prudp: Code cleanup

This commit is contained in:
Exzap 2024-04-18 19:22:28 +02:00
parent ee36992bd6
commit e2f9725719
3 changed files with 410 additions and 422 deletions

View File

@ -106,7 +106,7 @@ nexService::nexService()
nexService::nexService(prudpClient* con) : nexService()
{
if (con->isConnected() == false)
if (con->IsConnected() == false)
cemu_assert_suspicious();
this->conNexService = con;
bufferReceive = std::vector<uint8>(1024 * 4);
@ -191,7 +191,7 @@ void nexService::processQueuedRequest(queuedRequest_t* queuedRequest)
uint32 callId = _currentCallId;
_currentCallId++;
// check state of connection
if (conNexService->getConnectionState() != prudpClient::STATE_CONNECTED)
if (conNexService->GetConnectionState() != prudpClient::ConnectionState::Connected)
{
nexServiceResponse_t response = { 0 };
response.isSuccessful = false;
@ -214,7 +214,7 @@ void nexService::processQueuedRequest(queuedRequest_t* queuedRequest)
assert_dbg();
memcpy((packetBuffer + 0x0D), &queuedRequest->parameterData.front(), queuedRequest->parameterData.size());
sint32 length = 0xD + (sint32)queuedRequest->parameterData.size();
conNexService->sendDatagram(packetBuffer, length, true);
conNexService->SendDatagram(packetBuffer, length, true);
// remember request
nexActiveRequestInfo_t requestInfo = { 0 };
requestInfo.callId = callId;
@ -299,13 +299,13 @@ void nexService::registerForAsyncProcessing()
void nexService::updateTemporaryConnections()
{
// check for connection
conNexService->update();
if (conNexService->isConnected())
conNexService->Update();
if (conNexService->IsConnected())
{
if (connectionState == STATE_CONNECTING)
connectionState = STATE_CONNECTED;
}
if (conNexService->getConnectionState() == prudpClient::STATE_DISCONNECTED)
if (conNexService->GetConnectionState() == prudpClient::ConnectionState::Disconnected)
connectionState = STATE_DISCONNECTED;
}
@ -356,18 +356,18 @@ void nexService::sendRequestResponse(nexServiceRequest_t* request, uint32 errorC
// update length field
*(uint32*)response.getDataPtr() = response.getWriteIndex()-4;
if(request->nex->conNexService)
request->nex->conNexService->sendDatagram(response.getDataPtr(), response.getWriteIndex(), true);
request->nex->conNexService->SendDatagram(response.getDataPtr(), response.getWriteIndex(), true);
}
void nexService::updateNexServiceConnection()
{
if (conNexService->getConnectionState() == prudpClient::STATE_DISCONNECTED)
if (conNexService->GetConnectionState() == prudpClient::ConnectionState::Disconnected)
{
this->connectionState = STATE_DISCONNECTED;
return;
}
conNexService->update();
sint32 datagramLen = conNexService->receiveDatagram(bufferReceive);
conNexService->Update();
sint32 datagramLen = conNexService->ReceiveDatagram(bufferReceive);
if (datagramLen > 0)
{
if (nexIsRequest(&bufferReceive[0], datagramLen))
@ -454,12 +454,12 @@ bool _extractStationUrlParamValue(const char* urlStr, const char* paramName, cha
return false;
}
void nexServiceAuthentication_parseStationURL(char* urlStr, stationUrl_t* stationUrl)
void nexServiceAuthentication_parseStationURL(char* urlStr, prudpStationUrl* stationUrl)
{
// example:
// prudps:/address=34.210.xxx.xxx;port=60181;CID=1;PID=2;sid=1;stream=10;type=2
memset(stationUrl, 0, sizeof(stationUrl_t));
memset(stationUrl, 0, sizeof(prudpStationUrl));
char optionValue[128];
if (_extractStationUrlParamValue(urlStr, "address", optionValue, sizeof(optionValue)))
@ -499,7 +499,7 @@ typedef struct
sint32 kerberosTicketSize;
uint8 kerberosTicket2[4096];
sint32 kerberosTicket2Size;
stationUrl_t server;
prudpStationUrl server;
// progress info
bool hasError;
bool done;
@ -611,18 +611,18 @@ void nexServiceSecure_handleResponse_RegisterEx(nexService* nex, nexServiceRespo
return;
}
nexService* nex_secureLogin(authServerInfo_t* authServerInfo, const char* accessKey, const char* nexToken)
nexService* nex_secureLogin(prudpAuthServerInfo* authServerInfo, const char* accessKey, const char* nexToken)
{
prudpClient* prudpSecureSock = new prudpClient(authServerInfo->server.ip, authServerInfo->server.port, accessKey, authServerInfo);
// wait until connected
while (true)
{
prudpSecureSock->update();
if (prudpSecureSock->isConnected())
prudpSecureSock->Update();
if (prudpSecureSock->IsConnected())
{
break;
}
if (prudpSecureSock->getConnectionState() == prudpClient::STATE_DISCONNECTED)
if (prudpSecureSock->GetConnectionState() == prudpClient::ConnectionState::Disconnected)
{
// timeout or disconnected
cemuLog_log(LogType::Force, "NEX: Secure login connection time-out");
@ -638,7 +638,7 @@ nexService* nex_secureLogin(authServerInfo_t* authServerInfo, const char* access
nexPacketBuffer packetBuffer(tempNexBufferArray, sizeof(tempNexBufferArray), true);
char clientStationUrl[256];
sprintf(clientStationUrl, "prudp:/port=%u;natf=0;natm=0;pmp=0;sid=15;type=2;upnp=0", (uint32)nex->getPRUDPConnection()->getSourcePort());
sprintf(clientStationUrl, "prudp:/port=%u;natf=0;natm=0;pmp=0;sid=15;type=2;upnp=0", (uint32)nex->getPRUDPConnection()->GetSourcePort());
// station url list
packetBuffer.writeU32(1);
packetBuffer.writeString(clientStationUrl);
@ -737,9 +737,9 @@ nexService* nex_establishSecureConnection(uint32 authServerIp, uint16 authServer
return nullptr;
}
// auth info
auto authServerInfo = std::make_unique<authServerInfo_t>();
auto authServerInfo = std::make_unique<prudpAuthServerInfo>();
// decrypt ticket
RC4Ctx_t rc4Ticket;
RC4Ctx rc4Ticket;
RC4_initCtx(&rc4Ticket, kerberosKey, 16);
RC4_transform(&rc4Ticket, nexAuthService.kerberosTicket2, nexAuthService.kerberosTicket2Size - 16, nexAuthService.kerberosTicket2);
nexPacketBuffer packetKerberosTicket(nexAuthService.kerberosTicket2, nexAuthService.kerberosTicket2Size - 16, false);
@ -756,7 +756,7 @@ nexService* nex_establishSecureConnection(uint32 authServerIp, uint16 authServer
memcpy(authServerInfo->kerberosKey, kerberosKey, 16);
memcpy(authServerInfo->secureKey, secureKey, 16);
memcpy(&authServerInfo->server, &nexAuthService.server, sizeof(stationUrl_t));
memcpy(&authServerInfo->server, &nexAuthService.server, sizeof(prudpStationUrl));
authServerInfo->userPid = pid;
return nex_secureLogin(authServerInfo.get(), accessKey, nexToken);

View File

@ -6,67 +6,52 @@
#include <boost/random/uniform_int.hpp>
void swap(unsigned char *a, unsigned char *b)
static void KSA(unsigned char* key, int keyLen, unsigned char* S)
{
int tmp = *a;
*a = *b;
*b = tmp;
}
void KSA(unsigned char *key, int keyLen, unsigned char *S)
{
int j = 0;
for (int i = 0; i < RC4_N; i++)
S[i] = i;
int j = 0;
for (int i = 0; i < RC4_N; i++)
{
j = (j + S[i] + key[i % keyLen]) % RC4_N;
swap(&S[i], &S[j]);
std::swap(S[i], S[j]);
}
}
void PRGA(unsigned char *S, unsigned char* input, int len, unsigned char* output)
static void PRGA(unsigned char* S, unsigned char* input, int len, unsigned char* output)
{
int i = 0;
int j = 0;
for (size_t n = 0; n < len; n++)
{
i = (i + 1) % RC4_N;
j = (j + S[i]) % RC4_N;
swap(&S[i], &S[j]);
int i = (i + 1) % RC4_N;
int j = (j + S[i]) % RC4_N;
std::swap(S[i], S[j]);
int rnd = S[(S[i] + S[j]) % RC4_N];
output[n] = rnd ^ input[n];
}
}
void RC4(char* key, unsigned char* input, int len, unsigned char* output)
static void RC4(char* key, unsigned char* input, int len, unsigned char* output)
{
unsigned char S[RC4_N];
KSA((unsigned char*)key, (int)strlen(key), S);
PRGA(S, input, len, output);
}
void RC4_initCtx(RC4Ctx_t* rc4Ctx, const char* key)
void RC4_initCtx(RC4Ctx* rc4Ctx, const char* key)
{
rc4Ctx->i = 0;
rc4Ctx->j = 0;
KSA((unsigned char*)key, (int)strlen(key), rc4Ctx->S);
}
void RC4_initCtx(RC4Ctx_t* rc4Ctx, unsigned char* key, int keyLen)
void RC4_initCtx(RC4Ctx* rc4Ctx, unsigned char* key, int keyLen)
{
rc4Ctx->i = 0;
rc4Ctx->j = 0;
KSA(key, keyLen, rc4Ctx->S);
}
void RC4_transform(RC4Ctx_t* rc4Ctx, unsigned char* input, int len, unsigned char* output)
void RC4_transform(RC4Ctx* rc4Ctx, unsigned char* input, int len, unsigned char* output)
{
int i = rc4Ctx->i;
int j = rc4Ctx->j;
@ -75,13 +60,10 @@ void RC4_transform(RC4Ctx_t* rc4Ctx, unsigned char* input, int len, unsigned cha
{
i = (i + 1) % RC4_N;
j = (j + rc4Ctx->S[i]) % RC4_N;
swap(&rc4Ctx->S[i], &rc4Ctx->S[j]);
std::swap(rc4Ctx->S[i], rc4Ctx->S[j]);
int rnd = rc4Ctx->S[(rc4Ctx->S[i] + rc4Ctx->S[j]) % RC4_N];
output[n] = rnd ^ input[n];
}
rc4Ctx->i = i;
rc4Ctx->j = j;
}
@ -91,34 +73,14 @@ uint32 prudpGetMSTimestamp()
return GetTickCount();
}
std::bitset<10000> _portUsageMask;
uint16 getRandomSrcPRUDPPort()
{
while (true)
{
sint32 p = rand() % 10000;
if (_portUsageMask.test(p))
continue;
_portUsageMask.set(p);
return 40000 + p;
}
return 0;
}
void releasePRUDPPort(uint16 port)
{
uint32 bitIndex = port - 40000;
_portUsageMask.reset(bitIndex);
}
std::mt19937_64 prudpRG(GetTickCount());
// workaround for static asserts when using uniform_int_distribution
boost::random::uniform_int_distribution<int> prudpDis8(0, 0xFF);
// workaround for static asserts when using uniform_int_distribution (see https://github.com/cemu-project/Cemu/issues/48)
boost::random::uniform_int_distribution<int> prudpRandomDistribution8(0, 0xFF);
boost::random::uniform_int_distribution<int> prudpRandomDistributionPortGen(0, 10000);
uint8 prudp_generateRandomU8()
{
return prudpDis8(prudpRG);
return prudpRandomDistribution8(prudpRG);
}
uint32 prudp_generateRandomU32()
@ -133,7 +95,29 @@ uint32 prudp_generateRandomU32()
return v;
}
uint8 prudp_calculateChecksum(uint8 checksumBase, uint8* data, sint32 length)
std::bitset<10000> _portUsageMask;
static uint16 AllocateRandomSrcPRUDPPort()
{
while (true)
{
sint32 p = prudpRandomDistributionPortGen(prudpRG);
if (_portUsageMask.test(p))
continue;
_portUsageMask.set(p);
return 40000 + p;
}
}
static void ReleasePRUDPSrcPort(uint16 port)
{
cemu_assert_debug(port >= 40000);
uint32 bitIndex = port - 40000;
cemu_assert_debug(_portUsageMask.test(bitIndex));
_portUsageMask.reset(bitIndex);
}
static uint8 prudp_calculateChecksum(uint8 checksumBase, uint8* data, sint32 length)
{
uint32 checksum32 = 0;
for (sint32 i = 0; i < length / 4; i++)
@ -212,7 +196,7 @@ sint32 prudpPacket::calculateSizeFromPacketData(uint8* data, sint32 length)
return length;
}
prudpPacket::prudpPacket(prudpStreamSettings_t* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature)
prudpPacket::prudpPacket(prudpStreamSettings* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature)
{
this->src = src;
this->dst = dst;
@ -352,7 +336,8 @@ prudpIncomingPacket::prudpIncomingPacket()
streamSettings = nullptr;
}
prudpIncomingPacket::prudpIncomingPacket(prudpStreamSettings_t* streamSettings, uint8* data, sint32 length) : prudpIncomingPacket()
prudpIncomingPacket::prudpIncomingPacket(prudpStreamSettings* streamSettings, uint8* data, sint32 length)
: prudpIncomingPacket()
{
if (length < 0xB + 1)
{
@ -479,53 +464,41 @@ void prudpIncomingPacket::decrypt()
prudpClient::prudpClient()
{
currentConnectionState = STATE_CONNECTING;
serverConnectionSignature = 0;
clientConnectionSignature = 0;
hasSentCon = false;
outgoingSequenceId = 0;
incomingSequenceId = 0;
m_currentConnectionState = ConnectionState::Connecting;
m_serverConnectionSignature = 0;
m_clientConnectionSignature = 0;
m_incomingSequenceId = 0;
clientSessionId = 0;
serverSessionId = 0;
isSecureConnection = false;
m_clientSessionId = 0;
m_serverSessionId = 0;
}
prudpClient::~prudpClient()
prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key)
: prudpClient()
{
if (srcPort != 0)
{
releasePRUDPPort(srcPort);
closesocket(socketUdp);
}
}
prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key) : prudpClient()
{
this->dstIp = dstIp;
this->dstPort = dstPort;
m_dstIp = dstIp;
m_dstPort = dstPort;
// get unused random source port
for (sint32 tries = 0; tries < 5; tries++)
{
srcPort = getRandomSrcPRUDPPort();
m_srcPort = AllocateRandomSrcPRUDPPort();
// create and bind udp socket
socketUdp = socket(AF_INET, SOCK_DGRAM, 0);
m_socketUdp = socket(AF_INET, SOCK_DGRAM, 0);
struct sockaddr_in udpServer;
udpServer.sin_family = AF_INET;
udpServer.sin_addr.s_addr = INADDR_ANY;
udpServer.sin_port = htons(srcPort);
if (bind(socketUdp, (struct sockaddr *)&udpServer, sizeof(udpServer)) == SOCKET_ERROR)
udpServer.sin_port = htons(m_srcPort);
if (bind(m_socketUdp, (struct sockaddr*)&udpServer, sizeof(udpServer)) == SOCKET_ERROR)
{
ReleasePRUDPSrcPort(m_srcPort);
m_srcPort = 0;
if (tries == 4)
{
cemuLog_log(LogType::Force, "PRUDP: Failed to bind UDP socket");
currentConnectionState = STATE_DISCONNECTED;
srcPort = 0;
m_currentConnectionState = ConnectionState::Disconnected;
return;
}
releasePRUDPPort(srcPort);
closesocket(socketUdp);
closesocket(m_socketUdp);
continue;
}
else
@ -534,77 +507,75 @@ prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key) : prudpC
// set socket to non-blocking mode
#if BOOST_OS_WINDOWS
u_long nonBlockingMode = 1; // 1 to enable non-blocking socket
ioctlsocket(socketUdp, FIONBIO, &nonBlockingMode);
ioctlsocket(m_socketUdp, FIONBIO, &nonBlockingMode);
#else
int flags = fcntl(socketUdp, F_GETFL);
fcntl(socketUdp, F_SETFL, flags | O_NONBLOCK);
#endif
// generate frequently used parameters
this->vport_src = PRUDP_VPORT(prudpPacket::STREAM_TYPE_SECURE, 0xF);
this->vport_dst = PRUDP_VPORT(prudpPacket::STREAM_TYPE_SECURE, 0x1);
this->m_srcVPort = PRUDP_VPORT(prudpPacket::STREAM_TYPE_SECURE, 0xF);
this->m_dstVPort = PRUDP_VPORT(prudpPacket::STREAM_TYPE_SECURE, 0x1);
// set stream settings
uint8 checksumBase = 0;
for (sint32 i = 0; key[i] != '\0'; i++)
{
checksumBase += key[i];
}
streamSettings.checksumBase = checksumBase;
m_streamSettings.checksumBase = checksumBase;
MD5_CTX md5Ctx;
MD5_Init(&md5Ctx);
MD5_Update(&md5Ctx, key, (int)strlen(key));
MD5_Final(streamSettings.accessKeyDigest, &md5Ctx);
MD5_Final(m_streamSettings.accessKeyDigest, &md5Ctx);
// init stream ciphers
RC4_initCtx(&streamSettings.rc4Server, "CD&ML");
RC4_initCtx(&streamSettings.rc4Client, "CD&ML");
RC4_initCtx(&m_streamSettings.rc4Server, "CD&ML");
RC4_initCtx(&m_streamSettings.rc4Client, "CD&ML");
// send syn packet
prudpPacket* synPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_SYN, prudpPacket::FLAG_NEED_ACK, 0, 0, 0);
queuePacket(synPacket, dstIp, dstPort);
outgoingSequenceId++;
SendCurrentHandshakePacket();
// set incoming sequence id to 1
incomingSequenceId = 1;
m_incomingSequenceId = 1;
}
prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key, authServerInfo_t* authInfo) : prudpClient(dstIp, dstPort, key)
prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key, prudpAuthServerInfo* authInfo)
: prudpClient(dstIp, dstPort, key)
{
RC4_initCtx(&streamSettings.rc4Server, authInfo->secureKey, 16);
RC4_initCtx(&streamSettings.rc4Client, authInfo->secureKey, 16);
this->isSecureConnection = true;
memcpy(&this->authInfo, authInfo, sizeof(authServerInfo_t));
RC4_initCtx(&m_streamSettings.rc4Server, authInfo->secureKey, 16);
RC4_initCtx(&m_streamSettings.rc4Client, authInfo->secureKey, 16);
this->m_isSecureConnection = true;
memcpy(&this->m_authInfo, authInfo, sizeof(prudpAuthServerInfo));
}
bool prudpClient::isConnected()
prudpClient::~prudpClient()
{
return currentConnectionState == STATE_CONNECTED;
if (m_srcPort != 0)
{
ReleasePRUDPSrcPort(m_srcPort);
closesocket(m_socketUdp);
}
}
uint8 prudpClient::getConnectionState()
void prudpClient::AcknowledgePacket(uint16 sequenceId)
{
return currentConnectionState;
}
void prudpClient::acknowledgePacket(uint16 sequenceId)
{
auto it = std::begin(list_packetsWithAckReq);
while (it != std::end(list_packetsWithAckReq))
auto it = std::begin(m_dataPacketsWithAckReq);
while (it != std::end(m_dataPacketsWithAckReq))
{
if (it->packet->GetSequenceId() == sequenceId)
{
delete it->packet;
list_packetsWithAckReq.erase(it);
m_dataPacketsWithAckReq.erase(it);
return;
}
it++;
}
}
void prudpClient::sortIncomingDataPacket(prudpIncomingPacket* incomingPacket)
void prudpClient::SortIncomingDataPacket(std::unique_ptr<prudpIncomingPacket> incomingPacket)
{
uint16 sequenceIdIncomingPacket = incomingPacket->sequenceId;
// find insert index
sint32 insertIndex = 0;
while (insertIndex < queue_incomingPackets.size() )
while (insertIndex < m_incomingPacketQueue.size())
{
uint16 seqDif = sequenceIdIncomingPacket - queue_incomingPackets[insertIndex]->sequenceId;
uint16 seqDif = sequenceIdIncomingPacket - m_incomingPacketQueue[insertIndex]->sequenceId;
if (seqDif & 0x8000)
break; // negative seqDif -> insert before current element
#ifdef CEMU_DEBUG_ASSERT
@ -613,28 +584,72 @@ void prudpClient::sortIncomingDataPacket(prudpIncomingPacket* incomingPacket)
#endif
insertIndex++;
}
// insert
sint32 currentSize = (sint32)queue_incomingPackets.size();
queue_incomingPackets.resize(currentSize+1);
for(sint32 i=currentSize; i>insertIndex; i--)
m_incomingPacketQueue.insert(m_incomingPacketQueue.begin() + insertIndex, std::move(incomingPacket));
// debug check if packets are really ordered by sequence id
#ifdef CEMU_DEBUG_ASSERT
for (sint32 i = 1; i < m_incomingPacketQueue.size(); i++)
{
queue_incomingPackets[i] = queue_incomingPackets[i - 1];
uint16 seqDif = m_incomingPacketQueue[i]->sequenceId - m_incomingPacketQueue[i - 1]->sequenceId;
if (seqDif & 0x8000)
seqDif = -seqDif;
if (seqDif >= 0x8000)
assert_dbg();
}
queue_incomingPackets[insertIndex] = incomingPacket;
#endif
}
sint32 prudpClient::kerberosEncryptData(uint8* input, sint32 length, uint8* output)
sint32 prudpClient::KerberosEncryptData(uint8* input, sint32 length, uint8* output)
{
RC4Ctx_t rc4Kerberos;
RC4_initCtx(&rc4Kerberos, this->authInfo.secureKey, 16);
RC4Ctx rc4Kerberos;
RC4_initCtx(&rc4Kerberos, this->m_authInfo.secureKey, 16);
memcpy(output, input, length);
RC4_transform(&rc4Kerberos, output, length, output);
// calculate and append hmac
hmacMD5(this->authInfo.secureKey, 16, output, length, output+length);
hmacMD5(this->m_authInfo.secureKey, 16, output, length, output + length);
return length + 16;
}
void prudpClient::handleIncomingPacket(prudpIncomingPacket* incomingPacket)
// (re)sends either CON or SYN based on what stage of the login we are at
// the sequenceId for both is hardcoded for both because we'll never send anything in between
void prudpClient::SendCurrentHandshakePacket()
{
if (!m_hasSynAck)
{
// send syn (with a fixed sequenceId of 0)
prudpPacket synPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_SYN, prudpPacket::FLAG_NEED_ACK, 0, 0, 0);
DirectSendPacket(&synPacket);
}
else
{
// send con (with a fixed sequenceId of 1)
prudpPacket conPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_CON, prudpPacket::FLAG_NEED_ACK | prudpPacket::FLAG_RELIABLE, this->m_clientSessionId, 1, m_serverConnectionSignature);
if (this->m_isSecureConnection)
{
uint8 tempBuffer[512];
nexPacketBuffer conData(tempBuffer, sizeof(tempBuffer), true);
conData.writeU32(this->m_clientConnectionSignature);
conData.writeBuffer(m_authInfo.secureTicket, m_authInfo.secureTicketLength);
// encrypted request data
uint8 requestData[4 * 3];
uint8 requestDataEncrypted[4 * 3 + 0x10];
*(uint32*)(requestData + 0x0) = m_authInfo.userPid;
*(uint32*)(requestData + 0x4) = m_authInfo.server.cid;
*(uint32*)(requestData + 0x8) = prudp_generateRandomU32(); // todo - check value
sint32 encryptedSize = KerberosEncryptData(requestData, sizeof(requestData), requestDataEncrypted);
conData.writeBuffer(requestDataEncrypted, encryptedSize);
conPacket.setData(conData.getDataPtr(), conData.getWriteIndex());
}
else
{
conPacket.setData((uint8*)&this->m_clientConnectionSignature, sizeof(uint32));
}
DirectSendPacket(&conPacket);
}
m_lastHandshakeTimestamp = prudpGetMSTimestamp();
m_handshakeRetryCount++;
}
void prudpClient::HandleIncomingPacket(std::unique_ptr<prudpIncomingPacket> incomingPacket)
{
if (incomingPacket->type == prudpPacket::TYPE_PING)
{
@ -664,127 +679,114 @@ void prudpClient::handleIncomingPacket(prudpIncomingPacket* incomingPacket)
{
// other side is asking for ping ack
cemuLog_log(LogType::PRUDP, "[PRUDP] Received ping packet with NEED_ACK set. Sending ACK back");
cemu_assert_debug(incomingPacket->packetData.empty()); // todo - echo data?
prudpPacket ackPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_PING, prudpPacket::FLAG_ACK, this->clientSessionId, incomingPacket->sequenceId, 0);
directSendPacket(&ackPacket, dstIp, dstPort);
prudpPacket ackPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_PING, prudpPacket::FLAG_ACK, this->m_clientSessionId, incomingPacket->sequenceId, 0);
if(!incomingPacket->packetData.empty())
ackPacket.setData(incomingPacket->packetData.data(), incomingPacket->packetData.size());
DirectSendPacket(&ackPacket);
}
delete incomingPacket;
return;
}
// handle general packet ACK
if (incomingPacket->flags&prudpPacket::FLAG_ACK)
else if (incomingPacket->type == prudpPacket::TYPE_SYN)
{
acknowledgePacket(incomingPacket->sequenceId);
// syn packet from server is expected to have ACK set
if (!(incomingPacket->flags & prudpPacket::FLAG_ACK))
{
cemuLog_log(LogType::Force, "[PRUDP] Received SYN packet without ACK flag set"); // always log this
return;
}
// special cases
if (incomingPacket->type == prudpPacket::TYPE_SYN)
if (m_hasSynAck || !incomingPacket->hasData || incomingPacket->packetData.size() != 4)
{
if (hasSentCon == false && incomingPacket->hasData && incomingPacket->packetData.size() == 4)
{
this->serverConnectionSignature = *(uint32*)&incomingPacket->packetData.front();
this->clientSessionId = prudp_generateRandomU8();
// generate client session id
this->clientConnectionSignature = prudp_generateRandomU32();
// syn already acked or not a valid syn packet
cemuLog_log(LogType::PRUDP, "[PRUDP] Received unexpected SYN packet");
return;
}
m_hasSynAck = true;
this->m_serverConnectionSignature = *(uint32*)&incomingPacket->packetData.front();
// generate client session id and connection signature
this->m_clientSessionId = prudp_generateRandomU8();
this->m_clientConnectionSignature = prudp_generateRandomU32();
// send con packet
prudpPacket* conPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_CON, prudpPacket::FLAG_NEED_ACK|prudpPacket::FLAG_RELIABLE, this->clientSessionId, outgoingSequenceId, serverConnectionSignature);
outgoingSequenceId++;
if (this->isSecureConnection)
{
// set packet specific data (client connection signature)
uint8 tempBuffer[512];
nexPacketBuffer conData(tempBuffer, sizeof(tempBuffer), true);
conData.writeU32(this->clientConnectionSignature);
conData.writeBuffer(authInfo.secureTicket, authInfo.secureTicketLength);
// encrypted request data
uint8 requestData[4 * 3];
uint8 requestDataEncrypted[4 * 3 + 0x10];
*(uint32*)(requestData + 0x0) = authInfo.userPid;
*(uint32*)(requestData + 0x4) = authInfo.server.cid;
*(uint32*)(requestData + 0x8) = prudp_generateRandomU32(); // todo - check value
sint32 encryptedSize = kerberosEncryptData(requestData, sizeof(requestData), requestDataEncrypted);
conData.writeBuffer(requestDataEncrypted, encryptedSize);
conPacket->setData(conData.getDataPtr(), conData.getWriteIndex());
}
else
{
// set packet specific data (client connection signature)
conPacket->setData((uint8*)&this->clientConnectionSignature, sizeof(uint32));
}
// send packet
queuePacket(conPacket, dstIp, dstPort);
// remember con packet as sent
hasSentCon = true;
}
delete incomingPacket;
m_handshakeRetryCount = 0;
SendCurrentHandshakePacket();
return;
}
else if (incomingPacket->type == prudpPacket::TYPE_CON)
{
// connected!
if (currentConnectionState == STATE_CONNECTING)
if (!m_hasSynAck || m_hasConAck)
{
lastPingTimestamp = prudpGetMSTimestamp();
cemu_assert_debug(serverSessionId == 0);
serverSessionId = incomingPacket->sessionId;
currentConnectionState = STATE_CONNECTED;
cemuLog_log(LogType::PRUDP, "[PRUDP] Connection established. ClientSession {:02x} ServerSession {:02x}", clientSessionId, serverSessionId);
cemuLog_log(LogType::PRUDP, "[PRUDP] Received unexpected CON packet");
return;
}
delete incomingPacket;
// make sure the packet has the ACK flag set
if (!(incomingPacket->flags & prudpPacket::FLAG_ACK))
{
cemuLog_log(LogType::Force, "[PRUDP] Received CON packet without ACK flag set");
return;
}
m_hasConAck = true;
m_handshakeRetryCount = 0;
cemu_assert_debug(m_currentConnectionState == ConnectionState::Connecting);
// connected!
m_lastPingTimestamp = prudpGetMSTimestamp();
cemu_assert_debug(m_serverSessionId == 0);
m_serverSessionId = incomingPacket->sessionId;
m_currentConnectionState = ConnectionState::Connected;
cemuLog_log(LogType::PRUDP, "[PRUDP] Connection established. ClientSession {:02x} ServerSession {:02x}", m_clientSessionId, m_serverSessionId);
return;
}
else if (incomingPacket->type == prudpPacket::TYPE_DATA)
{
// handle ACK
if (incomingPacket->flags & prudpPacket::FLAG_ACK)
{
AcknowledgePacket(incomingPacket->sequenceId);
if(!incomingPacket->packetData.empty())
cemuLog_log(LogType::PRUDP, "[PRUDP] Received ACK data packet with payload");
return;
}
// send ack back if requested
if (incomingPacket->flags & prudpPacket::FLAG_NEED_ACK)
{
prudpPacket ackPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_DATA, prudpPacket::FLAG_ACK, this->clientSessionId, incomingPacket->sequenceId, 0);
directSendPacket(&ackPacket, dstIp, dstPort);
prudpPacket ackPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_DATA, prudpPacket::FLAG_ACK, this->m_clientSessionId, incomingPacket->sequenceId, 0);
DirectSendPacket(&ackPacket);
}
// skip data packets without payload
if (incomingPacket->packetData.empty())
{
delete incomingPacket;
return;
}
// verify some values
uint16 seqDist = incomingPacket->sequenceId - incomingSequenceId;
// verify sequence id
uint16 seqDist = incomingPacket->sequenceId - m_incomingSequenceId;
if (seqDist >= 0xC000)
{
// outdated
delete incomingPacket;
return;
}
// check if packet is already queued
for (auto& it : queue_incomingPackets)
for (auto& it : m_incomingPacketQueue)
{
if (it->sequenceId == incomingPacket->sequenceId)
{
// already queued (should check other values too, like packet type?)
cemuLog_log(LogType::PRUDP, "Duplicate PRUDP packet received");
delete incomingPacket;
return;
}
}
// put into ordered receive queue
sortIncomingDataPacket(incomingPacket);
SortIncomingDataPacket(std::move(incomingPacket));
}
else if (incomingPacket->type == prudpPacket::TYPE_DISCONNECT)
{
currentConnectionState = STATE_DISCONNECTED;
m_currentConnectionState = ConnectionState::Disconnected;
return;
}
else
{
// ignore unknown packet
delete incomingPacket;
return;
cemuLog_log(LogType::PRUDP, "[PRUDP] Received unknown packet type");
}
}
bool prudpClient::update()
bool prudpClient::Update()
{
if (currentConnectionState == STATE_DISCONNECTED)
if (m_currentConnectionState == ConnectionState::Disconnected)
return false;
uint32 currentTimestamp = prudpGetMSTimestamp();
// check for incoming packets
@ -793,7 +795,7 @@ bool prudpClient::update()
{
sockaddr receiveFrom = {0};
socklen_t receiveFromLen = sizeof(receiveFrom);
sint32 r = recvfrom(socketUdp, (char*)receiveBuffer, sizeof(receiveBuffer), 0, &receiveFrom, &receiveFromLen);
sint32 r = recvfrom(m_socketUdp, (char*)receiveBuffer, sizeof(receiveBuffer), 0, &receiveFrom, &receiveFromLen);
if (r >= 0)
{
// todo: Verify sender (receiveFrom)
@ -807,198 +809,190 @@ bool prudpClient::update()
cemuLog_log(LogType::Force, "[PRUDP] Invalid packet length");
break;
}
prudpIncomingPacket* incomingPacket = new prudpIncomingPacket(&streamSettings, receiveBuffer + pIdx, packetLength);
auto incomingPacket = std::make_unique<prudpIncomingPacket>(&m_streamSettings, receiveBuffer + pIdx, packetLength);
pIdx += packetLength;
if (incomingPacket->hasError())
{
cemuLog_log(LogType::Force, "[PRUDP] Packet error");
delete incomingPacket;
break;
}
if (incomingPacket->type != prudpPacket::TYPE_CON && incomingPacket->sessionId != serverSessionId)
if (incomingPacket->type != prudpPacket::TYPE_CON && incomingPacket->sessionId != m_serverSessionId)
{
cemuLog_log(LogType::PRUDP, "[PRUDP] Invalid session id");
delete incomingPacket;
continue; // different session
}
handleIncomingPacket(incomingPacket);
HandleIncomingPacket(std::move(incomingPacket));
}
}
else
break;
}
// check for ack timeouts
for (auto &it : list_packetsWithAckReq)
for (auto& it : m_dataPacketsWithAckReq)
{
if ((currentTimestamp - it.lastRetryTimestamp) >= 2300)
{
if (it.retryCount >= 7)
{
// after too many retries consider the connection dead
currentConnectionState = STATE_DISCONNECTED;
m_currentConnectionState = ConnectionState::Disconnected;
}
// resend
directSendPacket(it.packet, dstIp, dstPort);
DirectSendPacket(it.packet);
it.lastRetryTimestamp = currentTimestamp;
it.retryCount++;
}
}
// check if we need to send another ping
if (currentConnectionState == STATE_CONNECTED)
if (m_currentConnectionState == ConnectionState::Connecting)
{
// check if we need to resend SYN or CON
uint32 timeSinceLastHandshake = currentTimestamp - m_lastHandshakeTimestamp;
if (timeSinceLastHandshake >= 1200)
{
if (m_handshakeRetryCount >= 5)
{
// too many retries, assume the other side doesn't listen
m_currentConnectionState = ConnectionState::Disconnected;
cemuLog_log(LogType::PRUDP, "PRUDP: Failed to connect");
return false;
}
SendCurrentHandshakePacket();
}
}
else if (m_currentConnectionState == ConnectionState::Connected)
{
// handle pings
if (m_unacknowledgedPingCount != 0) // counts how many times we sent a ping packet (for the current sequenceId) without receiving an ack
{
// we are waiting for the ack of the previous ping, but it hasn't arrived yet so send another ping packet
if((currentTimestamp - lastPingTimestamp) >= 1500)
if ((currentTimestamp - m_lastPingTimestamp) >= 1500)
{
cemuLog_log(LogType::PRUDP, "[PRUDP] Resending ping packet (no ack received)");
if (m_unacknowledgedPingCount >= 10)
{
// too many unacknowledged pings, assume the connection is dead
currentConnectionState = STATE_DISCONNECTED;
m_currentConnectionState = ConnectionState::Disconnected;
cemuLog_log(LogType::PRUDP, "PRUDP: Connection did not receive a ping response in a while. Assuming disconnect");
return false;
}
// resend the ping packet
prudpPacket* pingPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK, this->clientSessionId, this->m_outgoingSequenceId_ping, serverConnectionSignature);
directSendPacket(pingPacket, dstIp, dstPort);
prudpPacket pingPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK, this->m_clientSessionId, this->m_outgoingSequenceId_ping, m_serverConnectionSignature);
DirectSendPacket(&pingPacket);
m_unacknowledgedPingCount++;
delete pingPacket;
lastPingTimestamp = currentTimestamp;
m_lastPingTimestamp = currentTimestamp;
}
}
else
{
if((currentTimestamp - lastPingTimestamp) >= 20000)
if ((currentTimestamp - m_lastPingTimestamp) >= 20000)
{
cemuLog_log(LogType::PRUDP, "[PRUDP] Sending new ping packet with sequenceId {}", this->m_outgoingSequenceId_ping + 1);
// start a new ping packet with a new sequenceId. Note that ping packets have their own sequenceId and acknowledgement happens by manually comparing the incoming ping ACK against the last sent sequenceId
// only one unacknowledged ping packet can be in flight at a time. We will resend the same ping packet until we receive an ack
this->m_outgoingSequenceId_ping++; // increment before sending. The first ping has a sequenceId of 1
prudpPacket* pingPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK, this->clientSessionId, this->m_outgoingSequenceId_ping, serverConnectionSignature);
directSendPacket(pingPacket, dstIp, dstPort);
prudpPacket pingPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK, this->m_clientSessionId, this->m_outgoingSequenceId_ping, m_serverConnectionSignature);
DirectSendPacket(&pingPacket);
m_unacknowledgedPingCount++;
delete pingPacket;
lastPingTimestamp = currentTimestamp;
m_lastPingTimestamp = currentTimestamp;
}
}
}
return false;
}
void prudpClient::directSendPacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort)
void prudpClient::DirectSendPacket(prudpPacket* packet)
{
uint8 packetBuffer[prudpPacket::PACKET_RAW_SIZE_MAX];
sint32 len = packet->buildData(packetBuffer, prudpPacket::PACKET_RAW_SIZE_MAX);
sockaddr_in destAddr;
destAddr.sin_family = AF_INET;
destAddr.sin_port = htons(dstPort);
destAddr.sin_addr.s_addr = dstIp;
sendto(socketUdp, (const char*)packetBuffer, len, 0, (const sockaddr*)&destAddr, sizeof(destAddr));
destAddr.sin_port = htons(m_dstPort);
destAddr.sin_addr.s_addr = m_dstIp;
sendto(m_socketUdp, (const char*)packetBuffer, len, 0, (const sockaddr*)&destAddr, sizeof(destAddr));
}
void prudpClient::queuePacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort)
void prudpClient::QueuePacket(prudpPacket* packet)
{
cemu_assert_debug(packet->GetType() == prudpPacket::TYPE_DATA); // only data packets should be queued
if (packet->requiresAck())
{
cemu_assert_debug(packet->GetType() != prudpPacket::TYPE_PING); // ping packets use their own logic for acks, dont queue them
// remember this packet until we receive the ack
prudpAckRequired_t ackRequired = { 0 };
ackRequired.packet = packet;
ackRequired.initialSendTimestamp = prudpGetMSTimestamp();
ackRequired.lastRetryTimestamp = ackRequired.initialSendTimestamp;
list_packetsWithAckReq.push_back(ackRequired);
directSendPacket(packet, dstIp, dstPort);
m_dataPacketsWithAckReq.emplace_back(packet, prudpGetMSTimestamp());
DirectSendPacket(packet);
}
else
{
directSendPacket(packet, dstIp, dstPort);
DirectSendPacket(packet);
delete packet;
}
}
void prudpClient::sendDatagram(uint8* input, sint32 length, bool reliable)
void prudpClient::SendDatagram(uint8* input, sint32 length, bool reliable)
{
cemu_assert_debug(reliable); // non-reliable packets require testing
cemu_assert_debug(reliable); // non-reliable packets require correct sequenceId handling and testing
cemu_assert_debug(m_hasSynAck && m_hasConAck); // cant send data packets before we are connected
if (length >= 0x300)
{
cemuLog_logOnce(LogType::Force, "PRUDP: Datagram too long");
cemuLog_logOnce(LogType::Force, "PRUDP: Datagram too long. Fragmentation not implemented yet");
}
// single fragment data packet
uint16 flags = prudpPacket::FLAG_NEED_ACK;
if (reliable)
flags |= prudpPacket::FLAG_RELIABLE;
prudpPacket* packet = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_DATA, flags, clientSessionId, outgoingSequenceId, 0);
prudpPacket* packet = new prudpPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_DATA, flags, m_clientSessionId, m_outgoingReliableSequenceId, 0);
if (reliable)
outgoingSequenceId++;
m_outgoingReliableSequenceId++;
packet->setFragmentIndex(0);
packet->setData(input, length);
queuePacket(packet, dstIp, dstPort);
QueuePacket(packet);
}
uint16 prudpClient::getSourcePort()
sint32 prudpClient::ReceiveDatagram(std::vector<uint8>& outputBuffer)
{
return this->srcPort;
}
SOCKET prudpClient::getSocket()
{
if (currentConnectionState == STATE_DISCONNECTED)
{
return INVALID_SOCKET;
}
return this->socketUdp;
}
sint32 prudpClient::receiveDatagram(std::vector<uint8>& outputBuffer)
{
if (queue_incomingPackets.empty())
outputBuffer.clear();
if (m_incomingPacketQueue.empty())
return -1;
prudpIncomingPacket* incomingPacket = queue_incomingPackets[0];
if (incomingPacket->sequenceId != this->incomingSequenceId)
prudpIncomingPacket* frontPacket = m_incomingPacketQueue[0].get();
if (frontPacket->sequenceId != this->m_incomingSequenceId)
return -1;
if (incomingPacket->fragmentIndex == 0)
if (frontPacket->fragmentIndex == 0)
{
// single-fragment packet
// decrypt
incomingPacket->decrypt();
frontPacket->decrypt();
// read data
sint32 datagramLen = (sint32)incomingPacket->packetData.size();
if (datagramLen > 0)
if (!frontPacket->packetData.empty())
{
// to conserve memory we will also shrink the buffer if it was previously extended beyond 32KB
constexpr size_t BUFFER_TARGET_SIZE = 1024 * 32;
if (frontPacket->packetData.size() < BUFFER_TARGET_SIZE && outputBuffer.capacity() > BUFFER_TARGET_SIZE)
{
// resize buffer if necessary
if (datagramLen > outputBuffer.size())
outputBuffer.resize(datagramLen);
// to conserve memory we will also shrink the buffer if it was previously extended beyond 64KB
constexpr size_t BUFFER_TARGET_SIZE = 1024 * 64;
if (datagramLen < BUFFER_TARGET_SIZE && outputBuffer.size() > BUFFER_TARGET_SIZE)
outputBuffer.resize(BUFFER_TARGET_SIZE);
// copy datagram to buffer
memcpy(outputBuffer.data(), &incomingPacket->packetData.front(), datagramLen);
outputBuffer.shrink_to_fit();
outputBuffer.clear();
}
delete incomingPacket;
// remove packet from queue
queue_incomingPackets.erase(queue_incomingPackets.begin());
// write packet data to output buffer
cemu_assert_debug(outputBuffer.empty());
outputBuffer.insert(outputBuffer.end(), frontPacket->packetData.begin(), frontPacket->packetData.end());
}
m_incomingPacketQueue.erase(m_incomingPacketQueue.begin());
// advance expected sequence id
this->incomingSequenceId++;
return datagramLen;
this->m_incomingSequenceId++;
return (sint32)outputBuffer.size();
}
else
{
// multi-fragment packet
if (incomingPacket->fragmentIndex != 1)
if (frontPacket->fragmentIndex != 1)
return -1; // first packet of the chain not received yet
// verify chain
sint32 packetIndex = 1;
sint32 chainLength = -1; // if full chain found, set to count of packets
for(sint32 i=1; i<queue_incomingPackets.size(); i++)
for (sint32 i = 1; i < m_incomingPacketQueue.size(); i++)
{
uint8 itFragmentIndex = queue_incomingPackets[packetIndex]->fragmentIndex;
uint8 itFragmentIndex = m_incomingPacketQueue[packetIndex]->fragmentIndex;
// sequence id must increase by 1 for every packet
if (queue_incomingPackets[packetIndex]->sequenceId != (this->incomingSequenceId+i) )
if (m_incomingPacketQueue[packetIndex]->sequenceId != (m_incomingSequenceId + i))
return -1; // missing packets
// last fragment in chain is marked by fragment index 0
if (itFragmentIndex == 0)
@ -1011,29 +1005,17 @@ sint32 prudpClient::receiveDatagram(std::vector<uint8>& outputBuffer)
if (chainLength < 1)
return -1; // chain not complete
// extract data from packet chain
sint32 writeIndex = 0;
cemu_assert_debug(outputBuffer.empty());
for (sint32 i = 0; i < chainLength; i++)
{
incomingPacket = queue_incomingPackets[i];
// decrypt
prudpIncomingPacket* incomingPacket = m_incomingPacketQueue[i].get();
incomingPacket->decrypt();
// extract data
sint32 datagramLen = (sint32)incomingPacket->packetData.size();
if (datagramLen > 0)
{
// make sure output buffer can fit the data
if ((writeIndex + datagramLen) > outputBuffer.size())
outputBuffer.resize(writeIndex + datagramLen + 4 * 1024);
memcpy(outputBuffer.data()+writeIndex, &incomingPacket->packetData.front(), datagramLen);
writeIndex += datagramLen;
}
// free packet memory
delete incomingPacket;
outputBuffer.insert(outputBuffer.end(), incomingPacket->packetData.begin(), incomingPacket->packetData.end());
}
// remove packets from queue
queue_incomingPackets.erase(queue_incomingPackets.begin(), queue_incomingPackets.begin() + chainLength);
this->incomingSequenceId += chainLength;
return writeIndex;
m_incomingPacketQueue.erase(m_incomingPacketQueue.begin(), m_incomingPacketQueue.begin() + chainLength);
m_incomingSequenceId += chainLength;
return (sint32)outputBuffer.size();
}
return -1;
}

View File

@ -4,26 +4,26 @@
#define RC4_N 256
typedef struct
struct RC4Ctx
{
unsigned char S[RC4_N];
int i;
int j;
}RC4Ctx_t;
};
void RC4_initCtx(RC4Ctx_t* rc4Ctx, char *key);
void RC4_initCtx(RC4Ctx_t* rc4Ctx, unsigned char* key, int keyLen);
void RC4_transform(RC4Ctx_t* rc4Ctx, unsigned char* input, int len, unsigned char* output);
void RC4_initCtx(RC4Ctx* rc4Ctx, const char* key);
void RC4_initCtx(RC4Ctx* rc4Ctx, unsigned char* key, int keyLen);
void RC4_transform(RC4Ctx* rc4Ctx, unsigned char* input, int len, unsigned char* output);
typedef struct
struct prudpStreamSettings
{
uint8 checksumBase; // calculated from key
uint8 accessKeyDigest[16]; // MD5 hash of key
RC4Ctx_t rc4Client;
RC4Ctx_t rc4Server;
}prudpStreamSettings_t;
RC4Ctx rc4Client;
RC4Ctx rc4Server;
};
typedef struct
struct prudpStationUrl
{
uint32 ip;
uint16 port;
@ -32,19 +32,17 @@ typedef struct
sint32 sid;
sint32 stream;
sint32 type;
}stationUrl_t;
};
typedef struct
struct prudpAuthServerInfo
{
uint32 userPid;
uint8 secureKey[16];
uint8 kerberosKey[16];
uint8 secureTicket[1024];
sint32 secureTicketLength;
stationUrl_t server;
}authServerInfo_t;
uint8 prudp_calculateChecksum(uint8 checksumBase, uint8* data, sint32 length);
prudpStationUrl server;
};
class prudpPacket
{
@ -66,7 +64,7 @@ public:
static sint32 calculateSizeFromPacketData(uint8* data, sint32 length);
prudpPacket(prudpStreamSettings_t* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature);
prudpPacket(prudpStreamSettings* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature);
bool requiresAck();
void setData(uint8* data, sint32 length);
void setFragmentIndex(uint8 fragmentIndex);
@ -87,7 +85,7 @@ private:
uint16 flags;
uint8 sessionId;
uint32 specifiedPacketSignature;
prudpStreamSettings_t* streamSettings;
prudpStreamSettings* streamSettings;
std::vector<uint8> packetData;
bool isEncrypted;
uint16 m_sequenceId{0};
@ -97,7 +95,7 @@ private:
class prudpIncomingPacket
{
public:
prudpIncomingPacket(prudpStreamSettings_t* streamSettings, uint8* data, sint32 length);
prudpIncomingPacket(prudpStreamSettings* streamSettings, uint8* data, sint32 length);
bool hasError();
@ -122,83 +120,91 @@ public:
private:
bool isInvalid = false;
prudpStreamSettings_t* streamSettings = nullptr;
prudpStreamSettings* streamSettings = nullptr;
};
typedef struct
{
prudpPacket* packet;
uint32 initialSendTimestamp;
uint32 lastRetryTimestamp;
sint32 retryCount;
}prudpAckRequired_t;
class prudpClient
{
struct PacketWithAckRequired
{
PacketWithAckRequired(prudpPacket* packet, uint32 initialSendTimestamp) :
packet(packet), initialSendTimestamp(initialSendTimestamp), lastRetryTimestamp(initialSendTimestamp) { }
prudpPacket* packet;
uint32 initialSendTimestamp;
uint32 lastRetryTimestamp;
sint32 retryCount{0};
};
public:
static const int STATE_CONNECTING = 0;
static const int STATE_CONNECTED = 1;
static const int STATE_DISCONNECTED = 2;
enum class ConnectionState : uint8
{
Connecting,
Connected,
Disconnected
};
public:
prudpClient(uint32 dstIp, uint16 dstPort, const char* key);
prudpClient(uint32 dstIp, uint16 dstPort, const char* key, authServerInfo_t* authInfo);
prudpClient(uint32 dstIp, uint16 dstPort, const char* key, prudpAuthServerInfo* authInfo);
~prudpClient();
bool isConnected();
bool IsConnected() const { return m_currentConnectionState == ConnectionState::Connected; }
ConnectionState GetConnectionState() const { return m_currentConnectionState; }
uint16 GetSourcePort() const { return m_srcPort; }
uint8 getConnectionState();
void acknowledgePacket(uint16 sequenceId);
void sortIncomingDataPacket(prudpIncomingPacket* incomingPacket);
void handleIncomingPacket(prudpIncomingPacket* incomingPacket);
bool update(); // check for new incoming packets, returns true if receiveDatagram() should be called
bool Update(); // update connection state and check for incoming packets. Returns true if ReceiveDatagram() should be called
sint32 receiveDatagram(std::vector<uint8>& outputBuffer);
void sendDatagram(uint8* input, sint32 length, bool reliable = true);
uint16 getSourcePort();
SOCKET getSocket();
sint32 ReceiveDatagram(std::vector<uint8>& outputBuffer);
void SendDatagram(uint8* input, sint32 length, bool reliable = true);
private:
prudpClient();
void directSendPacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort);
sint32 kerberosEncryptData(uint8* input, sint32 length, uint8* output);
void queuePacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort);
void HandleIncomingPacket(std::unique_ptr<prudpIncomingPacket> incomingPacket);
void DirectSendPacket(prudpPacket* packet);
sint32 KerberosEncryptData(uint8* input, sint32 length, uint8* output);
void QueuePacket(prudpPacket* packet);
void AcknowledgePacket(uint16 sequenceId);
void SortIncomingDataPacket(std::unique_ptr<prudpIncomingPacket> incomingPacket);
void SendCurrentHandshakePacket();
private:
uint16 srcPort;
uint32 dstIp;
uint16 dstPort;
uint8 vport_src;
uint8 vport_dst;
prudpStreamSettings_t streamSettings;
std::vector<prudpAckRequired_t> list_packetsWithAckReq;
std::vector<prudpIncomingPacket*> queue_incomingPackets;
uint16 m_srcPort;
uint32 m_dstIp;
uint16 m_dstPort;
uint8 m_srcVPort;
uint8 m_dstVPort;
prudpStreamSettings m_streamSettings;
std::vector<PacketWithAckRequired> m_dataPacketsWithAckReq;
std::vector<std::unique_ptr<prudpIncomingPacket>> m_incomingPacketQueue;
// connection handshake state
bool m_hasSynAck{false};
bool m_hasConAck{false};
uint32 m_lastHandshakeTimestamp{0};
uint8 m_handshakeRetryCount{0};
// connection
uint8 currentConnectionState;
uint32 serverConnectionSignature;
uint32 clientConnectionSignature;
bool hasSentCon;
uint32 lastPingTimestamp;
ConnectionState m_currentConnectionState;
uint32 m_serverConnectionSignature;
uint32 m_clientConnectionSignature;
uint32 m_lastPingTimestamp;
uint16 outgoingSequenceId;
uint16 incomingSequenceId;
uint16 m_outgoingReliableSequenceId{2}; // 1 is reserved for CON
uint16 m_incomingSequenceId;
uint16 m_outgoingSequenceId_ping{0};
uint8 m_unacknowledgedPingCount{0};
uint8 clientSessionId;
uint8 serverSessionId;
uint8 m_clientSessionId;
uint8 m_serverSessionId;
// secure
bool isSecureConnection;
authServerInfo_t authInfo;
bool m_isSecureConnection{false};
prudpAuthServerInfo m_authInfo;
// socket
SOCKET socketUdp;
SOCKET m_socketUdp;
};
uint32 prudpGetMSTimestamp();