diff --git a/src/di/Container.ts b/src/di/Container.ts index 1f401b4..c4f6a6a 100644 --- a/src/di/Container.ts +++ b/src/di/Container.ts @@ -1,4 +1,12 @@ -import {DependencyKey, InstanceRef, Instantiable, isInstantiable, StaticClass, TypedDependencyKey} from './types' +import { + DependencyKey, + InstanceRef, + Instantiable, + isInstantiable, + StaticClass, + StaticInstantiable, + TypedDependencyKey, +} from './types' import {AbstractFactory} from './factory/AbstractFactory' import {collect, Collection, globalRegistry, logIfDebugging} from '../util' import {Factory} from './factory/Factory' @@ -66,6 +74,12 @@ export class Container { */ protected instances: Collection = new Collection() + /** + * Collection of static-class overrides registered with this container. + * @protected + */ + protected staticOverrides: Collection<{ base: StaticInstantiable, override: StaticInstantiable }> = new Collection<{base: StaticInstantiable; override: StaticInstantiable}>() + /** * Collection of callbacks waiting for a dependency key to be resolved. * @protected @@ -110,6 +124,52 @@ export class Container { return this } + /** + * Register a static class as an override of some base class. + * @param base + * @param override + */ + registerStaticOverride(base: StaticInstantiable, override: StaticInstantiable): this { + if ( this.hasStaticOverride(base) ) { + throw new DuplicateFactoryKeyError(base) + } + + this.staticOverrides.push({ + base, + override, + }) + + return this + } + + /** Returns true if a static override exists for the given base class. */ + hasStaticOverride(base: StaticInstantiable): boolean { + return this.staticOverrides.where('base', '=', base).isNotEmpty() + } + + /** + * Get the static class overriding the base class. + * @param base + */ + getStaticOverride(base: StaticInstantiable): StaticInstantiable { + const override = this.staticOverrides.firstWhere('base', '=', base) + if ( override ) { + return override.override + } + + return base + } + + /** + * Get the registered instance of the static override of a given class. + * @param base + * @param parameters + */ + makeByStaticOverride(base: StaticInstantiable, ...parameters: any[]): T { + const key = this.getStaticOverride(base) + return this.make(key, ...parameters) + } + /** * Register the given function as a factory within the container. * @param {string} name - unique name to identify the factory in the container diff --git a/src/di/types.ts b/src/di/types.ts index dd54144..afe7c0f 100644 --- a/src/di/types.ts +++ b/src/di/types.ts @@ -38,6 +38,11 @@ export function isInstantiableOf(what: unknown, type: StaticClass): w */ export type StaticClass = Function & {prototype: T} & T2 // eslint-disable-line @typescript-eslint/ban-types +/** + * Type that identifies a value as a static class that instantiates to itself + */ +export type StaticInstantiable = StaticClass> + /** * Returns true if the parameter is a static class. * @param something