/**
|
* @license
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { env } from './environment';
|
import { getGlobal } from './global_util';
|
import * as log from './log';
|
// Global kernel registry. Keys are built by makeKey(kernelName, backendName);
// values are the kernel config objects handed to registerKernel. The Map is
// obtained through getGlobal under the name 'kernelRegistry' (presumably so
// that every loaded copy of this module shares a single Map, with the factory
// only invoked on first access — TODO confirm against global_util).
const kernelRegistry = getGlobal('kernelRegistry', () => {
  return new Map();
});
// Global gradient registry, keyed by kernel name; values are the gradient
// config objects handed to registerGradient.
const gradRegistry = getGlobal('gradRegistry', () => {
  return new Map();
});
|
/**
 * Looks up the kernel config registered for a kernel/backend pair.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 * @returns The registered kernel config, or `undefined` when nothing is
 *     registered for that pair.
 */
export function getKernel(kernelName, backendName) {
  return kernelRegistry.get(makeKey(kernelName, backendName));
}
|
/**
 * Looks up the gradient config registered for a kernel.
 *
 * Gradients are stored by kernel name alone, independent of any backend.
 *
 * @param kernelName The official TF kernel name.
 * @returns The registered gradient config, or `undefined` if none exists.
 */
export function getGradient(kernelName) {
  const gradConfig = gradRegistry.get(kernelName);
  return gradConfig;
}
|
/**
 * Returns the configs of every kernel currently registered for a backend.
 *
 * Entries are matched on each config's `backendName` instead of by splitting
 * the registry key on '_'. Keys have the form `${backendName}_${kernelName}`,
 * so splitting on the first underscore misidentifies any backend whose own
 * name contains an underscore (e.g. 'my_backend' would be read as 'my'):
 * its kernels were never returned, and a lookup for the bare prefix returned
 * kernels that do not belong to it.
 *
 * @param backendName The official name of the backend.
 * @returns The registered kernel configs for that backend, in the registry's
 *     iteration order; an empty array if none are registered.
 */
export function getKernelsForBackend(backendName) {
  const result = [];
  // registerKernel derives each entry's key from config.backendName, so that
  // field identifies the owning backend exactly.
  kernelRegistry.forEach((config) => {
    if (config.backendName === backendName) {
      result.push(config);
    }
  });
  return result;
}
|
/**
 * Stores a kernel's forward-pass config in the global kernel registry.
 *
 * If a config is already stored for the same kernel/backend pair, a warning
 * is logged and the new config replaces the old one.
 *
 * @param config A config object with the following properties:
 * - `kernelName` The official name of the kernel.
 * - `backendName` The official name of the backend.
 * - `kernelFunc` The function to run during the forward pass of the kernel.
 * - `setupFunc` Optional. Gets called once, after the backend initializes.
 * - `disposeFunc` Optional. Gets called once, right before the backend is
 *   disposed.
 */
export function registerKernel(config) {
  const name = config.kernelName;
  const backend = config.backendName;
  const registryKey = makeKey(name, backend);
  const alreadyRegistered = kernelRegistry.has(registryKey);
  if (alreadyRegistered) {
    log.warn(`The kernel '${name}' for backend '${backend}' is already registered`);
  }
  kernelRegistry.set(registryKey, config);
}
|
/**
 * Stores a gradient config for a kernel in the global gradient registry, for
 * use during back-propagation of that kernel.
 *
 * Re-registering a kernel's gradient replaces the previous config; a warning
 * is logged for the override only when the 'DEBUG' flag is enabled.
 *
 * @param config An object with the following properties:
 * - `kernelName` The name of the kernel that the gradient function is for.
 * - `gradFunc` The function to run during back-propagation.
 */
export function registerGradient(config) {
  const { kernelName } = config;
  // TODO (yassogba) after 3.0 assess whether we need to keep this gated
  // to debug mode.
  const isOverride = gradRegistry.has(kernelName);
  if (isOverride && env().getBool('DEBUG')) {
    log.warn(`Overriding the gradient for '${kernelName}'`);
  }
  gradRegistry.set(kernelName, config);
}
|
/**
 * Deletes a kernel/backend pair from the global kernel registry.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 * @throws Error if nothing is registered for that kernel/backend pair.
 */
export function unregisterKernel(kernelName, backendName) {
  const registryKey = makeKey(kernelName, backendName);
  // Map#delete reports whether the key existed, so a missing entry leaves the
  // registry untouched and is surfaced as an error.
  const removed = kernelRegistry.delete(registryKey);
  if (!removed) {
    throw new Error(`The kernel '${kernelName}' for backend '${backendName}' is not registered`);
  }
}
|
/**
 * Removes the registered gradient for a kernel from the global registry.
 *
 * @param kernelName The name of the kernel whose gradient should be removed.
 * @throws Error if no gradient is registered for `kernelName`.
 */
