diff --git a/include/kNet/MessageConnection.h b/include/kNet/MessageConnection.h index 24d63e9..4566209 100644 --- a/include/kNet/MessageConnection.h +++ b/include/kNet/MessageConnection.h @@ -269,6 +269,13 @@ class MessageConnection : public RefCountable /// Registers a new listener object for the events of this connection. void RegisterInboundMessageHandler(IMessageHandler *handler); // [main thread] + void SetUserContext(void* ctx); + + void* GetUserContext() const; + + template + T *GetUserContext() const; + /// Fetches all newly received messages waiting in the inbound queue, and passes each of these /// to the message handler registered using RegisterInboundMessageHandler. /// Call this function periodically to receive new data from the network if you are using the Observer pattern. @@ -486,6 +493,8 @@ class MessageConnection : public RefCountable /// The object that receives notifications of all received data. IMessageHandler *inboundMessageHandler; // [main thread] + void *userContext; + /// The underlying socket on top of which this connection operates. Socket *socket; // [set by main thread before the worker thread is running. Read-only when worker thread is running. Read by main and worker thread] @@ -603,4 +612,10 @@ void MessageConnection::Send(const SerializableMessage &data, unsigned long cont SendStruct(data, SerializableMessage::messageID, data.inOrder, data.reliable, data.priority, contentID); } +template +T *MessageConnection::GetUserContext() const +{ + return reinterpret_cast(userContext); +} + } // ~kNet diff --git a/samples/HelloClient/HelloClient.cpp b/samples/HelloClient/HelloClient.cpp index 46fd064..76e1801 100644 --- a/samples/HelloClient/HelloClient.cpp +++ b/samples/HelloClient/HelloClient.cpp @@ -22,48 +22,68 @@ using namespace kNet; // Define a MessageID for our a custom message. const message_id_t cHelloMessageID = 10; +const message_id_t cRegisterMessageID = 11; +const message_id_t cGeneralMessageID = 12; // This object gets called whenever new data is received. class MessageListener : public IMessageHandler { public: - void HandleMessage(MessageConnection *source, packet_id_t /*packetId*/, message_id_t messageId, const char *data, size_t numBytes) - { - if (messageId == cHelloMessageID) - { - // Read what we received. - DataDeserializer dd(data, numBytes); - std::cout << "Server says: " << dd.ReadString() << std::endl; - - source->Close(0); - } - } + void HandleMessage(MessageConnection* source, packet_id_t /*packetId*/, message_id_t messageId, const char* data, size_t numBytes) + { + const int maxBytesCount = 256; + if (messageId == cHelloMessageID) + { + // Read what we received. + DataDeserializer dd(data, numBytes); + std::cout << "Server says: " << dd.ReadString() << std::endl; + + NetworkMessage* msg = source->StartNewMessage(cRegisterMessageID, maxBytesCount); + msg->reliable = true; + DataSerializer ds(msg->data, maxBytesCount); + ds.AddString("Captain Jack"); + source->EndAndQueueMessage(msg, ds.BytesFilled()); + } + else if (messageId == cGeneralMessageID) + { + DataDeserializer dd(data, numBytes); + std::cout << "Server says: " << dd.ReadString() << std::endl; + + NetworkMessage* msg = source->StartNewMessage(cGeneralMessageID, maxBytesCount); + msg->reliable = true; + msg->priority = 10; + DataSerializer ds(msg->data, maxBytesCount); + ds.AddString("Nothing, bye!"); + source->EndAndQueueMessage(msg, ds.BytesFilled()); + source->Close(); + } + } }; BottomMemoryAllocator bma; -int main(int argc, char **argv) +int main(int argc, char** argv) { - if (argc < 2) - { - std::cout << "Usage: " << argv[0] << " server-ip" << std::endl; - return 0; - } - - kNet::SetLogChannels(LogUser | LogInfo | LogError); - - EnableMemoryLeakLoggingAtExit(); - - Network network; - MessageListener listener; - const unsigned short cServerPort = 1234; - Ptr(MessageConnection) connection = network.Connect(argv[1], cServerPort, SocketOverUDP, &listener); - - if (connection) - { - // Run the main client loop. - connection->RunModalClient(); - } - - return 0; + if (argc < 2) + { + std::cout << "Usage: " << argv[0] << " server-ip" << std::endl; + return 0; + } + + kNet::SetLogChannels(LogUser | LogInfo | LogError); + + EnableMemoryLeakLoggingAtExit(); + + Network network; + MessageListener listener; + const unsigned short cServerPort = 1234; + Ptr(MessageConnection) connection = network.Connect(argv[1], cServerPort, SocketOverUDP, &listener); + + if (connection) + { + // Run the main client loop. + connection->RunModalClient(); + } + + return 0; } diff --git a/samples/HelloServer/HelloServer.cpp b/samples/HelloServer/HelloServer.cpp index 7339ef0..1c3ecad 100644 --- a/samples/HelloServer/HelloServer.cpp +++ b/samples/HelloServer/HelloServer.cpp @@ -22,13 +22,68 @@ using namespace kNet; // Define a MessageID for our a custom message. const message_id_t cHelloMessageID = 10; +const message_id_t cRegisterMessageID = 11; +const message_id_t cGeneralMessageID = 12; + +class UserContext +{ +public: + explicit UserContext(MessageConnection* conn): messageConnection(conn) + { + } + + ~UserContext() + { + KNET_LOG(LogUser, "Client '%s' is being destroyed" , userName.c_str()); + } + + void OnPeerData(packet_id_t pid, message_id_t mid, const char* data, size_t len) + { + MessageConnection* conn = messageConnection; + if (mid == cRegisterMessageID) + { + DataDeserializer dd(data, len); + + userName = dd.ReadString(); + KNET_LOG(LogUser, "Client '%s' connected from %s.", userName.c_str(), conn->ToString().c_str()); + + const int maxBytesCount = 256; + NetworkMessage* msg = conn->StartNewMessage(cGeneralMessageID, maxBytesCount); + msg->reliable = true; + DataSerializer ds(msg->data, maxBytesCount); + ds.AddString("Hello, " + userName + "! What can I do for you?"); + conn->EndAndQueueMessage(msg, ds.BytesFilled()); + } + else + { + DataDeserializer dd(data, len); + KNET_LOG(LogUser, "Get message from %s: %s.", userName.c_str(), dd.ReadString().c_str()); + } + } + +private: + MessageConnection* messageConnection; + std::string userName; +}; + // This object gets called for notifications on new network connection events. -class ServerListener : public INetworkServerListener +class ServerListener : public INetworkServerListener, public IMessageHandler { public: + void HandleMessage(MessageConnection* source, packet_id_t packetId, message_id_t messageId, const char* data, + size_t numBytes) + { + UserContext* uc = source->GetUserContext(); + uc->OnPeerData(packetId, messageId, data, numBytes); + } + void NewConnectionEstablished(MessageConnection *connection) { + UserContext* uc = new UserContext(connection); + connection->SetUserContext(uc); + connection->RegisterInboundMessageHandler(this); + const int maxMsgBytes = 256; // Start building a new message. NetworkMessage *msg = connection->StartNewMessage(cHelloMessageID, maxMsgBytes); @@ -46,6 +101,8 @@ class ServerListener : public INetworkServerListener void ClientDisconnected(MessageConnection *connection) { connection->Disconnect(); + UserContext* uc = connection->GetUserContext(); + delete uc; } }; diff --git a/src/MessageConnection.cpp b/src/MessageConnection.cpp index b4c4af9..fa3ca13 100644 --- a/src/MessageConnection.cpp +++ b/src/MessageConnection.cpp @@ -94,7 +94,7 @@ outboundAcceptQueue(16*1024), inboundMessageQueue(16*1024), #ifdef KNET_NO_MAXHEAP outboundQueue(16 * 1024), #endif -inboundMessageHandler(0), socket(socket_), +inboundMessageHandler(0), userContext(0), socket(socket_), bOutboundSendsPaused(false), rtt(0.f), lastHeardTime(Clock::Tick()), @@ -1117,6 +1117,18 @@ void MessageConnection::RegisterInboundMessageHandler(IMessageHandler *handler) inboundMessageHandler = handler; } +void MessageConnection::SetUserContext(void *ctx) +{ + AssertInMainThreadContext(); + + userContext = ctx; +} + +void* MessageConnection::GetUserContext() const +{ + return userContext; +} + void MessageConnection::SendPingRequestMessage(bool internalQueue) { #ifdef KNET_THREAD_CHECKING_ENABLED