/**
|
* @license
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import * as util from './util';
|
/**
 * Computes a list of TapeNodes that connect x to y, filtering everything else
 * out and preserving the order of the original tape elements.
 *
 * A node is kept only if at least one of its inputs is (transitively) derived
 * from one of `xs` AND at least one of its outputs (transitively) feeds `y`.
 * Kept nodes are shallow copies whose `inputs` map is pruned down to the
 * entries derived from `xs`; the original tape nodes are not mutated.
 *
 * @param tape The tape elements to filter.
 * @param xs The input Tensors.
 * @param y The output Tensor.
 * @returns The filtered TapeNodes, in their original tape order.
 */
export function getFilteredNodesXToY(tape, xs, y) {
  // Forward pass to compute all the nodes and Tensors that are transitively a
  // function of x.
  const tensorsFromX = {};
  const nodesFromX = {};
  for (let i = 0; i < xs.length; i++) {
    tensorsFromX[xs[i].id] = true;
  }

  for (let i = 0; i < tape.length; i++) {
    const node = tape[i];
    for (const inputName in node.inputs) {
      if (tensorsFromX[node.inputs[inputName].id]) {
        // A single input derived from x is enough: every output of this node
        // is then a function of x too.
        node.outputs.forEach(output => {
          tensorsFromX[output.id] = true;
        });
        nodesFromX[node.id] = true;
        break;
      }
    }
  }

  // Backward pass to find all of the nodes and Tensors that lead to y.
  const tensorsLeadToY = {};
  tensorsLeadToY[y.id] = true;
  const nodesToY = {};

  for (let i = tape.length - 1; i >= 0; i--) {
    const node = tape[i];
    // If any of the outputs lead to y, mark the node and all of its inputs as
    // leading to y.
    const anyOutputLeadsToY =
        node.outputs.some(output => tensorsLeadToY[output.id]);
    if (!anyOutputLeadsToY) {
      continue;
    }
    nodesToY[node.id] = true;
    for (const inputName in node.inputs) {
      tensorsLeadToY[node.inputs[inputName].id] = true;
    }
  }

  // Return the paths that come from x and lead to y.
  const filteredTape = [];
  for (let i = 0; i < tape.length; i++) {
    const node = tape[i];
    if (!(nodesFromX[node.id] && nodesToY[node.id])) {
      continue;
    }

    // Prune the inputs from the node that aren't a function of x.
    const prunedInputs = {};
    for (const inputName in node.inputs) {
      const nodeInput = node.inputs[inputName];
      if (tensorsFromX[nodeInput.id]) {
        prunedInputs[inputName] = nodeInput;
      }
    }

    // Shallow-copy the node so the original tape entry keeps its full inputs;
    // only `inputs` is replaced, every other field (including `outputs`) is
    // shared with the original node.
    filteredTape.push(Object.assign({}, node, {inputs: prunedInputs}));
  }

  return filteredTape;
}
|
/**
 * Backpropagate gradients through the filtered TapeNodes.
 *
 * Nodes are visited from last to first. For each node, the upstream gradient
 * of every output is looked up in `tensorAccumulatedGradientMap` (null when
 * the output has no entry), the array is handed to `node.gradient`, and the
 * resulting gradient of each input is accumulated into the map under that
 * input's id.
 *
 * @param tensorAccumulatedGradientMap A map of Tensor id to its gradient. This
 *     map is mutated by this method: when an input already has a gradient, the
 *     entry is replaced by `add(previous, dx)` and `previous` is disposed.
 * @param filteredTape The filtered TapeNodes to backprop through.
 * @param tidy Wrapper used to invoke each per-input gradient function.
 * @param add Function used to sum two gradients for the same Tensor.
 * @throws Error if a node has no gradient function, if its gradient function
 *     provides no entry for one of the node's inputs, or if an input's
 *     gradient is not 'float32' or does not match the input's shape.
 */
