gristlabs_grist-core/app/common/BaseAPI.ts
Paul Fitzpatrick bd6a54e901 (core) mitigate csrf by requiring custom header for unsafe methods
Summary:
For methods other than `GET`, `HEAD`, and `OPTIONS`, allow cookie-based authentication only if a certain custom header is present.

Specifically, we check that `X-Requested-With` is set to `XMLHttpRequest`. This is somewhat arbitrary, but allows us to use https://expressjs.com/en/api.html#req.xhr.

A cross-origin request sent from a browser with a custom header will trigger a preflight check, giving us a chance to check whether the origin is trusted.

This diff deals with getting the header in place. There will be more work to do after this:
 * Make sure that all important endpoints check the origin. From skimming the code, /api endpoints check the origin, as do some but not all others.
 * Add tests spot-testing origin checks.
 * Check on cases that authenticate differently.
    - Check the websocket endpoint - it can be connected to from an arbitrary site; there is per-doc access control but probably better to lock it down more.
    - There may be old endpoints that authenticate based on knowledge of a client id rather than cookies.

Test Plan: added a test

Reviewers: dsagal

Reviewed By: dsagal

Differential Revision: https://phab.getgrist.com/D2631
2020-10-08 14:19:25 -04:00

136 lines
4.7 KiB
TypeScript

import {ApiError, ApiErrorDetails} from 'app/common/ApiError';
import axios, {AxiosRequestConfig, AxiosResponse} from 'axios';
import {tbind} from './tbind';
/** The subset of the Console interface that BaseAPI needs for logging. */
export type ILogger = Pick<Console, 'log' | 'debug' | 'info' | 'warn' | 'error'>;

/** Options accepted by the BaseAPI constructor; every field is optional. */
export interface IOptions {
  /** Extra headers, merged over (and able to override) the default request headers. */
  headers?: Record<string, string>;

  /** The fetch implementation to use; when omitted, BaseAPI uses `window.fetch`. */
  fetch?: typeof fetch;

  /** Factory for FormData objects, since the FormData constructor depends on the platform. */
  newFormData?: () => FormData;

  /** Where to send log messages; when omitted, BaseAPI uses `console`. */
  logger?: ILogger;

  /** When set, these query parameters are added to requests. */
  extraParameters?: Map<string, string>;
}
/**
 * Base setup class for creating a REST API client interface.
 *
 * Subclasses issue requests through `request` / `requestJson` (fetch-based) or `requestAxios`
 * (axios-based, supporting progress indicators). Any response whose status is not 200 results
 * in a thrown ApiError. Methods decorated with `@BaseAPI.countRequest` count as pending while
 * in flight; see `numPendingRequests`.
 */
export class BaseAPI {
  /** Count of pending requests. It is relied on by tests. */
  public static numPendingRequests(): number {
    // Read via BaseAPI (the same place countPendingRequest updates) rather than `this`, so the
    // result does not depend on the receiver and the method also works when called unbound.
    return BaseAPI._numPendingRequests;
  }

  /**
   * Wrap a promise to add to the count of pending requests until the promise settles.
   * The promise's result or rejection is passed through unchanged.
   */
  public static async countPendingRequest<T>(promise: Promise<T>): Promise<T> {
    try {
      BaseAPI._numPendingRequests++;
      return await promise;
    } finally {
      BaseAPI._numPendingRequests--;
    }
  }

  /**
   * Method decorator for BaseAPI or derived classes. Each call of the decorated method counts
   * as a pending request until the promise it returns settles.
   */
  public static countRequest(target: unknown, propertyKey: string, descriptor: PropertyDescriptor) {
    const originalMethod = descriptor.value;
    descriptor.value = async function(...args: any[]) {
      return BaseAPI.countPendingRequest(originalMethod.apply(this, args));
    };
  }

  private static _numPendingRequests: number = 0;

  protected fetch: typeof fetch;
  protected newFormData: () => FormData;
  private _headers: Record<string, string>;
  private _logger: ILogger;
  private _extraParameters?: Map<string, string>;

  /**
   * @param options.headers - Headers merged over (and able to override) the defaults
   *   `Content-Type: application/json` and `X-Requested-With: XMLHttpRequest`.
   * @param options.fetch - The fetch implementation; defaults to `window.fetch` bound to `window`.
   * @param options.newFormData - FormData factory; defaults to `() => new FormData()`.
   * @param options.logger - Logger; defaults to `console`.
   * @param options.extraParameters - Query parameters set on every URL passed to `request`.
   */
  constructor(options: IOptions = {}) {
    this.fetch = options.fetch || tbind(window.fetch, window);
    this.newFormData = options.newFormData || (() => new FormData());
    this._logger = options.logger || console;
    this._headers = {
      'Content-Type': 'application/json',
      // A custom header makes browsers send a preflight for cross-origin requests, which lets
      // the server check the origin before accepting cookie-authenticated unsafe methods.
      'X-Requested-With': 'XMLHttpRequest',
      ...options.headers
    };
    this._extraParameters = options.extraParameters;
  }

