import {Events as BackboneEvents} from 'backbone';
import {promisifyAll} from 'bluebird';
import {assert} from 'chai';
import * as http from 'http';
import {AddressInfo} from 'net';
import * as sinon from 'sinon';
import * as path from 'path';
import * as tmp from 'tmp';

import {GristWSConnection, GristWSSettings} from 'app/client/components/GristWSConnection';
import {GristClientSocket, GristClientSocketOptions} from 'app/client/components/GristClientSocket';
import {Comm as ClientComm} from 'app/client/components/Comm';
import * as log from 'app/client/lib/log';
import {Comm} from 'app/server/lib/Comm';
import {Client, ClientMethod} from 'app/server/lib/Client';
import {CommClientConnect} from 'app/common/CommTypes';
import {delay} from 'app/common/delay';
import {isLongerThan} from 'app/common/gutil';
import {fromCallback, listenPromise} from 'app/server/lib/serverUtils';
import {Sessions} from 'app/server/lib/Sessions';
import {TcpForwarder} from 'test/server/tcpForwarder';
import * as testUtils from 'test/server/testUtils';
import * as session from '@gristlabs/express-session';
import { Hosts, RequestOrgInfo } from 'app/server/lib/extractOrg';

const SQLiteStore = require('@gristlabs/connect-sqlite3')(session);
promisifyAll(SQLiteStore.prototype);

// Just enough implementation of Hosts to be able to fake using a custom host.
class FakeHosts {
  public isCustomHost = false;

  public get asHosts() { return this as unknown as Hosts; }

  public async addOrgInfo<T extends http.IncomingMessage>(req: T): Promise<T & RequestOrgInfo> {
    return Object.assign(req, {
      isCustomHost: this.isCustomHost,
      org: "example",
      url: req.url!
    });
  }
}

