prudp: Improve ping and ack logic

Fixes the issue where the friend service connection would always timeout on Pretendo servers

The individual changes are:
- Outgoing ping packets now use their own incrementing sequenceId (matches official NEX behavior)
- If the server sends us a ping packet with NEEDS_ACK, we now respond
- Misc smaller refactoring and code clean up
- Added PRUDP as a separate logging option
This commit is contained in:
Exzap 2024-04-16 21:38:20 +02:00
parent 10c78eccce
commit ee36992bd6
5 changed files with 122 additions and 65 deletions

View File

@ -42,6 +42,7 @@ enum class LogType : sint32
ProcUi = 39, ProcUi = 39,
PRUDP = 40,
}; };
template <> template <>

View File

@ -221,6 +221,7 @@ NexFriends::NexFriends(uint32 authServerIp, uint16 authServerPort, const char* a
NexFriends::~NexFriends() NexFriends::~NexFriends()
{ {
if(nexCon)
nexCon->destroy(); nexCon->destroy();
} }

View File

@ -219,7 +219,7 @@ prudpPacket::prudpPacket(prudpStreamSettings_t* streamSettings, uint8 src, uint8
this->type = type; this->type = type;
this->flags = flags; this->flags = flags;
this->sessionId = sessionId; this->sessionId = sessionId;
this->sequenceId = sequenceId; this->m_sequenceId = sequenceId;
this->specifiedPacketSignature = packetSignature; this->specifiedPacketSignature = packetSignature;
this->streamSettings = streamSettings; this->streamSettings = streamSettings;
this->fragmentIndex = 0; this->fragmentIndex = 0;
@ -257,7 +257,7 @@ sint32 prudpPacket::buildData(uint8* output, sint32 maxLength)
*(uint16*)(packetBuffer + 0x02) = typeAndFlags; *(uint16*)(packetBuffer + 0x02) = typeAndFlags;
*(uint8*)(packetBuffer + 0x04) = sessionId; *(uint8*)(packetBuffer + 0x04) = sessionId;
*(uint32*)(packetBuffer + 0x05) = packetSignature(); *(uint32*)(packetBuffer + 0x05) = packetSignature();
*(uint16*)(packetBuffer + 0x09) = sequenceId; *(uint16*)(packetBuffer + 0x09) = m_sequenceId;
writeIndex = 0xB; writeIndex = 0xB;
// variable fields // variable fields
if (this->type == TYPE_SYN) if (this->type == TYPE_SYN)
@ -286,7 +286,9 @@ sint32 prudpPacket::buildData(uint8* output, sint32 maxLength)
// no data // no data
} }
else else
assert_dbg(); {
cemu_assert_suspicious();
}
// checksum // checksum
*(uint8*)(packetBuffer + writeIndex) = calculateChecksum(packetBuffer, writeIndex); *(uint8*)(packetBuffer + writeIndex) = calculateChecksum(packetBuffer, writeIndex);
writeIndex++; writeIndex++;
@ -585,7 +587,7 @@ void prudpClient::acknowledgePacket(uint16 sequenceId)
auto it = std::begin(list_packetsWithAckReq); auto it = std::begin(list_packetsWithAckReq);
while (it != std::end(list_packetsWithAckReq)) while (it != std::end(list_packetsWithAckReq))
{ {
if (it->packet->sequenceId == sequenceId) if (it->packet->GetSequenceId() == sequenceId)
{ {
delete it->packet; delete it->packet;
list_packetsWithAckReq.erase(it); list_packetsWithAckReq.erase(it);
@ -633,17 +635,46 @@ sint32 prudpClient::kerberosEncryptData(uint8* input, sint32 length, uint8* outp
} }
void prudpClient::handleIncomingPacket(prudpIncomingPacket* incomingPacket) void prudpClient::handleIncomingPacket(prudpIncomingPacket* incomingPacket)
{
if(incomingPacket->type == prudpPacket::TYPE_PING)
{ {
if (incomingPacket->flags&prudpPacket::FLAG_ACK) if (incomingPacket->flags&prudpPacket::FLAG_ACK)
{ {
// ack packet // ack for our ping packet
acknowledgePacket(incomingPacket->sequenceId); if(incomingPacket->flags&prudpPacket::FLAG_NEED_ACK)
if ((incomingPacket->type == prudpPacket::TYPE_DATA || incomingPacket->type == prudpPacket::TYPE_PING) && incomingPacket->packetData.empty()) cemuLog_log(LogType::PRUDP, "[PRUDP] Received unexpected ping packet with both ACK and NEED_ACK set");
if(m_unacknowledgedPingCount > 0)
{ {
// ack packet if(incomingPacket->sequenceId == m_outgoingSequenceId_ping)
{
cemuLog_log(LogType::PRUDP, "[PRUDP] Received ping packet ACK (unacknowledged count: {})", m_unacknowledgedPingCount);
m_unacknowledgedPingCount = 0;
}
else
{
cemuLog_log(LogType::PRUDP, "[PRUDP] Received ping packet ACK with wrong sequenceId (expected: {}, received: {})", m_outgoingSequenceId_ping, incomingPacket->sequenceId);
}
}
else
{
cemuLog_log(LogType::PRUDP, "[PRUDP] Received ping packet ACK which we dont need");
}
}
else if (incomingPacket->flags&prudpPacket::FLAG_NEED_ACK)
{
// other side is asking for ping ack
cemuLog_log(LogType::PRUDP, "[PRUDP] Received ping packet with NEED_ACK set. Sending ACK back");
cemu_assert_debug(incomingPacket->packetData.empty()); // todo - echo data?
prudpPacket ackPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_PING, prudpPacket::FLAG_ACK, this->clientSessionId, incomingPacket->sequenceId, 0);
directSendPacket(&ackPacket, dstIp, dstPort);
}
delete incomingPacket; delete incomingPacket;
return; return;
} }
// handle general packet ACK
if (incomingPacket->flags&prudpPacket::FLAG_ACK)
{
acknowledgePacket(incomingPacket->sequenceId);
} }
// special cases // special cases
if (incomingPacket->type == prudpPacket::TYPE_SYN) if (incomingPacket->type == prudpPacket::TYPE_SYN)
@ -680,7 +711,7 @@ void prudpClient::handleIncomingPacket(prudpIncomingPacket* incomingPacket)
// set packet specific data (client connection signature) // set packet specific data (client connection signature)
conPacket->setData((uint8*)&this->clientConnectionSignature, sizeof(uint32)); conPacket->setData((uint8*)&this->clientConnectionSignature, sizeof(uint32));
} }
// sent packet // send packet
queuePacket(conPacket, dstIp, dstPort); queuePacket(conPacket, dstIp, dstPort);
// remember con packet as sent // remember con packet as sent
hasSentCon = true; hasSentCon = true;
@ -694,17 +725,28 @@ void prudpClient::handleIncomingPacket(prudpIncomingPacket* incomingPacket)
if (currentConnectionState == STATE_CONNECTING) if (currentConnectionState == STATE_CONNECTING)
{ {
lastPingTimestamp = prudpGetMSTimestamp(); lastPingTimestamp = prudpGetMSTimestamp();
if(serverSessionId != 0) cemu_assert_debug(serverSessionId == 0);
cemuLog_logDebug(LogType::Force, "PRUDP: ServerSessionId is already set");
serverSessionId = incomingPacket->sessionId; serverSessionId = incomingPacket->sessionId;
currentConnectionState = STATE_CONNECTED; currentConnectionState = STATE_CONNECTED;
//printf("Connection established. ClientSession %02x ServerSession %02x\n", clientSessionId, serverSessionId); cemuLog_log(LogType::PRUDP, "[PRUDP] Connection established. ClientSession {:02x} ServerSession {:02x}", clientSessionId, serverSessionId);
} }
delete incomingPacket; delete incomingPacket;
return; return;
} }
else if (incomingPacket->type == prudpPacket::TYPE_DATA) else if (incomingPacket->type == prudpPacket::TYPE_DATA)
{ {
// send ack back if requested
if (incomingPacket->flags&prudpPacket::FLAG_NEED_ACK)
{
prudpPacket ackPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_DATA, prudpPacket::FLAG_ACK, this->clientSessionId, incomingPacket->sequenceId, 0);
directSendPacket(&ackPacket, dstIp, dstPort);
}
// skip data packets without payload
if (incomingPacket->packetData.empty())
{
delete incomingPacket;
return;
}
// verify some values // verify some values
uint16 seqDist = incomingPacket->sequenceId - incomingSequenceId; uint16 seqDist = incomingPacket->sequenceId - incomingSequenceId;
if (seqDist >= 0xC000) if (seqDist >= 0xC000)
@ -719,7 +761,7 @@ void prudpClient::handleIncomingPacket(prudpIncomingPacket* incomingPacket)
if (it->sequenceId == incomingPacket->sequenceId) if (it->sequenceId == incomingPacket->sequenceId)
{ {
// already queued (should check other values too, like packet type?) // already queued (should check other values too, like packet type?)
cemuLog_logDebug(LogType::Force, "Duplicate PRUDP packet received"); cemuLog_log(LogType::PRUDP, "Duplicate PRUDP packet received");
delete incomingPacket; delete incomingPacket;
return; return;
} }
@ -738,21 +780,12 @@ void prudpClient::handleIncomingPacket(prudpIncomingPacket* incomingPacket)
delete incomingPacket; delete incomingPacket;
return; return;
} }
if (incomingPacket->flags&prudpPacket::FLAG_NEED_ACK && incomingPacket->type == prudpPacket::TYPE_DATA)
{
// send ack back
prudpPacket* ackPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_DATA, prudpPacket::FLAG_ACK, this->clientSessionId, incomingPacket->sequenceId, 0);
queuePacket(ackPacket, dstIp, dstPort);
}
} }
bool prudpClient::update() bool prudpClient::update()
{ {
if (currentConnectionState == STATE_DISCONNECTED) if (currentConnectionState == STATE_DISCONNECTED)
{
return false; return false;
}
uint32 currentTimestamp = prudpGetMSTimestamp(); uint32 currentTimestamp = prudpGetMSTimestamp();
// check for incoming packets // check for incoming packets
uint8 receiveBuffer[4096]; uint8 receiveBuffer[4096];
@ -771,25 +804,20 @@ bool prudpClient::update()
sint32 packetLength = prudpPacket::calculateSizeFromPacketData(receiveBuffer + pIdx, r - pIdx); sint32 packetLength = prudpPacket::calculateSizeFromPacketData(receiveBuffer + pIdx, r - pIdx);
if (packetLength <= 0 || (pIdx + packetLength) > r) if (packetLength <= 0 || (pIdx + packetLength) > r)
{ {
cemuLog_logDebug(LogType::Force, "PRUDP: Invalid packet length"); cemuLog_log(LogType::Force, "[PRUDP] Invalid packet length");
break; break;
} }
prudpIncomingPacket* incomingPacket = new prudpIncomingPacket(&streamSettings, receiveBuffer + pIdx, packetLength); prudpIncomingPacket* incomingPacket = new prudpIncomingPacket(&streamSettings, receiveBuffer + pIdx, packetLength);
pIdx += packetLength; pIdx += packetLength;
if (incomingPacket->hasError()) if (incomingPacket->hasError())
{ {
cemuLog_logDebug(LogType::Force, "PRUDP: Packet error"); cemuLog_log(LogType::Force, "[PRUDP] Packet error");
delete incomingPacket; delete incomingPacket;
break; break;
} }
// sessionId validation is complicated and depends on specific flags and type combinations. It does not seem to cover all packet types if (incomingPacket->type != prudpPacket::TYPE_CON && incomingPacket->sessionId != serverSessionId)
bool validateSessionId = serverSessionId != 0;
if((incomingPacket->type == prudpPacket::TYPE_PING && (incomingPacket->flags&prudpPacket::FLAG_ACK) != 0))
validateSessionId = false; // PING + ack -> disable session id validation. Pretendo's friend server sends PING ack packets without setting the sessionId (it is 0)
if (validateSessionId && incomingPacket->sessionId != serverSessionId)
{ {
cemuLog_logDebug(LogType::Force, "PRUDP: Invalid session id"); cemuLog_log(LogType::PRUDP, "[PRUDP] Invalid session id");
delete incomingPacket; delete incomingPacket;
continue; // different session continue; // different session
} }
@ -816,14 +844,45 @@ bool prudpClient::update()
} }
} }
// check if we need to send another ping // check if we need to send another ping
if (currentConnectionState == STATE_CONNECTED && (currentTimestamp - lastPingTimestamp) >= 20000) if (currentConnectionState == STATE_CONNECTED)
{ {
// send ping if(m_unacknowledgedPingCount != 0) // counts how many times we sent a ping packet (for the current sequenceId) without receiving an ack
prudpPacket* pingPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK | prudpPacket::FLAG_RELIABLE, this->clientSessionId, this->outgoingSequenceId, serverConnectionSignature); {
this->outgoingSequenceId++; // increase since prudpPacket::FLAG_RELIABLE is set (note: official Wii U friends client sends ping packets without FLAG_RELIABLE) // we are waiting for the ack of the previous ping, but it hasn't arrived yet so send another ping packet
queuePacket(pingPacket, dstIp, dstPort); if((currentTimestamp - lastPingTimestamp) >= 1500)
{
cemuLog_log(LogType::PRUDP, "[PRUDP] Resending ping packet (no ack received)");
if(m_unacknowledgedPingCount >= 10)
{
// too many unacknowledged pings, assume the connection is dead
currentConnectionState = STATE_DISCONNECTED;
cemuLog_log(LogType::PRUDP, "PRUDP: Connection did not receive a ping response in a while. Assuming disconnect");
return false;
}
// resend the ping packet
prudpPacket* pingPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK, this->clientSessionId, this->m_outgoingSequenceId_ping, serverConnectionSignature);
directSendPacket(pingPacket, dstIp, dstPort);
m_unacknowledgedPingCount++;
delete pingPacket;
lastPingTimestamp = currentTimestamp; lastPingTimestamp = currentTimestamp;
} }
}
else
{
if((currentTimestamp - lastPingTimestamp) >= 20000)
{
cemuLog_log(LogType::PRUDP, "[PRUDP] Sending new ping packet with sequenceId {}", this->m_outgoingSequenceId_ping+1);
// start a new ping packet with a new sequenceId. Note that ping packets have their own sequenceId and acknowledgement happens by manually comparing the incoming ping ACK against the last sent sequenceId
// only one unacknowledged ping packet can be in flight at a time. We will resend the same ping packet until we receive an ack
this->m_outgoingSequenceId_ping++; // increment before sending. The first ping has a sequenceId of 1
prudpPacket* pingPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK, this->clientSessionId, this->m_outgoingSequenceId_ping, serverConnectionSignature);
directSendPacket(pingPacket, dstIp, dstPort);
m_unacknowledgedPingCount++;
delete pingPacket;
lastPingTimestamp = currentTimestamp;
}
}
}
return false; return false;
} }
@ -844,6 +903,7 @@ void prudpClient::queuePacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort)
{ {
if (packet->requiresAck()) if (packet->requiresAck())
{ {
cemu_assert_debug(packet->GetType() != prudpPacket::TYPE_PING); // ping packets use their own logic for acks, dont queue them
// remember this packet until we receive the ack // remember this packet until we receive the ack
prudpAckRequired_t ackRequired = { 0 }; prudpAckRequired_t ackRequired = { 0 };
ackRequired.packet = packet; ackRequired.packet = packet;
@ -861,15 +921,17 @@ void prudpClient::queuePacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort)
void prudpClient::sendDatagram(uint8* input, sint32 length, bool reliable) void prudpClient::sendDatagram(uint8* input, sint32 length, bool reliable)
{ {
if (reliable == false) cemu_assert_debug(reliable); // non-reliable packets require testing
{
assert_dbg(); // todo
}
if(length >= 0x300) if(length >= 0x300)
assert_dbg(); // too long, need to split into multiple fragments {
cemuLog_logOnce(LogType::Force, "PRUDP: Datagram too long");
}
// single fragment data packet // single fragment data packet
prudpPacket* packet = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_DATA, prudpPacket::FLAG_NEED_ACK | prudpPacket::FLAG_RELIABLE, clientSessionId, outgoingSequenceId, 0); uint16 flags = prudpPacket::FLAG_NEED_ACK;
if(reliable)
flags |= prudpPacket::FLAG_RELIABLE;
prudpPacket* packet = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_DATA, flags, clientSessionId, outgoingSequenceId, 0);
if(reliable)
outgoingSequenceId++; outgoingSequenceId++;
packet->setFragmentIndex(0); packet->setFragmentIndex(0);
packet->setData(input, length); packet->setData(input, length);
@ -919,13 +981,7 @@ sint32 prudpClient::receiveDatagram(std::vector<uint8>& outputBuffer)
} }
delete incomingPacket; delete incomingPacket;
// remove packet from queue // remove packet from queue
sint32 size = (sint32)queue_incomingPackets.size(); queue_incomingPackets.erase(queue_incomingPackets.begin());
size--;
for (sint32 i = 0; i < size; i++)
{
queue_incomingPackets[i] = queue_incomingPackets[i + 1];
}
queue_incomingPackets.resize(size);
// advance expected sequence id // advance expected sequence id
this->incomingSequenceId++; this->incomingSequenceId++;
return datagramLen; return datagramLen;
@ -975,13 +1031,7 @@ sint32 prudpClient::receiveDatagram(std::vector<uint8>& outputBuffer)
delete incomingPacket; delete incomingPacket;
} }
// remove packets from queue // remove packets from queue
sint32 size = (sint32)queue_incomingPackets.size(); queue_incomingPackets.erase(queue_incomingPackets.begin(), queue_incomingPackets.begin() + chainLength);
size -= chainLength;
for (sint32 i = 0; i < size; i++)
{
queue_incomingPackets[i] = queue_incomingPackets[i + chainLength];
}
queue_incomingPackets.resize(size);
this->incomingSequenceId += chainLength; this->incomingSequenceId += chainLength;
return writeIndex; return writeIndex;
} }