  /** Make a modified request, exposed for test convenience. */
  public async testRequest(url: string, init: RequestInit = {}): Promise<Response> {
    return this.request(url, init);
  }

  /**
   * The headers `request` sends when the caller's `init` has no `headers` of its own.
   * NOTE(review): this returns the internal object itself, so mutating it changes the headers
   * of all later requests made by this instance.
   */
  public defaultHeaders() {
    return this._headers;
  }

  /**
   * A fresh copy of the default headers with `Content-Type` removed. Presumably meant for
   * bodies (such as FormData) whose Content-Type must not be forced to JSON.
   */
  public defaultHeadersWithoutContentType() {
    const headers = {...this.defaultHeaders()};
    delete headers['Content-Type'];
    return headers;
  }

  /**
   * Similar to `request`, but uses the axios library and supports progress indicators.
   *
   * Neither the default headers nor `extraParameters` are applied here; only `config.headers`
   * is sent, plus any headers supplied by a FormData-like `config.data.getHeaders()`. Cookies
   * are sent (`withCredentials: true`) unless `config` overrides that.
   *
   * @throws ApiError when the response status is not 200.
   */
  @BaseAPI.countRequest
  protected async requestAxios(url: string, config: AxiosRequestConfig): Promise<AxiosResponse> {
    // If using with FormData in node, axios needs the headers prepared by FormData.
    // Headers given explicitly in config take precedence over those.
    let headers = config.headers;
    if (config.data && typeof config.data.getHeaders === 'function') {
      headers = {...config.data.getHeaders(), ...headers};
    }
    const resp = await axios.request({
      url,
      withCredentials: true,
      // Never reject based on status (more like fetch); the status is checked below.
      validateStatus: (status) => true,
      ...config,
      headers,
    });
    if (resp.status !== 200) {
      throwApiError(url, resp, resp.data);
    }
    return resp;
  }

  /**
   * Make a request using fetch.
   *
   * Unless `init` says otherwise, the default headers are sent and credentials (cookies) are
   * included. Note that `init.headers`, if given, REPLACES the default headers entirely (it is
   * not merged with them), so callers supplying headers should start from `defaultHeaders()`
   * or `defaultHeadersWithoutContentType()` to keep `X-Requested-With`.
   *
   * Any configured `extraParameters` are set as query parameters on the URL, replacing any
   * same-named ones. The resulting URL is the one logged and reported in errors.
   *
   * @throws ApiError when the response status is not 200.
   */
  @BaseAPI.countRequest
  protected async request(input: string, init: RequestInit = {}): Promise<Response> {
    const fullInit: RequestInit = {headers: this._headers, credentials: 'include', ...init};
    // Skip URL parsing entirely when there is nothing to add. new URL() needs an absolute URL,
    // so parsing unconditionally would reject relative inputs for no reason.
    if (this._extraParameters && this._extraParameters.size > 0) {
      const fullUrl = new URL(input);
      for (const [key, val] of this._extraParameters.entries()) {
        fullUrl.searchParams.set(key, val);
      }
      input = fullUrl.href;
    }
    const resp = await this.fetch(input, fullInit);
    this._logger.log("Fetched", input);
    if (resp.status !== 200) {
      // The error body may be empty or not JSON; report it without details in that case.
      const body = await resp.json().catch(() => ({}));
      throwApiError(input, resp, body);
    }
    return resp;
  }

  /**
   * Make a request, and read the response as JSON. This allows counting the request as pending
   * until it has been read, which is relied on by tests.
   *
   * @throws ApiError when the response status is not 200.
   */
  @BaseAPI.countRequest
  protected async requestJson(input: string, init: RequestInit = {}): Promise<any> {
    return (await this.request(input, init)).json();
  }
}
/**
 * Throw an ApiError describing a failed request.
 *
 * @param url - The requested URL, included in the error message.
 * @param resp - The fetch or axios response; its status and statusText are reported.
 * @param body - The parsed response body, if any. An object-valued `body.details` becomes the
 *   ApiError's details. A truthy `body.error` is used in the message and is also copied into
 *   `details.userError`, which the Notifier uses.
 */
function throwApiError(url: string, resp: Response | AxiosResponse, body: any) {
  const payload = body || {};
  const hasObjectDetails = payload.details && typeof payload.details === 'object';
  const details: ApiErrorDetails = hasObjectDetails ? payload.details : {};
  if (payload.error) {
    details.userError = payload.error;
  }
  const cause = payload.error || 'unknown cause';
  const message = `Request to ${url} failed with status ${resp.status}: ${resp.statusText} (${cause})`;
  throw new ApiError(message, resp.status, details);
}