prudp: Code cleanup

This commit is contained in:
Exzap 2024-04-18 19:22:28 +02:00
parent ee36992bd6
commit e2f9725719
3 changed files with 410 additions and 422 deletions

View File

@ -106,7 +106,7 @@ nexService::nexService()
nexService::nexService(prudpClient* con) : nexService() nexService::nexService(prudpClient* con) : nexService()
{ {
if (con->isConnected() == false) if (con->IsConnected() == false)
cemu_assert_suspicious(); cemu_assert_suspicious();
this->conNexService = con; this->conNexService = con;
bufferReceive = std::vector<uint8>(1024 * 4); bufferReceive = std::vector<uint8>(1024 * 4);
@ -191,7 +191,7 @@ void nexService::processQueuedRequest(queuedRequest_t* queuedRequest)
uint32 callId = _currentCallId; uint32 callId = _currentCallId;
_currentCallId++; _currentCallId++;
// check state of connection // check state of connection
if (conNexService->getConnectionState() != prudpClient::STATE_CONNECTED) if (conNexService->GetConnectionState() != prudpClient::ConnectionState::Connected)
{ {
nexServiceResponse_t response = { 0 }; nexServiceResponse_t response = { 0 };
response.isSuccessful = false; response.isSuccessful = false;
@ -214,7 +214,7 @@ void nexService::processQueuedRequest(queuedRequest_t* queuedRequest)
assert_dbg(); assert_dbg();
memcpy((packetBuffer + 0x0D), &queuedRequest->parameterData.front(), queuedRequest->parameterData.size()); memcpy((packetBuffer + 0x0D), &queuedRequest->parameterData.front(), queuedRequest->parameterData.size());
sint32 length = 0xD + (sint32)queuedRequest->parameterData.size(); sint32 length = 0xD + (sint32)queuedRequest->parameterData.size();
conNexService->sendDatagram(packetBuffer, length, true); conNexService->SendDatagram(packetBuffer, length, true);
// remember request // remember request
nexActiveRequestInfo_t requestInfo = { 0 }; nexActiveRequestInfo_t requestInfo = { 0 };
requestInfo.callId = callId; requestInfo.callId = callId;
@ -299,13 +299,13 @@ void nexService::registerForAsyncProcessing()
void nexService::updateTemporaryConnections() void nexService::updateTemporaryConnections()
{ {
// check for connection // check for connection
conNexService->update(); conNexService->Update();
if (conNexService->isConnected()) if (conNexService->IsConnected())
{ {
if (connectionState == STATE_CONNECTING) if (connectionState == STATE_CONNECTING)
connectionState = STATE_CONNECTED; connectionState = STATE_CONNECTED;
} }
if (conNexService->getConnectionState() == prudpClient::STATE_DISCONNECTED) if (conNexService->GetConnectionState() == prudpClient::ConnectionState::Disconnected)
connectionState = STATE_DISCONNECTED; connectionState = STATE_DISCONNECTED;
} }
@ -356,18 +356,18 @@ void nexService::sendRequestResponse(nexServiceRequest_t* request, uint32 errorC
// update length field // update length field
*(uint32*)response.getDataPtr() = response.getWriteIndex()-4; *(uint32*)response.getDataPtr() = response.getWriteIndex()-4;
if(request->nex->conNexService) if(request->nex->conNexService)
request->nex->conNexService->sendDatagram(response.getDataPtr(), response.getWriteIndex(), true); request->nex->conNexService->SendDatagram(response.getDataPtr(), response.getWriteIndex(), true);
} }
void nexService::updateNexServiceConnection() void nexService::updateNexServiceConnection()
{ {
if (conNexService->getConnectionState() == prudpClient::STATE_DISCONNECTED) if (conNexService->GetConnectionState() == prudpClient::ConnectionState::Disconnected)
{ {
this->connectionState = STATE_DISCONNECTED; this->connectionState = STATE_DISCONNECTED;
return; return;
} }
conNexService->update(); conNexService->Update();
sint32 datagramLen = conNexService->receiveDatagram(bufferReceive); sint32 datagramLen = conNexService->ReceiveDatagram(bufferReceive);
if (datagramLen > 0) if (datagramLen > 0)
{ {
if (nexIsRequest(&bufferReceive[0], datagramLen)) if (nexIsRequest(&bufferReceive[0], datagramLen))
@ -454,12 +454,12 @@ bool _extractStationUrlParamValue(const char* urlStr, const char* paramName, cha
return false; return false;
} }
void nexServiceAuthentication_parseStationURL(char* urlStr, stationUrl_t* stationUrl) void nexServiceAuthentication_parseStationURL(char* urlStr, prudpStationUrl* stationUrl)
{ {
// example: // example:
// prudps:/address=34.210.xxx.xxx;port=60181;CID=1;PID=2;sid=1;stream=10;type=2 // prudps:/address=34.210.xxx.xxx;port=60181;CID=1;PID=2;sid=1;stream=10;type=2
memset(stationUrl, 0, sizeof(stationUrl_t)); memset(stationUrl, 0, sizeof(prudpStationUrl));
char optionValue[128]; char optionValue[128];
if (_extractStationUrlParamValue(urlStr, "address", optionValue, sizeof(optionValue))) if (_extractStationUrlParamValue(urlStr, "address", optionValue, sizeof(optionValue)))
@ -499,7 +499,7 @@ typedef struct
sint32 kerberosTicketSize; sint32 kerberosTicketSize;
uint8 kerberosTicket2[4096]; uint8 kerberosTicket2[4096];
sint32 kerberosTicket2Size; sint32 kerberosTicket2Size;
stationUrl_t server; prudpStationUrl server;
// progress info // progress info
bool hasError; bool hasError;
bool done; bool done;
@ -611,18 +611,18 @@ void nexServiceSecure_handleResponse_RegisterEx(nexService* nex, nexServiceRespo
return; return;
} }
nexService* nex_secureLogin(authServerInfo_t* authServerInfo, const char* accessKey, const char* nexToken) nexService* nex_secureLogin(prudpAuthServerInfo* authServerInfo, const char* accessKey, const char* nexToken)
{ {
prudpClient* prudpSecureSock = new prudpClient(authServerInfo->server.ip, authServerInfo->server.port, accessKey, authServerInfo); prudpClient* prudpSecureSock = new prudpClient(authServerInfo->server.ip, authServerInfo->server.port, accessKey, authServerInfo);
// wait until connected // wait until connected
while (true) while (true)
{ {
prudpSecureSock->update(); prudpSecureSock->Update();
if (prudpSecureSock->isConnected()) if (prudpSecureSock->IsConnected())
{ {
break; break;
} }
if (prudpSecureSock->getConnectionState() == prudpClient::STATE_DISCONNECTED) if (prudpSecureSock->GetConnectionState() == prudpClient::ConnectionState::Disconnected)
{ {
// timeout or disconnected // timeout or disconnected
cemuLog_log(LogType::Force, "NEX: Secure login connection time-out"); cemuLog_log(LogType::Force, "NEX: Secure login connection time-out");
@ -638,7 +638,7 @@ nexService* nex_secureLogin(authServerInfo_t* authServerInfo, const char* access
nexPacketBuffer packetBuffer(tempNexBufferArray, sizeof(tempNexBufferArray), true); nexPacketBuffer packetBuffer(tempNexBufferArray, sizeof(tempNexBufferArray), true);
char clientStationUrl[256]; char clientStationUrl[256];
sprintf(clientStationUrl, "prudp:/port=%u;natf=0;natm=0;pmp=0;sid=15;type=2;upnp=0", (uint32)nex->getPRUDPConnection()->getSourcePort()); sprintf(clientStationUrl, "prudp:/port=%u;natf=0;natm=0;pmp=0;sid=15;type=2;upnp=0", (uint32)nex->getPRUDPConnection()->GetSourcePort());
// station url list // station url list
packetBuffer.writeU32(1); packetBuffer.writeU32(1);
packetBuffer.writeString(clientStationUrl); packetBuffer.writeString(clientStationUrl);
@ -737,9 +737,9 @@ nexService* nex_establishSecureConnection(uint32 authServerIp, uint16 authServer
return nullptr; return nullptr;
} }
// auth info // auth info
auto authServerInfo = std::make_unique<authServerInfo_t>(); auto authServerInfo = std::make_unique<prudpAuthServerInfo>();
// decrypt ticket // decrypt ticket
RC4Ctx_t rc4Ticket; RC4Ctx rc4Ticket;
RC4_initCtx(&rc4Ticket, kerberosKey, 16); RC4_initCtx(&rc4Ticket, kerberosKey, 16);
RC4_transform(&rc4Ticket, nexAuthService.kerberosTicket2, nexAuthService.kerberosTicket2Size - 16, nexAuthService.kerberosTicket2); RC4_transform(&rc4Ticket, nexAuthService.kerberosTicket2, nexAuthService.kerberosTicket2Size - 16, nexAuthService.kerberosTicket2);
nexPacketBuffer packetKerberosTicket(nexAuthService.kerberosTicket2, nexAuthService.kerberosTicket2Size - 16, false); nexPacketBuffer packetKerberosTicket(nexAuthService.kerberosTicket2, nexAuthService.kerberosTicket2Size - 16, false);
@ -756,7 +756,7 @@ nexService* nex_establishSecureConnection(uint32 authServerIp, uint16 authServer
memcpy(authServerInfo->kerberosKey, kerberosKey, 16); memcpy(authServerInfo->kerberosKey, kerberosKey, 16);
memcpy(authServerInfo->secureKey, secureKey, 16); memcpy(authServerInfo->secureKey, secureKey, 16);
memcpy(&authServerInfo->server, &nexAuthService.server, sizeof(stationUrl_t)); memcpy(&authServerInfo->server, &nexAuthService.server, sizeof(prudpStationUrl));
authServerInfo->userPid = pid; authServerInfo->userPid = pid;
return nex_secureLogin(authServerInfo.get(), accessKey, nexToken); return nex_secureLogin(authServerInfo.get(), accessKey, nexToken);

File diff suppressed because it is too large Load Diff

View File

@ -4,26 +4,26 @@
#define RC4_N 256 #define RC4_N 256
typedef struct struct RC4Ctx
{ {
unsigned char S[RC4_N]; unsigned char S[RC4_N];
int i; int i;
int j; int j;
}RC4Ctx_t; };
void RC4_initCtx(RC4Ctx_t* rc4Ctx, char *key); void RC4_initCtx(RC4Ctx* rc4Ctx, const char* key);
void RC4_initCtx(RC4Ctx_t* rc4Ctx, unsigned char* key, int keyLen); void RC4_initCtx(RC4Ctx* rc4Ctx, unsigned char* key, int keyLen);
void RC4_transform(RC4Ctx_t* rc4Ctx, unsigned char* input, int len, unsigned char* output); void RC4_transform(RC4Ctx* rc4Ctx, unsigned char* input, int len, unsigned char* output);
typedef struct struct prudpStreamSettings
{ {
uint8 checksumBase; // calculated from key uint8 checksumBase; // calculated from key
uint8 accessKeyDigest[16]; // MD5 hash of key uint8 accessKeyDigest[16]; // MD5 hash of key
RC4Ctx_t rc4Client; RC4Ctx rc4Client;
RC4Ctx_t rc4Server; RC4Ctx rc4Server;
}prudpStreamSettings_t; };
typedef struct struct prudpStationUrl
{ {
uint32 ip; uint32 ip;
uint16 port; uint16 port;
@ -32,19 +32,17 @@ typedef struct
sint32 sid; sint32 sid;
sint32 stream; sint32 stream;
sint32 type; sint32 type;
}stationUrl_t; };
typedef struct struct prudpAuthServerInfo
{ {
uint32 userPid; uint32 userPid;
uint8 secureKey[16]; uint8 secureKey[16];
uint8 kerberosKey[16]; uint8 kerberosKey[16];
uint8 secureTicket[1024]; uint8 secureTicket[1024];
sint32 secureTicketLength; sint32 secureTicketLength;
stationUrl_t server; prudpStationUrl server;
}authServerInfo_t; };
uint8 prudp_calculateChecksum(uint8 checksumBase, uint8* data, sint32 length);
class prudpPacket class prudpPacket
{ {
@ -66,7 +64,7 @@ public:
static sint32 calculateSizeFromPacketData(uint8* data, sint32 length); static sint32 calculateSizeFromPacketData(uint8* data, sint32 length);
prudpPacket(prudpStreamSettings_t* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature); prudpPacket(prudpStreamSettings* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature);
bool requiresAck(); bool requiresAck();
void setData(uint8* data, sint32 length); void setData(uint8* data, sint32 length);
void setFragmentIndex(uint8 fragmentIndex); void setFragmentIndex(uint8 fragmentIndex);
@ -87,7 +85,7 @@ private:
uint16 flags; uint16 flags;
uint8 sessionId; uint8 sessionId;
uint32 specifiedPacketSignature; uint32 specifiedPacketSignature;
prudpStreamSettings_t* streamSettings; prudpStreamSettings* streamSettings;
std::vector<uint8> packetData; std::vector<uint8> packetData;
bool isEncrypted; bool isEncrypted;
uint16 m_sequenceId{0}; uint16 m_sequenceId{0};
@ -97,7 +95,7 @@ private:
class prudpIncomingPacket class prudpIncomingPacket
{ {
public: public:
prudpIncomingPacket(prudpStreamSettings_t* streamSettings, uint8* data, sint32 length); prudpIncomingPacket(prudpStreamSettings* streamSettings, uint8* data, sint32 length);
bool hasError(); bool hasError();
@ -122,83 +120,91 @@ public:
private: private:
bool isInvalid = false; bool isInvalid = false;
prudpStreamSettings_t* streamSettings = nullptr; prudpStreamSettings* streamSettings = nullptr;
}; };
typedef struct
{
prudpPacket* packet;
uint32 initialSendTimestamp;
uint32 lastRetryTimestamp;
sint32 retryCount;
}prudpAckRequired_t;
class prudpClient class prudpClient
{ {
struct PacketWithAckRequired
{
PacketWithAckRequired(prudpPacket* packet, uint32 initialSendTimestamp) :
packet(packet), initialSendTimestamp(initialSendTimestamp), lastRetryTimestamp(initialSendTimestamp) { }
prudpPacket* packet;
uint32 initialSendTimestamp;
uint32 lastRetryTimestamp;
sint32 retryCount{0};
};
public: public:
static const int STATE_CONNECTING = 0; enum class ConnectionState : uint8
static const int STATE_CONNECTED = 1; {
static const int STATE_DISCONNECTED = 2; Connecting,
Connected,
Disconnected
};
public:
prudpClient(uint32 dstIp, uint16 dstPort, const char* key); prudpClient(uint32 dstIp, uint16 dstPort, const char* key);
prudpClient(uint32 dstIp, uint16 dstPort, const char* key, authServerInfo_t* authInfo); prudpClient(uint32 dstIp, uint16 dstPort, const char* key, prudpAuthServerInfo* authInfo);
~prudpClient(); ~prudpClient();
bool isConnected(); bool IsConnected() const { return m_currentConnectionState == ConnectionState::Connected; }
ConnectionState GetConnectionState() const { return m_currentConnectionState; }
uint16 GetSourcePort() const { return m_srcPort; }
uint8 getConnectionState(); bool Update(); // update connection state and check for incoming packets. Returns true if ReceiveDatagram() should be called
void acknowledgePacket(uint16 sequenceId);
void sortIncomingDataPacket(prudpIncomingPacket* incomingPacket);
void handleIncomingPacket(prudpIncomingPacket* incomingPacket);
bool update(); // check for new incoming packets, returns true if receiveDatagram() should be called
sint32 receiveDatagram(std::vector<uint8>& outputBuffer); sint32 ReceiveDatagram(std::vector<uint8>& outputBuffer);
void sendDatagram(uint8* input, sint32 length, bool reliable = true); void SendDatagram(uint8* input, sint32 length, bool reliable = true);
uint16 getSourcePort();
SOCKET getSocket();
private: private:
prudpClient(); prudpClient();
void directSendPacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort);
sint32 kerberosEncryptData(uint8* input, sint32 length, uint8* output); void HandleIncomingPacket(std::unique_ptr<prudpIncomingPacket> incomingPacket);
void queuePacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort); void DirectSendPacket(prudpPacket* packet);
sint32 KerberosEncryptData(uint8* input, sint32 length, uint8* output);
void QueuePacket(prudpPacket* packet);
void AcknowledgePacket(uint16 sequenceId);
void SortIncomingDataPacket(std::unique_ptr<prudpIncomingPacket> incomingPacket);
void SendCurrentHandshakePacket();
private: private:
uint16 srcPort; uint16 m_srcPort;
uint32 dstIp; uint32 m_dstIp;
uint16 dstPort; uint16 m_dstPort;
uint8 vport_src; uint8 m_srcVPort;
uint8 vport_dst; uint8 m_dstVPort;
prudpStreamSettings_t streamSettings; prudpStreamSettings m_streamSettings;
std::vector<prudpAckRequired_t> list_packetsWithAckReq; std::vector<PacketWithAckRequired> m_dataPacketsWithAckReq;
std::vector<prudpIncomingPacket*> queue_incomingPackets; std::vector<std::unique_ptr<prudpIncomingPacket>> m_incomingPacketQueue;
// connection handshake state
bool m_hasSynAck{false};
bool m_hasConAck{false};
uint32 m_lastHandshakeTimestamp{0};
uint8 m_handshakeRetryCount{0};
// connection // connection
uint8 currentConnectionState; ConnectionState m_currentConnectionState;
uint32 serverConnectionSignature; uint32 m_serverConnectionSignature;
uint32 clientConnectionSignature; uint32 m_clientConnectionSignature;
bool hasSentCon; uint32 m_lastPingTimestamp;
uint32 lastPingTimestamp;
uint16 outgoingSequenceId; uint16 m_outgoingReliableSequenceId{2}; // 1 is reserved for CON
uint16 incomingSequenceId; uint16 m_incomingSequenceId;
uint16 m_outgoingSequenceId_ping{0}; uint16 m_outgoingSequenceId_ping{0};
uint8 m_unacknowledgedPingCount{0}; uint8 m_unacknowledgedPingCount{0};
uint8 clientSessionId; uint8 m_clientSessionId;
uint8 serverSessionId; uint8 m_serverSessionId;
// secure // secure
bool isSecureConnection; bool m_isSecureConnection{false};
authServerInfo_t authInfo; prudpAuthServerInfo m_authInfo;
// socket // socket
SOCKET socketUdp; SOCKET m_socketUdp;
}; };
uint32 prudpGetMSTimestamp(); uint32 prudpGetMSTimestamp();