View File

@ -71,15 +71,14 @@ public:
void setData(uint8* data, sint32 length); void setData(uint8* data, sint32 length);
void setFragmentIndex(uint8 fragmentIndex); void setFragmentIndex(uint8 fragmentIndex);
sint32 buildData(uint8* output, sint32 maxLength); sint32 buildData(uint8* output, sint32 maxLength);
uint8 GetType() const { return type; }
uint16 GetSequenceId() const { return m_sequenceId; }
private: private:
uint32 packetSignature(); uint32 packetSignature();
uint8 calculateChecksum(uint8* data, sint32 length); uint8 calculateChecksum(uint8* data, sint32 length);
public:
uint16 sequenceId;
private: private:
uint8 src; uint8 src;
uint8 dst; uint8 dst;
@ -91,6 +90,8 @@ private:
prudpStreamSettings_t* streamSettings; prudpStreamSettings_t* streamSettings;
std::vector<uint8> packetData; std::vector<uint8> packetData;
bool isEncrypted; bool isEncrypted;
uint16 m_sequenceId{0};
}; };
class prudpIncomingPacket class prudpIncomingPacket
@ -186,6 +187,9 @@ private:
uint16 outgoingSequenceId; uint16 outgoingSequenceId;
uint16 incomingSequenceId; uint16 incomingSequenceId;
uint16 m_outgoingSequenceId_ping{0};
uint8 m_unacknowledgedPingCount{0};
uint8 clientSessionId; uint8 clientSessionId;
uint8 serverSessionId; uint8 serverSessionId;

