/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {BackendTimingInfo, DataMover, KernelBackend} from './backends/backend';
import {Environment, setEnvironmentGlobal} from './environment';
import {getGradient, getKernel, getKernelsForBackend, NamedAttrMap, TensorInfo} from './kernel_registry';
import {Profiler} from './profiler';
import {backpropagateGradients, getFilteredNodesXToY, NamedGradientMap, TapeNode} from './tape';
import {DataId, setTensorTracker, Tensor, TensorTracker, Variable} from './tensor';
import {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';
import {getTensorsInContainer} from './tensor_util';
import {BackendValues, DataType, DataValues} from './types';
import * as util from './util';
import {bytesFromStringArray, makeOnesTypedArray, now, sizeFromShape} from './util';
/**
 * A function that computes an output. The save function can be used to save
 * tensors computed in the forward pass that are needed in the backward pass.
 */
export type ForwardFunc<T> = (backend: KernelBackend, save?: GradSaveFunc) => T;
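
// Illustrative sketch of a ForwardFunc: it computes outputs directly against
// the backend and may stash tensors for the backward pass via `save`. The
// `backend.relu` call and the captured `x` are hypothetical; real backends
// expose their own kernel methods.
//
//   const forward: ForwardFunc<Tensor> = (backend, save) => {
//     const y = backend.relu(x);  // `x` is an input captured by the op
//     save([x]);
//     return y;
//   };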

/**
 * @docalias (a: Tensor, b: Tensor,..., save?: Function) => {
 *   value: Tensor,
 *   gradFunc: (dy: Tensor, saved?: NamedTensorMap) => Tensor | Tensor[]
 * }
 */
export type CustomGradientFunc<T extends Tensor> =
    (...inputs: Array<Tensor|GradSaveFunc>) => {
      value: T;
      gradFunc: (dy: T, saved: Tensor[]) => Tensor | Tensor[];
    };
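
// A minimal sketch of a function matching CustomGradientFunc. `mul` stands
// for any engine-routed multiply op; it is not defined in this file.
//
//   const f: CustomGradientFunc<Tensor> = (x: Tensor, save: GradSaveFunc) => ({
//     value: mul(x, x),
//     gradFunc: (dy: Tensor, saved: Tensor[]) => mul(dy, x),
//   });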

export type MemoryInfo = {
  numTensors: number; numDataBuffers: number; numBytes: number;
  unreliable?: boolean; reasons: string[];
};

type KernelProfile = {
  name: string; bytesAdded: number; totalBytesSnapshot: number;
  tensorsAdded: number;
  totalTensorsSnapshot: number;
  inputShapes: number[][];
  outputShapes: number[][];
};

export type ProfileInfo = {
  newBytes: number; newTensors: number; peakBytes: number;
  kernels: KernelProfile[];
  result: TensorContainer;
};

export interface TimingInfo extends BackendTimingInfo {
  wallMs: number;
}

/** @docalias Function */
export type ScopeFn<T extends TensorContainer> = () => T;

interface ScopeState {
  track: Tensor[];
  name: string;
  id: number;
}

class EngineState {
  // Public since optimizers will use it.
  registeredVariables: NamedVariableMap = {};

  nextTapeNodeId = 0;
  numBytes = 0;
  numTensors = 0;
  numStringTensors = 0;
  numDataBuffers = 0;

  activeTape: TapeNode[];
  // Number of nested tf.grad() statements when computing higher-order
  // gradients. E.g. `1` for first-order gradients and `2` for second-order
  // gradients. Used to track if the tape should be removed after a backprop.
  gradientDepth = 0;
  // Number of nested kernel calls. When kernel depth is greater than 0, we
  // turn off the tape.
  kernelDepth = 0;

  // Scopes that keep track of created Tensors, maintained in parallel with
  // the tape.
  activeScope: ScopeState;
  scopeStack: ScopeState[] = [];
  /**
   * Keeps track of the number of data moves during a kernel execution. We
   * maintain a stack since kernels can call other kernels, recursively.
   */
  numDataMovesStack: number[] = [];
  nextScopeId = 0;

  tensorInfo = new WeakMap<DataId, {
    backend: KernelBackend,
    bytes: number,
    dtype: DataType,
    shape: number[],
    refCount: number
  }>();

  profiling = false;
  activeProfile: ProfileInfo =
      {newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null};

  dispose() {
    for (const variableName in this.registeredVariables) {
      this.registeredVariables[variableName].dispose();
    }
  }
}

export class Engine implements TensorTracker, DataMover {
  state: EngineState;
  backendName: string;
  registry: {[id: string]: KernelBackend} = {};
  registryFactory: {
    [id: string]: {
      factory: () => KernelBackend | Promise<KernelBackend>,
      priority: number
    }
  } = {};

  private profiler: Profiler;
  private backendInstance: KernelBackend;
  private pendingBackendInit: Promise<boolean>;
  private pendingBackendInitId = 0;

  constructor(public ENV: Environment) {
    this.state = new EngineState();
  }

  async ready(): Promise<void> {
    if (this.pendingBackendInit != null) {
      return this.pendingBackendInit.then(() => {});
    }
    if (this.backendInstance != null) {
      return;
    }
    const sortedBackends = this.getSortedBackends();

    for (let i = 0; i < sortedBackends.length; i++) {
      const backendName = sortedBackends[i];
      const success = await this.initializeBackend(backendName).success;
      if (success) {
        await this.setBackend(backendName);
        return;
      }
    }

    throw new Error(
        `Could not initialize any backends, all backend initializations ` +
        `failed.`);
  }

  get backend(): KernelBackend {
    if (this.pendingBackendInit != null) {
      throw new Error(
          `Backend '${this.backendName}' has not yet been initialized. Make ` +
          `sure to await tf.ready() or await tf.setBackend() before calling ` +
          `other methods`);
    }
    if (this.backendInstance == null) {
      const {name, asyncInit} = this.initializeBackendsAndReturnBest();
      if (asyncInit) {
        throw new Error(
            `The highest priority backend '${name}' has not yet been ` +
            `initialized. Make sure to await tf.ready() or ` +
            `await tf.setBackend() before calling other methods`);
      }
      this.setBackend(name);
    }
    return this.backendInstance;
  }

  backendNames(): string[] {
    return Object.keys(this.registryFactory);
  }

  findBackend(backendName: string): KernelBackend {
    if (!(backendName in this.registry)) {
      // If the backend hasn't been initialized but we have a registry entry
      // for it, initialize it and return it.
      if (backendName in this.registryFactory) {
        const {asyncInit} = this.initializeBackend(backendName);
        if (asyncInit) {
          // Backend is not ready yet.
          return null;
        }
      } else {
        return null;
      }
    }
    return this.registry[backendName];
  }

  findBackendFactory(backendName: string):
      () => KernelBackend | Promise<KernelBackend> {
    if (!(backendName in this.registryFactory)) {
      return null;
    }
    return this.registryFactory[backendName].factory;
  }

  registerBackend(
      backendName: string,
      factory: () => KernelBackend | Promise<KernelBackend>,
      priority = 1): boolean {
    if (backendName in this.registryFactory) {
      console.warn(
          `${backendName} backend was already registered. ` +
          `Reusing existing backend factory.`);
      return false;
    }
    this.registryFactory[backendName] = {factory, priority};
    return true;
  }
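
  // Illustrative sketch: registering a backend factory. `MyBackend` is a
  // hypothetical KernelBackend implementation; when several backends are
  // registered, the highest `priority` wins during default initialization.
  //
  //   ENGINE.registerBackend('my-backend', () => new MyBackend(), 2);
  //   await ENGINE.setBackend('my-backend');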

  async setBackend(backendName: string): Promise<boolean> {
    if (this.registryFactory[backendName] == null) {
      throw new Error(`Backend name '${backendName}' not found in registry`);
    }
    this.backendName = backendName;
    if (this.registry[backendName] == null) {
      this.backendInstance = null;
      const {success, asyncInit} = this.initializeBackend(backendName);
      const result = asyncInit ? await success : success;
      if (!result) {
        return false;
      }
    }
    this.backendInstance = this.registry[backendName];
    this.setupRegisteredKernels();
    // Reset the profiler.
    this.profiler = new Profiler(this.backendInstance);

    return true;
  }
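
  // Illustrative usage (a sketch): switching the active backend. The name
  // must match one previously passed to registerBackend(); async factories
  // are awaited before the switch completes.
  //
  //   const ok = await ENGINE.setBackend('my-backend');
  //   if (!ok) {
  //     await ENGINE.ready();  // initialize the best remaining backend
  //   }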

  private setupRegisteredKernels(): void {
    const kernels = getKernelsForBackend(this.backendName);
    kernels.forEach(kernel => {
      if (kernel.setupFunc != null) {
        kernel.setupFunc(this.backendInstance);
      }
    });
  }

  private disposeRegisteredKernels(backendName: string): void {
    const kernels = getKernelsForBackend(backendName);
    kernels.forEach(kernel => {
      if (kernel.disposeFunc != null) {
        kernel.disposeFunc(this.registry[backendName]);
      }
    });
  }

  /**
   * Initializes a backend by looking up the backend name in the factory
   * registry and calling the factory method. Returns a boolean representing
   * whether the initialization of the backend succeeded. Throws an error if
   * there is no backend in the factory registry.
   */
  private initializeBackend(backendName: string):
      {success: boolean|Promise<boolean>, asyncInit: boolean} {
    const registryFactoryEntry = this.registryFactory[backendName];
    if (registryFactoryEntry == null) {
      throw new Error(
          `Cannot initialize backend ${backendName}, no registration found.`);
    }

    try {
      const backend = registryFactoryEntry.factory();
      // Test if the factory returns a promise.
      if (Promise.resolve(backend) === backend) {
        const promiseId = ++this.pendingBackendInitId;
        const success =
            backend
                .then(backendInstance => {
                  // Outdated promise. Another backend was set in the meantime.
                  if (promiseId < this.pendingBackendInitId) {
                    return false;
                  }
                  this.registry[backendName] = backendInstance;
                  this.pendingBackendInit = null;
                  return true;
                })
                .catch(err => {
                  // Outdated promise. Another backend was set in the meantime.
                  if (promiseId < this.pendingBackendInitId) {
                    return false;
                  }
                  this.pendingBackendInit = null;
                  console.warn(
                      `Initialization of backend ${backendName} failed`);
                  console.warn(err.stack || err.message);
                  return false;
                });
        this.pendingBackendInit = success;
        return {success, asyncInit: true};
      } else {
        this.registry[backendName] = backend as KernelBackend;
        return {success: true, asyncInit: false};
      }
    } catch (err) {
      console.warn(`Initialization of backend ${backendName} failed`);
      console.warn(err.stack || err.message);
      return {success: false, asyncInit: false};
    }
  }

  removeBackend(backendName: string): void {
    if (!(backendName in this.registryFactory)) {
      throw new Error(`${backendName} backend not found in registry`);
    }
    if (this.backendName === backendName && this.pendingBackendInit != null) {
      // There is a pending promise of the backend we want to remove. Make it
      // obsolete.
      this.pendingBackendInitId++;
    }

    if (backendName in this.registry) {
      this.disposeRegisteredKernels(backendName);
      this.registry[backendName].dispose();
      delete this.registry[backendName];
    }

    delete this.registryFactory[backendName];

    // Unset the backend if it is active.
    if (this.backendName === backendName) {
      this.pendingBackendInit = null;
      this.backendName = null;
      this.backendInstance = null;
    }
  }

  private getSortedBackends(): string[] {
    if (Object.keys(this.registryFactory).length === 0) {
      throw new Error('No backend found in registry.');
    }
    return Object.keys(this.registryFactory).sort((a: string, b: string) => {
      // Highest priority comes first.
      return this.registryFactory[b].priority -
          this.registryFactory[a].priority;
    });
  }

  private initializeBackendsAndReturnBest():
      {name: string, asyncInit: boolean} {
    const sortedBackends = this.getSortedBackends();

    for (let i = 0; i < sortedBackends.length; i++) {
      const backendName = sortedBackends[i];
      const {success, asyncInit} = this.initializeBackend(backendName);
      if (asyncInit || success) {
        return {name: backendName, asyncInit};
      }
    }
    throw new Error(
        `Could not initialize any backends, all backend initializations ` +
        `failed.`);
  }

  moveData(destBackend: KernelBackend, dataId: DataId) {
    const info = this.state.tensorInfo.get(dataId);
    const srcBackend = info.backend;
    const values = this.readSync(dataId);
    // Delete the tensor from the old backend and move it to the new
    // backend.
    srcBackend.disposeData(dataId);
    info.backend = destBackend;
    destBackend.move(dataId, values, info.shape, info.dtype);
    if (this.shouldCheckForMemLeaks()) {
      // Track the number of moves during a kernel execution to correctly
      // detect memory leaks.
      this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
    }
  }

  tidy<T extends TensorContainer>(nameOrFn: string|ScopeFn<T>, fn?: ScopeFn<T>):
      T {
    let name: string = null;
    if (fn == null) {
      // Called with only 1 argument.
      if (typeof nameOrFn !== 'function') {
        throw new Error('Please provide a function to tidy()');
      }
      fn = nameOrFn;
    } else {
      // Called with 2 arguments.
      if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
        throw new Error(
            'When calling with two arguments, the first argument ' +
            'to tidy() must be a string');
      }
      if (typeof fn !== 'function') {
        throw new Error(
            'When calling with two arguments, the 2nd argument ' +
            'to tidy() must be a function');
      }
      name = nameOrFn as string;
      // TODO(nsthorat,smilkov): Do operation logging and performance
      // profiling.
    }
    let result: T;
    return this.scopedRun(
        () => this.startScope(name), () => this.endScope(result), () => {
          result = fn();
          if (result instanceof Promise) {
            console.error('Cannot return a Promise inside of tidy.');
          }
          return result;
        });
  }
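
  // Illustrative usage (a sketch): tensors created inside the callback are
  // disposed when the scope ends, except those returned or kept.
  //
  //   const kept = ENGINE.tidy('example', () => {
  //     const a = ENGINE.makeTensor(new Float32Array([1, 2]), [2], 'float32');
  //     const b = ENGINE.makeTensor(new Float32Array([3, 4]), [2], 'float32');
  //     // `a` is returned, so it survives the scope; `b` is disposed.
  //     return a;
  //   });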

  private scopedRun<T>(start: () => void, end: () => void, f: () => T): T {
    start();
    try {
      const res = f();
      end();
      return res;
    } catch (ex) {
      end();
      throw ex;
    }
  }

  private static nextTensorId = 0;
  private nextTensorId(): number {
    return Engine.nextTensorId++;
  }

  private static nextVariableId = 0;
  private nextVariableId(): number {
    return Engine.nextVariableId++;
  }

  /**
   * This method is called instead of the public-facing tensor.clone() when
   * saving a tensor for the backward pass. It makes sure to add the clone
   * operation to the tape even when called inside a kernel execution.
   *
   * This method will go away once all kernels are modularized since we won't
   * need to turn off the tape inside runKernel().
   */
  private clone(x: Tensor): Tensor {
    const y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype);
    const inputs = {x};
    const grad = (dy: Tensor) => ({x: () => dy.toFloat()});
    const saved: Tensor[] = [];
    this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved);
    return y;
  }

  /**
   * Executes a kernel with the given name and returns the output tensor.
   *
   * @param kernelName The name of the kernel to execute.
   * @param inputs A map of input names to tensors.
   * @param attrs A map of attribute names to their values. An attribute is a
   *     primitive (non-tensor) input to the kernel.
   * @param inputsToSave A list of tensors, inputs to save for the backprop
   *     computation.
   * @param outputsToSave A list of booleans, specifying which outputs to save
   *     for the backprop computation. These are booleans since the output
   *     tensors are not visible to the user.
   */
  runKernel(
      kernelName: string, inputs: NamedTensorMap, attrs: NamedAttrMap,
      inputsToSave?: Tensor[], outputsToSave?: boolean[]): Tensor|Tensor[] {
    const forwardFunc: null = null;
    const backwardsFunc: null = null;
    // Call runKernelFunc as a stop-gap until we modularize all kernels.
    // Once we modularize all kernels, we will remove the existing
    // `runKernelFunc`.
    return this.runKernelFunc(
        forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave,
        outputsToSave);
  }
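
  // Illustrative usage (a sketch): invoking a kernel by name. 'Relu' must
  // have been registered for the active backend in the kernel registry.
  //
  //   const x = ENGINE.makeTensor(new Float32Array([-1, 2]), [2], 'float32');
  //   const out = ENGINE.runKernel('Relu', {x}, {});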

  private shouldCheckForMemLeaks(): boolean {
    return this.ENV.getBool('IS_TEST');
  }

  private checkKernelForMemLeak(
      kernelName: string, numDataIdsBefore: number,
      outInfos: TensorInfo[]): void {
    const numDataIdsAfter = this.backend.numDataIds();

    // Count the number of data ids associated with the result of the kernel.
    let numOutputDataIds = 0;
    outInfos.forEach(info => {
      // Complex numbers allocate 3 data ids, one for 'real', one for
      // 'imaginary', and one for the container that holds the former two.
      numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1);
    });

    // Account for the number of moves during kernel execution. A "data move"
    // can happen in the middle of a kernel execution, placing a new (key,
    // value) pair in the data storage. Since data moves have net zero effect
    // (we always remove the data from the old backend), we have to cancel
    // them out when detecting memory leaks.
    const numMoves =
        this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
    const dataIdsLeaked =
        numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
    if (dataIdsLeaked > 0) {
      throw new Error(
          `Backend '${this.backendName}' has an internal memory leak ` +
          `(${dataIdsLeaked} data ids) after running '${kernelName}'`);
    }
  }

  /**
   * @deprecated Use `runKernel` for newly added kernels. Keep using this
   * method only for kernels that are not yet fully modularized.
   */
  runKernelFunc<T extends Tensor|Tensor[], I extends NamedTensorMap>(
      forwardFunc: ForwardFunc<T>, inputs: I,
      backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]},
      kernelName?: string, attrs?: NamedAttrMap, inputsToSave: Tensor[] = [],
      outputsToSave: boolean[] = []): T {
    let outputs: Tensor[];
    let saved: Tensor[] = [];
    const isTapeOn = this.isTapeOn();
    if (kernelName == null) {
      kernelName =
          this.state.activeScope != null ? this.state.activeScope.name : '';
    }
    const saveFunc: GradSaveFunc = (tensors) => {
      // Do not save unless we are recording to the tape. Otherwise it would
      // cause a mem leak since we would never run backprop, which disposes
      // the kept tensors.
      if (!isTapeOn) {
        return;
      }
      saved = tensors.map(tensor => this.keep(this.clone(tensor)));
    };

    const startingBytecount = this.state.numBytes;
    const startingNumTensors = this.state.numTensors;

    if (this.shouldCheckForMemLeaks()) {
      this.state.numDataMovesStack.push(0);
    }

    let kernelFunc: () => Tensor[];
    const kernel = getKernel(kernelName, this.backendName);
    let out: TensorInfo|TensorInfo[];
    if (kernel != null) {
      kernelFunc = () => {
        const numDataIdsBefore = this.backend.numDataIds();
        out = kernel.kernelFunc({inputs, attrs, backend: this.backend});
        const outInfos = Array.isArray(out) ? out : [out];
        if (this.shouldCheckForMemLeaks()) {
          this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
        }
        const outTensors = outInfos.map(
            ({dataId, shape, dtype}) =>
                this.makeTensorFromDataId(dataId, shape, dtype));
        const outsToSave = outTensors.filter((_, i) => outputsToSave[i]);
        // Save the inputs and outputs.
        saveFunc((inputsToSave || []).slice().concat(outsToSave));
        return outTensors;
      };
    } else {
      kernelFunc = () => {
        const numDataIdsBefore = this.backend.numDataIds();
        out = this.tidy(() => forwardFunc(this.backend, saveFunc));
        const outs = (Array.isArray(out) ? out : [out]) as Tensor[];
        if (this.shouldCheckForMemLeaks()) {
          this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs);
        }
        return outs;
      };
    }

    // Stop recording to a tape when running a kernel.
    this.scopedRun(
        () => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
          if (!this.ENV.getBool('DEBUG')) {
            outputs = kernelFunc();
          } else {
            outputs = this.profiler.profileKernel(
                kernelName, inputs, () => kernelFunc());
          }
        });

    if (isTapeOn) {
      this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved);
    }

    if (this.state.profiling) {
      this.state.activeProfile.kernels.push({
        name: kernelName,
        bytesAdded: this.state.numBytes - startingBytecount,
        totalBytesSnapshot: this.state.numBytes,
        tensorsAdded: this.state.numTensors - startingNumTensors,
        totalTensorsSnapshot: this.state.numTensors,
        inputShapes: Object.keys(inputs).map(key => inputs[key].shape),
        outputShapes: outputs.map(item => item.shape)
      });
    }
    return (Array.isArray(out) ? outputs : outputs[0]) as T;
  }

  /**
   * Internal method used by public APIs for tensor creation. Makes a new
   * tensor with the provided shape, dtype and values. It always
   * creates a new data id and writes the values to the underlying backend.
   */
  makeTensor(
      values: DataValues, shape: number[], dtype: DataType,
      backend?: KernelBackend): Tensor {
    if (values == null) {
      throw new Error('Values passed to engine.makeTensor() are null');
    }
    dtype = dtype || 'float32';
    backend = backend || this.backend;
    let backendVals = values as BackendValues;
    if (dtype === 'string' && util.isString(values[0])) {
      backendVals = (values as string[]).map(d => util.encodeString(d));
    }
    const dataId = backend.write(backendVals, shape, dtype);
    const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
    this.incRef(t, backend);

    // Count bytes for string tensors.
    if (dtype === 'string') {
      const info = this.state.tensorInfo.get(dataId);
      const newBytes = bytesFromStringArray(backendVals as Uint8Array[]);
      this.state.numBytes += newBytes - info.bytes;
      info.bytes = newBytes;
    }
    return t;
  }
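
  // Illustrative usage (a sketch): the values buffer must match the product
  // of the shape dimensions and the dtype.
  //
  //   const t = ENGINE.makeTensor(
  //       new Float32Array([1, 2, 3, 4]), [2, 2], 'float32');
  //   t.dispose();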

  /**
   * Internal method used by backends. Makes a new tensor
   * that is a wrapper around an existing data id. It doesn't create
   * a new data id, only increments the ref count used in memory tracking.
   */
  makeTensorFromDataId(
      dataId: DataId, shape: number[], dtype: DataType,
      backend?: KernelBackend): Tensor {
    dtype = dtype || 'float32';
    const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
    this.incRef(t, backend);
    return t;
  }

  makeVariable(
      initialValue: Tensor, trainable = true, name?: string,
      dtype?: DataType): Variable {
    name = name || this.nextVariableId().toString();
    if (dtype != null && dtype !== initialValue.dtype) {
      initialValue = initialValue.asType(dtype);
    }
    const v = new Variable(initialValue, trainable, name, this.nextTensorId());
    if (this.state.registeredVariables[v.name] != null) {
      throw new Error(`Variable with name ${v.name} was already registered`);
    }
    this.state.registeredVariables[v.name] = v;
    this.incRef(v, this.backend);
    return v;
  }
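
  // Illustrative usage (a sketch): variables are registered by name and are
  // not tracked by tidy() scopes, so they live until explicitly disposed.
  //
  //   const init = ENGINE.makeTensor(new Float32Array([0]), [1], 'float32');
  //   const w = ENGINE.makeVariable(init, /* trainable= */ true, 'w');
  //   ENGINE.disposeVariable(w);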

  incRef(a: Tensor, backend: KernelBackend): void {
    const refCount = this.state.tensorInfo.has(a.dataId) ?
        this.state.tensorInfo.get(a.dataId).refCount :
        0;
    this.state.numTensors++;
    if (a.dtype === 'string') {
      this.state.numStringTensors++;
    }
    if (refCount === 0) {
      this.state.numDataBuffers++;

      // Bytes for complex numbers are counted by their components. Bytes for
      // string tensors are counted when writing values.
      let bytes = 0;
      if (a.dtype !== 'complex64' && a.dtype !== 'string') {
        bytes = a.size * util.bytesPerElement(a.dtype);
      }
      this.state.tensorInfo.set(a.dataId, {
        backend: backend || this.backend,
        dtype: a.dtype,
        shape: a.shape,
        bytes,
        refCount: 0
      });
      this.state.numBytes += bytes;
    }
    this.state.tensorInfo.get(a.dataId).refCount++;
    if (!(a instanceof Variable)) {
      this.track(a);
    }
  }

  disposeTensor(a: Tensor): void {
    if (!this.state.tensorInfo.has(a.dataId)) {
      return;
    }

    this.state.numTensors--;
    if (a.dtype === 'string') {
      this.state.numStringTensors--;
    }
    const info = this.state.tensorInfo.get(a.dataId);
    const refCount = info.refCount;
    if (refCount <= 1) {
      // Don't count bytes for complex numbers as they are counted by their
      // components.
      if (a.dtype !== 'complex64') {
        this.state.numBytes -= info.bytes;
      }
      this.state.numDataBuffers--;
      info.backend.disposeData(a.dataId);
      this.state.tensorInfo.delete(a.dataId);
    } else {
      this.state.tensorInfo.get(a.dataId).refCount--;
    }
    // TODO(nsthorat): Construct an error and save the stack trace for
    // debugging when in debug mode. Creating a stack trace is too expensive
    // to do unconditionally.
  }

  disposeVariables(): void {
    for (const varName in this.state.registeredVariables) {
      const v = this.state.registeredVariables[varName];
      this.disposeVariable(v);
    }
  }

  disposeVariable(v: Variable): void {
    this.disposeTensor(v);
    if (this.state.registeredVariables[v.name] != null) {
      delete this.state.registeredVariables[v.name];
    }
  }

  memory(): MemoryInfo {
    const info = this.backend.memory() as MemoryInfo;
    info.numTensors = this.state.numTensors;
    info.numDataBuffers = this.state.numDataBuffers;
    info.numBytes = this.state.numBytes;
    if (this.state.numStringTensors > 0) {
      info.unreliable = true;
      if (info.reasons == null) {
        info.reasons = [];
      }
      info.reasons.push(
          'Memory usage by string tensors is approximate ' +
          '(2 bytes per character)');
    }
    return info;
  }
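
  // Illustrative usage (a sketch): inspecting engine-level memory counters.
  //
  //   const {numTensors, numBytes, unreliable} = ENGINE.memory();
  //   console.log(`${numTensors} tensors, ${numBytes} bytes` +
  //       (unreliable ? ' (approximate)' : ''));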

  async profile(query: () => TensorContainer): Promise<ProfileInfo> {
    this.state.profiling = true;

    const startBytes = this.state.numBytes;
    const startNumTensors = this.state.numTensors;

    this.state.activeProfile.kernels = [];
    this.state.activeProfile.result = query();

    this.state.profiling = false;

    this.state.activeProfile.peakBytes = Math.max(
        ...this.state.activeProfile.kernels.map(d => d.totalBytesSnapshot));
    this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
    this.state.activeProfile.newTensors =
        this.state.numTensors - startNumTensors;
    return this.state.activeProfile;
  }
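
  // Illustrative usage (a sketch): the query runs synchronously, and any
  // kernels it routes through the engine are recorded in `kernels`.
  //
  //   const info = await ENGINE.profile(() => {
  //     const x = ENGINE.makeTensor(new Float32Array([1, 2]), [2], 'float32');
  //     return x;
  //   });
  //   console.log(info.newBytes, info.newTensors, info.kernels.length);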

  isTapeOn(): boolean {
    return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
  }

  private addTapeNode(
      kernelName: string, inputs: NamedTensorMap, outputs: Tensor[],
      gradientsFunc: (dy: Tensor|Tensor[], saved: Tensor[]) => NamedGradientMap,
      saved: Tensor[]): void {
    const tapeNode: TapeNode =
        {id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved};

    const gradConfig = getGradient(kernelName);
    if (gradConfig != null) {
      gradientsFunc = gradConfig.gradFunc;
    }
    if (gradientsFunc != null) {
      tapeNode.gradient = (dys: Tensor[]) => {
        // TODO(smilkov): To optimize back-prop, pass dys that are not used in
        // the backprop graph to the user as null instead of zeros.
        dys = dys.map((dy, i) => {
          if (dy == null) {
            const output = outputs[i];
            const vals = util.makeZerosTypedArray(output.size, output.dtype);
            return this.makeTensor(vals, output.shape, output.dtype);
          }
          return dy;
        });
        // Grad functions of ops with single outputs expect a dy, while ops
        // with multiple outputs expect dys (array of dy).
        return gradientsFunc(dys.length > 1 ? dys : dys[0], saved);
      };
    }
    this.state.activeTape.push(tapeNode);
  }

  keep<T extends Tensor>(result: T): T {
    result.kept = true;
    return result;
  }
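
  // Illustrative usage (a sketch): keep() marks a tensor so the enclosing
  // tidy() scope does not dispose it when the scope ends.
  //
  //   let saved: Tensor;
  //   ENGINE.tidy(() => {
  //     saved = ENGINE.keep(
  //         ENGINE.makeTensor(new Float32Array([1]), [1], 'float32'));
  //   });
  //   // `saved` is still usable here; dispose it manually when done.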

  private startTape() {
    if (this.state.gradientDepth === 0) {
      this.state.activeTape = [];
    }
    this.state.gradientDepth++;
  }

  private endTape() {
    this.state.gradientDepth--;
  }

  /**
   * Starts a scope. Use this with endScope() to achieve the same
   * functionality as tidy() without the need for a function closure.
   */
  startScope(name?: string) {
    const scopeInfo: ScopeState = {
      track: [],
      name: 'unnamed scope',
      id: this.state.nextScopeId++
    };
    if (name) {
      scopeInfo.name = name;
    }
    this.state.scopeStack.push(scopeInfo);
    this.state.activeScope = scopeInfo;
  }

  /**
   * Ends a scope. Use this with startScope() to achieve the same
   * functionality as tidy() without the need for a function closure.
   */
  endScope(result?: TensorContainer) {
    const tensorsToTrackInParent = getTensorsInContainer(result);
    const tensorsToTrackInParentSet =
        new Set(tensorsToTrackInParent.map(t => t.id));

    // Dispose the arrays tracked in this scope.
    for (let i = 0; i < this.state.activeScope.track.length; i++) {
      const tensor = this.state.activeScope.track[i];
      if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
        tensor.dispose();
      }
    }

    const oldScope = this.state.scopeStack.pop();
    this.state.activeScope = this.state.scopeStack.length === 0 ?
        null :
        this.state.scopeStack[this.state.scopeStack.length - 1];

    // Track the current result in the parent scope.
    tensorsToTrackInParent.forEach(tensor => {
      // Only track the tensor if it was allocated in the inner scope and is
      // not globally kept.
      if (!tensor.kept && tensor.scopeId === oldScope.id) {
        this.track(tensor);
      }
    });
  }

  /**
   * Returns gradients of `f` with respect to each of the `xs`. The gradients
   * returned are of the same length as `xs`, but some might be null if `f`
   * was not a function of that `x`. It also takes an optional `dy` to
   * multiply with the gradient, which defaults to ones.
   */
  gradients<T extends Tensor>(
      f: () => T, xs: Tensor[], dy?: T,
      allowNoGradients = false): {value: T, grads: Tensor[]} {
    util.assert(
        xs.length > 0, () => 'gradients() received an empty list of xs.');
    if (dy != null && dy.dtype !== 'float32') {
      throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
    }

    const y = this.scopedRun(
        () => this.startTape(), () => this.endTape(),
        () => this.tidy('forward', f));

    util.assert(
        y instanceof Tensor,
        () => 'The result y returned by f() must be a tensor.');
    // Filter out the nodes that don't connect x => y.
    const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
    if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
      throw new Error(
          'Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
          'that the f you passed encloses all operations that lead from x ' +
          'to y.');
    }

    return this.tidy('backward', () => {
      const accumulatedGradientMap: {[tensorId: number]: Tensor} = {};
      accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy;

      // Backprop gradients through the filtered nodes.
      backpropagateGradients(
          accumulatedGradientMap, filteredTape,
          // Pass the tidy function to avoid circular dep with `tape.ts`.
          f => this.tidy(f as ScopeFn<Tensor>));
      const grads = xs.map(x => accumulatedGradientMap[x.id]);

      if (this.state.gradientDepth === 0) {
        // This means that we are not computing higher-order gradients
        // and can clean up the tape.
        this.state.activeTape.forEach(node => {
          for (const tensor of node.saved) {
            tensor.dispose();
          }
        });
        this.state.activeTape = null;
      }
      return {value: y, grads};
    });
  }
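
  // Illustrative usage (a sketch): `square` stands for any engine-routed op
  // with a registered gradient; it is not defined in this file.
  //
  //   const x = ENGINE.makeTensor(new Float32Array([3]), [1], 'float32');
  //   const {value, grads} = ENGINE.gradients(() => square(x), [x]);
  //   // grads[0] holds dy/dx, here 2 * x = [6].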

  customGrad<T extends Tensor>(f: CustomGradientFunc<T>):
      (...args: Array<Tensor|GradSaveFunc>) => T {
    util.assert(
        util.isFunction(f),
        () => 'The f passed in customGrad(f) must be a function.');
    return (...inputs: Tensor[]): T => {
      util.assert(
          inputs.every(t => t instanceof Tensor),
          () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
              'tensors');

      let res: {
        value: T,
        gradFunc: (dy: T, saved: Tensor[]) => Tensor | Tensor[],
      };
      const inputMap: NamedTensorMap = {};
      inputs.forEach((input, i) => {
        inputMap[i] = input;
      });
      return this.runKernelFunc(
          (_, save) => {
            res = f(...[...inputs, save]);
            util.assert(
                res.value instanceof Tensor,
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.value` is a tensor');
            util.assert(
                util.isFunction(res.gradFunc),
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function.');
            return res.value;
          },
          inputMap,
          (dy: T, saved: Tensor[]) => {
            const gradRes = res.gradFunc(dy, saved);
            const grads: Tensor[] =
                Array.isArray(gradRes) ? gradRes : [gradRes];
            util.assert(
                grads.length === inputs.length,
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function that returns ' +
                    'the same number of tensors as inputs passed to f(...).');
            util.assert(
                grads.every(t => t instanceof Tensor),
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function that returns ' +
                    'a list of only tensors.');
            const gradMap: {[key: string]: () => Tensor} = {};
            grads.forEach((grad, i) => {
              gradMap[i] = () => grad;
            });
            return gradMap;
          });
    };
  }
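
  // Illustrative sketch of a custom gradient: the returned gradFunc replaces
  // the default backprop for the ops run inside `f`. `square` and `mul` stand
  // for engine-routed ops that are not defined in this file.
  //
  //   const customSquare = ENGINE.customGrad(
  //       (x: Tensor, save: GradSaveFunc) => {
  //         save([x]);
  //         return {
  //           value: square(x),
  //           gradFunc: (dy: Tensor, saved: Tensor[]) =>
  //               mul(dy, mul(saved[0], 2)),
  //         };
  //       });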

  readSync(dataId: DataId): BackendValues {
    // Route the read to the correct backend.
    const info = this.state.tensorInfo.get(dataId);
    return info.backend.readSync(dataId);
  }

  read(dataId: DataId): Promise<BackendValues> {
    // Route the read to the correct backend.
    const info = this.state.tensorInfo.get(dataId);
    return info.backend.read(dataId);
  }

  async time(query: () => void): Promise<TimingInfo> {
    const start = now();
    const timingInfo = await this.backend.time(query) as TimingInfo;
    timingInfo.wallMs = now() - start;
    return timingInfo;
  }
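
  // Illustrative usage (a sketch): timing a block of synchronous work. The
  // backend may add backend-specific fields (e.g. kernel time) to the result.
  //
  //   const {wallMs} = await ENGINE.time(() => {
  //     // ... run some ops ...
  //   });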

  /**
   * Tracks a Tensor in the current scope to be automatically cleaned up
   * when the current scope ends, and returns the value.
   *
   * @param result The Tensor to track in the current scope.
   */
  private track<T extends Tensor>(result: T): T {
    if (this.state.activeScope != null) {
      result.scopeId = this.state.activeScope.id;
      this.state.activeScope.track.push(result);
    }

    return result;
  }

  get registeredVariables(): NamedVariableMap {
    return this.state.registeredVariables;
  }

  /**
   * Resets the engine state. Removes all backends but does not remove
   * registered backend factories.
   */
  reset(): void {
    // Make any pending promise obsolete.
    this.pendingBackendInitId++;

    this.state.dispose();
    this.ENV.reset();
    this.state = new EngineState();

    for (const backendName in this.registry) {
      this.disposeRegisteredKernels(backendName);
      this.registry[backendName].dispose();
      delete this.registry[backendName];
    }
    this.backendName = null;
    this.backendInstance = null;
    this.pendingBackendInit = null;
  }
}

function ones(shape: number[]): Tensor {
  const values = makeOnesTypedArray(sizeFromShape(shape), 'float32');
  return ENGINE.makeTensor(values, shape, 'float32');
}

let GLOBAL: {_tfengine: Engine};
function getGlobalNamespace(): {_tfengine: Engine} {
  if (GLOBAL == null) {
    // tslint:disable-next-line:no-any
    let ns: any;
    if (typeof (window) !== 'undefined') {
      ns = window;
    } else if (typeof (global) !== 'undefined') {
      ns = global;
    } else if (typeof (process) !== 'undefined') {
      ns = process;
    } else if (typeof (self) !== 'undefined') {
      ns = self;
    } else {
      throw new Error('Could not find a global object');
    }
    GLOBAL = ns;
  }
  return GLOBAL;
}

function getOrMakeEngine(): Engine {
  const ns = getGlobalNamespace();
  if (ns._tfengine == null) {
    const environment = new Environment(ns);
    ns._tfengine = new Engine(environment);
  }
  setEnvironmentGlobal(ns._tfengine.ENV);

  // Tell the current tensor interface that the global engine is responsible
  // for tracking.
  setTensorTracker(() => ns._tfengine);
  return ns._tfengine;
}

export const ENGINE = getOrMakeEngine();