#include "stdafx.h"
#include "HandshakeManager.h"
#include "AuthModule.h"
#include "HttpClient.h"
#include "StringHelpers.h"
#include "Common/vendor/nlohmann/json.hpp"

#include <algorithm>

// Protocol version this build speaks. ANNOUNCE_VERSION (server side) and
// DECLARE_SCHEME (client side) are rejected unless the peer sends exactly this.
static constexpr auto PROTOCOL_VERSION = L"1.0";

// Returns the value of the first field whose key equals `key`, or an empty
// string when no such field exists. A missing field and a field with an empty
// value are therefore indistinguishable to callers.
static wstring getField(const vector<pair<wstring, wstring>> &fields, const wchar_t *key)
{
    for (const auto &[k, v] : fields)
        if (k == key)
            return v;
    return {};
}

// Creates a handshake driver for one side of the connection. `isServer`
// selects which state machine handlePacket() dispatches to.
HandshakeManager::HandshakeManager(bool isServer)
    : isServer(isServer), state(HandshakeState::IDLE), activeModule(nullptr)
{
}

// Registers an auth module under its scheme name. A module already registered
// under the same name is replaced.
void HandshakeManager::registerModule(unique_ptr<AuthModule> module)
{
    wstring name = module->schemeName();
    modules[std::move(name)] = std::move(module);
}

// Stores the client-side credentials used later in the handshake:
// `token` and `uid` are sent to the session join endpoint (SCHEME_SETTINGS),
// `uid`/`username` are sent in AUTH_DATA/AUTH_DONE, and `variation` is the
// preferred scheme variation to accept when the module supports it.
void HandshakeManager::setCredentials(const wstring &token, const wstring &uid,
                                      const wstring &username, const wstring &variation)
{
    accessToken = token;
    clientUid = uid;
    clientUsername = username;
    preferredVariation = variation;
}

// Returns every packet queued for sending (the client queues BEGIN_AUTH,
// AUTH_DATA and AUTH_DONE while handling SCHEME_SETTINGS) and leaves the
// queue empty.
vector<shared_ptr<AuthPacket>> HandshakeManager::drainPendingPackets()
{
    auto out = std::move(pendingPackets);
    // A moved-from vector is only "valid but unspecified"; make it empty.
    pendingPackets.clear();
    return out;
}

// Feeds one received packet into the state machine for this side.
// Returns the packet to send back, or nullptr when there is no direct reply.
// A null packet is treated as a protocol failure.
shared_ptr<AuthPacket> HandshakeManager::handlePacket(const shared_ptr<AuthPacket> &packet)
{
    if (!packet)
        return fail();
    return isServer ? handleServer(packet) : handleClient(packet);
}

// Builds the opening ANNOUNCE_VERSION packet, carrying the protocol version
// and a comma-separated list of every registered scheme name. handleServer()
// is the side that consumes ANNOUNCE_VERSION, so this is the client's first
// message; it moves this side to VERSION_SENT.
shared_ptr<AuthPacket> HandshakeManager::createInitialPacket()
{
    state = HandshakeState::VERSION_SENT;

    wstring schemes;
    for (const auto &[name, mod] : modules)
    {
        if (!schemes.empty())
            schemes += L",";
        schemes += name;
    }

    return makePacket(AuthStage::ANNOUNCE_VERSION, {
        {L"version", PROTOCOL_VERSION},
        {L"schemes", schemes}
    });
}

