"use strict";
|
/**
|
* @license
|
* Copyright 2017 Google Inc. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
Object.defineProperty(exports, "__esModule", { value: true });
|
var util = require("./util");
|
/**
 * Computes the list of TapeNodes that connect any of the inputs `xs` to the
 * output `y`, filtering everything else out and preserving the order of the
 * original tape elements.
 *
 * A node is kept only when at least one of its inputs is (transitively) a
 * function of some tensor in `xs` AND at least one of its outputs
 * (transitively) feeds `y`. Tensors and nodes are matched by their `id`.
 *
 * @param tape The tape elements to filter. Neither the array nor its nodes are
 *     mutated.
 * @param xs The input Tensors.
 * @param y The output Tensor.
 * @returns A new array holding shallow copies of the retained nodes, each with
 *     its `inputs` pruned down to the entries that are a function of `xs`.
 */
function getFilteredNodesXToY(tape, xs, y) {
    // Forward pass to compute all the nodes and Tensors that are transitively a
    // function of x.
    var tensorsFromX = {};
    var nodesFromX = {};
    for (var xIndex = 0; xIndex < xs.length; xIndex++) {
        tensorsFromX[xs[xIndex].id] = true;
    }
    for (var fwd = 0; fwd < tape.length; fwd++) {
        var fwdNode = tape[fwd];
        var fwdInputs = fwdNode.inputs;
        for (var fwdName in fwdInputs) {
            // A single lookup per input is enough: membership in tensorsFromX
            // does not depend on which element of xs seeded it.
            if (tensorsFromX[fwdInputs[fwdName].id]) {
                fwdNode.outputs.forEach(function (output) {
                    tensorsFromX[output.id] = true;
                });
                nodesFromX[fwdNode.id] = true;
                // One input reachable from x is sufficient for this node.
                break;
            }
        }
    }
    // Backward pass to find all of the nodes and Tensors that lead to y.
    var tensorsLeadToY = {};
    tensorsLeadToY[y.id] = true;
    var nodesToY = {};
    for (var bwd = tape.length - 1; bwd >= 0; bwd--) {
        var bwdNode = tape[bwd];
        var anyOutputLeadsToY = bwdNode.outputs.some(function (output) {
            return tensorsLeadToY[output.id];
        });
        if (anyOutputLeadsToY) {
            // If any of the outputs lead to y, mark all of the inputs as
            // leading to y.
            nodesToY[bwdNode.id] = true;
            for (var bwdName in bwdNode.inputs) {
                tensorsLeadToY[bwdNode.inputs[bwdName].id] = true;
            }
        }
    }
    // Return the paths that come from x and lead to y.
    var filteredTape = [];
    for (var k = 0; k < tape.length; k++) {
        var node = tape[k];
        if (nodesFromX[node.id] && nodesToY[node.id]) {
            // Prune the inputs from the node that aren't a function of x.
            var prunedInputs = {};
            for (var inputName in node.inputs) {
                var nodeInput = node.inputs[inputName];
                if (tensorsFromX[nodeInput.id]) {
                    prunedInputs[inputName] = nodeInput;
                }
            }
            // Shallow-copy the node (outputs and every other field are shared
            // with the original) and overwrite inputs with the pruned version.
            var prunedNode = Object.assign({}, node);
            prunedNode.inputs = prunedInputs;
            filteredTape.push(prunedNode);
        }
    }
    return filteredTape;
}
|
exports.getFilteredNodesXToY = getFilteredNodesXToY;
|
/**
 * Walks the filtered tape from its last node to its first, pushing the
 * gradients of each node's outputs back onto that node's inputs.
 *
 * @param tensorAccumulatedGradientMap A map from Tensor id to that Tensor's
 *     accumulated gradient. This map is mutated by this method: gradients of
 *     node inputs are inserted, or summed into an existing entry (in which case
 *     the previous entry is disposed).
 * @param filteredTape The filtered TapeNodes to backprop through.
 * @param tidy Function that is handed a thunk and returns the thunk's result;
 *     every per-input gradient function is invoked through it.
 * @throws Error if a node has no gradient function, if the gradient function
 *     returns no entry for one of the node's inputs, or if a computed gradient
 *     is not 'float32' or does not match the shape of its input.
 */
function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy) {
    // Adds dx into the accumulated gradient stored for tensor x.
    function accumulate(x, dx) {
        if (tensorAccumulatedGradientMap[x.id] == null) {
            tensorAccumulatedGradientMap[x.id] = dx;
            return;
        }
        var previous = tensorAccumulatedGradientMap[x.id];
        tensorAccumulatedGradientMap[x.id] = previous.add(dx);
        // The old partial sum has been replaced by the new one.
        previous.dispose();
    }
    // Computes, validates and accumulates the gradient for one named input.
    function backpropInput(node, inputGradients, inputName) {
        if (!(inputName in inputGradients)) {
            throw new Error("Cannot backprop through input " + inputName + ". " +
                "Available gradients found: " + Object.keys(inputGradients) + ".");
        }
        // Call the gradient function.
        var dx = tidy(function () {
            return inputGradients[inputName]();
        });
        if (dx.dtype !== 'float32') {
            throw new Error("Error in gradient for op " + node.kernelName +
                ". The gradient of input " + inputName +
                " must have 'float32' dtype, but has '" + dx.dtype + "'");
        }
        var x = node.inputs[inputName];
        if (!util.arraysEqual(dx.shape, x.shape)) {
            throw new Error("Error in gradient for op " + node.kernelName +
                ". The gradient of input '" + inputName + "' has shape '" +
                dx.shape + "', which does not match the shape of the input '" +
                x.shape + "'");
        }
        accumulate(x, dx);
    }
    // Backprops dy through a single node and accumulates over its inputs.
    function backpropNode(node) {
        var dys = [];
        node.outputs.forEach(function (output) {
            var gradTensor = tensorAccumulatedGradientMap[output.id];
            // An output outside the back-propagation subgraph does not affect
            // the final output, so its dy is null.
            dys.push(gradTensor == null ? null : gradTensor);
        });
        if (node.gradient == null) {
            throw new Error("Cannot compute gradient: gradient function not found " +
                "for " + node.kernelName + ".");
        }
        var inputGradients = node.gradient(dys);
        for (var inputName in node.inputs) {
            backpropInput(node, inputGradients, inputName);
        }
    }
    // Walk the tape backward and keep a map of Tensor to its gradient.
    var position = filteredTape.length;
    while (position-- > 0) {
        backpropNode(filteredTape[position]);
    }
}
|
exports.backpropagateGradients = backpropagateGradients;
|
//# sourceMappingURL=tape.js.map
|