diff --git a/src/server.cpp b/src/server.cpp index 02a53e2..aa984e9 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -257,14 +257,13 @@ public: } } } + // registerClient can happen after a successful websocket handshake. + // However, the connection might not be closed gracefully, so the + // corresponding deregister operation happens in the connection + // destructor rather than the close handler laminar.registerClient(c->lc); laminar.sendStatus(c->lc); }); - - wss.set_close_handler([this](websocketpp::connection_hdl hdl){ - websocket::connection_ptr c = wss.get_con_from_hdl(hdl); - laminar.deregisterClient(c->lc); - }); } // Return a new connection object linked with the context defined below. @@ -276,6 +275,12 @@ public: return c; } + void connectionDestroyed(LaminarClient* lc) { + // This will be called for all connections, not just websockets, so + // the laminar instance should silently ignore unknown clients + laminar.deregisterClient(lc); + } + private: Resources resources; LaminarInterface& laminar; @@ -302,6 +307,7 @@ struct RpcConnection { // and the corresponding kj async methods struct Server::WebsocketConnection : public LaminarClient { WebsocketConnection(kj::Own&& stream, Server::HttpImpl& http) : + http(http), stream(kj::mv(stream)), cn(http.newConnection(this)), writePaf(kj::newPromiseAndFulfiller()) @@ -314,7 +320,13 @@ struct Server::WebsocketConnection : public LaminarClient { cn->start(); } - virtual ~WebsocketConnection() noexcept(true) override {} + virtual ~WebsocketConnection() noexcept(true) override { + // Removes the connection from the list of registered clients. Must be + // here rather than in the websocket closing handshake because connections + // might be unexpectedly/aggressively closed and any references must be + // removed. + http.connectionDestroyed(this); + } kj::Promise pend() { return stream->tryRead(ibuf, 1, sizeof(ibuf)).then([this](size_t sz){ @@ -347,6 +359,7 @@ struct Server::WebsocketConnection : public LaminarClient { cn->send(payload, websocketpp::frame::opcode::text); } + HttpImpl& http; kj::Own stream; websocket::connection_ptr cn; std::string outputBuffer; @@ -357,9 +370,11 @@ struct Server::WebsocketConnection : public LaminarClient { Server::Server(LaminarInterface& li, kj::StringPtr rpcBindAddress, kj::StringPtr httpBindAddress) : rpcInterface(kj::heap(li)), - httpInterface(new HttpImpl(li)), + laminarInterface(li), + httpInterface(kj::heap(li)), ioContext(kj::setupAsyncIo()), - tasks(*this) + tasks(*this), + httpReady(kj::newPromiseAndFulfiller()) { // RPC task if(rpcBindAddress.startsWith("unix:")) @@ -375,13 +390,12 @@ Server::Server(LaminarInterface& li, kj::StringPtr rpcBindAddress, tasks.add(ioContext.provider->getNetwork().parseAddress(httpBindAddress) .then([this](kj::Own&& addr) { acceptHttpClient(addr->listen()); + // TODO: a better way? Currently used only for testing + httpReady.fulfiller->fulfill(); })); } Server::~Server() { - // RpcImpl is deleted through Capability::Client. - // Deal with the HTTP interface the old-fashioned way - delete httpInterface; } void Server::start() { @@ -410,14 +424,11 @@ void Server::addDescriptor(int fd, std::function cb) { void Server::acceptHttpClient(kj::Own&& listener) { auto ptr = listener.get(); tasks.add(ptr->accept().then(kj::mvCapture(kj::mv(listener), - [this](kj::Own&& listener, - kj::Own&& connection) { + [this](kj::Own&& listener, kj::Own&& connection) { acceptHttpClient(kj::mv(listener)); auto conn = kj::heap(kj::mv(connection), *httpInterface); - auto promises = kj::heapArrayBuilder>(2); - promises.add(conn->pend()); - promises.add(conn->writeTask()); - return kj::joinPromises(promises.finish()).attach(std::move(conn)); + // delete the connection when either the read or write task completes + return conn->pend().exclusiveJoin(conn->writeTask()).attach(kj::mv(conn)); })) ); } @@ -425,8 +436,7 @@ void Server::acceptHttpClient(kj::Own&& listener) { void Server::acceptRpcClient(kj::Own&& listener) { auto ptr = listener.get(); tasks.add(ptr->accept().then(kj::mvCapture(kj::mv(listener), - [this](kj::Own&& listener, - kj::Own&& connection) { + [this](kj::Own&& listener, kj::Own&& connection) { acceptRpcClient(kj::mv(listener)); auto server = kj::heap(kj::mv(connection), rpcInterface, capnp::ReaderOptions()); tasks.add(server->network.onDisconnect().attach(kj::mv(server))); @@ -446,3 +456,10 @@ kj::Promise Server::handleFdRead(kj::AsyncInputStream* stream, char* buffe return kj::Promise(kj::READY_NOW); }); } + +void Server::taskFailed(kj::Exception &&exception) { + // An unexpected http connection close can cause an exception, so don't re-throw. + // TODO: consider re-throwing selected exceptions + LLOG(INFO, exception); + //kj::throwFatalException(kj::mv(exception)); +} diff --git a/src/server.h b/src/server.h index ffc70a4..5fe5442 100644 --- a/src/server.h +++ b/src/server.h @@ -48,9 +48,7 @@ private: void acceptRpcClient(kj::Own&& listener); kj::Promise handleFdRead(kj::AsyncInputStream* stream, char* buffer, std::function cb); - void taskFailed(kj::Exception&& exception) override { - kj::throwFatalException(kj::mv(exception)); - } + void taskFailed(kj::Exception&& exception) override; private: struct WebsocketConnection; @@ -58,9 +56,14 @@ private: int efd_quit; capnp::Capability::Client rpcInterface; - HttpImpl* httpInterface; + LaminarInterface& laminarInterface; + kj::Own httpInterface; kj::AsyncIoContext ioContext; kj::TaskSet tasks; + + // TODO: restructure so this isn't necessary + friend class ServerTest; + kj::PromiseFulfillerPair httpReady; }; #endif // LAMINAR_SERVER_H_ diff --git a/test/test-server.cpp b/test/test-server.cpp index 67a06fc..19ea8d8 100644 --- a/test/test-server.cpp +++ b/test/test-server.cpp @@ -43,8 +43,18 @@ public: class MockLaminar : public LaminarInterface { public: - MOCK_METHOD1(registerClient, void(LaminarClient*)); - MOCK_METHOD1(deregisterClient, void(LaminarClient*)); + LaminarClient* client = nullptr; + virtual void registerClient(LaminarClient* c) override { + ASSERT_EQ(nullptr, client); + client = c; + EXPECT_CALL(*this, sendStatus(client)).Times(testing::Exactly(1)); + } + + virtual void deregisterClient(LaminarClient* c) override { + ASSERT_EQ(client, c); + client = nullptr; + } + MOCK_METHOD2(queueJob, std::shared_ptr(std::string name, ParamMap params)); MOCK_METHOD1(registerWaiter, void(LaminarWaiter* waiter)); MOCK_METHOD1(deregisterWaiter, void(LaminarWaiter* waiter)); @@ -75,6 +85,11 @@ protected: kj::WaitScope& ws() const { return server->ioContext.waitScope; } + void waitForHttpReady() { + server->httpReady.promise.wait(server->ioContext.waitScope); + } + + kj::Network& network() { return server->ioContext.provider->getNetwork(); } TempDir tempDir; MockLaminar mockLaminar; Server* server; @@ -86,3 +101,41 @@ TEST_F(ServerTest, RpcTrigger) { EXPECT_CALL(mockLaminar, queueJob("foo", ParamMap())).Times(testing::Exactly(1)); req.send().wait(ws()); } + +// Tests that agressively closed websockets are properly removed +// and will not be attempted to be contacted again +TEST_F(ServerTest, HttpWebsocketRST) { + waitForHttpReady(); + + // TODO: generalize + constexpr const char* WS = + "GET / HTTP/1.1\r\n" + "Host: localhost:8080\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Key: GTFmrUCM9N6B32LdDE3Rzw==\r\n" + "Sec-WebSocket-Version: 13\r\n\r\n"; + + static char buffer[256]; + network().parseAddress("localhost:8080").then([this](kj::Own&& addr){ + return addr->connect().attach(kj::mv(addr)).then([this](kj::Own&& stream){ + return stream->write(WS, strlen(WS)).then(kj::mvCapture(kj::mv(stream), [this](kj::Own&& stream){ + // Read the websocket header response, ensure the client has been registered + return stream->tryRead(buffer, 64, 256).then(kj::mvCapture(kj::mv(stream), [this](kj::Own&& stream, size_t sz){ + EXPECT_LE(64, sz); + EXPECT_NE(nullptr, mockLaminar.client); + // agressively abort the connection + struct linger so_linger; + so_linger.l_onoff = 1; + so_linger.l_linger = 0; + stream->setsockopt(SOL_SOCKET, SO_LINGER, &so_linger, sizeof(so_linger)); + return kj::Promise(kj::READY_NOW); + })); + })); + }); + }).wait(ws()); + ws().poll(); + // Expect that the client has been cleared. If it has not, Laminar could + // try to write to the closed file descriptor, causing an exception + EXPECT_EQ(nullptr, mockLaminar.client); +}