// Server-side state machine. Expected client sequence and resulting state:
//   ANNOUNCE_VERSION  (IDLE)                  -> SCHEME_DECLARED, reply DECLARE_SCHEME
//   ACCEPT_SCHEME     (SCHEME_DECLARED)       -> SETTINGS_SENT,   reply SCHEME_SETTINGS
//   BEGIN_AUTH        (SETTINGS_SENT)         -> AUTH_IN_PROGRESS
//   AUTH_DATA         (AUTH_IN_PROGRESS)      -> AUTH_DATA_EXCHANGED
//   AUTH_DONE         (AUTH_DATA_EXCHANGED)   -> IDENTITY_ASSIGNED, reply ASSIGN_IDENTITY
//   CONFIRM_IDENTITY  (IDENTITY_ASSIGNED)     -> COMPLETE,        reply AUTH_SUCCESS
// A stage arriving in any other state fails the handshake. Without these
// checks, a client could send CONFIRM_IDENTITY with empty uid/username as its
// first packet. That matches finalUid/finalUsername while they are still at
// their initial (presumably empty) values, and would reach COMPLETE without
// any module validating credentials. ACCEPT_SCHEME/AUTH_DATA before
// ANNOUNCE_VERSION would also dereference a null activeModule.
shared_ptr<AuthPacket> HandshakeManager::handleServer(const shared_ptr<AuthPacket> &packet)
{
    switch (packet->stage)
    {
    case AuthStage::ANNOUNCE_VERSION:
    {
        if (state != HandshakeState::IDLE)
            return fail();

        protocolVersion = getField(packet->fields, L"version");
        if (protocolVersion != PROTOCOL_VERSION)
            return fail();
        if (modules.empty())
            return fail();

        // Splits "a,b,c" into items. Items are not trimmed; empty items between
        // consecutive commas are kept; an empty input yields an empty list.
        auto splitCsv = [](const wstring &s)
        {
            vector<wstring> out;
            for (size_t p = 0; p < s.size(); )
            {
                size_t c = s.find(L',', p);
                if (c == wstring::npos)
                    c = s.size();
                out.push_back(s.substr(p, c - p));
                p = c + 1;
            }
            return out;
        };

        // Pick the first registered module (in `modules` iteration order) that
        // the client also lists. An empty or missing list accepts the first module.
        auto supported = splitCsv(getField(packet->fields, L"schemes"));
        activeModule = nullptr;
        for (auto &[name, mod] : modules)
        {
            if (supported.empty() ||
                std::find(supported.begin(), supported.end(), name) != supported.end())
            {
                activeModule = mod.get();
                break;
            }
        }
        if (!activeModule)
            return fail();

        state = HandshakeState::SCHEME_DECLARED;
        return makePacket(AuthStage::DECLARE_SCHEME, {
            {L"version", PROTOCOL_VERSION},
            {L"scheme", activeModule->schemeName()}
        });
    }
    case AuthStage::ACCEPT_SCHEME:
    {
        if (state != HandshakeState::SCHEME_DECLARED || !activeModule)
            return fail();

        // NOTE(review): the client-chosen variation is forwarded without being
        // checked against activeModule->supportedVariations(); presumably
        // getSettings() copes with unknown values -- confirm.
        activeVariation = getField(packet->fields, L"variation");
        state = HandshakeState::SETTINGS_SENT;
        auto settings = activeModule->getSettings(activeVariation);
        return makePacket(AuthStage::SCHEME_SETTINGS, std::move(settings));
    }
    case AuthStage::BEGIN_AUTH:
    {
        if (state != HandshakeState::SETTINGS_SENT)
            return fail();
        state = HandshakeState::AUTH_IN_PROGRESS;
        return nullptr;
    }
    case AuthStage::AUTH_DATA:
    {
        if (state != HandshakeState::AUTH_IN_PROGRESS || !activeModule)
            return fail();

        // The module validates the client's data and yields the identity to assign.
        wstring uid, username;
        if (!activeModule->onAuthData(packet->fields, uid, username))
            return fail();
        finalUid = uid;
        finalUsername = username;
        state = HandshakeState::AUTH_DATA_EXCHANGED;
        return nullptr;
    }
    case AuthStage::AUTH_DONE:
    {
        // Only valid once onAuthData() has accepted the client's data.
        if (state != HandshakeState::AUTH_DATA_EXCHANGED)
            return fail();
        if (getField(packet->fields, L"uid") != finalUid ||
            getField(packet->fields, L"username") != finalUsername)
            return fail();
        state = HandshakeState::IDENTITY_ASSIGNED;
        return makePacket(AuthStage::ASSIGN_IDENTITY, {
            {L"uid", finalUid},
            {L"username", finalUsername}
        });
    }
    case AuthStage::CONFIRM_IDENTITY:
    {
        if (state != HandshakeState::IDENTITY_ASSIGNED)
            return fail();
        if (getField(packet->fields, L"uid") != finalUid ||
            getField(packet->fields, L"username") != finalUsername)
            return fail();
        state = HandshakeState::COMPLETE;
        return makePacket(AuthStage::AUTH_SUCCESS);
    }
    default:
        return fail();
    }
}