describe('Comm', function() {
  testUtils.setTmpLogLevel(process.env.VERBOSE ? 'debug' : 'warn');

  // Allow test cases to register afterEach callbacks here for easier cleanup.
  const cleanup: Array<() => void> = [];

  let server: http.Server;
  let sessions: Sessions;
  let fakeHosts: FakeHosts;
  let comm: Comm|null = null;
  const sandbox = sinon.createSandbox();

  before(async function() {
    const sessionDB = tmp.fileSync();
    const sessionStore = new SQLiteStore({
      dir: path.dirname(sessionDB.name),
      db: path.basename(sessionDB.name),
      table: 'sessions'
    });
    // Random string to use for the test session secret.
    const sessionSecret = 'xkwriagasaqystubgkkbwhqtyyncwqjemyncnmetjpkiwtfzvllejpfneldmoyri';
    sessions = new Sessions(sessionSecret, sessionStore);
  });

  // Creates a fresh HTTP server, FakeHosts and server-side Comm with the given methods registered,
  // and returns a promise for the server listening on port 0 of localhost.
  function startComm(methods: {[name: string]: ClientMethod}) {
    server = http.createServer();
    fakeHosts = new FakeHosts();
    comm = new Comm(server, {sessions, hosts: fakeHosts.asHosts});
    comm.registerMethods(methods);
    return listenPromise(server.listen(0, 'localhost'));
  }

  // Destroys all clients, shuts down Comm, and closes the HTTP server along with its connections.
  async function stopComm() {
    comm?.destroyAllClients();
    await comm?.testServerShutdown();
    await fromCallback(cb => {
      server.close(cb);
      server.closeAllConnections();
    });
  }

  // Server methods exercised by the tests below.
  const assortedMethods: {[name: string]: ClientMethod} = {
    methodSync: async function(client, x, y) {
      return {x: x, y: y, name: "methodSync"};
    },
    methodError: async function(client, x, y) {
      throw new Error("fake error");
    },
    methodAsync: async function(client, x, y) {
      await delay(20);
      return {x: x, y: y, name: "methodAsync"};
    },
    methodSend: async function(client, docFD) {
      void(client.sendMessage({docFD, type: "fooType" as any, data: "foo"}));
      void(client.sendMessage({docFD, type: "barType" as any, data: "bar"}));
    }
  };

  afterEach(async function() {
    try {
      // Run the cleanup callbacks registered in cleanup().
      await Promise.all(cleanup.splice(0).map(callback => callback()));
    } finally {
      // Always restore the sandbox, even when a cleanup callback fails; otherwise stubs (such as
      // the fake global `window`) would stay installed for the tests that follow.
      sandbox.restore();
    }
  });

  // Returns a promise for the next `count` messages received on ws, each parsed as JSON. Rejects
  // if ws reports an error first. Takes over ws.onmessage and ws.onerror until it settles.
  function getMessages(ws: GristClientSocket, count: number): Promise<any[]> {
    return new Promise((resolve, reject) => {
      const messages: object[] = [];
      ws.onerror = (err) => {
        ws.onmessage = null;
        reject(err);
      };
      ws.onmessage = (data: string) => {
        messages.push(JSON.parse(data));
        if (messages.length >= count) {
          ws.onerror = null;
          ws.onmessage = null;
          resolve(messages);
        }
      };
    });
  }

  /**
   * Returns a promise for the connected websocket.
   */
  function connect(options?: GristClientSocketOptions): Promise<GristClientSocket> {
    const ws = new GristClientSocket('ws://localhost:' + (server.address() as AddressInfo).port, options);
    return new Promise<GristClientSocket>((resolve, reject) => {
      ws.onopen = () => {
        ws.onerror = null;
        resolve(ws);
      };
      ws.onerror = (err) => {
        ws.onopen = null;
        reject(err);
      };
    });
  }

  describe("server methods", function() {
    let ws: GristClientSocket;

    beforeEach(async function() {
      await startComm(assortedMethods);
      ws = await connect();
      await getMessages(ws, 1);   // consume a clientConnect message
    });

    afterEach(async function() {
      await stopComm();
    });

    it("should return data for valid calls", async function() {
      ws.send(JSON.stringify({reqId: 10, method: "methodSync", args: ["hello", "world"]}));
      const messages = await getMessages(ws, 1);
      const resp = messages[0];
      assert.equal(resp.reqId, 10, `Messages received instead: ${JSON.stringify(messages)}`);
      assert.deepEqual(resp.data, {x: "hello", y: "world", name: "methodSync"});
    });

    it("should work for async calls", async function() {
      ws.send(JSON.stringify({reqId: 20, method: "methodAsync", args: ["hello", "world"]}));
      const messages = await getMessages(ws, 1);
      const resp = messages[0];
      assert.equal(resp.reqId, 20);
      assert.deepEqual(resp.data, {x: "hello", y: "world", name: "methodAsync"});
    });

    it("should work for out-of-order calls", async function() {
      ws.send(JSON.stringify({reqId: 30, method: "methodAsync", args: [1, 2]}));
      ws.send(JSON.stringify({reqId: 31, method: "methodSync", args: [3, 4]}));
      const messages = await getMessages(ws, 2);
      assert.equal(messages[0].reqId, 31);
      assert.deepEqual(messages[0].data, {x: 3, y: 4, name: "methodSync"});
      assert.equal(messages[1].reqId, 30);
      assert.deepEqual(messages[1].data, {x: 1, y: 2, name: "methodAsync"});
    });

    it("should return error when a call fails", async function() {
      const logMessages = await testUtils.captureLog('warn', async () => {
        ws.send(JSON.stringify({reqId: 40, method: "methodError", args: ["hello"]}));
        const messages = await getMessages(ws, 1);
        const resp = messages[0];
        assert.equal(resp.reqId, 40);
        assert.equal(resp.data, undefined);
        assert(resp.error.indexOf('fake error') >= 0);
      });
      testUtils.assertMatchArray(logMessages, [
        /^warn: Client.* Error: fake error[^]+at methodError/,
        /^warn: Client.* responding to .* ERROR fake error/,
      ]);
    });

    it("should return error for unknown methods", async function() {
      const logMessages = await testUtils.captureLog('warn', async () => {
        ws.send(JSON.stringify({reqId: 50, method: "someUnknownMethod", args: []}));
        const messages = await getMessages(ws, 1);
        const resp = messages[0];
        assert.equal(resp.reqId, 50);
        assert.equal(resp.data, undefined);
        assert(resp.error.indexOf('Unknown method') >= 0);
      });
      testUtils.assertMatchArray(logMessages, [
        /^warn: Client.* Unknown method.*someUnknownMethod/
      ]);
    });

    it('should only log warning for malformed JSON data', async function () {
      const logMessages = await testUtils.captureLog('warn', async () => {
        ws.send('foobar');
      }, {waitForFirstLog: true});
      testUtils.assertMatchArray(logMessages, [
        /^warn: Client.* Unexpected token.*/
      ]);
    });

    it('should log warning when null value is passed', async function () {
      const logMessages = await testUtils.captureLog('warn', async () => {
        ws.send('null');
      }, {waitForFirstLog: true});
      testUtils.assertMatchArray(logMessages, [
        /^warn: Client.*Cannot read properties of null/
      ]);
    });

    it("should support app-level events correctly", async function() {
      comm!.broadcastMessage('fooType' as any, 'hello');
      comm!.broadcastMessage('barType' as any, 'world');
      const messages = await getMessages(ws, 2);
      assert.equal(messages[0].type, 'fooType');
      assert.equal(messages[0].data, 'hello');
      assert.equal(messages[1].type, 'barType');
      assert.equal(messages[1].data, 'world');
    });

    it("should support doc-level events", async function() {
      ws.send(JSON.stringify({reqId: 60, method: "methodSend", args: [13]}));
      const messages = await getMessages(ws, 3);
      assert.equal(messages[0].type, 'fooType');
      assert.equal(messages[0].data, 'foo');
      assert.equal(messages[0].docFD, 13);
      assert.equal(messages[1].type, 'barType');
      assert.equal(messages[1].data, 'bar');
      assert.equal(messages[1].docFD, 13);
      assert.equal(messages[2].reqId, 60);
      assert.equal(messages[2].data, undefined);
      assert.equal(messages[2].error, undefined);
    });
  });

  describe("reconnects", function() {
    const docId = "docId_abc";
    this.timeout(10000);

    // Helper to set up a Comm server, a Comm client, and a forwarder between them that allows
    // simulating disconnects.
    async function startManagedConnection(methods: {[name: string]: ClientMethod}) {
      // Start the server Comm, providing a few methods.
      await startComm(methods);
      cleanup.push(() => stopComm());

      // Create a forwarder, which we use to test disconnects.
      const serverPort = (server.address() as AddressInfo).port;
      const forwarder = new TcpForwarder(serverPort);
      const forwarderPort = await forwarder.pickForwarderPort();
      await forwarder.connect();
      cleanup.push(() => forwarder.disconnect());

      // To create a client-side Comm object, we need to trick GristWSConnection's check for
      // whether there is a worker to connect to.
      (global as any).window = undefined;
      sandbox.stub(global as any, 'window').value({gristConfig: {getWorker: 'STUB', assignmentId: docId}});

      // We also need to get GristWSConnection to use a custom GristWSSettings object, and to
      // connect to the forwarder's port.
      const docWorkerUrl = `http://localhost:${forwarderPort}`;
      const settings = getWSSettings(docWorkerUrl);
      const stubGristWsCreate = sandbox.stub(GristWSConnection, 'create').callsFake(function(this: any, owner) {
        return (stubGristWsCreate as any).wrappedMethod.call(this, owner, settings);
      });

      // Cast with BackboneEvents to allow using cliComm.on().
      const cliComm = ClientComm.create() as ClientComm & BackboneEvents;
      cliComm.useDocConnection(docId);
      cleanup.push(() => cliComm.dispose());    // Dispose after this test ends.

      return {cliComm, forwarder};
    }

    it('should forward calls on a normal connection', async function() {
      const {cliComm} = await startManagedConnection(assortedMethods);

      // A couple of regular requests.
      const resp1 = await cliComm._makeRequest(null, null, "methodSync", "foo", 1);
      assert.deepEqual(resp1, {name: 'methodSync', x: "foo", y: 1});
      const resp2 = await cliComm._makeRequest(null, null, "methodAsync", "foo", 2);
      assert.deepEqual(resp2, {name: 'methodAsync', x: "foo", y: 2});

      // Try calls that return out of order.
      const [resp3, resp4] = await Promise.all([
        cliComm._makeRequest(null, null, "methodAsync", "foo", 3),
        cliComm._makeRequest(null, null, "methodSync", "foo", 4),
      ]);
      assert.deepEqual(resp3, {name: 'methodAsync', x: "foo", y: 3});
      assert.deepEqual(resp4, {name: 'methodSync', x: "foo", y: 4});
    });

    it('should forward missed responses when a server send fails', async function() {
      await testMissedResponses(true);
    });

    it('should forward missed responses when a server send is queued', async function() {
      await testMissedResponses(false);
    });

    async function testMissedResponses(sendShouldFail: boolean) {
      let failedSendCount = 0;

      const {cliComm, forwarder} = await startManagedConnection({...assortedMethods,
        // An extra method that simulates a lost connection on server side prior to response.
        testDisconnect: async function(client, x, y) {
          setTimeout(() => forwarder.disconnectServerSide(), 0);
          if (!sendShouldFail) {
            // Add a delay to let the 'close' event get noticed first.
            await delay(20);
          }
          return {x: x, y: y, name: "testDisconnect"};
        },
      });

      const resp1 = await cliComm._makeRequest(null, null, "methodSync", "foo", 1);
      assert.deepEqual(resp1, {name: 'methodSync', x: "foo", y: 1});

      if (sendShouldFail) {
        // In Node 18, the socket is closed during the call to 'testDisconnect'.
        // In prior versions of Node, the socket was still disconnecting.
        // This test is sensitive to timing and only passes in the latter, unless we
        // stub the method below to produce similar behavior in the former.
        sandbox.stub(Client.prototype as any, '_sendToWebsocket')
          .onFirstCall()
          .callsFake(() => {
            failedSendCount += 1;
            throw new Error('WebSocket is not open');
          })
          .callThrough();
      }

      // Make more calls, with a disconnect before they return. The server should queue up responses.
      const resp2Promise = cliComm._makeRequest(null, null, "testDisconnect", "foo", 2);
      const resp3Promise = cliComm._makeRequest(null, null, "methodAsync", "foo", 3);
      assert.equal(await isLongerThan(resp2Promise, 250), true);

      // Once we reconnect, the response should arrive.
      await forwarder.connect();
      assert.deepEqual(await resp2Promise, {name: 'testDisconnect', x: "foo", y: 2});
      assert.deepEqual(await resp3Promise, {name: 'methodAsync', x: "foo", y: 3});

      // Check that we saw the situation we were hoping to test.
      assert.equal(failedSendCount, sendShouldFail ? 1 : 0, 'Expected to see a failed send');
    }

    it("should receive all server messages (small) in order when send doesn't fail", async function() {
      await testSendOrdering({noFailedSend: true, useSmallMsgs: true});
    });

    it("should receive all server messages (large) in order when send doesn't fail", async function() {
      await testSendOrdering({noFailedSend: true});
    });

    it("should order server messages correctly with failedSend before close", async function() {
      await testSendOrdering({closeHappensFirst: false});
    });

    it("should order server messages correctly with close before failedSend", async function() {
      await testSendOrdering({closeHappensFirst: true});
    });

    async function testSendOrdering(
      options: {noFailedSend?: boolean, closeHappensFirst?: boolean, useSmallMsgs?: boolean}
    ) {
      const eventsSeen: Array<'failedSend'|'close'> = [];

      // Server-side Client object.
      let ssClient!: Client;

      const {cliComm, forwarder} = await startManagedConnection(assortedMethods);

      // Intercept the call to _onClose to know when it occurs, since we are trying to hit a
      // situation where 'close' and 'failedSend' events happen in either order.
      const stubOnClose: any = sandbox.stub(Client.prototype as any, '_onClose')
        .callsFake(function(this: Client) {
          eventsSeen.push('close');
          return stubOnClose.wrappedMethod.apply(this, arguments);
        });

      // Intercept calls to client.sendMessage(), to know when it fails, and possibly to delay the
      // failures to hit a particular order in which 'close' and 'failedSend' events are seen by
      // Client.ts. This is the only reliable way I found to reproduce this order of events.
      const stubSendToWebsocket: any = sandbox.stub(Client.prototype as any, '_sendToWebsocket')
        .callsFake(async function(this: Client) {
          try {
            return await stubSendToWebsocket.wrappedMethod.apply(this, arguments);
          } catch (err) {
            if (options.closeHappensFirst) {
              await delay(100);
            }
            eventsSeen.push('failedSend');
            throw err;
          }
        });

      // Watch the events received all the way on the client side.
      const eventSpy = sinon.spy();
      const clientConnectSpy = sinon.spy();
      cliComm.on('docUserAction', eventSpy);
      cliComm.on('clientConnect', clientConnectSpy);

      // We need to simulate an important property of the browser client: when needReload is set
      // in the clientConnect message, we are expected to reload the app. In the test, we replace
      // the GristWSConnection.
      cliComm.on('clientConnect', async (msg: CommClientConnect) => {
        ssClient = comm!.getClient(msg.clientId);
        if (msg.needReload) {
          await delay(0);
          cliComm.releaseDocConnection(docId);
          cliComm.useDocConnection(docId);
        }
      });

      // Wait for a connect call, which we rely on to get access to the Client object (ssClient).
      await waitForCondition(() => (clientConnectSpy.callCount > 0), 1000);

      // Send large buffers, to fill up the socket's buffers to get it to block.
      const data = "x".repeat(options.useSmallMsgs ? 100_000 : 10_000_000);
      const makeMessage = (n: number) => ({type: 'docUserAction', n, data});

      let n = 0;
      const sendPromises: Array<Promise<void>> = [];
      const sendNextMessage = () => sendPromises.push(ssClient.sendMessage(makeMessage(n++) as any));

      await testUtils.captureLog('warn', async () => {
        // Make a few sends. These are big enough not to return immediately. Keep the first two
        // successful (by awaiting them). And keep a few more that will fail. This is to test the
        // ordering of successful and failed messages that may be missed.
        sendNextMessage();
        sendNextMessage();
        sendNextMessage();
        await sendPromises[0];
        await sendPromises[1];

        sendNextMessage();
        sendNextMessage();

        // Forcibly close the forwarder, so that the server sees a 'close' event. But first let
        // some messages get to the client. In case we want all sends to succeed, let them all get
        // forwarded before disconnect; otherwise, disconnect after 2 are forwarded.
        const countToWaitFor = options.noFailedSend ? 5 : 2;
        await waitForCondition(() => eventSpy.callCount >= countToWaitFor);

        void(forwarder.disconnectServerSide());

        // Wait less than the delay that we add to failed sends when closeHappensFirst is set, and
        // send another message. There used to be a bug that such a message would get recorded into
        // missedMessages out of order.
        await delay(50);
        sendNextMessage();

        // Now reconnect, and collect the messages that the client sees.
        clientConnectSpy.resetHistory();
        await forwarder.connect();

        // Wait until we get a clientConnect message that does not require a reload. (Except with
        // noFailedSend, the first one would have needReload set; and after the reconnect, we should
        // get one without.)
        await waitForCondition(() =>
          (clientConnectSpy.callCount > 0 && clientConnectSpy.lastCall.args[0].needReload === false),
          3000);
      });

      // This test helper is used for 3 different situations. Check that we observed the
      // situation we were trying to hit.
      if (options.noFailedSend) {
        if (options.useSmallMsgs) {
          assert.deepEqual(eventsSeen, ['close']);
        } else {
          // Make sure to have waited long enough for the 'close' event we may have delayed
          await delay(20);

          // Large messages now cause a send to fail, after filling up buffer, and close the socket.
          assert.deepEqual(eventsSeen, ['close', 'close']);
        }
      } else if (options.closeHappensFirst) {
        assert.equal(eventsSeen[0], 'close');
        assert.include(eventsSeen, 'failedSend');
      } else {
        assert.equal(eventsSeen[0], 'failedSend');
        assert.include(eventsSeen, 'close');
      }

      // After a successful reconnect, subsequent calls should work normally.
      assert.deepEqual(await cliComm._makeRequest(null, null, "methodSync", 1, 2),
        {name: 'methodSync', x: 1, y: 2});

      // Check that all the received messages are in order.
      const messageNums = eventSpy.getCalls().map(call => call.args[0].n);
      assert.isAtLeast(messageNums.length, 2);
      assert.deepEqual(messageNums, nrange(0, messageNums.length),
        `Unexpected message sequence ${JSON.stringify(messageNums)}`);

      // Subsequent messages should work normally too.
      eventSpy.resetHistory();
      sendNextMessage();
      await waitForCondition(() => eventSpy.callCount > 0);
      assert.deepEqual(eventSpy.getCalls().map(call => call.args[0].n), [n - 1]);
    }
  });

  describe("Allowed Origin", function() {
    beforeEach(async function () {
      await startComm(assortedMethods);
    });

    afterEach(async function() {
      await stopComm();
    });

    // Attempts a websocket connection with the given headers, and asserts that it succeeds or
    // fails according to `allowed`.
    async function checkOrigin(headers: { origin: string, host: string }, allowed: boolean) {
      const promise = connect({ headers });
      if (allowed) {
        await assert.isFulfilled(promise, `${headers.host} should allow ${headers.origin}`);
      } else {
        await assert.isRejected(promise, /.*/, `${headers.host} should reject ${headers.origin}`);
      }
    }

    it('origin should match base domain of host', async () => {
      await checkOrigin({origin: "https://www.toto.com", host: "worker.example.com"}, false);
      await checkOrigin({origin: "https://badexample.com", host: "worker.example.com"}, false);
      await checkOrigin({origin: "https://bad.com/example.com", host: "worker.example.com"}, false);
      await checkOrigin({origin: "https://front.example.com", host: "worker.example.com"}, true);
      await checkOrigin({origin: "https://front.example.com:3000", host: "worker.example.com"}, true);
      await checkOrigin({origin: "https://example.com", host: "example.com"}, true);
    });

    it('with custom domains, origin should match the full hostname', async () => {
      fakeHosts.isCustomHost = true;

      // For a request to a custom domain, the full hostname must match.
      await checkOrigin({origin: "https://front.example.com", host: "worker.example.com"}, false);
      await checkOrigin({origin: "https://front.example.com", host: "front.example.com"}, true);
      await checkOrigin({origin: "https://front.example.com:3000", host: "front.example.com"}, true);
    });
  });
});