export function unregisterGradient(kernelName) {
  if (!gradRegistry.has(kernelName)) {
    // Gradients are keyed by kernel name only (not per backend), so the
    // message names just the kernel; the previous text contained a dangling
    // "for backend" with no backend name.
    throw new Error(`The gradient '${kernelName}' is not registered`);
  }
  gradRegistry.delete(kernelName);
}
|
/**
 * Re-registers every kernel of an existing backend under a new backend name,
 * which is handy when building a custom backend on top of an existing one.
 *
 * Each copy is a shallow clone of the source config with only `backendName`
 * replaced; the source configs themselves are not modified.
 *
 * @param registeredBackendName Backend whose kernels are already registered.
 * @param newBackendName Backend name to register the copies under.
 */
export function copyRegisteredKernels(registeredBackendName, newBackendName) {
  for (const sourceConfig of getKernelsForBackend(registeredBackendName)) {
    registerKernel(Object.assign({}, sourceConfig, { backendName: newBackendName }));
  }
}
|
/**
 * Builds the kernel-registry key for a kernel/backend pair.
 *
 * The backend name comes first, joined to the kernel name with '_'.
 */
function makeKey(kernelName, backendName) {
  const separator = '_';
  return `${backendName}${separator}${kernelName}`;
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"kernel_registry.js","sourceRoot":"","sources":["../../../../../tfjs-core/src/kernel_registry.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AACH,OAAO,EAAC,GAAG,EAAC,MAAM,eAAe,CAAC;AAClC,OAAO,EAAC,SAAS,EAAC,MAAM,eAAe,CAAC;AACxC,OAAO,KAAK,GAAG,MAAM,OAAO,CAAC;AAM7B,MAAM,cAAc,GAClB,SAAS,CAAC,gBAAgB,EAAE,GAAG,EAAE,CAAC,IAAI,GAAG,EACxB,CAAC,CAAC;AACrB,MAAM,YAAY,GAChB,SAAS,CAAC,cAAc,EAAE,GAAG,EAAE,CAAC,IAAI,GAAG,EAAsB,CAAC,CAAC;AAqDjE;;;;;GAKG;AACH,MAAM,UAAU,SAAS,CACrB,UAAkB,EAAE,WAAmB;IACzC,MAAM,GAAG,GAAG,OAAO,CAAC,UAAU,EAAE,WAAW,CAAC,CAAC;IAC7C,OAAO,cAAc,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;AACjC,CAAC;AAED;;;GAGG;AACH,MAAM,UAAU,WAAW,CAAC,UAAkB;IAC5C,OAAO,YAAY,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;AACtC,CAAC;AAED,MAAM,UAAU,oBAAoB,CAAC,WAAmB;IACtD,MAAM,EAAE,GAAG,cAAc,CAAC,OAAO,EAAE,CAAC;IACpC,MAAM,MAAM,GAAmB,EAAE,CAAC;IAElC,OAAO,IAAI,EAAE;QACX,MAAM,EAAC,IAAI,EAAE,KAAK,EAAC,GAAG,EAAE,CAAC,IAAI,EAAE,CAAC;QAChC,IAAI,IAAI,EAAE;YACR,MAAM;SACP;QACD,MAAM,CAAC,GAAG,EAAE,MAAM,CAAC,GAAG,KAAK,CAAC;QAC5B,MAAM,CAAC,OAAO,EAAG,GAAG,GAAG,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC;QACnC,IAAI,OAAO,KAAK,WAAW,EAAE;YAC3B,MAAM,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;SACrB;KACF;IACD,OAAO,MAAM,CAAC;AAChB,CAAC;AAED;;;;;;;;;;GAUG;AACH,MAAM,UAAU,cAAc,CAAC,MAAoB;IACjD,MAAM,EAAC,UAAU,EAAE,WAAW,EAAC,GAAG,MAAM,CAAC;IACzC,MAAM,GAAG,GAAG,OAAO,CAAC,UAAU,EAAE,WAAW,CAAC,CAAC;IAC7C,IAAI,cAAc,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE;QAC3B,GAAG,CAAC,IAAI,CACJ,eAAe,UAAU,gBAAgB;YACzC,IAAI,WAAW,yBAAyB,CAAC,CAAC;KAC/C;IACD,cAAc,CAAC,GAAG,CAAC,GAAG,EAAE,MAAM,CAAC,CAAC;AAClC,CAAC;AAED;;;;;;;GAOG;AACH,MAAM,UAAU,gBAAgB,CAAC,MAAkB;IACjD,MAAM,EAAC,UAAU,EAAC,GAAG,MAAM,CAAC;IAE5B,IAAI,YAAY,CAAC,GAAG,CAAC,UAAU,CAAC,EAAE;QAChC,sEAAsE;QACtE,iBAAiB;QACjB,IAAI,GAAG,EAAE,CAAC,OAAO,CAAC,OAAO,CAAC,EAAE;YAC1B,GAAG,CAAC,IAAI,CAAC,gCAAgC,UAAU,GAAG,CAAC,CAAC;SACzD;KACF;IACD,YAAY,CAAC,GAAG,CAAC,UAAU,EAAE,MAAM,CAAC,CAAC;AACvC,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,gBAAgB,CAC5B,UAAkB,EAAE,WAAmB;IACzC,MAAM,GAAG,GAAG
,OAAO,CAAC,UAAU,EAAE,WAAW,CAAC,CAAC;IAC7C,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE;QAC5B,MAAM,IAAI,KAAK,CACX,eAAe,UAAU,gBAAgB;YACzC,IAAI,WAAW,qBAAqB,CAAC,CAAC;KAC3C;IACD,cAAc,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;AAC7B,CAAC;AAED,gEAAgE;AAChE,MAAM,UAAU,kBAAkB,CAAC,UAAkB;IACnD,IAAI,CAAC,YAAY,CAAC,GAAG,CAAC,UAAU,CAAC,EAAE;QACjC,MAAM,IAAI,KAAK,CACX,iBAAiB,UAAU,iCAAiC,CAAC,CAAC;KACnE;IACD,YAAY,CAAC,MAAM,CAAC,UAAU,CAAC,CAAC;AAClC,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,qBAAqB,CACjC,qBAA6B,EAAE,cAAsB;IACvD,MAAM,OAAO,GAAG,oBAAoB,CAAC,qBAAqB,CAAC,CAAC;IAC5D,OAAO,CAAC,OAAO,CAAC,YAAY,CAAC,EAAE;QAC7B,MAAM,eAAe,GACjB,MAAM,CAAC,MAAM,CAAC,EAAE,EAAE,YAAY,EAAE,EAAC,WAAW,EAAE,cAAc,EAAC,CAAC,CAAC;QACnE,cAAc,CAAC,eAAe,CAAC,CAAC;IAClC,CAAC,CAAC,CAAC;AACL,CAAC;AAED,SAAS,OAAO,CAAC,UAAkB,EAClB,WAAmB;IAClC,OAAO,GAAG,WAAW,IAAI,UAAU,EAAE,CAAC;AACxC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {env} from './environment';\nimport {getGlobal} from './global_util';\nimport * as log from './log';\nimport {NamedGradientMap} from './tape';\nimport {Tensor} from './tensor';\nimport {TensorInfo} from './tensor_info';\nimport {RecursiveArray} from './types';\n\nconst kernelRegistry =\n  getGlobal('kernelRegistry', () => new Map<`${string}_${string}`,\n    KernelConfig>());\nconst gradRegistry =\n  
getGlobal('gradRegistry', () => new Map<string, GradConfig>());\n\ntype AttributeValue =\n  number | number[] | boolean | boolean[] | string | string[] | NamedAttrMap;\n\n/** These are extra non-tensor/primitive params passed to kernel functions. */\nexport type Attribute = AttributeValue | RecursiveArray<AttributeValue>;\n\n/** Specifies the code to run when executing a kernel. */\nexport type KernelFunc = (params: {\n  inputs: NamedTensorInfoMap,\n  backend: {},\n  attrs?: NamedAttrMap,\n}) => TensorInfo | TensorInfo[];\n\n/** The function to run when computing a gradient during backprop. */\nexport type GradFunc =\n  (dy: Tensor | Tensor[], saved: Tensor[], attrs: NamedAttrMap) =>\n    NamedGradientMap;\n\n/** Function that gets called after the backend initializes. */\nexport type KernelSetupFunc = (backend: {}) => void;\n/** Function that gets called right before the backend is disposed. */\nexport type KernelDisposeFunc = KernelSetupFunc;\n\n/** Config object for registering a kernel in the global registry. */\nexport interface KernelConfig {\n  kernelName: string;\n  backendName: string;\n  kernelFunc: KernelFunc;\n  setupFunc?: KernelSetupFunc;\n  disposeFunc?: KernelDisposeFunc;\n}\n\n/** Config object for registering a gradient in the global registry. */\nexport interface GradConfig {\n  kernelName: string;\n  inputsToSave?: string[];\n  // When saveAllInputs is true, all inputs will be saved. 
Only use this flag\n  // if inputs is an array of Tensors.\n  saveAllInputs?: boolean;\n  outputsToSave?: boolean[];\n  gradFunc: GradFunc;\n}\n\nexport interface NamedTensorInfoMap {\n  [name: string]: TensorInfo|undefined;\n}\n\nexport interface NamedAttrMap {\n  [name: string]: Attribute;\n}\n\n/**\n * Returns the kernel function (code) associated with the provided names.\n *\n * @param kernelName The official name of the kernel.\n * @param backendName The official name of the backend.\n */\nexport function getKernel(\n    kernelName: string, backendName: string): KernelConfig {\n  const key = makeKey(kernelName, backendName);\n  return kernelRegistry.get(key);\n}\n\n/**\n * Returns the registered gradient info associated with the provided kernel.\n * @param kernelName The official TF kernel name.\n */\nexport function getGradient(kernelName: string): GradConfig {\n  return gradRegistry.get(kernelName);\n}\n\nexport function getKernelsForBackend(backendName: string): KernelConfig[] {\n  const it = kernelRegistry.entries();\n  const result: KernelConfig[] = [];\n\n  while (true) {\n    const {done, value} = it.next();\n    if (done) {\n      break;\n    }\n    const [key, config] = value;\n    const [backend, ] = key.split('_');\n    if (backend === backendName) {\n      result.push(config);\n    }\n  }\n  return result;\n}\n\n/**\n * Registers the function (forward pass) for the kernel in a global registry.\n *\n * @param config A config object with the following properties:\n * - `kernelName` The official name of the kernel.\n * - `backendName` The official name of the backend.\n * - `kernelFunc` The function to run during the forward pass of the kernel.\n * - `setupFunc` Optional. Gets called once, after the backend initializes.\n * - `disposeFunc` Optional. 
Gets called once, right before the backend is\n * disposed.\n */\nexport function registerKernel(config: KernelConfig) {\n  const {kernelName, backendName} = config;\n  const key = makeKey(kernelName, backendName);\n  if (kernelRegistry.has(key)) {\n    log.warn(\n        `The kernel '${kernelName}' for backend ` +\n        `'${backendName}' is already registered`);\n  }\n  kernelRegistry.set(key, config);\n}\n\n/**\n * Registers a gradient function for a given kernel in the global registry,\n * to be used during the back-propagation of that kernel.\n *\n * @param config An object with the following properties:\n * - `kernelName` The name of the kernel that the gradient function is for.\n * - `gradFunc` The function to run during back-propagation.\n */\nexport function registerGradient(config: GradConfig) {\n  const {kernelName} = config;\n\n  if (gradRegistry.has(kernelName)) {\n    // TODO (yassogba) after 3.0 assess whether we need to keep this gated\n    // to debug mode.\n    if (env().getBool('DEBUG')) {\n      log.warn(`Overriding the gradient for '${kernelName}'`);\n    }\n  }\n  gradRegistry.set(kernelName, config);\n}\n\n/**\n * Removes the kernel function from the registry.\n *\n * @param kernelName The official name of the kernel.\n * @param backendName The official name of the backend.\n *\n */\nexport function unregisterKernel(\n    kernelName: string, backendName: string): void {\n  const key = makeKey(kernelName, backendName);\n  if (!kernelRegistry.has(key)) {\n    throw new Error(\n        `The kernel '${kernelName}' for backend ` +\n        `'${backendName}' is not registered`);\n  }\n  kernelRegistry.delete(key);\n}\n\n/** Removes the registered gradient from the global registry. 
*/\nexport function unregisterGradient(kernelName: string): void {\n  if (!gradRegistry.has(kernelName)) {\n    throw new Error(\n        `The gradient '${kernelName}' for backend is not registered`);\n  }\n  gradRegistry.delete(kernelName);\n}\n\n/**\n * Finds kernels that have already been registered to a backend and re-registers\n * them for a new backend. Useful for registering custom backends.\n * @param registeredBackendName Already registered backend.\n * @param newBackendName New backend.\n */\nexport function copyRegisteredKernels(\n    registeredBackendName: string, newBackendName: string): void {\n  const kernels = getKernelsForBackend(registeredBackendName);\n  kernels.forEach(kernelConfig => {\n    const newKernelConfig =\n        Object.assign({}, kernelConfig, {backendName: newBackendName});\n    registerKernel(newKernelConfig);\n  });\n}\n\nfunction makeKey(kernelName: string,\n                 backendName: string): `${string}_${string}` {\n  return `${backendName}_${kernelName}`;\n}\n"]}
|