// Client-side state machine. Expected server sequence and resulting state:
//   DECLARE_SCHEME   (VERSION_SENT)        -> SCHEME_ACCEPTED, reply ACCEPT_SCHEME
//   SCHEME_SETTINGS  (SCHEME_ACCEPTED)     -> AUTH_IN_PROGRESS, queue BEGIN_AUTH/AUTH_DATA/AUTH_DONE
//   ASSIGN_IDENTITY  (AUTH_IN_PROGRESS)    -> IDENTITY_CONFIRMED, reply CONFIRM_IDENTITY
//   AUTH_SUCCESS     (IDENTITY_CONFIRMED)  -> COMPLETE
//   AUTH_FAILURE     (any state)           -> FAILED
// A stage arriving in any other state fails the handshake.
shared_ptr<AuthPacket> HandshakeManager::handleClient(const shared_ptr<AuthPacket> &packet)
{
    switch (packet->stage)
    {
    case AuthStage::DECLARE_SCHEME:
    {
        if (state != HandshakeState::VERSION_SENT)
            return fail();

        protocolVersion = getField(packet->fields, L"version");
        wstring scheme = getField(packet->fields, L"scheme");
        app.DebugPrintf("AUTH CLIENT: DECLARE_SCHEME scheme=%ls\n", scheme.c_str());
        if (protocolVersion != PROTOCOL_VERSION)
            return fail();

        auto it = modules.find(scheme);
        if (it == modules.end())
            return fail();
        activeModule = it->second.get();

        // Use the preferred variation when the module supports it; otherwise
        // fall back to the module's first variation, or none at all.
        auto variations = activeModule->supportedVariations();
        if (!preferredVariation.empty() &&
            std::find(variations.begin(), variations.end(), preferredVariation) != variations.end())
            activeVariation = preferredVariation;
        else
            activeVariation = variations.empty() ? L"" : variations[0];
        app.DebugPrintf("AUTH CLIENT: accepting variation=%ls\n", activeVariation.c_str());

        state = HandshakeState::SCHEME_ACCEPTED;
        return makePacket(AuthStage::ACCEPT_SCHEME, {{L"variation", activeVariation}});
    }
    case AuthStage::SCHEME_SETTINGS:
    {
        if (state != HandshakeState::SCHEME_ACCEPTED || !activeModule)
            return fail();

        wstring serverId = getField(packet->fields, L"serverId");
        wstring joinUrlW = getField(packet->fields, L"joinUrl");
        wstring scheme(activeModule->schemeName());
        app.DebugPrintf("AUTH CLIENT: SCHEME_SETTINGS joinUrl=%ls serverId=%ls\n", joinUrlW.c_str(), serverId.c_str());

        // For the session scheme, register the join with the session service
        // before authenticating. A thrown exception or non-2xx status fails the
        // handshake.
        // NOTE(review): joinUrl is supplied by the server, so the client's
        // access token goes to whatever URL the server names. Consider
        // restricting joinUrl to a trusted host before sending the token.
        if (scheme == L"mcconsoles:session" && !accessToken.empty() && !joinUrlW.empty())
        {
            nlohmann::json body = {
                {"accessToken", narrowStr(accessToken)},
                {"selectedProfile", narrowStr(clientUid)},
                {"serverId", narrowStr(serverId)}
            };
            string joinUrl = narrowStr(joinUrlW);
            app.DebugPrintf("AUTH CLIENT: POSTing join to %s\n", joinUrl.c_str());

            HttpResponse resp;
            try
            {
                resp = HttpClient::post(joinUrl, body.dump());
            }
            catch (...)
            {
                app.DebugPrintf("AUTH CLIENT: join POST threw exception\n");
                return fail();
            }
            app.DebugPrintf("AUTH CLIENT: join POST status=%d\n", resp.statusCode);
            if (resp.statusCode < 200 || resp.statusCode >= 300)
                return fail();
        }

        // No direct reply: the next three packets are queued and must be
        // collected by the caller via drainPendingPackets().
        state = HandshakeState::AUTH_IN_PROGRESS;
        pendingPackets.push_back(makePacket(AuthStage::BEGIN_AUTH));
        pendingPackets.push_back(makePacket(AuthStage::AUTH_DATA, {
            {L"uid", clientUid},
            {L"username", clientUsername}
        }));
        pendingPackets.push_back(makePacket(AuthStage::AUTH_DONE, {
            {L"uid", clientUid},
            {L"username", clientUsername}
        }));
        return nullptr;
    }
    case AuthStage::ASSIGN_IDENTITY:
    {
        if (state != HandshakeState::AUTH_IN_PROGRESS)
            return fail();

        // Adopt the identity the server assigned and echo it back to confirm.
        finalUid = getField(packet->fields, L"uid");
        finalUsername = getField(packet->fields, L"username");
        state = HandshakeState::IDENTITY_CONFIRMED;
        return makePacket(AuthStage::CONFIRM_IDENTITY, {
            {L"uid", finalUid},
            {L"username", finalUsername}
        });
    }
    case AuthStage::AUTH_SUCCESS:
    {
        if (state != HandshakeState::IDENTITY_CONFIRMED)
            return fail();
        state = HandshakeState::COMPLETE;
        return nullptr;
    }
    case AuthStage::AUTH_FAILURE:
    {
        // Server-reported failure is accepted in any state; nothing is sent back.
        state = HandshakeState::FAILED;
        return nullptr;
    }
    default:
        return fail();
    }
}

// Creates a packet for `stage` carrying the given key/value fields.
shared_ptr<AuthPacket> HandshakeManager::makePacket(AuthStage stage, vector<pair<wstring, wstring>> fields)
{
    return std::make_shared<AuthPacket>(stage, std::move(fields));
}

// Moves this side to FAILED and returns an AUTH_FAILURE packet to send to the peer.
shared_ptr<AuthPacket> HandshakeManager::fail()
{
    state = HandshakeState::FAILED;
    return makePacket(AuthStage::AUTH_FAILURE);
}