export function backpropagateGradients(
    tensorAccumulatedGradientMap, filteredTape, tidy, add) {
  for (let idx = filteredTape.length - 1; idx >= 0; --idx) {
    const node = filteredTape[idx];

    // Upstream gradient for each output. An output without an accumulated
    // gradient is outside the back-propagation subgraph, so it cannot affect
    // the final output and receives a null dy.
    const dys = node.outputs.map(output => {
      const upstream = tensorAccumulatedGradientMap[output.id];
      return upstream != null ? upstream : null;
    });

    if (node.gradient == null) {
      throw new Error(
          `Cannot compute gradient: gradient function not found ` +
          `for ${node.kernelName}.`);
    }

    // Backprop dy through this node; the result maps input names to thunks.
    const inputGradients = node.gradient(dys);

    for (const inputName in node.inputs) {
      if (!(inputName in inputGradients)) {
        throw new Error(
            `Cannot backprop through input ${inputName}. ` +
            `Available gradients found: ${Object.keys(inputGradients)}.`);
      }

      // Evaluate this input's gradient thunk inside `tidy`.
      const dx = tidy(() => inputGradients[inputName]());
      if (dx.dtype !== 'float32') {
        throw new Error(
            `Error in gradient for op ${node.kernelName}. The gradient of input ` +
            `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`);
      }

      const input = node.inputs[inputName];
      if (!util.arraysEqual(dx.shape, input.shape)) {
        throw new Error(
            `Error in gradient for op ${node.kernelName}. The gradient of input ` +
            `'${inputName}' has shape '${dx.shape}', which does not match ` +
            `the shape of the input '${input.shape}'`);
      }

      // Sum into any gradient already collected for this input, releasing the
      // superseded tensor.
      const previous = tensorAccumulatedGradientMap[input.id];
      if (previous == null) {
        tensorAccumulatedGradientMap[input.id] = dx;
      } else {
        tensorAccumulatedGradientMap[input.id] = add(previous, dx);
        previous.dispose();
      }
    }
  }
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"tape.js","sourceRoot":"","sources":["../../../../../tfjs-core/src/tape.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAIH,OAAO,KAAK,IAAI,MAAM,QAAQ,CAAC;AAgB/B;;;;;;;GAOG;AACH,MAAM,UAAU,oBAAoB,CAChC,IAAgB,EAAE,EAAY,EAAE,CAAS;IAC3C,4EAA4E;IAC5E,iBAAiB;IACjB,MAAM,YAAY,GAAkC,EAAE,CAAC;IACvD,MAAM,UAAU,GAAgC,EAAE,CAAC;IACnD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;QAClC,YAAY,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC;KAC/B;IAED,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;QACpC,MAAM,IAAI,GAAG,IAAI,CAAC,CAAC,CAAC,CAAC;QACrB,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC;QAC/B,KAAK,MAAM,SAAS,IAAI,UAAU,EAAE;YAClC,MAAM,KAAK,GAAG,UAAU,CAAC,SAAS,CAAC,CAAC;YAEpC,IAAI,aAAa,GAAG,KAAK,CAAC;YAC1B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;gBAClC,IAAI,YAAY,CAAC,KAAK,CAAC,EAAE,CAAC,EAAE;oBAC1B,IAAI,CAAC,OAAO,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC,YAAY,CAAC,MAAM,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,CAAC;oBAC/D,aAAa,GAAG,IAAI,CAAC;oBACrB,UAAU,CAAC,IAAI,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC;oBAC3B,MAAM;iBACP;aACF;YAED,IAAI,aAAa,EAAE;gBACjB,MAAM;aACP;SACF;KACF;IAED,qEAAqE;IACrE,MAAM,cAAc,GAAkC,EAAE,CAAC;IACzD,cAAc,CAAC,CAAC,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC;IAC5B,MAAM,QAAQ,GAAgC,EAAE,CAAC;IAEjD,KAAK,IAAI,CAAC,GAAG,IAAI,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,EAAE;QACzC,MAAM,IAAI,GAAG,IAAI,CAAC,CAAC,CAAC,CAAC;QACrB,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC;QAE/B,2EAA2E;QAC3E,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;YAC5C,IAAI,cAAc,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE;gBACtC,KAAK,MAAM,SAAS,IAAI,UAAU,EAAE;oBAClC,cAAc,CAAC,UAAU,CAAC,SAAS,CAAC,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC;oBAChD,QAAQ,CAAC,IAAI,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC;iBAC1B;gBACD,MAAM;aACP;SACF;KACF;IAED,mDAAmD;IACnD,MAAM,YAAY,GAAe,EAAE,CAAC;IACpC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE
;QACpC,MAAM,IAAI,GAAG,IAAI,CAAC,CAAC,CAAC,CAAC;QAErB,IAAI,UAAU,CAAC,IAAI,CAAC,EAAE,CAAC,IAAI,QAAQ,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE;YAC5C,8DAA8D;YAC9D,MAAM,YAAY,GAAkC,EAAE,CAAC;YACvD,KAAK,MAAM,SAAS,IAAI,IAAI,CAAC,MAAM,EAAE;gBACnC,MAAM,SAAS,GAAG,IAAI,CAAC,MAAM,CAAC,SAAS,CAAC,CAAC;gBACzC,IAAI,YAAY,CAAC,SAAS,CAAC,EAAE,CAAC,EAAE;oBAC9B,YAAY,CAAC,SAAS,CAAC,GAAG,SAAS,CAAC;iBACrC;aACF;YAED,mEAAmE;YACnE,MAAM,UAAU,GAAG,MAAM,CAAC,MAAM,CAAC,EAAE,EAAE,IAAI,CAAC,CAAC;YAC3C,UAAU,CAAC,MAAM,GAAG,YAAY,CAAC;YACjC,UAAU,CAAC,OAAO,GAAG,IAAI,CAAC,OAAO,CAAC;YAElC,YAAY,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;SAC/B;KACF;IAED,OAAO,YAAY,CAAC;AACtB,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,sBAAsB,CAClC,4BAA0D,EAC1D,YAAwB,EAAE,IAA6B,EACvD,GAAqC;IACvC,mEAAmE;IACnE,KAAK,IAAI,CAAC,GAAG,YAAY,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,EAAE;QACjD,MAAM,IAAI,GAAG,YAAY,CAAC,CAAC,CAAC,CAAC;QAE7B,MAAM,GAAG,GAAa,EAAE,CAAC;QACzB,IAAI,CAAC,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE;YACvB,MAAM,UAAU,GAAG,4BAA4B,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;YACtD,IAAI,UAAU,IAAI,IAAI,EAAE;gBACtB,GAAG,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;aACtB;iBAAM;gBACL,wEAAwE;gBACxE,iEAAiE;gBACjE,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;aAChB;QACH,CAAC,CAAC,CAAC;QAEH,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,MAAM,IAAI,KAAK,CACX,uDAAuD;gBACvD,OAAO,IAAI,CAAC,UAAU,GAAG,CAAC,CAAC;SAChC;QAED,0EAA0E;QAC1E,MAAM,cAAc,GAAG,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,CAAC;QAE1C,KAAK,MAAM,SAAS,IAAI,IAAI,CAAC,MAAM,EAAE;YACnC,IAAI,CAAC,CAAC,SAAS,IAAI,cAAc,CAAC,EAAE;gBAClC,MAAM,IAAI,KAAK,CACX,iCAAiC,SAAS,IAAI;oBAC9C,8BAA8B,MAAM,CAAC,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,CAAC;aACnE;YAED,8BAA8B;YAC9B,MAAM,EAAE,GAAG,IAAI,CAAC,GAAG,EAAE,CAAC,cAAc,CAAC,SAAS,CAAC,EAAE,CAAC,CAAC;YACnD,IAAI,EAAE,CAAC,KAAK,KAAK,SAAS,EAAE;gBAC1B,MAAM,IAAI,KAAK,CACX,4BACI,IAAI,CAAC,UAAU,0BAA0B;oBAC7C,GAAG,SAAS,wCAAwC,EAAE,CAAC,KAAK,GAAG,CAAC,CAAC;aACtE;YACD,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,SAAS,CAAC,CAAC;YACjC,IAAI,CAAC,IAAI,CAAC,WAAW,CAAC,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,CAAC,EAAE;gBACxC,MAAM,IAAI,KAAK,CACX,4BACI,IAAI,CAAC,UAAU,0B
AA0B;oBAC7C,IAAI,SAAS,gBAAgB,EAAE,CAAC,KAAK,0BAA0B;oBAC/D,2BAA2B,CAAC,CAAC,KAAK,GAAG,CAAC,CAAC;aAC5C;YAED,IAAI,4BAA4B,CAAC,CAAC,CAAC,EAAE,CAAC,IAAI,IAAI,EAAE;gBAC9C,4BAA4B,CAAC,CAAC,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC;aACzC;iBAAM;gBACL,MAAM,WAAW,GAAG,4BAA4B,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;gBACvD,4BAA4B,CAAC,CAAC,CAAC,EAAE,CAAC,GAAG,GAAG,CAAC,WAAW,EAAE,EAAE,CAAC,CAAC;gBAC1D,WAAW,CAAC,OAAO,EAAE,CAAC;aACvB;SACF;KACF;AACH,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2017 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from './tensor';\nimport {NamedTensorMap} from './tensor_types';\nimport * as util from './util';\n\nexport interface TapeNode {\n  id: number;\n  kernelName: string;\n  outputs: Tensor[];\n  inputs: NamedTensorMap;\n  // Optional params, defined only for ops with gradient impl.\n  gradient?: (dys: Tensor[]) => NamedGradientMap;\n  saved?: Tensor[];\n}\n\nexport type NamedGradientMap = {\n  [inputName: string]: () => Tensor;\n};\n\n/**\n * Computes a list of TapeNodes that connect x to y, filtering everything else\n * out and preserving the order of the original tape elements.\n *\n * @param tape The tape elements to filter.\n * @param xs The input Tensors.\n * @param y The output Tensor.\n */\nexport function getFilteredNodesXToY(\n    tape: TapeNode[], xs: Tensor[], y: Tensor): TapeNode[] {\n  // Forward 
pass to compute all the nodes and Tensors that are transitively a\n  // function of x.\n  const tensorsFromX: {[tensorId: number]: boolean} = {};\n  const nodesFromX: {[nodeId: number]: boolean} = {};\n  for (let i = 0; i < xs.length; i++) {\n    tensorsFromX[xs[i].id] = true;\n  }\n\n  for (let i = 0; i < tape.length; i++) {\n    const node = tape[i];\n    const nodeInputs = node.inputs;\n    for (const inputName in nodeInputs) {\n      const input = nodeInputs[inputName];\n\n      let anyInputFromX = false;\n      for (let j = 0; j < xs.length; j++) {\n        if (tensorsFromX[input.id]) {\n          node.outputs.forEach(output => tensorsFromX[output.id] = true);\n          anyInputFromX = true;\n          nodesFromX[node.id] = true;\n          break;\n        }\n      }\n\n      if (anyInputFromX) {\n        break;\n      }\n    }\n  }\n\n  // Backward pass to find all of the nodes and Tensors that lead to y.\n  const tensorsLeadToY: {[tensorId: number]: boolean} = {};\n  tensorsLeadToY[y.id] = true;\n  const nodesToY: {[nodeId: number]: boolean} = {};\n\n  for (let i = tape.length - 1; i >= 0; i--) {\n    const node = tape[i];\n    const nodeInputs = node.inputs;\n\n    // If any of the outputs lead to y, mark all of the inputs as leading to y.\n    for (let j = 0; j < node.outputs.length; j++) {\n      if (tensorsLeadToY[node.outputs[j].id]) {\n        for (const inputName in nodeInputs) {\n          tensorsLeadToY[nodeInputs[inputName].id] = true;\n          nodesToY[node.id] = true;\n        }\n        break;\n      }\n    }\n  }\n\n  // Return the paths that come from x and lead to y.\n  const filteredTape: TapeNode[] = [];\n  for (let i = 0; i < tape.length; i++) {\n    const node = tape[i];\n\n    if (nodesFromX[node.id] && nodesToY[node.id]) {\n      // Prune the inputs from the node that aren't a function of x.\n      const prunedInputs: {[inputName: string]: Tensor} = {};\n      for (const inputName in node.inputs) {\n        const nodeInput = 
node.inputs[inputName];\n        if (tensorsFromX[nodeInput.id]) {\n          prunedInputs[inputName] = nodeInput;\n        }\n      }\n\n      // Copy the node and overwrite inputsAndArgs to the pruned version.\n      const prunedNode = Object.assign({}, node);\n      prunedNode.inputs = prunedInputs;\n      prunedNode.outputs = node.outputs;\n\n      filteredTape.push(prunedNode);\n    }\n  }\n\n  return filteredTape;\n}\n\n/**\n * Backpropagate gradients through the filtered TapeNodes.\n *\n * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map\n * is mutated by this method.\n * @param filteredTape The filtered TapeNodes to backprop through.\n */\nexport function backpropagateGradients(\n    tensorAccumulatedGradientMap: {[tensorId: number]: Tensor},\n    filteredTape: TapeNode[], tidy: (f: Function) => Tensor,\n    add: (a: Tensor, b: Tensor) => Tensor) {\n  // Walk the tape backward and keep a map of Tensor to its gradient.\n  for (let i = filteredTape.length - 1; i >= 0; i--) {\n    const node = filteredTape[i];\n\n    const dys: Tensor[] = [];\n    node.outputs.forEach(o => {\n      const gradTensor = tensorAccumulatedGradientMap[o.id];\n      if (gradTensor != null) {\n        dys.push(gradTensor);\n      } else {\n        // This particular output is not in the back-propagation subgraph, so it\n        // does not affect the final output, thus we put null for its dy.\n        dys.push(null);\n      }\n    });\n\n    if (node.gradient == null) {\n      throw new Error(\n          `Cannot compute gradient: gradient function not found ` +\n          `for ${node.kernelName}.`);\n    }\n\n    // Backprop dy through this node and accumulate gradients over the inputs.\n    const inputGradients = node.gradient(dys);\n\n    for (const inputName in node.inputs) {\n      if (!(inputName in inputGradients)) {\n        throw new Error(\n            `Cannot backprop through input ${inputName}. 
` +\n            `Available gradients found: ${Object.keys(inputGradients)}.`);\n      }\n\n      // Call the gradient function.\n      const dx = tidy(() => inputGradients[inputName]());\n      if (dx.dtype !== 'float32') {\n        throw new Error(\n            `Error in gradient for op ${\n                node.kernelName}. The gradient of input ` +\n            `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`);\n      }\n      const x = node.inputs[inputName];\n      if (!util.arraysEqual(dx.shape, x.shape)) {\n        throw new Error(\n            `Error in gradient for op ${\n                node.kernelName}. The gradient of input ` +\n            `'${inputName}' has shape '${dx.shape}', which does not match ` +\n            `the shape of the input '${x.shape}'`);\n      }\n\n      if (tensorAccumulatedGradientMap[x.id] == null) {\n        tensorAccumulatedGradientMap[x.id] = dx;\n      } else {\n        const curGradient = tensorAccumulatedGradientMap[x.id];\n        tensorAccumulatedGradientMap[x.id] = add(curGradient, dx);\n        curGradient.dispose();\n      }\n    }\n  }\n}\n"]}
|