(core) updates from grist-core

This commit is contained in:
Paul Fitzpatrick
2024-04-01 09:22:25 -04:00
24 changed files with 1006 additions and 138 deletions

View File

@@ -4,11 +4,11 @@ import {assert} from 'chai';
import * as http from 'http';
import {AddressInfo} from 'net';
import * as sinon from 'sinon';
import WebSocket from 'ws';
import * as path from 'path';
import * as tmp from 'tmp';
import {GristWSConnection, GristWSSettings} from 'app/client/components/GristWSConnection';
import {GristClientSocket, GristClientSocketOptions} from 'app/client/components/GristClientSocket';
import {Comm as ClientComm} from 'app/client/components/Comm';
import * as log from 'app/client/lib/log';
import {Comm} from 'app/server/lib/Comm';
@@ -21,10 +21,28 @@ import {Sessions} from 'app/server/lib/Sessions';
import {TcpForwarder} from 'test/server/tcpForwarder';
import * as testUtils from 'test/server/testUtils';
import * as session from '@gristlabs/express-session';
import { Hosts, RequestOrgInfo } from 'app/server/lib/extractOrg';
const SQLiteStore = require('@gristlabs/connect-sqlite3')(session);
promisifyAll(SQLiteStore.prototype);
// Just enough implementation of Hosts to be able to fake using a custom host.
class FakeHosts {
public isCustomHost = false;
public get asHosts() { return this as unknown as Hosts; }
public async addOrgInfo<T extends http.IncomingMessage>(req: T): Promise<T & RequestOrgInfo> {
return Object.assign(req, {
isCustomHost: this.isCustomHost,
org: "example",
url: req.url!
});
}
}
describe('Comm', function() {
testUtils.setTmpLogLevel(process.env.VERBOSE ? 'debug' : 'warn');
@@ -34,6 +52,7 @@ describe('Comm', function() {
let server: http.Server;
let sessions: Sessions;
let fakeHosts: FakeHosts;
let comm: Comm|null = null;
const sandbox = sinon.createSandbox();
@@ -51,14 +70,19 @@ describe('Comm', function() {
function startComm(methods: {[name: string]: ClientMethod}) {
server = http.createServer();
comm = new Comm(server, {sessions});
fakeHosts = new FakeHosts();
comm = new Comm(server, {sessions, hosts: fakeHosts.asHosts});
comm.registerMethods(methods);
return listenPromise(server.listen(0, 'localhost'));
}
async function stopComm() {
comm?.destroyAllClients();
return fromCallback(cb => server.close(cb));
await comm?.testServerShutdown();
await fromCallback(cb => {
server.close(cb);
server.closeAllConnections();
});
}
const assortedMethods: {[name: string]: ClientMethod} = {
@@ -95,34 +119,43 @@ describe('Comm', function() {
sandbox.restore();
});
function getMessages(ws: WebSocket, count: number): Promise<any[]> {
function getMessages(ws: GristClientSocket, count: number): Promise<any[]> {
return new Promise((resolve, reject) => {
const messages: object[] = [];
ws.on('error', reject);
ws.on('message', (msg: string) => {
messages.push(JSON.parse(msg));
ws.onerror = (err) => {
ws.onmessage = null;
reject(err);
};
ws.onmessage = (data: string) => {
messages.push(JSON.parse(data));
if (messages.length >= count) {
ws.onerror = null;
ws.onmessage = null;
resolve(messages);
ws.removeListener('error', reject);
ws.removeAllListeners('message');
}
});
};
});
}
/**
* Returns a promise for the connected websocket.
*/
function connect() {
const ws = new WebSocket('ws://localhost:' + (server.address() as AddressInfo).port);
return new Promise<WebSocket>((resolve, reject) => {
ws.on('open', () => resolve(ws));
ws.on('error', reject);
function connect(options?: GristClientSocketOptions): Promise<GristClientSocket> {
const ws = new GristClientSocket('ws://localhost:' + (server.address() as AddressInfo).port, options);
return new Promise<GristClientSocket>((resolve, reject) => {
ws.onopen = () => {
ws.onerror = null;
resolve(ws);
};
ws.onerror = (err) => {
ws.onopen = null;
reject(err);
};
});
}
describe("server methods", function() {
let ws: WebSocket;
let ws: GristClientSocket;
beforeEach(async function() {
await startComm(assortedMethods);
ws = await connect();
@@ -370,7 +403,8 @@ describe('Comm', function() {
// Intercept the call to _onClose to know when it occurs, since we are trying to hit a
// situation where 'close' and 'failedSend' events happen in either order.
const stubOnClose = sandbox.stub(Client.prototype as any, '_onClose')
.callsFake(function(this: Client) {
.callsFake(async function(this: Client) {
if (!options.closeHappensFirst) { await delay(10); }
eventsSeen.push('close');
return (stubOnClose as any).wrappedMethod.apply(this, arguments);
});
@@ -462,6 +496,9 @@ describe('Comm', function() {
if (options.useSmallMsgs) {
assert.deepEqual(eventsSeen, ['close']);
} else {
// Make sure to have waited long enough for the 'close' event we may have delayed
await delay(20);
// Large messages now cause a send to fail, after filling up buffer, and close the socket.
assert.deepEqual(eventsSeen, ['close', 'close']);
}
@@ -490,17 +527,54 @@ describe('Comm', function() {
assert.deepEqual(eventSpy.getCalls().map(call => call.args[0].n), [n - 1]);
}
});
describe("Allowed Origin", function() {
beforeEach(async function () {
await startComm(assortedMethods);
});
afterEach(async function() {
await stopComm();
});
async function checkOrigin(headers: { origin: string, host: string }, allowed: boolean) {
const promise = connect({ headers });
if (allowed) {
await assert.isFulfilled(promise, `${headers.host} should allow ${headers.origin}`);
} else {
await assert.isRejected(promise, /.*/, `${headers.host} should reject ${headers.origin}`);
}
}
it('origin should match base domain of host', async () => {
await checkOrigin({origin: "https://www.toto.com", host: "worker.example.com"}, false);
await checkOrigin({origin: "https://badexample.com", host: "worker.example.com"}, false);
await checkOrigin({origin: "https://bad.com/example.com", host: "worker.example.com"}, false);
await checkOrigin({origin: "https://front.example.com", host: "worker.example.com"}, true);
await checkOrigin({origin: "https://front.example.com:3000", host: "worker.example.com"}, true);
await checkOrigin({origin: "https://example.com", host: "example.com"}, true);
});
it('with custom domains, origin should match the full hostname', async () => {
fakeHosts.isCustomHost = true;
// For a request to a custom domain, the full hostname must match.
await checkOrigin({origin: "https://front.example.com", host: "worker.example.com"}, false);
await checkOrigin({origin: "https://front.example.com", host: "front.example.com"}, true);
await checkOrigin({origin: "https://front.example.com:3000", host: "front.example.com"}, true);
});
});
});
// Waits for condFunc() to return true, for up to timeoutMs milliseconds, sleeping for stepMs
// between checks. Returns true if succeeded, false if failed.
async function waitForCondition(condFunc: () => boolean, timeoutMs = 1000, stepMs = 10): Promise<boolean> {
// between checks. Returns if succeeded, throws if failed.
async function waitForCondition(condFunc: () => boolean, timeoutMs = 1000, stepMs = 10): Promise<void> {
const end = Date.now() + timeoutMs;
while (Date.now() < end) {
if (condFunc()) { return true; }
if (condFunc()) { return; }
await delay(stepMs);
}
return false;
throw new Error(`Condition not met after ${timeoutMs}ms: ${condFunc.toString()}`);
}
// Returns a range of count consecutive numbers starting with start.
@@ -513,7 +587,7 @@ function getWSSettings(docWorkerUrl: string): GristWSSettings {
let clientId: string = 'clientid-abc';
let counter: number = 0;
return {
makeWebSocket(url: string): any { return new WebSocket(url, undefined, {}); },
makeWebSocket(url: string): any { return new GristClientSocket(url); },
async getTimezone() { return 'UTC'; },
getPageUrl() { return "http://localhost"; },
async getDocWorkerUrl() { return docWorkerUrl; },

View File

@@ -4,7 +4,7 @@ import { SchemaTypes } from 'app/common/schema';
import { FlexServer } from 'app/server/lib/FlexServer';
import axios from 'axios';
import pick = require('lodash/pick');
import WebSocket from 'ws';
import {GristClientSocket} from 'app/client/components/GristClientSocket';
interface GristRequest {
reqId: number;
@@ -34,9 +34,9 @@ export class GristClient {
private _consumer: () => void;
private _ignoreTrivialActions: boolean = false;
constructor(public ws: any) {
ws.onmessage = (data: any) => {
const msg = pick(JSON.parse(data.data),
constructor(public ws: GristClientSocket) {
ws.onmessage = (data: string) => {
const msg = pick(JSON.parse(data),
['reqId', 'error', 'errorCode', 'data', 'type', 'docFD']);
if (this._ignoreTrivialActions && msg.type === 'docUserAction' &&
msg.data?.actionGroup?.internal === true &&
@@ -149,7 +149,6 @@ export class GristClient {
}
public async close() {
this.ws.terminate();
this.ws.close();
}
@@ -180,17 +179,19 @@ export async function openClient(server: FlexServer, email: string, org: string,
} else {
headers[emailHeader] = email;
}
const ws = new WebSocket('ws://localhost:' + server.getOwnPort() + `/o/${org}`, {
const ws = new GristClientSocket('ws://localhost:' + server.getOwnPort() + `/o/${org}`, {
headers
});
const client = new GristClient(ws);
await new Promise(function(resolve, reject) {
ws.on('open', function() {
ws.onopen = function() {
ws.onerror = null;
resolve(ws);
});
ws.on('error', function(err: any) {
};
ws.onerror = function(err: Error) {
ws.onopen = null;
reject(err);
});
};
});
return client;
}

View File

@@ -4879,11 +4879,17 @@ function testDocApi() {
delete chimpyConfig.headers["X-Requested-With"];
delete anonConfig.headers["X-Requested-With"];
// Target a more realistic Host than "localhost:port"
anonConfig.headers.Host = chimpyConfig.headers.Host = 'api.example.com';
const url = `${serverUrl}/api/docs/${docId}/tables/Table1/records`;
const data = {records: [{fields: {}}]};
const data = { records: [{ fields: {} }] };
const allowedOrigin = 'http://front.example.com';
const forbiddenOrigin = 'http://evil.com';
// Normal same origin requests
anonConfig.headers.Origin = serverUrl;
anonConfig.headers.Origin = allowedOrigin;
let response: AxiosResponse;
for (response of [
await axios.post(url, data, anonConfig),
@@ -4893,13 +4899,13 @@ function testDocApi() {
assert.equal(response.status, 200);
assert.equal(response.headers['access-control-allow-methods'], 'GET, PATCH, PUT, POST, DELETE, OPTIONS');
assert.equal(response.headers['access-control-allow-headers'], 'Authorization, Content-Type, X-Requested-With');
assert.equal(response.headers['access-control-allow-origin'], serverUrl);
assert.equal(response.headers['access-control-allow-origin'], allowedOrigin);
assert.equal(response.headers['access-control-allow-credentials'], 'true');
}
// Cross origin requests from untrusted origin.
for (const config of [anonConfig, chimpyConfig]) {
config.headers.Origin = "https://evil.com/";
config.headers.Origin = forbiddenOrigin;
for (response of [
await axios.post(url, data, config),
await axios.get(url, config),

View File

@@ -0,0 +1,170 @@
import { assert } from 'chai';
import * as http from 'http';
import { GristClientSocket } from 'app/client/components/GristClientSocket';
import { GristSocketServer } from 'app/server/lib/GristSocketServer';
import { fromCallback, listenPromise } from 'app/server/lib/serverUtils';
import { AddressInfo } from 'net';
import httpProxy from 'http-proxy';
describe(`GristSockets`, function () {
for (const webSocketsSupported of [true, false]) {
describe(`when the networks ${webSocketsSupported ? "supports" : "does not support"} WebSockets`, function () {
let server: http.Server | null;
let serverPort: number;
let socketServer: GristSocketServer | null;
let proxy: httpProxy | null;
let proxyServer: http.Server | null;
let proxyPort: number;
let wsAddress: string;
beforeEach(async function () {
await startSocketServer();
await startProxyServer();
});
afterEach(async function () {
await stopProxyServer();
await stopSocketServer();
});
async function startSocketServer() {
server = http.createServer((req, res) => res.writeHead(404).end());
socketServer = new GristSocketServer(server);
await listenPromise(server.listen(0, 'localhost'));
serverPort = (server.address() as AddressInfo).port;
}
async function stopSocketServer() {
await fromCallback(cb => socketServer?.close(cb));
await fromCallback(cb => { server?.close(); server?.closeAllConnections(); server?.on("close", cb); });
socketServer = server = null;
}
// Start an HTTP proxy that supports WebSockets or not
async function startProxyServer() {
proxy = httpProxy.createProxy({
target: `http://localhost:${serverPort}`,
ws: webSocketsSupported,
timeout: 1000,
});
proxy.on('error', () => { });
proxyServer = http.createServer();
if (webSocketsSupported) {
// prevent non-WebSocket requests
proxyServer.on('request', (req, res) => res.writeHead(404).end());
// proxy WebSocket requests
proxyServer.on('upgrade', (req, socket, head) => proxy!.ws(req, socket, head));
} else {
// proxy non-WebSocket requests
proxyServer.on('request', (req, res) => proxy!.web(req, res));
// don't leave WebSocket connection attempts hanging
proxyServer.on('upgrade', (req, socket, head) => socket.destroy());
}
await listenPromise(proxyServer.listen(0, 'localhost'));
proxyPort = (proxyServer.address() as AddressInfo).port;
wsAddress = `ws://localhost:${proxyPort}`;
}
async function stopProxyServer() {
if (proxy) {
proxy.close();
proxy = null;
}
if (proxyServer) {
const server = proxyServer;
await fromCallback(cb => { server.close(cb); server.closeAllConnections(); });
}
proxyServer = null;
}
function getMessages(ws: GristClientSocket, count: number): Promise<string[]> {
return new Promise((resolve, reject) => {
const messages: string[] = [];
ws.onerror = (err) => {
ws.onerror = ws.onmessage = null;
reject(err);
};
ws.onmessage = (data: string) => {
messages.push(data);
if (messages.length >= count) {
ws.onerror = ws.onmessage = null;
resolve(messages);
}
};
});
}
/**
* Returns a promise for the connected websocket.
*/
function connectClient(url: string): Promise<GristClientSocket> {
const socket = new GristClientSocket(url);
return new Promise<GristClientSocket>((resolve, reject) => {
socket.onopen = () => {
socket.onerror = null;
resolve(socket);
};
socket.onerror = (err) => {
socket.onopen = null;
reject(err);
};
});
}
it("should expose initial request", async function () {
const connectionPromise = new Promise<http.IncomingMessage>((resolve) => {
socketServer!.onconnection = (socket, req) => {
resolve(req);
};
});
const clientWs = new GristClientSocket(wsAddress + "/path?query=value", {
headers: { "cookie": "session=1234" }
});
const req = await connectionPromise;
clientWs.close();
// Engine.IO may append extra query parameters, so we check only the start of the URL
assert.match(req.url!, /^\/path\?query=value/);
assert.equal(req.headers.cookie, "session=1234");
});
it("should receive and send messages", async function () {
socketServer!.onconnection = (socket, req) => {
socket.onmessage = (data) => {
socket.send("hello, " + data);
};
};
const clientWs = await connectClient(wsAddress);
clientWs.send("world");
assert.deepEqual(await getMessages(clientWs, 1), ["hello, world"]);
clientWs.close();
});
it("should invoke send callbacks", async function () {
const connectionPromise = new Promise<void>((resolve) => {
socketServer!.onconnection = (socket, req) => {
socket.send("hello", () => resolve());
};
});
const clientWs = await connectClient(wsAddress);
await connectionPromise;
clientWs.close();
});
it("should emit close event for client", async function () {
const clientWs = await connectClient(wsAddress);
const closePromise = new Promise<void>(resolve => {
clientWs.onclose = resolve;
});
clientWs.close();
await closePromise;
});
});
}
});

View File

@@ -12,7 +12,7 @@ import {createTestDir, EnvironmentSnapshot, setTmpLogLevel} from 'test/server/te
import {assert} from 'chai';
import * as cookie from 'cookie';
import fetch from 'node-fetch';
import WebSocket from 'ws';
import {GristClientSocket} from 'app/client/components/GristClientSocket';
describe('ManyFetches', function() {
this.timeout(30000);
@@ -244,7 +244,7 @@ describe('ManyFetches', function() {
return function createConnectionFunc() {
let clientId: string = '0';
return GristWSConnection.create(null, {
makeWebSocket(url: string): any { return new WebSocket(url, undefined, { headers }); },
makeWebSocket(url: string) { return new GristClientSocket(url, { headers }); },
getTimezone() { return Promise.resolve('UTC'); },
getPageUrl() { return pageUrl; },
getDocWorkerUrl() { return Promise.resolve(docWorkerUrl); },

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
# This runs browser tests with the server started using docker, to
# catch any configuration problems.