// gristlabs_grist-core/test/server/Comm.ts
//
// 591 lines
// 24 KiB
// TypeScript
// Raw Permalink Normal View History

import {Events as BackboneEvents} from 'backbone';
import {promisifyAll} from 'bluebird';
import {assert} from 'chai';
import * as http from 'http';
import {AddressInfo} from 'net';
import * as sinon from 'sinon';
import * as path from 'path';
import * as tmp from 'tmp';
import {GristWSConnection, GristWSSettings} from 'app/client/components/GristWSConnection';
import {GristClientSocket, GristClientSocketOptions} from 'app/client/components/GristClientSocket';
import {Comm as ClientComm} from 'app/client/components/Comm';
import * as log from 'app/client/lib/log';
import {Comm} from 'app/server/lib/Comm';
import {Client, ClientMethod} from 'app/server/lib/Client';
import {CommClientConnect} from 'app/common/CommTypes';
import {delay} from 'app/common/delay';
import {isLongerThan} from 'app/common/gutil';
import {fromCallback, listenPromise} from 'app/server/lib/serverUtils';
import {Sessions} from 'app/server/lib/Sessions';
import {TcpForwarder} from 'test/server/tcpForwarder';
import * as testUtils from 'test/server/testUtils';
import * as session from '@gristlabs/express-session';
import { Hosts, RequestOrgInfo } from 'app/server/lib/extractOrg';
// The connect-sqlite3 module exports a factory: given the express-session module, it returns a
// session Store class (instantiated in before() below). bluebird's promisifyAll() then adds
// promise-returning "...Async" variants of the store's callback-style methods.
const connectSQLite3 = require('@gristlabs/connect-sqlite3');
const SQLiteStore = connectSQLite3(session);
promisifyAll(SQLiteStore.prototype);
/**
 * Minimal stand-in for `Hosts`. It implements only addOrgInfo(), which is enough for Comm to treat
 * every request as belonging to org "example". Tests can flip `isCustomHost` to fake a request
 * arriving on a custom host.
 */
