Support HTTP long polling as an alternative to WebSockets (#859)

The motivation for supporting an alternative to WebSockets is that while all browsers supported by Grist offer native WebSocket support, some networking environments do not allow WebSocket traffic.

Engine.IO is used as the underlying implementation of HTTP long polling. The Grist client will first attempt a regular WebSocket connection, using the same protocol and endpoints as before, but fall back to long polling using Engine.IO if the WebSocket connection fails.

Include these changes:
- CORS websocket requests are now rejected as a stronger security measure. This shouldn’t affect anything in practice; but previously it could be possible to make unauthenticated websocket requests from another origin.
- GRIST_HOST variable no longer affects CORS responses (also should not affect anything in practice, as it wasn't serving a useful purpose)
This commit is contained in:
Jonathan Perret
2024-03-28 18:22:20 +01:00
committed by GitHub
parent b4c2562029
commit 96b652fb52
14 changed files with 852 additions and 118 deletions

View File

@@ -4,11 +4,11 @@ import {assert} from 'chai';
import * as http from 'http';
import {AddressInfo} from 'net';
import * as sinon from 'sinon';
import WebSocket from 'ws';
import * as path from 'path';
import * as tmp from 'tmp';
import {GristWSConnection, GristWSSettings} from 'app/client/components/GristWSConnection';
import {GristClientSocket, GristClientSocketOptions} from 'app/client/components/GristClientSocket';
import {Comm as ClientComm} from 'app/client/components/Comm';
import * as log from 'app/client/lib/log';
import {Comm} from 'app/server/lib/Comm';
@@ -21,10 +21,28 @@ import {Sessions} from 'app/server/lib/Sessions';
import {TcpForwarder} from 'test/server/tcpForwarder';
import * as testUtils from 'test/server/testUtils';
import * as session from '@gristlabs/express-session';
import { Hosts, RequestOrgInfo } from 'app/server/lib/extractOrg';
const SQLiteStore = require('@gristlabs/connect-sqlite3')(session);
promisifyAll(SQLiteStore.prototype);
// Just enough implementation of Hosts to be able to fake using a custom host.
class FakeHosts {
public isCustomHost = false;
public get asHosts() { return this as unknown as Hosts; }
public async addOrgInfo<T extends http.IncomingMessage>(req: T): Promise<T & RequestOrgInfo> {
return Object.assign(req, {
isCustomHost: this.isCustomHost,
org: "example",
url: req.url!
});
}
}
describe('Comm', function() {
testUtils.setTmpLogLevel(process.env.VERBOSE ? 'debug' : 'warn');
@@ -34,6 +52,7 @@ describe('Comm', function() {
let server: http.Server;
let sessions: Sessions;
let fakeHosts: FakeHosts;
let comm: Comm|null = null;
const sandbox = sinon.createSandbox();
@@ -51,14 +70,19 @@ describe('Comm', function() {
function startComm(methods: {[name: string]: ClientMethod}) {
server = http.createServer();
comm = new Comm(server, {sessions});
fakeHosts = new FakeHosts();
comm = new Comm(server, {sessions, hosts: fakeHosts.asHosts});
comm.registerMethods(methods);
return listenPromise(server.listen(0, 'localhost'));
}
async function stopComm() {
comm?.destroyAllClients();
return fromCallback(cb => server.close(cb));
await comm?.testServerShutdown();
await fromCallback(cb => {
server.close(cb);
server.closeAllConnections();
});
}
const assortedMethods: {[name: string]: ClientMethod} = {
@@ -95,34 +119,43 @@ describe('Comm', function() {
sandbox.restore();
});
function getMessages(ws: WebSocket, count: number): Promise<any[]> {
function getMessages(ws: GristClientSocket, count: number): Promise<any[]> {
return new Promise((resolve, reject) => {
const messages: object[] = [];
ws.on('error', reject);
ws.on('message', (msg: string) => {
messages.push(JSON.parse(msg));
ws.onerror = (err) => {
ws.onmessage = null;
reject(err);
};
ws.onmessage = (data: string) => {
messages.push(JSON.parse(data));
if (messages.length >= count) {
ws.onerror = null;
ws.onmessage = null;
resolve(messages);
ws.removeListener('error', reject);
ws.removeAllListeners('message');
}
});
};
});
}
/**
* Returns a promise for the connected websocket.
*/
function connect() {
const ws = new WebSocket('ws://localhost:' + (server.address() as AddressInfo).port);
return new Promise<WebSocket>((resolve, reject) => {
ws.on('open', () => resolve(ws));
ws.on('error', reject);
function connect(options?: GristClientSocketOptions): Promise<GristClientSocket> {
const ws = new GristClientSocket('ws://localhost:' + (server.address() as AddressInfo).port, options);
return new Promise<GristClientSocket>((resolve, reject) => {
ws.onopen = () => {
ws.onerror = null;
resolve(ws);
};
ws.onerror = (err) => {
ws.onopen = null;
reject(err);
};
});
}
describe("server methods", function() {
let ws: WebSocket;
let ws: GristClientSocket;
beforeEach(async function() {
await startComm(assortedMethods);
ws = await connect();
@@ -370,7 +403,8 @@ describe('Comm', function() {
// Intercept the call to _onClose to know when it occurs, since we are trying to hit a
// situation where 'close' and 'failedSend' events happen in either order.
const stubOnClose = sandbox.stub(Client.prototype as any, '_onClose')
.callsFake(function(this: Client) {
.callsFake(async function(this: Client) {
if (!options.closeHappensFirst) { await delay(10); }
eventsSeen.push('close');
return (stubOnClose as any).wrappedMethod.apply(this, arguments);
});
@@ -462,6 +496,9 @@ describe('Comm', function() {
if (options.useSmallMsgs) {
assert.deepEqual(eventsSeen, ['close']);
} else {
// Make sure to have waited long enough for the 'close' event we may have delayed
await delay(20);
// Large messages now cause a send to fail, after filling up buffer, and close the socket.
assert.deepEqual(eventsSeen, ['close', 'close']);
}
@@ -490,17 +527,54 @@ describe('Comm', function() {
assert.deepEqual(eventSpy.getCalls().map(call => call.args[0].n), [n - 1]);
}
});
describe("Allowed Origin", function() {
beforeEach(async function () {
await startComm(assortedMethods);
});
afterEach(async function() {
await stopComm();
});
async function checkOrigin(headers: { origin: string, host: string }, allowed: boolean) {
const promise = connect({ headers });
if (allowed) {
await assert.isFulfilled(promise, `${headers.host} should allow ${headers.origin}`);
} else {
await assert.isRejected(promise, /.*/, `${headers.host} should reject ${headers.origin}`);
}
}
it('origin should match base domain of host', async () => {
await checkOrigin({origin: "https://www.toto.com", host: "worker.example.com"}, false);
await checkOrigin({origin: "https://badexample.com", host: "worker.example.com"}, false);
await checkOrigin({origin: "https://bad.com/example.com", host: "worker.example.com"}, false);
await checkOrigin({origin: "https://front.example.com", host: "worker.example.com"}, true);
await checkOrigin({origin: "https://front.example.com:3000", host: "worker.example.com"}, true);
await checkOrigin({origin: "https://example.com", host: "example.com"}, true);
});
it('with custom domains, origin should match the full hostname', async () => {
fakeHosts.isCustomHost = true;
// For a request to a custom domain, the full hostname must match.
await checkOrigin({origin: "https://front.example.com", host: "worker.example.com"}, false);
await checkOrigin({origin: "https://front.example.com", host: "front.example.com"}, true);
await checkOrigin({origin: "https://front.example.com:3000", host: "front.example.com"}, true);
});
});
});
// Waits for condFunc() to return true, for up to timeoutMs milliseconds, sleeping for stepMs
// between checks. Returns true if succeeded, false if failed.
async function waitForCondition(condFunc: () => boolean, timeoutMs = 1000, stepMs = 10): Promise<boolean> {
// between checks. Returns if succeeded, throws if failed.
async function waitForCondition(condFunc: () => boolean, timeoutMs = 1000, stepMs = 10): Promise<void> {
const end = Date.now() + timeoutMs;
while (Date.now() < end) {
if (condFunc()) { return true; }
if (condFunc()) { return; }
await delay(stepMs);
}
return false;
throw new Error(`Condition not met after ${timeoutMs}ms: ${condFunc.toString()}`);
}
// Returns a range of count consecutive numbers starting with start.
@@ -513,7 +587,7 @@ function getWSSettings(docWorkerUrl: string): GristWSSettings {
let clientId: string = 'clientid-abc';
let counter: number = 0;
return {
makeWebSocket(url: string): any { return new WebSocket(url, undefined, {}); },
makeWebSocket(url: string): any { return new GristClientSocket(url); },
async getTimezone() { return 'UTC'; },
getPageUrl() { return "http://localhost"; },
async getDocWorkerUrl() { return docWorkerUrl; },

View File

@@ -4,7 +4,7 @@ import { SchemaTypes } from 'app/common/schema';
import { FlexServer } from 'app/server/lib/FlexServer';
import axios from 'axios';
import pick = require('lodash/pick');
import WebSocket from 'ws';
import {GristClientSocket} from 'app/client/components/GristClientSocket';
interface GristRequest {
reqId: number;
@@ -34,9 +34,9 @@ export class GristClient {
private _consumer: () => void;
private _ignoreTrivialActions: boolean = false;
constructor(public ws: any) {
ws.onmessage = (data: any) => {
const msg = pick(JSON.parse(data.data),
constructor(public ws: GristClientSocket) {
ws.onmessage = (data: string) => {
const msg = pick(JSON.parse(data),
['reqId', 'error', 'errorCode', 'data', 'type', 'docFD']);
if (this._ignoreTrivialActions && msg.type === 'docUserAction' &&
msg.data?.actionGroup?.internal === true &&
@@ -149,7 +149,6 @@ export class GristClient {
}
public async close() {
this.ws.terminate();
this.ws.close();
}
@@ -180,17 +179,19 @@ export async function openClient(server: FlexServer, email: string, org: string,
} else {
headers[emailHeader] = email;
}
const ws = new WebSocket('ws://localhost:' + server.getOwnPort() + `/o/${org}`, {
const ws = new GristClientSocket('ws://localhost:' + server.getOwnPort() + `/o/${org}`, {
headers
});
const client = new GristClient(ws);
await new Promise(function(resolve, reject) {
ws.on('open', function() {
ws.onopen = function() {
ws.onerror = null;
resolve(ws);
});
ws.on('error', function(err: any) {
};
ws.onerror = function(err: Error) {
ws.onopen = null;
reject(err);
});
};
});
return client;
}

View File

@@ -4879,11 +4879,17 @@ function testDocApi() {
delete chimpyConfig.headers["X-Requested-With"];
delete anonConfig.headers["X-Requested-With"];
// Target a more realistic Host than "localhost:port"
anonConfig.headers.Host = chimpyConfig.headers.Host = 'api.example.com';
const url = `${serverUrl}/api/docs/${docId}/tables/Table1/records`;
const data = {records: [{fields: {}}]};
const data = { records: [{ fields: {} }] };
const allowedOrigin = 'http://front.example.com';
const forbiddenOrigin = 'http://evil.com';
// Normal same origin requests
anonConfig.headers.Origin = serverUrl;
anonConfig.headers.Origin = allowedOrigin;
let response: AxiosResponse;
for (response of [
await axios.post(url, data, anonConfig),
@@ -4893,13 +4899,13 @@ function testDocApi() {
assert.equal(response.status, 200);
assert.equal(response.headers['access-control-allow-methods'], 'GET, PATCH, PUT, POST, DELETE, OPTIONS');
assert.equal(response.headers['access-control-allow-headers'], 'Authorization, Content-Type, X-Requested-With');
assert.equal(response.headers['access-control-allow-origin'], serverUrl);
assert.equal(response.headers['access-control-allow-origin'], allowedOrigin);
assert.equal(response.headers['access-control-allow-credentials'], 'true');
}
// Cross origin requests from untrusted origin.
for (const config of [anonConfig, chimpyConfig]) {
config.headers.Origin = "https://evil.com/";
config.headers.Origin = forbiddenOrigin;
for (response of [
await axios.post(url, data, config),
await axios.get(url, config),

View File

@@ -0,0 +1,170 @@
import { assert } from 'chai';
import * as http from 'http';
import { GristClientSocket } from 'app/client/components/GristClientSocket';
import { GristSocketServer } from 'app/server/lib/GristSocketServer';
import { fromCallback, listenPromise } from 'app/server/lib/serverUtils';
import { AddressInfo } from 'net';
import httpProxy from 'http-proxy';
describe(`GristSockets`, function () {
for (const webSocketsSupported of [true, false]) {
describe(`when the networks ${webSocketsSupported ? "supports" : "does not support"} WebSockets`, function () {
let server: http.Server | null;
let serverPort: number;
let socketServer: GristSocketServer | null;
let proxy: httpProxy | null;
let proxyServer: http.Server | null;
let proxyPort: number;
let wsAddress: string;
beforeEach(async function () {
await startSocketServer();
await startProxyServer();
});
afterEach(async function () {
await stopProxyServer();
await stopSocketServer();
});
async function startSocketServer() {
server = http.createServer((req, res) => res.writeHead(404).end());
socketServer = new GristSocketServer(server);
await listenPromise(server.listen(0, 'localhost'));
serverPort = (server.address() as AddressInfo).port;
}
async function stopSocketServer() {
await fromCallback(cb => socketServer?.close(cb));
await fromCallback(cb => { server?.close(); server?.closeAllConnections(); server?.on("close", cb); });
socketServer = server = null;
}
// Start an HTTP proxy that supports WebSockets or not
async function startProxyServer() {
proxy = httpProxy.createProxy({
target: `http://localhost:${serverPort}`,
ws: webSocketsSupported,
timeout: 1000,
});
proxy.on('error', () => { });
proxyServer = http.createServer();
if (webSocketsSupported) {
// prevent non-WebSocket requests
proxyServer.on('request', (req, res) => res.writeHead(404).end());
// proxy WebSocket requests
proxyServer.on('upgrade', (req, socket, head) => proxy!.ws(req, socket, head));
} else {
// proxy non-WebSocket requests
proxyServer.on('request', (req, res) => proxy!.web(req, res));
// don't leave WebSocket connection attempts hanging
proxyServer.on('upgrade', (req, socket, head) => socket.destroy());
}
await listenPromise(proxyServer.listen(0, 'localhost'));
proxyPort = (proxyServer.address() as AddressInfo).port;
wsAddress = `ws://localhost:${proxyPort}`;
}
async function stopProxyServer() {
if (proxy) {
proxy.close();
proxy = null;
}
if (proxyServer) {
const server = proxyServer;
await fromCallback(cb => { server.close(cb); server.closeAllConnections(); });
}
proxyServer = null;
}
function getMessages(ws: GristClientSocket, count: number): Promise<string[]> {
return new Promise((resolve, reject) => {
const messages: string[] = [];
ws.onerror = (err) => {
ws.onerror = ws.onmessage = null;
reject(err);
};
ws.onmessage = (data: string) => {
messages.push(data);
if (messages.length >= count) {
ws.onerror = ws.onmessage = null;
resolve(messages);
}
};
});
}
/**
* Returns a promise for the connected websocket.
*/
function connectClient(url: string): Promise<GristClientSocket> {
const socket = new GristClientSocket(url);
return new Promise<GristClientSocket>((resolve, reject) => {
socket.onopen = () => {
socket.onerror = null;
resolve(socket);
};
socket.onerror = (err) => {
socket.onopen = null;
reject(err);
};
});
}
it("should expose initial request", async function () {
const connectionPromise = new Promise<http.IncomingMessage>((resolve) => {
socketServer!.onconnection = (socket, req) => {
resolve(req);
};
});
const clientWs = new GristClientSocket(wsAddress + "/path?query=value", {
headers: { "cookie": "session=1234" }
});
const req = await connectionPromise;
clientWs.close();
// Engine.IO may append extra query parameters, so we check only the start of the URL
assert.match(req.url!, /^\/path\?query=value/);
assert.equal(req.headers.cookie, "session=1234");
});
it("should receive and send messages", async function () {
socketServer!.onconnection = (socket, req) => {
socket.onmessage = (data) => {
socket.send("hello, " + data);
};
};
const clientWs = await connectClient(wsAddress);
clientWs.send("world");
assert.deepEqual(await getMessages(clientWs, 1), ["hello, world"]);
clientWs.close();
});
it("should invoke send callbacks", async function () {
const connectionPromise = new Promise<void>((resolve) => {
socketServer!.onconnection = (socket, req) => {
socket.send("hello", () => resolve());
};
});
const clientWs = await connectClient(wsAddress);
await connectionPromise;
clientWs.close();
});
it("should emit close event for client", async function () {
const clientWs = await connectClient(wsAddress);
const closePromise = new Promise<void>(resolve => {
clientWs.onclose = resolve;
});
clientWs.close();
await closePromise;
});
});
}
});

View File

@@ -12,7 +12,7 @@ import {createTestDir, EnvironmentSnapshot, setTmpLogLevel} from 'test/server/te
import {assert} from 'chai';
import * as cookie from 'cookie';
import fetch from 'node-fetch';
import WebSocket from 'ws';
import {GristClientSocket} from 'app/client/components/GristClientSocket';
describe('ManyFetches', function() {
this.timeout(30000);
@@ -244,7 +244,7 @@ describe('ManyFetches', function() {
return function createConnectionFunc() {
let clientId: string = '0';
return GristWSConnection.create(null, {
makeWebSocket(url: string): any { return new WebSocket(url, undefined, { headers }); },
makeWebSocket(url: string) { return new GristClientSocket(url, { headers }); },
getTimezone() { return Promise.resolve('UTC'); },
getPageUrl() { return pageUrl; },
getDocWorkerUrl() { return Promise.resolve(docWorkerUrl); },