(core) updates from grist-core

This commit is contained in:
Paul Fitzpatrick
2024-04-01 09:22:25 -04:00
24 changed files with 1006 additions and 138 deletions

View File

@@ -22,7 +22,7 @@ import {fromCallback} from 'app/server/lib/serverUtils';
import {i18n} from 'i18next';
import * as crypto from 'crypto';
import moment from 'moment';
import * as WebSocket from 'ws';
import {GristServerSocket} from 'app/server/lib/GristServerSocket';
// How many messages and bytes to accumulate for a disconnected client before booting it.
// The benefit is that a client who temporarily disconnects and reconnects without missing much,
@@ -97,8 +97,7 @@ export class Client {
private _missedMessagesTotalLength: number = 0;
private _destroyTimer: NodeJS.Timer|null = null;
private _destroyed: boolean = false;
private _websocket: WebSocket|null;
private _websocketEventHandlers: Array<{event: string, handler: (...args: any[]) => void}> = [];
private _websocket: GristServerSocket|null;
private _org: string|null = null;
private _profile: UserProfile|null = null;
private _user: FullUser|undefined = undefined;
@@ -131,18 +130,14 @@ export class Client {
return this._locale;
}
public setConnection(websocket: WebSocket, counter: string|null, browserSettings: BrowserSettings) {
public setConnection(websocket: GristServerSocket, counter: string|null, browserSettings: BrowserSettings) {
this._websocket = websocket;
this._counter = counter;
this.browserSettings = browserSettings;
const addHandler = (event: string, handler: (...args: any[]) => void) => {
websocket.on(event, handler);
this._websocketEventHandlers.push({event, handler});
};
addHandler('error', (err: unknown) => this._onError(err));
addHandler('close', () => this._onClose());
addHandler('message', (msg: string) => this._onMessage(msg));
websocket.onerror = (err: Error) => this._onError(err);
websocket.onclose = () => this._onClose();
websocket.onmessage = (msg: string) => this._onMessage(msg);
}
/**
@@ -189,7 +184,7 @@ export class Client {
public interruptConnection() {
if (this._websocket) {
this._removeWebsocketListeners();
this._websocket.removeAllListeners();
this._websocket.terminate(); // close() is inadequate when ws routed via loadbalancer
this._websocket = null;
}
@@ -359,7 +354,7 @@ export class Client {
// See also my report at https://stackoverflow.com/a/48411315/328565
await delay(250);
if (!this._destroyed && this._websocket?.readyState === WebSocket.OPEN) {
if (!this._destroyed && this._websocket?.isOpen) {
await this._sendToWebsocket(JSON.stringify({...clientConnectMsg, dup: true}));
}
} catch (err) {
@@ -604,7 +599,7 @@ export class Client {
/**
* Processes an error on the websocket.
*/
private _onError(err: unknown) {
private _onError(err: Error) {
this._log.warn(null, "onError", err);
// TODO Make sure that this is followed by onClose when the connection is lost.
}
@@ -613,7 +608,7 @@ export class Client {
* Processes the closing of a websocket.
*/
private _onClose() {
this._removeWebsocketListeners();
this._websocket?.removeAllListeners();
// Remove all references to the websocket.
this._websocket = null;
@@ -629,15 +624,4 @@ export class Client {
this._destroyTimer = setTimeout(() => this.destroy(), Deps.clientRemovalTimeoutMs);
}
}
private _removeWebsocketListeners() {
if (this._websocket) {
// Avoiding websocket.removeAllListeners() because WebSocket.Server registers listeners
// internally for websockets it keeps track of, and we should not accidentally remove those.
for (const {event, handler} of this._websocketEventHandlers) {
this._websocket.off(event, handler);
}
this._websocketEventHandlers = [];
}
}
}

View File

@@ -35,7 +35,8 @@
import {EventEmitter} from 'events';
import * as http from 'http';
import * as https from 'https';
import * as WebSocket from 'ws';
import {GristSocketServer} from 'app/server/lib/GristSocketServer';
import {GristServerSocket} from 'app/server/lib/GristServerSocket';
import {parseFirstUrlPart} from 'app/common/gristUrls';
import {firstDefined, safeJsonParse} from 'app/common/gutil';
@@ -50,6 +51,7 @@ import {localeFromRequest} from 'app/server/lib/ServerLocale';
import {fromCallback} from 'app/server/lib/serverUtils';
import {Sessions} from 'app/server/lib/Sessions';
import {i18n} from 'i18next';
import { trustOrigin } from './requestUtils';
export interface CommOptions {
sessions: Sessions; // A collection of all sessions for this instance of Grist
@@ -74,7 +76,7 @@ export interface CommOptions {
export class Comm extends EventEmitter {
// Collection of all sessions; maps sessionIds to ScopedSession objects.
public readonly sessions: Sessions = this._options.sessions;
private _wss: WebSocket.Server[]|null = null;
private _wss: GristSocketServer[]|null = null;
private _clients = new Map<string, Client>(); // Maps clientIds to Client objects.
@@ -146,11 +148,6 @@ export class Comm extends EventEmitter {
public async testServerShutdown() {
if (this._wss) {
for (const wssi of this._wss) {
// Terminate all clients. WebSocket.Server used to do it automatically in close() but no
// longer does (see https://github.com/websockets/ws/pull/1904#discussion_r668844565).
for (const ws of wssi.clients) {
ws.terminate();
}
await fromCallback((cb) => wssi.close(cb));
}
this._wss = null;
@@ -204,14 +201,7 @@ export class Comm extends EventEmitter {
/**
* Processes a new websocket connection, and associates the websocket and a Client object.
*/
private async _onWebSocketConnection(websocket: WebSocket, req: http.IncomingMessage) {
if (this._options.hosts) {
// DocWorker ID (/dw/) and version tag (/v/) may be present in this request but are not
// needed. addOrgInfo assumes req.url starts with /o/ if present.
req.url = parseFirstUrlPart('dw', req.url!).path;
req.url = parseFirstUrlPart('v', req.url).path;
await this._options.hosts.addOrgInfo(req);
}
private async _onWebSocketConnection(websocket: GristServerSocket, req: http.IncomingMessage) {
// Parse the cookie in the request to get the sessionId.
const sessionId = this.sessions.getSessionIdFromRequest(req);
@@ -255,15 +245,28 @@ export class Comm extends EventEmitter {
if (this._options.httpsServer) { servers.push(this._options.httpsServer); }
const wss = [];
for (const server of servers) {
const wssi = new WebSocket.Server({server});
wssi.on('connection', async (websocket: WebSocket, req) => {
const wssi = new GristSocketServer(server, {
verifyClient: async (req: http.IncomingMessage) => {
if (this._options.hosts) {
// DocWorker ID (/dw/) and version tag (/v/) may be present in this request but are not
// needed. addOrgInfo assumes req.url starts with /o/ if present.
req.url = parseFirstUrlPart('dw', req.url!).path;
req.url = parseFirstUrlPart('v', req.url).path;
await this._options.hosts.addOrgInfo(req);
}
return trustOrigin(req);
}
});
wssi.onconnection = async (websocket: GristServerSocket, req) => {
try {
await this._onWebSocketConnection(websocket, req);
} catch (e) {
log.error("Comm connection for %s threw exception: %s", req.url, e.message);
websocket.terminate(); // close() is inadequate when ws routed via loadbalancer
}
});
};
wss.push(wssi);
}
return wss;

View File

@@ -0,0 +1,138 @@
import * as WS from 'ws';
import * as EIO from 'engine.io';
export abstract class GristServerSocket {
public abstract set onerror(handler: (err: Error) => void);
public abstract set onclose(handler: () => void);
public abstract set onmessage(handler: (data: string) => void);
public abstract removeAllListeners(): void;
public abstract close(): void;
public abstract terminate(): void;
public abstract get isOpen(): boolean;
public abstract send(data: string, cb?: (err?: Error) => void): void;
}
export class GristServerSocketEIO extends GristServerSocket {
private _eventHandlers: Array<{ event: string, handler: (...args: any[]) => void }> = [];
private _messageCounter = 0;
// Engine.IO only invokes send() callbacks on success. We keep a map of
// send callbacks for messages in flight so that we can invoke them for
// any messages still unsent upon receiving a "close" event.
private _messageCallbacks: Map<number, (err: Error) => void> = new Map();
constructor(private _socket: EIO.Socket) { super(); }
public set onerror(handler: (err: Error) => void) {
// Note that as far as I can tell, Engine.IO sockets never emit "error"
// but instead include error information in the "close" event.
this._socket.on('error', handler);
this._eventHandlers.push({ event: 'error', handler });
}
public set onclose(handler: () => void) {
const wrappedHandler = (reason: string, description: any) => {
// In practice, when available, description has more error details,
// possibly in the form of an Error object.
const maybeErr = description ?? reason;
const err = maybeErr instanceof Error ? maybeErr : new Error(maybeErr);
for (const cb of this._messageCallbacks.values()) {
cb(err);
}
this._messageCallbacks.clear();
handler();
};
this._socket.on('close', wrappedHandler);
this._eventHandlers.push({ event: 'close', handler: wrappedHandler });
}
public set onmessage(handler: (data: string) => void) {
const wrappedHandler = (msg: Buffer) => {
handler(msg.toString());
};
this._socket.on('message', wrappedHandler);
this._eventHandlers.push({ event: 'message', handler: wrappedHandler });
}
public removeAllListeners() {
for (const { event, handler } of this._eventHandlers) {
this._socket.off(event, handler);
}
this._eventHandlers = [];
}
public close() {
this._socket.close();
}
// Terminates the connection without waiting for the client to close its own side.
public terminate() {
// Trigger a normal close. For the polling transport, this is sufficient and instantaneous.
this._socket.close(/* discard */ true);
}
public get isOpen() {
return this._socket.readyState === 'open';
}
public send(data: string, cb?: (err?: Error) => void) {
const msgNum = this._messageCounter++;
if (cb) {
this._messageCallbacks.set(msgNum, cb);
}
this._socket.send(data, {}, () => {
if (cb && this._messageCallbacks.delete(msgNum)) {
// send was successful: pass no Error to callback
cb();
}
});
}
}
export class GristServerSocketWS extends GristServerSocket {
private _eventHandlers: Array<{ event: string, handler: (...args: any[]) => void }> = [];
constructor(private _ws: WS.WebSocket) { super(); }
public set onerror(handler: (err: Error) => void) {
this._ws.on('error', handler);
this._eventHandlers.push({ event: 'error', handler });
}
public set onclose(handler: () => void) {
this._ws.on('close', handler);
this._eventHandlers.push({ event: 'close', handler });
}
public set onmessage(handler: (data: string) => void) {
const wrappedHandler = (msg: Buffer) => handler(msg.toString());
this._ws.on('message', wrappedHandler);
this._eventHandlers.push({ event: 'message', handler: wrappedHandler });
}
public removeAllListeners() {
// Avoiding websocket.removeAllListeners() because WS.Server registers listeners
// internally for websockets it keeps track of, and we should not accidentally remove those.
for (const { event, handler } of this._eventHandlers) {
this._ws.off(event, handler);
}
this._eventHandlers = [];
}
public close() {
this._ws.close();
}
public terminate() {
this._ws.terminate();
}
public get isOpen() {
return this._ws.readyState === WS.OPEN;
}
public send(data: string, cb?: (err?: Error) => void) {
this._ws.send(data, cb);
}
}

View File

@@ -0,0 +1,111 @@
import * as http from 'http';
import * as WS from 'ws';
import * as EIO from 'engine.io';
import {GristServerSocket, GristServerSocketEIO, GristServerSocketWS} from './GristServerSocket';
import * as net from 'net';
const MAX_PAYLOAD = 100e6;
export interface GristSocketServerOptions {
verifyClient?: (request: http.IncomingMessage) => Promise<boolean>;
}
export class GristSocketServer {
private _wsServer: WS.Server;
private _eioServer: EIO.Server;
private _connectionHandler: (socket: GristServerSocket, req: http.IncomingMessage) => void;
constructor(server: http.Server, private _options?: GristSocketServerOptions) {
this._wsServer = new WS.Server({ noServer: true, maxPayload: MAX_PAYLOAD });
this._eioServer = new EIO.Server({
// We only use Engine.IO for its polling transport,
// so we disable the built-in Engine.IO upgrade mechanism.
allowUpgrades: false,
transports: ['polling'],
maxHttpBufferSize: MAX_PAYLOAD,
cors: {
// This will cause Engine.IO to reflect any client-provided Origin into
// the Access-Control-Allow-Origin header, essentially disabling the
// protection offered by the Same-Origin Policy. This sounds insecure
// but is actually the security model of native WebSockets (they are
// not covered by SOP; any webpage can open a WebSocket connecting to
// any other domain, including the target domain's cookies; it is up to
// the receiving server to check the request's Origin header). Since
// the connection attempt is validated in `verifyClient` later,
// it is safe to let any client attempt a connection here.
origin: true,
// We need to allow the client to send its cookies. See above for the
// reasoning on why it is safe to do so.
credentials: true,
methods: ["GET", "POST"],
},
});
this._eioServer.on('connection', this._onEIOConnection.bind(this));
this._attach(server);
}
public set onconnection(handler: (socket: GristServerSocket, req: http.IncomingMessage) => void) {
this._connectionHandler = handler;
}
public close(cb: (...args: any[]) => void) {
this._eioServer.close();
// Terminate all clients. WS.Server used to do it automatically in close() but no
// longer does (see https://github.com/websockets/ws/pull/1904#discussion_r668844565).
for (const ws of this._wsServer.clients) {
ws.terminate();
}
this._wsServer.close(cb);
}
private _attach(server: http.Server) {
// Forward all WebSocket upgrade requests to WS
server.on('upgrade', async (request, socket, head) => {
if (this._options?.verifyClient && !await this._options.verifyClient(request)) {
// Because we are handling an "upgrade" event, we don't have access to
// a "response" object, just the raw socket. We can still construct
// a well-formed HTTP error response.
socket.write('HTTP/1.1 403 Forbidden\r\n\r\n');
socket.destroy();
return;
}
this._wsServer.handleUpgrade(request, socket as net.Socket, head, (client) => {
this._connectionHandler?.(new GristServerSocketWS(client), request);
});
});
// At this point an Express app is installed as the handler for the server's
// "request" event. We need to install our own listener instead, to intercept
// requests that are meant for the Engine.IO polling implementation.
const listeners = [...server.listeners("request")];
server.removeAllListeners("request");
server.on("request", async (req, res) => {
// Intercept requests that have transport=polling in their querystring
if (/[&?]transport=polling(&|$)/.test(req.url ?? '')) {
if (this._options?.verifyClient && !await this._options.verifyClient(req)) {
res.writeHead(403).end();
return;
}
this._eioServer.handleRequest(req, res);
} else {
// Otherwise fallback to the pre-existing listener(s)
for (const listener of listeners) {
listener.call(server, req, res);
}
}
});
server.on("close", this.close.bind(this));
}
private _onEIOConnection(socket: EIO.Socket) {
const req = socket.request;
(socket as any).request = null; // Free initial request as recommended in the Engine.IO documentation
this._connectionHandler?.(new GristServerSocketEIO(socket), req);
}
}

View File

@@ -8,7 +8,9 @@ import {RequestWithGrist} from 'app/server/lib/GristServer';
import log from 'app/server/lib/log';
import {Permit} from 'app/server/lib/Permit';
import {Request, Response} from 'express';
import { IncomingMessage } from 'http';
import {Writable} from 'stream';
import { TLSSocket } from 'tls';
// log api details outside of dev environment (when GRIST_HOSTED_VERSION is set)
const shouldLogApiDetails = Boolean(process.env.GRIST_HOSTED_VERSION);
@@ -75,33 +77,33 @@ export function getOrgUrl(req: Request, path: string = '/') {
}
/**
* Returns true for requests from permitted origins. For such requests, an
* "Access-Control-Allow-Origin" header is added to the response. Vary: Origin
* is also set to reflect the fact that the headers are a function of the origin,
* to prevent inappropriate caching on the browser's side.
* Returns true for requests from permitted origins. For such requests, if
* a Response object is provided, an "Access-Control-Allow-Origin" header is added
* to the response. Vary: Origin is also set to reflect the fact that the headers
* are a function of the origin, to prevent inappropriate caching on the browser's side.
*/
export function trustOrigin(req: Request, resp: Response): boolean {
export function trustOrigin(req: IncomingMessage, resp?: Response): boolean {
// TODO: We may want to consider changing allowed origin values in the future.
// Note that the request origin is undefined for non-CORS requests.
const origin = req.get('origin');
const origin = req.headers.origin;
if (!origin) { return true; } // Not a CORS request.
if (process.env.GRIST_HOST && req.hostname === process.env.GRIST_HOST) { return true; }
if (!allowHost(req, new URL(origin))) { return false; }
// For a request to a custom domain, the full hostname must match.
resp.header("Access-Control-Allow-Origin", origin);
resp.header("Vary", "Origin");
if (resp) {
// For a request to a custom domain, the full hostname must match.
resp.header("Access-Control-Allow-Origin", origin);
resp.header("Vary", "Origin");
}
return true;
}
// Returns whether req satisfies the given allowedHost. Unless req is to a custom domain, it is
// enough if only the base domains match. Differing ports are allowed, which helps in dev/testing.
export function allowHost(req: Request, allowedHost: string|URL) {
const mreq = req as RequestWithOrg;
export function allowHost(req: IncomingMessage, allowedHost: string|URL) {
const proto = getEndUserProtocol(req);
const actualUrl = new URL(getOriginUrl(req));
const allowedUrl = (typeof allowedHost === 'string') ? new URL(`${proto}://${allowedHost}`) : allowedHost;
if (mreq.isCustomHost) {
if ((req as RequestWithOrg).isCustomHost) {
// For a request to a custom domain, the full hostname must match.
return actualUrl.hostname === allowedUrl.hostname;
} else {
@@ -330,8 +332,8 @@ export interface RequestWithGristInfo extends Request {
* https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-Proto
* https://docs.aws.amazon.com/elasticloadbalancing/latest/classic/x-forwarded-headers.html
*/
export function getOriginUrl(req: Request) {
const host = req.get('host')!;
export function getOriginUrl(req: IncomingMessage) {
const host = req.headers.host;
const protocol = getEndUserProtocol(req);
return `${protocol}://${host}`;
}
@@ -342,11 +344,13 @@ export function getOriginUrl(req: Request) {
* otherwise X-Forwarded-Proto is set on the provided request, otherwise
* the protocol of the request itself.
*/
export function getEndUserProtocol(req: Request) {
export function getEndUserProtocol(req: IncomingMessage) {
if (process.env.APP_HOME_URL) {
return new URL(process.env.APP_HOME_URL).protocol.replace(':', '');
}
return req.get("X-Forwarded-Proto") || req.protocol;
// TODO we shouldn't blindly trust X-Forwarded-Proto. See the Express approach:
// https://expressjs.com/en/5x/api.html#trust.proxy.options.table
return req.headers["x-forwarded-proto"] || ((req.socket as TLSSocket).encrypted ? 'https' : 'http');
}
/**