class FakeHosts {
  /** Reported as `isCustomHost` on every request decorated by addOrgInfo(). */
  public isCustomHost = false;
  /** This object, typed as `Hosts` so it can be passed wherever the real class is expected. */
  public get asHosts(): Hosts {
    const self: unknown = this;
    return self as Hosts;
  }
  /** Adds fake org info to `req` in place, and returns that same (now decorated) object. */
  public async addOrgInfo<T extends http.IncomingMessage>(req: T): Promise<T & RequestOrgInfo> {
    const orgInfo = {
      isCustomHost: this.isCustomHost,
      org: "example",
      url: req.url!,
    };
    return Object.assign(req, orgInfo);
  }
}
describe('Comm', function() {
testUtils.setTmpLogLevel(process.env.VERBOSE ? 'debug' : 'warn');
// Teardown callbacks registered by individual test cases. The suite's afterEach() hook calls all
// of them and awaits their results, so a callback may (and often does) return a promise.
const cleanup: Array<() => unknown> = [];
// Assigned anew by startComm() for each server started.
let server: http.Server;
// Created once, in before().
let sessions: Sessions;
// Assigned anew by startComm(); tests may tweak it (e.g. isCustomHost) before connecting.
let fakeHosts: FakeHosts;
// The server-side Comm created by startComm(), or null before the first start.
let comm: Comm|null = null;
// Holds stubs that must be undone after each test; restored in afterEach().
const sandbox = sinon.createSandbox();
// One-time setup: a Sessions object backed by a SQLite store in a temporary file. It is shared
// by every Comm server created in this suite.
before(async function() {
  const dbFile = tmp.fileSync().name;
  const store = new SQLiteStore({
    dir: path.dirname(dbFile),
    db: path.basename(dbFile),
    table: 'sessions',
  });
  // Arbitrary fixed string, used as the session secret for tests.
  const secret = 'xkwriagasaqystubgkkbwhqtyyncwqjemyncnmetjpkiwtfzvllejpfneldmoyri';
  sessions = new Sessions(secret, store);
});
/**
 * Creates a fresh HTTP server with a server-side Comm attached (exposing `methods`). It then
 * starts listening on an OS-assigned port on localhost.
 *
 * @returns the promise from listenPromise() for the listening server.
 */
function startComm(methods: {[name: string]: ClientMethod}) {
  server = http.createServer();
  fakeHosts = new FakeHosts();
  const serverComm = new Comm(server, {sessions, hosts: fakeHosts.asHosts});
  comm = serverComm;
  serverComm.registerMethods(methods);
  const listening = server.listen(0, 'localhost');
  return listenPromise(listening);
}
/**
 * Tears down what startComm() created. It destroys all clients, shuts down the Comm, and then
 * closes the HTTP server. Any connections still open are dropped forcibly, so that close() does
 * not wait on lingering sockets.
 */
async function stopComm() {
  const current = comm;
  current?.destroyAllClients();
  await current?.testServerShutdown();
  await fromCallback(done => {
    server.close(done);
    server.closeAllConnections();
  });
}
// Server methods exposed via Comm in most tests:
//  - methodSync: echoes its two arguments, tagged with the method name;
//  - methodError: always fails with "fake error";
//  - methodAsync: like methodSync, but answers after a 20ms delay (so responses can be reordered);
//  - methodSend: pushes a "fooType" and then a "barType" message for `docFD` to the caller,
//    without awaiting those sends, and resolves with no value.
const assortedMethods: {[name: string]: ClientMethod} = {
  methodSync: async function(client, x, y) {
    return {x, y, name: "methodSync"};
  },
  methodError: async function(client, x, y) {
    throw new Error("fake error");
  },
  methodAsync: async function(client, x, y) {
    await delay(20);
    return {x, y, name: "methodAsync"};
  },
  methodSend: async function(client, docFD) {
    // Fire-and-forget: the sends are started but deliberately not awaited.
    const send = (type: string, data: string) => void(client.sendMessage({docFD, type: type as any, data}));
    send("fooType", "foo");
    send("barType", "bar");
  },
};
// Per-test teardown. It runs (and awaits) every callback registered in `cleanup`, then undoes all
// sandbox stubs. The restore is in `finally` so that a failing cleanup callback cannot leave stubs
// (e.g. on Client.prototype or the global `window`) in place for the tests that follow.
afterEach(async function() {
  const callbacks = cleanup.splice(0);
  try {
    await Promise.all(callbacks.map(callback => callback()));
  } finally {
    sandbox.restore();
  }
});
/**
 * Collects the next `count` messages received on `ws`, each parsed as JSON.
 *
 * This takes over ws.onmessage and ws.onerror. On success both handlers are cleared and the
 * promise resolves with the messages. On a socket error, only onmessage is cleared and the promise
 * rejects with that error. The promise settles only in response to socket events.
 */
function getMessages(ws: GristClientSocket, count: number): Promise<any[]> {
  return new Promise((resolve, reject) => {
    const received: object[] = [];
    ws.onmessage = (data: string) => {
      received.push(JSON.parse(data));
      if (received.length < count) { return; }
      ws.onerror = null;
      ws.onmessage = null;
      resolve(received);
    };
    ws.onerror = (err) => {
      ws.onmessage = null;
      reject(err);
    };
  });
}
/**
 * Opens a websocket to the server started by startComm(), passing along any socket options
 * (e.g. request headers).
 *
 * @returns a promise that resolves with the socket once it is open, or rejects with the first error.
 */
function connect(options?: GristClientSocketOptions): Promise<GristClientSocket> {
  const {port} = server.address() as AddressInfo;
  const ws = new GristClientSocket(`ws://localhost:${port}`, options);
  return new Promise<GristClientSocket>((resolve, reject) => {
    ws.onerror = (err) => {
      ws.onopen = null;
      reject(err);
    };
    ws.onopen = () => {
      ws.onerror = null;
      resolve(ws);
    };
  });
}
// Tests of the raw websocket protocol against methods registered with the server-side Comm.
// Each test gets a fresh server and a freshly connected socket, whose initial clientConnect
// message has already been consumed.
describe("server methods", function() {
  let ws: GristClientSocket;
  beforeEach(async function() {
    await startComm(assortedMethods);
    ws = await connect();
    await getMessages(ws, 1); // consume a clientConnect message
  });
  afterEach(async function() {
    await stopComm();
  });
  it("should return data for valid calls", async function() {
    ws.send(JSON.stringify({reqId: 10, method: "methodSync", args: ["hello", "world"]}));
    const messages = await getMessages(ws, 1);
    const resp = messages[0];
    assert.equal(resp.reqId, 10, `Messages received instead: ${JSON.stringify(messages)}`);
    assert.deepEqual(resp.data, {x: "hello", y: "world", name: "methodSync"});
  });
  it("should work for async calls", async function() {
    ws.send(JSON.stringify({reqId: 20, method: "methodAsync", args: ["hello", "world"]}));
    const messages = await getMessages(ws, 1);
    const resp = messages[0];
    assert.equal(resp.reqId, 20);
    assert.deepEqual(resp.data, {x: "hello", y: "world", name: "methodAsync"});
  });
  it("should work for out-of-order calls", async function() {
    // methodAsync answers after a delay, so the later methodSync call gets its response first.
    ws.send(JSON.stringify({reqId: 30, method: "methodAsync", args: [1, 2]}));
    ws.send(JSON.stringify({reqId: 31, method: "methodSync", args: [3, 4]}));
    const messages = await getMessages(ws, 2);
    assert.equal(messages[0].reqId, 31);
    assert.deepEqual(messages[0].data, {x: 3, y: 4, name: "methodSync"});
    assert.equal(messages[1].reqId, 30);
    assert.deepEqual(messages[1].data, {x: 1, y: 2, name: "methodAsync"});
  });
  it("should return error when a call fails", async function() {
    const logMessages = await testUtils.captureLog('warn', async () => {
      ws.send(JSON.stringify({reqId: 40, method: "methodError", args: ["hello"]}));
      const messages = await getMessages(ws, 1);
      const resp = messages[0];
      assert.equal(resp.reqId, 40);
      assert.equal(resp.data, undefined);
      // assert.include gives a clear assertion failure (not a TypeError) if `error` is missing.
      assert.include(resp.error, 'fake error');
    });
    testUtils.assertMatchArray(logMessages, [
      /^warn: Client.* Error: fake error[^]+at methodError/,
      /^warn: Client.* responding to .* ERROR fake error/,
    ]);
  });
  it("should return error for unknown methods", async function() {
    const logMessages = await testUtils.captureLog('warn', async () => {
      ws.send(JSON.stringify({reqId: 50, method: "someUnknownMethod", args: []}));
      const messages = await getMessages(ws, 1);
      const resp = messages[0];
      assert.equal(resp.reqId, 50);
      assert.equal(resp.data, undefined);
      assert.include(resp.error, 'Unknown method');
    });
    testUtils.assertMatchArray(logMessages, [
      /^warn: Client.* Unknown method.*someUnknownMethod/
    ]);
  });
  it('should only log warning for malformed JSON data', async function () {
    const logMessages = await testUtils.captureLog('warn', async () => {
      ws.send('foobar');
    }, {waitForFirstLog: true});
    testUtils.assertMatchArray(logMessages, [
      /^warn: Client.* Unexpected token.*/
    ]);
  });
  it('should log warning when null value is passed', async function () {
    const logMessages = await testUtils.captureLog('warn', async () => {
      ws.send('null');
    }, {waitForFirstLog: true});
    testUtils.assertMatchArray(logMessages, [
      /^warn: Client.*Cannot read properties of null/
    ]);
  });
  it("should support app-level events correctly", async function() {
    comm!.broadcastMessage('fooType' as any, 'hello');
    comm!.broadcastMessage('barType' as any, 'world');
    const messages = await getMessages(ws, 2);
    assert.equal(messages[0].type, 'fooType');
    assert.equal(messages[0].data, 'hello');
    assert.equal(messages[1].type, 'barType');
    assert.equal(messages[1].data, 'world');
  });
  it("should support doc-level events", async function() {
    // methodSend pushes two messages for docFD 13, which should arrive before its own response.
    ws.send(JSON.stringify({reqId: 60, method: "methodSend", args: [13]}));
    const messages = await getMessages(ws, 3);
    assert.equal(messages[0].type, 'fooType');
    assert.equal(messages[0].data, 'foo');
    assert.equal(messages[0].docFD, 13);
    assert.equal(messages[1].type, 'barType');
    assert.equal(messages[1].data, 'bar');
    assert.equal(messages[1].docFD, 13);
    assert.equal(messages[2].reqId, 60);
    assert.equal(messages[2].data, undefined);
    assert.equal(messages[2].error, undefined);
  });
});
describe("reconnects", function() {
const docId = "docId_abc";
this.timeout(10000);
/**
 * Sets up three pieces for reconnect tests. The first is a server-side Comm exposing `methods`.
 * The second is a TcpForwarder in front of it, whose disconnect methods simulate network trouble.
 * The third is a client-side Comm that reaches the server through the forwarder, using a doc
 * connection for `docId`.
 * Teardown for all three is registered in `cleanup`.
 *
 * @returns the client-side Comm and the forwarder.
 */
async function startManagedConnection(methods: {[name: string]: ClientMethod}) {
  await startComm(methods);
  cleanup.push(() => stopComm());
  // The forwarder sits between client and server so tests can cut the connection at will.
  const forwarder = new TcpForwarder((server.address() as AddressInfo).port);
  const forwarderPort = await forwarder.pickForwarderPort();
  await forwarder.connect();
  cleanup.push(() => forwarder.disconnect());
  // GristWSConnection checks window.gristConfig for a worker to connect to, so fake a `window`.
  // The property is first created on `global` so that sandbox.stub() has one to replace.
  (global as any).window = undefined;
  sandbox.stub(global as any, 'window').value({gristConfig: {getWorker: 'STUB', assignmentId: docId}});
  // Make GristWSConnection.create() use our own settings, which point at the forwarder's port.
  const settings = getWSSettings(`http://localhost:${forwarderPort}`);
  const createStub = sandbox.stub(GristWSConnection, 'create').callsFake(function(this: any, owner) {
    return (createStub as any).wrappedMethod.call(this, owner, settings);
  });
  // BackboneEvents is mixed into the type so that cliComm.on() can be used.
  const cliComm = ClientComm.create() as ClientComm & BackboneEvents;
  cliComm.useDocConnection(docId);
  cleanup.push(() => cliComm.dispose());
  return {cliComm, forwarder};
}
it('should forward calls on a normal connection', async function() {
  const {cliComm} = await startManagedConnection(assortedMethods);
  // Plain sequential requests, each awaited before the next is made.
  assert.deepEqual(await cliComm._makeRequest(null, null, "methodSync", "foo", 1),
    {name: 'methodSync', x: "foo", y: 1});
  assert.deepEqual(await cliComm._makeRequest(null, null, "methodAsync", "foo", 2),
    {name: 'methodAsync', x: "foo", y: 2});
  // Concurrent requests whose responses come back in the opposite order (methodAsync is slower);
  // each result must still be matched to its own request.
  const results = await Promise.all([
    cliComm._makeRequest(null, null, "methodAsync", "foo", 3),
    cliComm._makeRequest(null, null, "methodSync", "foo", 4),
  ]);
  assert.deepEqual(results, [
    {name: 'methodAsync', x: "foo", y: 3},
    {name: 'methodSync', x: "foo", y: 4},
  ]);
});
it('should forward missed responses when a server send fails', async function() {
  await testMissedResponses(true);
});
it('should forward missed responses when a server send is queued', async function() {
  await testMissedResponses(false);
});
/**
 * Makes calls whose responses are still pending when the server side of the connection drops.
 * It then checks that those responses reach the client once the forwarder reconnects.
 *
 * @param sendShouldFail - If true, the first server-side send after the initial call is stubbed
 *   to throw 'WebSocket is not open', simulating a failed send. If false, testDisconnect waits
 *   20ms before answering, so the disconnect is noticed first and the response is presumably
 *   queued rather than attempted.
 */
async function testMissedResponses(sendShouldFail: boolean) {
  let failedSendCount = 0;
  const {cliComm, forwarder} = await startManagedConnection({...assortedMethods,
    // An extra method that simulates a lost connection on server side prior to response.
    testDisconnect: async function(client, x, y) {
      setTimeout(() => void(forwarder.disconnectServerSide()), 0);
      if (!sendShouldFail) {
        // Add a delay to let the 'close' event get noticed first.
        await delay(20);
      }
      return {x: x, y: y, name: "testDisconnect"};
    },
  });
  const resp1 = await cliComm._makeRequest(null, null, "methodSync", "foo", 1);
  assert.deepEqual(resp1, {name: 'methodSync', x: "foo", y: 1});
  if (sendShouldFail) {
    // In Node 18, the socket is closed during the call to 'testDisconnect'.
    // In prior versions of Node, the socket was still disconnecting.
    // This test is sensitive to timing and only passes in the latter, unless we
    // stub the method below to produce similar behavior in the former.
    sandbox.stub(Client.prototype as any, '_sendToWebsocket')
      .onFirstCall()
      .callsFake(() => {
        failedSendCount += 1;
        throw new Error('WebSocket is not open');
      })
      .callThrough();
  }
  // Make more calls, with a disconnect before they return. The server should queue up responses.
  const resp2Promise = cliComm._makeRequest(null, null, "testDisconnect", "foo", 2);
  const resp3Promise = cliComm._makeRequest(null, null, "methodAsync", "foo", 3);
  // While disconnected, the response should not arrive.
  assert.equal(await isLongerThan(resp2Promise, 250), true);
  // Once we reconnect, the response should arrive.
  await forwarder.connect();
  assert.deepEqual(await resp2Promise, {name: 'testDisconnect', x: "foo", y: 2});
  assert.deepEqual(await resp3Promise, {name: 'methodAsync', x: "foo", y: 3});
  // Check that we saw the situation we were hoping to test.
  assert.equal(failedSendCount, sendShouldFail ? 1 : 0,
    sendShouldFail ? 'Expected to see a failed send' : 'Expected no failed sends');
}
// Each case runs testSendOrdering() with a different combination of options; mocha awaits the
// returned promise.
it("should receive all server messages (small) in order when send doesn't fail", function() {
  return testSendOrdering({noFailedSend: true, useSmallMsgs: true});
});
it("should receive all server messages (large) in order when send doesn't fail", function() {
  return testSendOrdering({noFailedSend: true});
});
it("should order server messages correctly with failedSend before close", function() {
  return testSendOrdering({closeHappensFirst: false});
});
it("should order server messages correctly with close before failedSend", function() {
  return testSendOrdering({closeHappensFirst: true});
});
/**
 * Checks that messages pushed by the server (via Client.sendMessage) reach the client in order
 * and without gaps, across a forced disconnect and reconnect. It also checks that requests and
 * further messages work normally afterwards.
 *
 * options.noFailedSend: wait until all 5 initial messages reached the client before
 *   disconnecting (otherwise disconnect after 2 have arrived).
 * options.closeHappensFirst: delay the reporting of a failed send by 100ms, so that the server's
 *   'close' event is seen before the 'failedSend'.
 * options.useSmallMsgs: use 100,000-character payloads instead of 10,000,000-character ones.
 */
async function testSendOrdering(
options: {noFailedSend?: boolean, closeHappensFirst?: boolean, useSmallMsgs?: boolean}
) {
// Order in which the server-side Client saw 'close' and 'failedSend' events (recorded by stubs).
const eventsSeen: Array<'failedSend'|'close'> = [];
// Server-side Client object; set by the clientConnect handler below.
let ssClient!: Client;
const {cliComm, forwarder} = await startManagedConnection(assortedMethods);
// Intercept the call to _onClose to know when it occurs, since we are trying to hit a
// situation where 'close' and 'failedSend' events happen in either order.
const stubOnClose: any = sandbox.stub(Client.prototype as any, '_onClose')
.callsFake(function(this: Client) {
eventsSeen.push('close');
return stubOnClose.wrappedMethod.apply(this, arguments);
});
// Intercept calls to client.sendMessage(), to know when it fails, and possibly to delay the
// failures to hit a particular order in which 'close' and 'failedSend' events are seen by
// Client.ts. This is the only reliable way I found to reproduce this order of events.
const stubSendToWebsocket: any = sandbox.stub(Client.prototype as any, '_sendToWebsocket')
.callsFake(async function(this: Client) {
try {
return await stubSendToWebsocket.wrappedMethod.apply(this, arguments);
} catch (err) {
if (options.closeHappensFirst) { await delay(100); }
eventsSeen.push('failedSend');
throw err;
}
});
// Watch the events received all the way on the client side.
const eventSpy = sinon.spy();
const clientConnectSpy = sinon.spy();
cliComm.on('docUserAction', eventSpy);
cliComm.on('clientConnect', clientConnectSpy);
// We need to simulate an important property of the browser client: when needReload is set
// in the clientConnect message, we are expected to reload the app. In the test, we replace
// the GristWSConnection.
cliComm.on('clientConnect', async (msg: CommClientConnect) => {
ssClient = comm!.getClient(msg.clientId);
if (msg.needReload) {
await delay(0);
cliComm.releaseDocConnection(docId);
cliComm.useDocConnection(docId);
}
});
// Wait for a connect call, which we rely on to get access to the Client object (ssClient).
await waitForCondition(() => (clientConnectSpy.callCount > 0), 1000);
// Send large buffers, to fill up the socket's buffers to get it to block.
const data = "x".repeat(options.useSmallMsgs ? 100_000 : 10_000_000);
// Each message carries its sequence number `n`, which the client-side checks below rely on.
const makeMessage = (n: number) => ({type: 'docUserAction', n, data});
let n = 0;
const sendPromises: Array<Promise<void>> = [];
const sendNextMessage= () => sendPromises.push(ssClient.sendMessage(makeMessage(n++) as any));
await testUtils.captureLog('warn', async () => {
// Make a few sends. These are big enough not to return immediately. Keep the first two
// successful (by awaiting them). And keep a few more that will fail. This is to test the
// ordering of successful and failed messages that may be missed.
sendNextMessage();
sendNextMessage();
sendNextMessage();
await sendPromises[0];
await sendPromises[1];
sendNextMessage();
sendNextMessage();
// Forcibly close the forwarder, so that the server sees a 'close' event. But first let
// some messages get to the client. In case we want all sends to succeed, let them all get
// forwarded before disconnect; otherwise, disconnect after 2 are forwarded.
const countToWaitFor = options.noFailedSend ? 5 : 2;
await waitForCondition(() => eventSpy.callCount >= countToWaitFor);
void(forwarder.disconnectServerSide());
// Wait less than the 100ms delay added to failed sends when closeHappensFirst is set, and send
// another message. There used to be a bug that such a message would get recorded into
// missedMessages out of order.
await delay(50);
sendNextMessage();
// Now reconnect, and collect the messages that the client sees.
clientConnectSpy.resetHistory();
await forwarder.connect();
// Wait until we get a clientConnect message that does not require a reload. (Except with
// noFailedSend, the first one would have needReload set; and after the reconnect, we should
// get one without.)
await waitForCondition(() =>
(clientConnectSpy.callCount > 0 && clientConnectSpy.lastCall.args[0].needReload === false),
3000);
});
// This helper covers several situations (no failed send; close before failedSend; failedSend
// before close). Check that we observed the situation we were trying to hit.
if (options.noFailedSend) {
if (options.useSmallMsgs) {
assert.deepEqual(eventsSeen, ['close']);
} else {
// Make sure to have waited long enough for the 'close' event we may have delayed
await delay(20);
// Large messages now cause a send to fail, after filling up buffer, and close the socket.
assert.deepEqual(eventsSeen, ['close', 'close']);
}
} else if (options.closeHappensFirst) {
assert.equal(eventsSeen[0], 'close');
assert.include(eventsSeen, 'failedSend');
} else {
assert.equal(eventsSeen[0], 'failedSend');
assert.include(eventsSeen, 'close');
}
// After a successful reconnect, subsequent calls should work normally.
assert.deepEqual(await cliComm._makeRequest(null, null, "methodSync", 1, 2),
{name: 'methodSync', x: 1, y: 2});
// Check that all the received messages are in order, i.e. numbered 0, 1, 2, ... with no gaps.
const messageNums = eventSpy.getCalls().map(call => call.args[0].n);
assert.isAtLeast(messageNums.length, 2);
assert.deepEqual(messageNums, nrange(0, messageNums.length),
`Unexpected message sequence ${JSON.stringify(messageNums)}`);
// Subsequent messages should work normally too.
eventSpy.resetHistory();
sendNextMessage();
await waitForCondition(() => eventSpy.callCount > 0);
assert.deepEqual(eventSpy.getCalls().map(call => call.args[0].n), [n - 1]);
}
});
// Checks which websocket Origin headers the server accepts. The answer depends on the Host header
// and on whether the host is treated as a custom domain (fakeHosts.isCustomHost).
describe("Allowed Origin", function() {
  beforeEach(async function() {
    await startComm(assortedMethods);
  });
  afterEach(async function() {
    await stopComm();
  });
  /** Tries to connect with the given headers; asserts that this succeeds exactly when `allowed`. */
  async function checkOrigin(headers: { origin: string, host: string }, allowed: boolean) {
    const attempt = connect({headers});
    if (!allowed) {
      await assert.isRejected(attempt, /.*/, `${headers.host} should reject ${headers.origin}`);
      return;
    }
    await assert.isFulfilled(attempt, `${headers.host} should allow ${headers.origin}`);
  }
  /** Runs checkOrigin() on each [origin, host, allowed] triple, one at a time, in order. */
  async function checkOrigins(cases: Array<[string, string, boolean]>) {
    for (const [origin, host, allowed] of cases) {
      await checkOrigin({origin, host}, allowed);
    }
  }
  it('origin should match base domain of host', async () => {
    await checkOrigins([
      ["https://www.toto.com", "worker.example.com", false],
      ["https://badexample.com", "worker.example.com", false],
      ["https://bad.com/example.com", "worker.example.com", false],
      ["https://front.example.com", "worker.example.com", true],
      ["https://front.example.com:3000", "worker.example.com", true],
      ["https://example.com", "example.com", true],
    ]);
  });
  it('with custom domains, origin should match the full hostname', async () => {
    fakeHosts.isCustomHost = true;
    // For a request to a custom domain, the full hostname must match.
    await checkOrigins([
      ["https://front.example.com", "worker.example.com", false],
      ["https://front.example.com", "front.example.com", true],
      ["https://front.example.com:3000", "front.example.com", true],
    ]);
  });
});
});
/**
 * Polls `condFunc` until it returns true, sleeping `stepMs` between checks.
 *
 * The condition is checked before the deadline on every iteration. So it is always evaluated at
 * least once, and once more after the final sleep. A condition that becomes true during the last
 * sleep is therefore not misreported as a timeout, and `timeoutMs = 0` means "check exactly once".
 *
 * @param condFunc - Predicate to poll.
 * @param timeoutMs - How long to keep polling, in milliseconds.
 * @param stepMs - Sleep between checks, in milliseconds.
 * @throws Error if the condition is still false once `timeoutMs` has elapsed.
 */