View File

@ -2232,6 +2232,7 @@ void MainWindow::RecreateMenu()
debugLoggingMenu->AppendCheckItem(MAINFRAME_MENU_ID_DEBUG_LOGGING0 + stdx::to_underlying(LogType::CoreinitThread), _("&Coreinit Thread API"), wxEmptyString)->Check(cemuLog_isLoggingEnabled(LogType::CoreinitThread)); debugLoggingMenu->AppendCheckItem(MAINFRAME_MENU_ID_DEBUG_LOGGING0 + stdx::to_underlying(LogType::CoreinitThread), _("&Coreinit Thread API"), wxEmptyString)->Check(cemuLog_isLoggingEnabled(LogType::CoreinitThread));
debugLoggingMenu->AppendCheckItem(MAINFRAME_MENU_ID_DEBUG_LOGGING0 + stdx::to_underlying(LogType::NN_NFP), _("&NN NFP"), wxEmptyString)->Check(cemuLog_isLoggingEnabled(LogType::NN_NFP)); debugLoggingMenu->AppendCheckItem(MAINFRAME_MENU_ID_DEBUG_LOGGING0 + stdx::to_underlying(LogType::NN_NFP), _("&NN NFP"), wxEmptyString)->Check(cemuLog_isLoggingEnabled(LogType::NN_NFP));
debugLoggingMenu->AppendCheckItem(MAINFRAME_MENU_ID_DEBUG_LOGGING0 + stdx::to_underlying(LogType::NN_FP), _("&NN FP"), wxEmptyString)->Check(cemuLog_isLoggingEnabled(LogType::NN_FP)); debugLoggingMenu->AppendCheckItem(MAINFRAME_MENU_ID_DEBUG_LOGGING0 + stdx::to_underlying(LogType::NN_FP), _("&NN FP"), wxEmptyString)->Check(cemuLog_isLoggingEnabled(LogType::NN_FP));
debugLoggingMenu->AppendCheckItem(MAINFRAME_MENU_ID_DEBUG_LOGGING0 + stdx::to_underlying(LogType::PRUDP), _("&PRUDP (for NN FP)"), wxEmptyString)->Check(cemuLog_isLoggingEnabled(LogType::PRUDP));
debugLoggingMenu->AppendCheckItem(MAINFRAME_MENU_ID_DEBUG_LOGGING0 + stdx::to_underlying(LogType::NN_BOSS), _("&NN BOSS"), wxEmptyString)->Check(cemuLog_isLoggingEnabled(LogType::NN_BOSS)); debugLoggingMenu->AppendCheckItem(MAINFRAME_MENU_ID_DEBUG_LOGGING0 + stdx::to_underlying(LogType::NN_BOSS), _("&NN BOSS"), wxEmptyString)->Check(cemuLog_isLoggingEnabled(LogType::NN_BOSS));
debugLoggingMenu->AppendCheckItem(MAINFRAME_MENU_ID_DEBUG_LOGGING0 + stdx::to_underlying(LogType::GX2), _("&GX2 API"), wxEmptyString)->Check(cemuLog_isLoggingEnabled(LogType::GX2)); debugLoggingMenu->AppendCheckItem(MAINFRAME_MENU_ID_DEBUG_LOGGING0 + stdx::to_underlying(LogType::GX2), _("&GX2 API"), wxEmptyString)->Check(cemuLog_isLoggingEnabled(LogType::GX2));
debugLoggingMenu->AppendCheckItem(MAINFRAME_MENU_ID_DEBUG_LOGGING0 + stdx::to_underlying(LogType::SoundAPI), _("&Audio API"), wxEmptyString)->Check(cemuLog_isLoggingEnabled(LogType::SoundAPI)); debugLoggingMenu->AppendCheckItem(MAINFRAME_MENU_ID_DEBUG_LOGGING0 + stdx::to_underlying(LogType::SoundAPI), _("&Audio API"), wxEmptyString)->Check(cemuLog_isLoggingEnabled(LogType::SoundAPI));