/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {Tensor} from './tensor';
import {NamedTensorMap} from './tensor_types';
import * as util from './util';

export interface TapeNode {
  id: number;
  kernelName: string;
  outputs: Tensor[];
  inputs: NamedTensorMap;
  // Optional params, defined only for ops with gradient impl.
  gradient?: (dys: Tensor[]) => NamedGradientMap;
  saved?: Tensor[];
}

export type NamedGradientMap = {
  [inputName: string]: () => Tensor;
};
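
// For orientation, a sketch of the TapeNode a single recorded op might
// produce. The concrete values below (kernel name, tensors, gradient body)
// are illustrative assumptions, not something this module constructs, and
// they assume chained ops such as `mul` are available on Tensor:
//
//   const x: Tensor = ...;   // some float32 input
//   const y = square(x);     // recorded output of the op
//   const node: TapeNode = {
//     id: 0,
//     kernelName: 'Square',
//     inputs: {x},
//     outputs: [y],
//     // d(x^2)/dx = 2x, so the gradient map returns one closure per input:
//     gradient: (dys: Tensor[]) => ({x: () => dys[0].mul(x).mul(2)})
//   };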

/**
 * Computes a list of TapeNodes that connect x to y, filtering everything else
 * out and preserving the order of the original tape elements.
 *
 * @param tape The tape elements to filter.
 * @param xs The input Tensors.
 * @param y The output Tensor.
 */
export function getFilteredNodesXToY(
    tape: TapeNode[], xs: Tensor[], y: Tensor): TapeNode[] {
  // Forward pass to compute all the nodes and Tensors that are transitively a
  // function of x.
  const tensorsFromX: {[tensorId: number]: boolean} = {};
  const nodesFromX: {[nodeId: number]: boolean} = {};
  for (let i = 0; i < xs.length; i++) {
    tensorsFromX[xs[i].id] = true;
  }

  for (let i = 0; i < tape.length; i++) {
    const node = tape[i];
    const nodeInputs = node.inputs;
    for (const inputName in nodeInputs) {
      const input = nodeInputs[inputName];
      // If any input is a function of x, the node and all of its outputs are
      // too; the remaining inputs cannot change that, so stop scanning them.
      if (tensorsFromX[input.id]) {
        node.outputs.forEach(output => tensorsFromX[output.id] = true);
        nodesFromX[node.id] = true;
        break;
      }
    }
  }

  // Backward pass to find all of the nodes and Tensors that lead to y.
  const tensorsLeadToY: {[tensorId: number]: boolean} = {};
  tensorsLeadToY[y.id] = true;
  const nodesToY: {[nodeId: number]: boolean} = {};

  for (let i = tape.length - 1; i >= 0; i--) {
    const node = tape[i];
    const nodeInputs = node.inputs;

    // If any of the outputs lead to y, mark all of the inputs as leading to y.
    for (let j = 0; j < node.outputs.length; j++) {
      if (tensorsLeadToY[node.outputs[j].id]) {
        for (const inputName in nodeInputs) {
          tensorsLeadToY[nodeInputs[inputName].id] = true;
          nodesToY[node.id] = true;
        }
        break;
      }
    }
  }

  // Return the paths that come from x and lead to y.
  const filteredTape: TapeNode[] = [];
  for (let i = 0; i < tape.length; i++) {
    const node = tape[i];

    if (nodesFromX[node.id] && nodesToY[node.id]) {
      // Prune the inputs from the node that aren't a function of x.
      const prunedInputs: {[inputName: string]: Tensor} = {};
      for (const inputName in node.inputs) {
        const nodeInput = node.inputs[inputName];
        if (tensorsFromX[nodeInput.id]) {
          prunedInputs[inputName] = nodeInput;
        }
      }

      // Copy the node and overwrite its inputs with the pruned version.
      const prunedNode = Object.assign({}, node);
      prunedNode.inputs = prunedInputs;
      prunedNode.outputs = node.outputs;

      filteredTape.push(prunedNode);
    }
  }

  return filteredTape;
}
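
// A hedged example of the filtering above (the names are illustrative): if
// the tape recorded `a = add(x, 1); b = square(a); c = exp(w)` in that
// order, then filtering from [x] to b keeps the add and square nodes, in
// tape order, and drops the exp node, which neither depends on x nor leads
// to b:
//
//   const filtered = getFilteredNodesXToY(tape, [x], b);
//   // filtered contains only the 'Add' and 'Square' nodes.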

/**
 * Backpropagates gradients through the filtered TapeNodes.
 *
 * @param tensorAccumulatedGradientMap A map of Tensor id to its gradient. This
 *     map is mutated by this method.
 * @param filteredTape The filtered TapeNodes to backprop through.
 * @param tidy A function that wraps each gradient computation, so that
 *     intermediate tensors can be tracked and cleaned up.
 */
export function backpropagateGradients(
    tensorAccumulatedGradientMap: {[tensorId: number]: Tensor},
    filteredTape: TapeNode[], tidy: (f: Function) => Tensor) {
  // Walk the tape backward and keep a map of Tensor to its gradient.
  for (let i = filteredTape.length - 1; i >= 0; i--) {
    const node = filteredTape[i];

    const dys: Tensor[] = [];
    node.outputs.forEach(o => {
      const gradTensor = tensorAccumulatedGradientMap[o.id];
      if (gradTensor != null) {
        dys.push(gradTensor);
      } else {
        // This particular output is not in the back-propagation subgraph, so
        // it does not affect the final output; pass null for its dy.
        dys.push(null);
      }
    });

    if (node.gradient == null) {
      throw new Error(
          `Cannot compute gradient: gradient function not found ` +
          `for ${node.kernelName}.`);
    }

    // Backprop dy through this node and accumulate gradients over the inputs.
    const inputGradients = node.gradient(dys);

    for (const inputName in node.inputs) {
      if (!(inputName in inputGradients)) {
        throw new Error(
            `Cannot backprop through input ${inputName}. ` +
            `Available gradients found: ${Object.keys(inputGradients)}.`);
      }

      // Call the gradient function.
      const dx = tidy(() => inputGradients[inputName]());
      if (dx.dtype !== 'float32') {
        throw new Error(
            `Error in gradient for op ${node.kernelName}. The gradient of ` +
            `input '${inputName}' must have 'float32' dtype, but has ` +
            `'${dx.dtype}'.`);
      }
      const x = node.inputs[inputName];
      if (!util.arraysEqual(dx.shape, x.shape)) {
        throw new Error(
            `Error in gradient for op ${node.kernelName}. The gradient of ` +
            `input '${inputName}' has shape '${dx.shape}', which does not ` +
            `match the shape of the input '${x.shape}'.`);
      }

      // Accumulate: the first gradient for x is stored directly; later ones
      // are added in, disposing the stale partial sum to avoid leaking memory.
      if (tensorAccumulatedGradientMap[x.id] == null) {
        tensorAccumulatedGradientMap[x.id] = dx;
      } else {
        const curGradient = tensorAccumulatedGradientMap[x.id];
        tensorAccumulatedGradientMap[x.id] = curGradient.add(dx);
        curGradient.dispose();
      }
    }
  }
}
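
// A hedged end-to-end sketch of how the two functions compose. The `onesLike`
// seed and the trivial `tidy` passed here are illustrative assumptions; the
// engine normally supplies real implementations that also dispose
// intermediates:
//
//   const gradMap: {[tensorId: number]: Tensor} = {};
//   gradMap[y.id] = onesLike(y);  // seed dy/dy = 1
//   const filtered = getFilteredNodesXToY(tape, [x], y);
//   backpropagateGradients(gradMap, filtered, f => f() as Tensor);
//   const dydx = gradMap[x.id];  // accumulated gradient of y w.r.t. x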