async function waitForCondition(condFunc: () => boolean, timeoutMs = 1000, stepMs = 10): Promise<void> {
  const end = Date.now() + timeoutMs;
  for (;;) {
    if (condFunc()) { return; }
    // Test the deadline only after evaluating the condition, so the last sleep is never wasted.
    if (Date.now() >= end) { break; }
    await delay(stepMs);
  }
  throw new Error(`Condition not met after ${timeoutMs}ms: ${condFunc.toString()}`);
}
/**
 * Returns `count` consecutive integers beginning at `start`, e.g. nrange(3, 2) -> [3, 4].
 * A negative or non-integer `count` throws a RangeError (from the Array constructor).
 */
function nrange(start: number, count: number): number[] {
  return Array.from(Array(count).keys(), offset => start + offset);
}
/**
 * Builds GristWSSettings for GristWSConnection. The settings always report `docWorkerUrl` as the
 * doc worker and create real GristClientSocket websockets. Client id and counter state live in
 * this closure. Log and warn calls are routed to the client-side log module.
 */
function getWSSettings(docWorkerUrl: string): GristWSSettings {
  const state = {clientId: 'clientid-abc', counter: 0};
  return {
    makeWebSocket(url: string): any { return new GristClientSocket(url); },
    async getTimezone() { return 'UTC'; },
    getPageUrl() { return "http://localhost"; },
    async getDocWorkerUrl() { return docWorkerUrl; },
    getClientId(did: any) { return state.clientId; },
    getUserSelector() { return ''; },
    updateClientId(did: string, cid: string) { state.clientId = cid; },
    advanceCounter(): string { return String(state.counter++); },
    log(...args: any[]) { (log as any).debug(...args); },
    warn(...args: any[]) { (log as any).warn(...args); },
  };
}