// Waits for condFunc() to return true, for up to timeoutMs milliseconds, sleeping for stepMs
// between checks. Returns if succeeded, throws if failed.
// The condition is always evaluated at least once (even when timeoutMs is 0), and once more
// after the final sleep, so a condition that turns true during that last sleep is not reported
// as a timeout.
async function waitForCondition(condFunc: () => boolean, timeoutMs = 1000, stepMs = 10): Promise<void> {
  const end = Date.now() + timeoutMs;
  while (!condFunc()) {
    // Only give up right after a failed check, never right after a sleep.
    if (Date.now() >= end) {
      throw new Error(`Condition not met after ${timeoutMs}ms: ${condFunc.toString()}`);
    }
    await delay(stepMs);
  }
}

// Returns a range of count consecutive numbers starting with start.
function nrange(start: number, count: number): number[] {
  return Array.from(Array(count), (_, i) => start + i);
}

// Returns a GristWSSettings object, for use with GristWSConnection.
// The returned settings always report docWorkerUrl as the doc worker, keep a single client id
// (initially 'clientid-abc', replaced by updateClientId regardless of the doc id passed), and hand
// out an incrementing counter as strings starting at "0".
function getWSSettings(docWorkerUrl: string): GristWSSettings {
  let clientId = 'clientid-abc';
  let counter = 0;
  return {
    makeWebSocket(url: string): any { return new GristClientSocket(url); },
    async getTimezone() { return 'UTC'; },
    getPageUrl() { return "http://localhost"; },
    async getDocWorkerUrl() { return docWorkerUrl; },
    getClientId(did: any) { return clientId; },
    getUserSelector() { return ''; },
    updateClientId(did: string, cid: string) { clientId = cid; },
    advanceCounter(): string { return String(counter++); },
    log() { (log as any).debug(...arguments); },
    warn() { (log as any).warn(...arguments); },
  };
}