/**
|
* @license
|
* Copyright 2018 Google LLC
|
*
|
* Use of this source code is governed by an MIT-style
|
* license that can be found in the LICENSE file or at
|
* https://opensource.org/licenses/MIT.
|
* =============================================================================
|
*/
|
/**
|
* Executor: Evaluates SymbolicTensors based on feed values.
|
*/
|
import { cast, dispose, memory, util } from '@tensorflow/tfjs-core';
|
import { ValueError } from '../errors';
|
import { LruCache } from '../utils/executor_utils';
|
import { toList } from '../utils/generic_utils';
|
import { InputLayer } from './input_layer';
|
import { SymbolicTensor } from './topology';
|
/**
 * Checks a feed value against the dtype declared by its key.
 *
 * When the key declares no dtype, or the dtypes already agree, the value is
 * returned unchanged. Otherwise a `cast` to the key's dtype is attempted.
 *
 * @param key The `SymbolicTensor` being fed.
 * @param val The concrete tensor supplied for `key`.
 * @returns `val` itself, or the result of casting `val` to `key.dtype`.
 * @throws ValueError If `cast` throws while converting `val`.
 */
function assertFeedCompatibility(key, val) {
  const expectedDtype = key.dtype;
  const dtypesAgree = expectedDtype == null || expectedDtype === val.dtype;
  if (!dtypesAgree) {
    try {
      return cast(val, expectedDtype);
    } catch (err) {
      // Replace the low-level cast failure with a message naming the key.
      throw new ValueError(
          `The dtype of the feed (${val.dtype}) can not be cast to the dtype ` +
          `of the key '${key.name}' (${key.dtype}).`);
    }
  }
  return val;
}
|
/**
 * FeedDict: A mapping from unique SymbolicTensors to feed values for them.
 * A feed value is a concrete value represented as a `Tensor`. An optional
 * mask `Tensor` may be stored alongside each value.
 *
 * Entries are indexed by `SymbolicTensor.id`; a secondary `name -> id` index
 * allows lookups by tensor name.
 */
export class FeedDict {
  /**
   * Constructor, optionally does copy-construction.
   * @param feeds An Array of `Feed`s, or another `FeedDict`, in which case
   *   copy-construction will be performed. The copy is shallow: value and
   *   mask tensors are shared with the source, not cloned.
   */
  constructor(feeds) {
    this.id2Value = {};
    this.id2Mask = {};
    this.name2Id = {};
    if (feeds instanceof FeedDict) {
      for (const id in feeds.id2Value) {
        this.id2Value[id] = feeds.id2Value[id];
        if (id in feeds.id2Mask) {
          this.id2Mask[id] = feeds.id2Mask[id];
        }
      }
      // The name index must be copied too; otherwise `names()` and lookups
      // by name on the copy would not see any of the source's feeds.
      for (const name in feeds.name2Id) {
        this.name2Id[name] = feeds.name2Id[name];
      }
    } else if (feeds != null) {
      for (const feed of feeds) {
        this.add(feed.key, feed.value);
      }
    }
  }

  /**
   * Add a key-value pair to the FeedDict.
   *
   * The value is passed through `assertFeedCompatibility`, so it may be
   * cast to the key's dtype before being stored.
   *
   * @param key The key of the feed.
   * @param value The value of the tensor feed.
   * @param mask The value of the mask feed (optional).
   * @returns This `FeedDict`.
   * @throws ValueError: If the key `SymbolicTensor` already exists in the
   *   `FeedDict`.
   */
  add(key, value, mask) {
    if (this.id2Value[key.id] == null) {
      this.id2Value[key.id] = assertFeedCompatibility(key, value);
      this.name2Id[key.name] = key.id;
      if (mask != null) {
        this.id2Mask[key.id] = mask;
      }
    } else {
      throw new ValueError(`Duplicate key: name=${key.name}, id=${key.id}`);
    }
    return this;
  }

  /**
   * Add a Feed to the FeedDict.
   * @param feed The new `Feed` to add.
   * @returns This `FeedDict`.
   * @throws ValueError: If `feed.key` already exists in the `FeedDict`.
   */
  addFeed(feed) {
    return this.add(feed.key, feed.value);
  }

  /**
   * Probe whether a key already exists in the FeedDict.
   * @param key The `SymbolicTensor` to look up (matched by `id`).
   * @returns `true` if a non-null value is stored for `key`.
   */
  hasKey(key) {
    return this.id2Value[key.id] != null;
  }

  /**
   * Get the names of all SymbolicTensors available in this FeedDict.
   * @returns A new Array of tensor names.
   */
  names() {
    return Object.keys(this.name2Id);
  }

  /**
   * Get the feed value for given key.
   * @param key The SymbolicTensor, or its name (as a string), of which the
   *   value is sought.
   * @returns If `key` exists, the corresponding feed value.
   * @throws ValueError: If `key` does not exist in this `FeedDict`.
   */
  getValue(key) {
    if (key instanceof SymbolicTensor) {
      if (this.id2Value[key.id] == null) {
        throw new ValueError(`Nonexistent key: ${key.name}`);
      } else {
        return this.id2Value[key.id];
      }
    } else {
      const id = this.name2Id[key];
      if (id == null) {
        throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
      }
      return this.id2Value[id];
    }
  }

  /**
   * Get the feed mask for given key.
   * @param key The SymbolicTensor, or its name (as a string), of which the
   *   mask is sought.
   * @returns If `key` exists, the corresponding feed mask (may be undefined
   *   when no mask was stored).
   * @throws ValueError: If `key` does not exist in this `FeedDict`.
   */
  getMask(key) {
    if (key instanceof SymbolicTensor) {
      if (this.id2Value[key.id] == null) {
        throw new ValueError(`Nonexistent key: ${key.name}`);
      } else {
        return this.id2Mask[key.id];
      }
    } else {
      const id = this.name2Id[key];
      if (id == null) {
        throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
      }
      return this.id2Mask[id];
    }
  }

  /**
   * Dispose all mask Tensors held by this object.
   *
   * Masks obtained through copy-construction are shared with the source
   * `FeedDict`, so they are disposed for the source as well.
   */
  disposeMasks() {
    if (this.id2Mask != null) {
      dispose(this.id2Mask);
    }
  }
}
|
// LRU cache of topologically sorted SymbolicTensors, keyed by the
// fetch/feed signature of an execution (i.e., its fetches and feed names).
export const cachedSorted = new LruCache();
// LRU cache of recipient-count maps, keyed by the same fetch/feed signature.
export const cachedRecipientCounts = new LruCache();

/**
 * Sets the capacity of both executor caches.
 * @param maxEntries The new maximum entry count, forwarded to each cache's
 *   `setMaxEntries`.
 */
export function updateCacheMaxEntries(maxEntries) {
  const executorCaches = [cachedSorted, cachedRecipientCounts];
  for (const cache of executorCaches) {
    if (cache != null) {
      cache.setMaxEntries(maxEntries);
    }
  }
}
|
/**
 * Execute a SymbolicTensor by using concrete feed values.
 *
 * A `SymbolicTensor` object is a node in a computation graph of TF.js
 * Layers. The object is backed by a source layer and input
 * `SymbolicTensor`s to the source layer. This function walks the
 * topologically-sorted upstream `SymbolicTensor`s of `fetches` and calls
 * each source layer's `apply()` method, using concrete input values obtained
 * from either
 *   * `feedDict`, if the input key exists in `feedDict`, or else,
 *   * the output of a layer evaluated earlier in the same walk.
 *
 * When not training, an intermediate tensor is disposed once its last
 * recipient has consumed it, unless it is fed, fetched, already disposed, or
 * produced by a stateful layer.
 *
 * @param fetches: A `SymbolicTensor`, or an Array of them, to execute.
 * @param feedDict: The feed values that seed the execution.
 * @param kwargs: Optional keyword arguments passed to every layer's
 *   `apply()`. `kwargs['training']` disables eager disposal of intermediate
 *   tensors. The caller's object is never modified.
 * @param probe: A probe object (of interface `ExecutionProbe`) used for
 *   testing memory footprint of `execute` calls.
 * @returns Result of the execution: an Array when `fetches` is an Array,
 *   otherwise a single value.
 * @throws ValueError: If any `SymbolicTensor`s from `InputLayer`s
 *   encountered during the execution lacks a feed value in `feedDict`.
 */
export function execute(fetches, feedDict, kwargs, probe) {
  const training = kwargs == null ? false : kwargs['training'];
  const arrayFetches = Array.isArray(fetches);
  const fetchArray = arrayFetches ? fetches : [fetches];
  const outputNames = fetchArray.map(t => t.name);

  // Fetches that are directly fed are answered from the feed; the rest are
  // filled in as the corresponding layers are evaluated.
  const finalOutputs = [];
  const feedNames = feedDict.names();
  for (const outputName of outputNames) {
    if (feedNames.indexOf(outputName) !== -1) {
      finalOutputs.push(feedDict.getValue(outputName));
    } else {
      finalOutputs.push(null);
    }
  }

  if (probe != null) {
    // For optional probing of memory footprint during execution.
    probe.maxNumTensors = -Infinity;
    probe.minNumTensors = Infinity;
  }

  // Check cache. The two LRU caches are separate objects that are queried
  // and evicted independently, so one may still hold this key after the other
  // has dropped it. Recompute when EITHER entry is missing, and use the
  // freshly computed counts directly. A missing counts map would otherwise
  // yield NaN counts and silently disable disposal of intermediates.
  const fetchAndFeedKey =
      outputNames.join(',') + '|' + feedNames.slice().sort().join(',');
  let sorted = cachedSorted.get(fetchAndFeedKey);
  let cachedCounts = cachedRecipientCounts.get(fetchAndFeedKey);
  if (sorted == null || cachedCounts == null) {
    const out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict);
    sorted = out.sorted;
    cachedCounts = out.recipientCounts;
    // Store results in cache for future use.
    cachedSorted.put(fetchAndFeedKey, sorted);
    cachedRecipientCounts.put(fetchAndFeedKey, cachedCounts);
  }

  // Per-call working copy of the counts, decremented as inputs are consumed.
  // It stays empty in training mode, where no eager disposal happens.
  const recipientCounts = {};
  if (!training) {
    Object.assign(recipientCounts, cachedCounts);
  }

  const internalFeedDict = new FeedDict(feedDict);

  // Start iterative execution on the topologically-sorted SymbolicTensors.
  for (const symbolic of sorted) {
    if (probe != null) {
      // For optional probing of memory usage during execution.
      const numTensors = memory().numTensors;
      if (numTensors > probe.maxNumTensors) {
        probe.maxNumTensors = numTensors;
      }
      if (numTensors < probe.minNumTensors) {
        probe.minNumTensors = numTensors;
      }
    }

    const srcLayer = symbolic.sourceLayer;
    if (srcLayer instanceof InputLayer) {
      continue;
    }

    const inputValues = [];
    const inputMasks = [];
    const tensorsToDispose = [];
    let maskExists = false;
    for (const input of symbolic.inputs) {
      const value = internalFeedDict.getValue(input);
      const mask = internalFeedDict.getMask(input);
      inputValues.push(value);
      inputMasks.push(mask);
      if (mask != null) {
        maskExists = true;
      }
      if (!training) {
        recipientCounts[input.name]--;
        if (recipientCounts[input.name] === 0 && !feedDict.hasKey(input) &&
            outputNames.indexOf(input.name) === -1 && !value.isDisposed &&
            input.sourceLayer.stateful !== true) {
          tensorsToDispose.push(value);
        }
      }
    }

    if (maskExists) {
      // Copy rather than mutate: the caller's kwargs must not retain a mask
      // tensor that `disposeMasks()` disposes at the end of this call. The
      // reassigned local still carries the mask on to later layers.
      kwargs = Object.assign({}, kwargs, {mask: inputMasks[0]});
    }

    const outputTensors = toList(srcLayer.apply(inputValues, kwargs));
    let outputMask = null;
    if (srcLayer.supportsMasking) {
      outputMask = srcLayer.computeMask(inputValues, inputMasks);
    }
    // NOTE(review): every output of a multi-output layer receives the first
    // mask; confirm whether per-output masks were intended.
    const firstOutputMask =
        Array.isArray(outputMask) ? outputMask[0] : outputMask;

    const layerOutputs = getNodeOutputs(symbolic);
    const outputSymbolicTensors =
        Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs];
    for (let j = 0; j < outputSymbolicTensors.length; ++j) {
      const outputSymbolic = outputSymbolicTensors[j];
      if (!internalFeedDict.hasKey(outputSymbolic)) {
        internalFeedDict.add(outputSymbolic, outputTensors[j], firstOutputMask);
      }
      const index = outputNames.indexOf(outputSymbolic.name);
      if (index !== -1) {
        finalOutputs[index] = outputTensors[j];
      }
    }

    if (!training) {
      // Clean up Tensors that are no longer needed.
      dispose(tensorsToDispose);
    }
  }

  // NOTE(cais): Unlike intermediate tensors, we don't discard mask
  // tensors as we go, because these tensors are sometimes passed over a
  // series of multiple layers, i.e., not obeying the immediate input
  // relations in the graph. If this becomes a memory-usage concern,
  // we can improve this in the future.
  internalFeedDict.disposeMasks();

  return arrayFetches ? finalOutputs : finalOutputs[0];
}
|
/**
 * Topologically sorts the `SymbolicTensor`s needed by an array of fetches.
 *
 * Each fetch is sorted on its own with
 * `getTopologicalSortAndRecipientCountsForOneFetch`. With several fetches the
 * results are merged:
 *   * sorted lists are concatenated, dropping any tensor whose name was
 *     already emitted;
 *   * recipient sets are unioned per tensor name.
 *
 * @param fetches The fetches requested. Must be a non-empty array.
 * @param feedDict The dictionary of fed values.
 * @returns sorted: Topologically-sorted array of SymbolicTensors.
 *   recipientCounts: Number of distinct recipients per SymbolicTensor name.
 */
function getTopologicalSortAndRecipientCounts(fetches, feedDict) {
  util.assert(fetches != null && fetches.length > 0, () => `Expected at least one fetch, got none`);

  // Fast path: a single fetch needs no merging.
  if (fetches.length === 1) {
    const single = getTopologicalSortAndRecipientCountsForOneFetch(fetches[0], feedDict);
    return {
      sorted: single.sorted,
      recipientCounts: recipientMap2Counts(single.recipientMap)
    };
  }

  const mergedSorted = [];
  const mergedRecipients = {};
  const emittedNames = new Set();
  fetches.forEach(fetch => {
    const perFetch = getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict);

    for (const symbolicTensor of perFetch.sorted) {
      if (emittedNames.has(symbolicTensor.name)) {
        continue;
      }
      mergedSorted.push(symbolicTensor);
      emittedNames.add(symbolicTensor.name);
    }

    Object.keys(perFetch.recipientMap).forEach(name => {
      if (mergedRecipients[name] == null) {
        mergedRecipients[name] = new Set();
      }
      const target = mergedRecipients[name];
      perFetch.recipientMap[name].forEach(recipient => target.add(recipient));
    });
  });

  return {
    sorted: mergedSorted,
    recipientCounts: recipientMap2Counts(mergedRecipients)
  };
}
|
/**
 * Converts a recipient map (tensor name -> Set of recipient names) into a
 * count map (tensor name -> number of distinct recipients).
 * @param recipientMap Object whose values are `Set`s.
 * @returns A new object with the same keys, each mapped to its set's size.
 */
function recipientMap2Counts(recipientMap) {
  const names = recipientMap == null ? [] : Object.keys(recipientMap);
  return names.reduce((counts, name) => {
    counts[name] = recipientMap[name].size;
    return counts;
  }, {});
}
|
/**
 * Topologically sorts the `SymbolicTensor`s upstream of a single fetch.
 *
 * Performs an iterative depth-first walk from `fetch` through each tensor's
 * `inputs`. A tensor is appended to `sorted` only once it is a leaf (no
 * inputs) or all of its inputs have been handled, so every entry comes after
 * the inputs it depends on. Names returned by `feedDict.names()` are treated
 * as already visited: the walk does not descend past fed tensors, and they do
 * not appear in `sorted`.
 *
 * @param fetch The single fetch requested.
 * @param feedDict The dictionary of fed values (only `names()` is used).
 * @returns sorted: Topologically-sorted array of SymbolicTensors.
 *   recipientMap: For every input name encountered, the Set of names of the
 *   tensors that consume it (recorded even when the input is fed or was
 *   already visited).
 */
export function getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict) {
  // Fed tensors are seeded as visited so the walk never descends past them.
  const visited = new Set(feedDict.names());
  const sorted = [];
  const recipientMap = {};

  // `pending` is the work stack. `expandedAt` holds the stack positions of
  // entries whose inputs have already been pushed; when such an entry is back
  // on top, all of its inputs are done.
  const pending = [fetch];
  const expandedAt = [];

  while (pending.length !== 0) {
    const topPos = pending.length - 1;
    const current = pending[topPos];

    if (visited.has(current.name)) {
      // Already emitted (or fed) through another path: drop the duplicate.
      pending.pop();
      continue;
    }

    const alreadyExpanded = expandedAt[expandedAt.length - 1] === topPos;
    if (current.inputs.length === 0 || alreadyExpanded) {
      // Leaf tensor, or every input has been handled: emit it.
      pending.pop();
      sorted.push(current);
      visited.add(current.name);
      if (alreadyExpanded) {
        expandedAt.pop();
      }
      continue;
    }

    // First time reaching this tensor: record its edges and queue its
    // not-yet-visited inputs above it on the stack.
    expandedAt.push(topPos);
    for (const input of current.inputs) {
      let recipients = recipientMap[input.name];
      if (recipients == null) {
        recipients = new Set();
        recipientMap[input.name] = recipients;
      }
      // The edge is recorded even for visited inputs, so recipient counts
      // reflect every consumer.
      recipients.add(current.name);
      if (!visited.has(input.name)) {
        pending.push(input);
      }
    }
  }

  return { sorted, recipientMap };
}
|
/**
 * Get the symbolic output tensors of the node to which a given fetch belongs.
 *
 * If the source layer has exactly one inbound node, that node's outputs are
 * `layer.output`. Otherwise the inbound node whose `outputTensors` contains a
 * tensor with the same `id` as `fetch` is located, and
 * `layer.getOutputAt(nodeIndex)` is returned.
 *
 * @param fetch The fetched symbolic tensor.
 * @returns The symbolic tensor(s) output by the node to which `fetch`
 *   belongs, as returned by the layer (a single tensor or an Array).
 * @throws ValueError If the layer has several inbound nodes and none of them
 *   lists `fetch` among its output tensors.
 */
function getNodeOutputs(fetch) {
  const layer = fetch.sourceLayer;
  const nodes = layer.inboundNodes;
  if (nodes.length === 1) {
    return layer.output;
  }
  // Stop at the first matching node (the original `break` only left the
  // inner loop and kept scanning the remaining nodes).
  const nodeIndex = nodes.findIndex(
      node => node.outputTensors.some(outputTensor => outputTensor.id === fetch.id));
  if (nodeIndex === -1) {
    // Previously `getOutputAt(null)` was called here, which failed with an
    // unrelated error deep inside the layer.
    throw new ValueError(
        `Cannot find the node that outputs SymbolicTensor '${fetch.name}' ` +
        `(id=${fetch.id}) among the ${nodes.length} inbound nodes of ` +
        `layer '${layer.name}'.`);
  }
  return layer.getOutputAt(nodeIndex);
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"executor.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/engine/executor.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH;;GAEG;AAEH,OAAO,EAAC,IAAI,EAAE,OAAO,EAAE,MAAM,EAAU,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAE1E,OAAO,EAAC,UAAU,EAAC,MAAM,WAAW,CAAC;AAErC,OAAO,EAAC,QAAQ,EAAC,MAAM,yBAAyB,CAAC;AACjD,OAAO,EAAC,MAAM,EAAC,MAAM,wBAAwB,CAAC;AAE9C,OAAO,EAAC,UAAU,EAAC,MAAM,eAAe,CAAC;AACzC,OAAO,EAAC,cAAc,EAAC,MAAM,YAAY,CAAC;AAE1C;;GAEG;AACH,SAAS,uBAAuB,CAAC,GAAmB,EAAE,GAAW;IAC/D,6BAA6B;IAC7B,IAAI,GAAG,CAAC,KAAK,IAAI,IAAI,IAAI,GAAG,CAAC,KAAK,KAAK,GAAG,CAAC,KAAK,EAAE;QAChD,gDAAgD;QAChD,OAAO,GAAG,CAAC;KACZ;IACD,IAAI;QACF,2CAA2C;QAC3C,OAAO,IAAI,CAAC,GAAG,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC;KAC7B;IAAC,OAAO,GAAG,EAAE;QACZ,iDAAiD;QACjD,MAAM,IAAI,UAAU,CAChB,0BAA0B,GAAG,CAAC,KAAK,iCAAiC;YACpE,eAAe,GAAG,CAAC,IAAI,MAAM,GAAG,CAAC,KAAK,IAAI,CAAC,CAAC;KACjD;AACH,CAAC;AAUD;;;GAGG;AACH,MAAM,OAAO,QAAQ;IAKnB;;;;OAIG;IACH,YAAY,KAAuB;QAT3B,aAAQ,GAA2B,EAAE,CAAC;QACtC,YAAO,GAA2B,EAAE,CAAC;QACrC,YAAO,GAA6B,EAAE,CAAC;QAQ7C,IAAI,KAAK,YAAY,QAAQ,EAAE;YAC7B,KAAK,MAAM,EAAE,IAAI,KAAK,CAAC,QAAQ,EAAE;gBAC/B,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC;gBACvC,IAAI,EAAE,IAAI,KAAK,CAAC,OAAO,EAAE;oBACvB,IAAI,CAAC,OAAO,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,OAAO,CAAC,EAAE,CAAC,CAAC;iBACtC;aACF;SACF;aAAM;YACL,IAAI,KAAK,IAAI,IAAI,EAAE;gBACjB,OAAO;aACR;YACD,KAAK,MAAM,IAAI,IAAI,KAAK,EAAE;gBACxB,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,GAAG,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;aAChC;SACF;IACH,CAAC;IAED;;;;;;;;;OASG;IACH,GAAG,CAAC,GAAmB,EAAE,KAAa,EAAE,IAAa;QACnD,IAAI,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,IAAI,EAAE;YACjC,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,uBAAuB,CAAC,GAAG,EAAE,KAAK,CAAC,CAAC;YAC5D,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,IAAI,CAAC,GAAG,GAAG,CAAC,EAAE,CAAC;YAChC,IAAI,IAAI,IAAI,IAAI,EAAE;gBAChB,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC;aAC7B;SACF;aAAM;YACL,MAAM,IAAI,UAAU,CAAC,uBAAuB,GAAG,CAAC,IAAI,QAAQ,GAAG,CAAC,E
AAE,EAAE,CAAC,CAAC;SACvE;QACD,OAAO,IAAI,CAAC;IACd,CAAC;IAED;;;;OAIG;IACH,OAAO,CAAC,IAAU;QAChB,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,GAAG,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;IACjC,CAAC;IAED;;;OAGG;IACH,MAAM,CAAC,GAAmB;QACxB,OAAO,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,IAAI,CAAC;IACvC,CAAC;IAED;;OAEG;IACH,KAAK;QACH,OAAO,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;IACnC,CAAC;IAED;;;;;;OAMG;IACH,QAAQ,CAAC,GAA0B;QACjC,IAAI,GAAG,YAAY,cAAc,EAAE;YACjC,IAAI,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,IAAI,EAAE;gBACjC,MAAM,IAAI,UAAU,CAAC,oBAAoB,GAAG,CAAC,IAAI,EAAE,CAAC,CAAC;aACtD;iBAAM;gBACL,OAAO,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;aAC9B;SACF;aAAM;YACL,MAAM,EAAE,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC;YAC7B,IAAI,EAAE,IAAI,IAAI,EAAE;gBACd,MAAM,IAAI,UAAU,CAAC,yCAAyC,GAAG,EAAE,CAAC,CAAC;aACtE;YACD,OAAO,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC;SAC1B;IACH,CAAC;IAED;;;;;;OAMG;IACH,OAAO,CAAC,GAA0B;QAChC,IAAI,GAAG,YAAY,cAAc,EAAE;YACjC,IAAI,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,IAAI,EAAE;gBACjC,MAAM,IAAI,UAAU,CAAC,oBAAoB,GAAG,CAAC,IAAI,EAAE,CAAC,CAAC;aACtD;iBAAM;gBACL,OAAO,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;aAC7B;SACF;aAAM;YACL,MAAM,EAAE,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC;YAC7B,IAAI,EAAE,IAAI,IAAI,EAAE;gBACd,MAAM,IAAI,UAAU,CAAC,yCAAyC,GAAG,EAAE,CAAC,CAAC;aACtE;YACD,OAAO,IAAI,CAAC,OAAO,CAAC,EAAE,CAAC,CAAC;SACzB;IACH,CAAC;IAED,oDAAoD;IACpD,YAAY;QACV,IAAI,IAAI,CAAC,OAAO,IAAI,IAAI,EAAE;YACxB,OAAO,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;SACvB;IACH,CAAC;CACF;AAED,qEAAqE;AACrE,2BAA2B;AAC3B,MAAM,CAAC,MAAM,YAAY,GACrB,IAAI,QAAQ,EAAoB,CAAC;AAErC,8EAA8E;AAC9E,MAAM,CAAC,MAAM,qBAAqB,GAC9B,IAAI,QAAQ,EAAmB,CAAC;AAEpC,MAAM,UAAU,qBAAqB,CAAC,UAAkB;IACtD,IAAI,YAAY,IAAI,IAAI,EAAE;QACxB,YAAY,CAAC,aAAa,CAAC,UAAU,CAAC,CAAC;KACxC;IACD,IAAI,qBAAqB,IAAI,IAAI,EAAE;QACjC,qBAAqB,CAAC,aAAa,CAAC,UAAU,CAAC,CAAC;KACjD;AACH,CAAC;AAsBD;;;;;;;;;;;;;;;;;;;;GAoBG;AACH,MAAM,UAAU,OAAO,CACnB,OAAwC,EAAE,QAAkB,EAC5D,MAAe,EAAE,KAAsB;IAEzC,MAAM,QAAQ,GAAY,MAAM,IAAI,IAAI,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,MAAM,CAAC,UAAU,CAA
C,CAAC;IAEtE,MAAM,YAAY,GAAG,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,CAAC;IAC5C,MAAM,UAAU,GACZ,YAAY,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC;IAEvC,MAAM,WAAW,GAAG,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;IAChD,MAAM,YAAY,GAAa,EAAE,CAAC;IAClC,MAAM,SAAS,GAAG,QAAQ,CAAC,KAAK,EAAE,CAAC;IACnC,KAAK,MAAM,UAAU,IAAI,WAAW,EAAE;QACpC,IAAI,SAAS,CAAC,OAAO,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,EAAE;YACxC,YAAY,CAAC,IAAI,CAAC,QAAQ,CAAC,QAAQ,CAAC,UAAU,CAAC,CAAC,CAAC;SAClD;aAAM;YACL,YAAY,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;SACzB;KACF;IAED,IAAI,KAAK,IAAI,IAAI,EAAE;QACjB,6DAA6D;QAC7D,KAAK,CAAC,aAAa,GAAG,CAAC,QAAQ,CAAC;QAChC,KAAK,CAAC,aAAa,GAAG,QAAQ,CAAC;KAChC;IAED,eAAe;IACf,MAAM,eAAe,GACjB,WAAW,CAAC,IAAI,CAAC,GAAG,CAAC,GAAG,GAAG,GAAG,QAAQ,CAAC,KAAK,EAAE,CAAC,IAAI,EAAE,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;IACpE,IAAI,MAAM,GAAqB,YAAY,CAAC,GAAG,CAAC,eAAe,CAAC,CAAC;IACjE,IAAI,eAA8C,CAAC;IACnD,IAAI,MAAM,IAAI,IAAI,EAAE;QAClB,oEAAoE;QACpE,2DAA2D;QAC3D,MAAM,GAAG,GAAG,oCAAoC,CAAC,UAAU,EAAE,QAAQ,CAAC,CAAC;QACvE,MAAM,GAAG,GAAG,CAAC,MAAM,CAAC;QACpB,eAAe,GAAG,GAAG,CAAC,eAAe,CAAC;QAEtC,yCAAyC;QACzC,YAAY,CAAC,GAAG,CAAC,eAAe,EAAE,MAAM,CAAC,CAAC;QAC1C,qBAAqB,CAAC,GAAG,CAAC,eAAe,EAAE,eAAe,CAAC,CAAC;KAC7D;IACD,eAAe,GAAG,EAAE,CAAC;IACrB,IAAI,CAAC,QAAQ,EAAE;QACb,MAAM,CAAC,MAAM,CAAC,eAAe,EAAE,qBAAqB,CAAC,GAAG,CAAC,eAAe,CAAC,CAAC,CAAC;KAC5E;IAED,MAAM,gBAAgB,GAAG,IAAI,QAAQ,CAAC,QAAQ,CAAC,CAAC;IAEhD,yEAAyE;IACzE,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QACtC,IAAI,KAAK,IAAI,IAAI,EAAE;YACjB,yDAAyD;YACzD,MAAM,UAAU,GAAG,MAAM,EAAE,CAAC,UAAU,CAAC;YACvC,IAAI,UAAU,GAAG,KAAK,CAAC,aAAa,EAAE;gBACpC,KAAK,CAAC,aAAa,GAAG,UAAU,CAAC;aAClC;YACD,IAAI,UAAU,GAAG,KAAK,CAAC,aAAa,EAAE;gBACpC,KAAK,CAAC,aAAa,GAAG,UAAU,CAAC;aAClC;SACF;QAED,MAAM,QAAQ,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QAC3B,MAAM,QAAQ,GAAG,QAAQ,CAAC,WAAW,CAAC;QACtC,IAAI,QAAQ,YAAY,UAAU,EAAE;YAClC,SAAS;SACV;QACD,MAAM,WAAW,GAAa,EAAE,CAAC;QACjC,MAAM,UAAU,GAAa,EAAE,CAAC;QAChC,MAAM,gBAAgB,GAAa,EAAE,CAAC;QAEtC,IAAI,UAAU,GAAG,KAAK,CAAC;QACvB,KAAK,M
AAM,KAAK,IAAI,QAAQ,CAAC,MAAM,EAAE;YACnC,MAAM,KAAK,GAAG,gBAAgB,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC;YAC/C,MAAM,IAAI,GAAG,gBAAgB,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC;YAC7C,WAAW,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;YACxB,UAAU,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;YACtB,IAAI,IAAI,IAAI,IAAI,EAAE;gBAChB,UAAU,GAAG,IAAI,CAAC;aACnB;YACD,IAAI,CAAC,QAAQ,EAAE;gBACb,eAAe,CAAC,KAAK,CAAC,IAAI,CAAC,EAAE,CAAC;gBAC9B,IAAI,eAAe,CAAC,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,KAAK,CAAC;oBAC5D,WAAW,CAAC,OAAO,CAAC,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,UAAU;oBAC3D,KAAK,CAAC,WAAW,CAAC,QAAQ,KAAK,IAAI,EAAE;oBACvC,gBAAgB,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;iBAC9B;aACF;SACF;QAED,IAAI,UAAU,EAAE;YACd,MAAM,GAAG,MAAM,IAAI,EAAE,CAAC;YACtB,MAAM,CAAC,MAAM,CAAC,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC;SAChC;QACD,MAAM,aAAa,GACf,MAAM,CAAC,QAAQ,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAa,CAAC;QAC5D,IAAI,UAAU,GAAoB,IAAI,CAAC;QACvC,IAAI,QAAQ,CAAC,eAAe,EAAE;YAC5B,UAAU,GAAG,QAAQ,CAAC,WAAW,CAAC,WAAW,EAAE,UAAU,CAAC,CAAC;SAC5D;QACD,MAAM,YAAY,GAAG,cAAc,CAAC,QAAQ,CAAC,CAAC;QAC9C,MAAM,qBAAqB,GACvB,KAAK,CAAC,OAAO,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,CAAC;QAChE,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,qBAAqB,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;YACrD,IAAI,CAAC,gBAAgB,CAAC,MAAM,CAAC,qBAAqB,CAAC,CAAC,CAAC,CAAC,EAAE;gBACtD,gBAAgB,CAAC,GAAG,CAChB,qBAAqB,CAAC,CAAC,CAAC,EAAE,aAAa,CAAC,CAAC,CAAC,EAC1C,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC;aAC7D;YACD,MAAM,KAAK,GAAG,WAAW,CAAC,OAAO,CAAC,qBAAqB,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;YACjE,IAAI,KAAK,KAAK,CAAC,CAAC,EAAE;gBAChB,YAAY,CAAC,KAAK,CAAC,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC;aACxC;SACF;QAED,IAAI,CAAC,QAAQ,EAAE;YACb,8CAA8C;YAC9C,OAAO,CAAC,gBAAgB,CAAC,CAAC;SAC3B;KACF;IACD,iEAAiE;IACjE,sEAAsE;IACtE,mEAAmE;IACnE,kEAAkE;IAClE,qCAAqC;IACrC,gBAAgB,CAAC,YAAY,EAAE,CAAC;IAEhC,OAAO,YAAY,CAAC,CAAC,CAAC,YAAY,CAAC,CAAC,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC;AACvD,CAAC;AAUD;;;;;;;;;;GAUG;AACH,SAAS,oCAAoC,CACzC,OAAyB,EAAE,QAAkB;IAE/C,IAAI,CAAC,MA
AM,CACP,OAAO,IAAI,IAAI,IAAI,OAAO,CAAC,MAAM,GAAG,CAAC,EACrC,GAAG,EAAE,CAAC,uCAAuC,CAAC,CAAC;IAEnD,IAAI,WAAW,GAAqB,EAAE,CAAC;IACvC,IAAI,iBAAiB,GAAiB,EAAE,CAAC;IACzC,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,EAAE;QACxB,yCAAyC;QACzC,MAAM,GAAG,GACL,+CAA+C,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,CAAC;QAC1E,WAAW,GAAG,GAAG,CAAC,MAAM,CAAC;QACzB,iBAAiB,GAAG,GAAG,CAAC,YAAY,CAAC;KACtC;SAAM;QACL,MAAM,OAAO,GAAG,IAAI,GAAG,EAAU,CAAC;QAClC,KAAK,MAAM,KAAK,IAAI,OAAO,EAAE;YAC3B,MAAM,EAAC,MAAM,EAAE,YAAY,EAAC,GACxB,+CAA+C,CAAC,KAAK,EAAE,QAAQ,CAAC,CAAC;YAErE,sCAAsC;YACtC,KAAK,MAAM,cAAc,IAAI,MAAM,EAAE;gBACnC,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,cAAc,CAAC,IAAI,CAAC,EAAE;oBACrC,WAAW,CAAC,IAAI,CAAC,cAAc,CAAC,CAAC;oBACjC,OAAO,CAAC,GAAG,CAAC,cAAc,CAAC,IAAI,CAAC,CAAC;iBAClC;aACF;YAED,wBAAwB;YACxB,KAAK,MAAM,IAAI,IAAI,YAAY,EAAE;gBAC/B,IAAI,iBAAiB,CAAC,IAAI,CAAC,IAAI,IAAI,EAAE;oBACnC,iBAAiB,CAAC,IAAI,CAAC,GAAG,IAAI,GAAG,EAAU,CAAC;iBAC7C;gBACD,YAAY,CAAC,IAAI,CAAC,CAAC,OAAO,CACtB,SAAS,CAAC,EAAE,CAAC,iBAAiB,CAAC,IAAI,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC,CAAC;aAC1D;SACF;KACF;IACD,OAAO;QACL,MAAM,EAAE,WAAW;QACnB,eAAe,EAAE,mBAAmB,CAAC,iBAAiB,CAAC;KACxD,CAAC;AACJ,CAAC;AAED,SAAS,mBAAmB,CAAC,YAA0B;IACrD,MAAM,eAAe,GAAoB,EAAE,CAAC;IAC5C,KAAK,MAAM,IAAI,IAAI,YAAY,EAAE;QAC/B,eAAe,CAAC,IAAI,CAAC,GAAG,YAAY,CAAC,IAAI,CAAC,CAAC,IAAI,CAAC;KACjD;IACD,OAAO,eAAe,CAAC;AACzB,CAAC;AAED;;;;;;;;;;GAUG;AACH,MAAM,UAAU,+CAA+C,CAC3D,KAAqB,EAAE,QAAkB;IAE3C,MAAM,OAAO,GAAG,IAAI,GAAG,EAAU,CAAC;IAClC,MAAM,MAAM,GAAqB,EAAE,CAAC;IACpC,MAAM,YAAY,GAAiB,EAAE,CAAC;IAEtC,wEAAwE;IACxE,wEAAwE;IACxE,gCAAgC;IAChC,KAAK,MAAM,GAAG,IAAI,QAAQ,CAAC,KAAK,EAAE,EAAE;QAClC,OAAO,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;KAClB;IAED,MAAM,KAAK,GAAqB,EAAE,CAAC;IACnC,MAAM,KAAK,GAAa,EAAE,CAAC;IAE3B,yCAAyC;IACzC,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;IAElB,OAAO,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE;QACvB,MAAM,GAAG,GAAG,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QACpC,IAAI,OAAO,CAAC,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE;YACzB,KAAK,CAAC,GAAG,EAAE,CAAC;YACZ,SAAS;SACV;QACD,MAAM,WAAW,GAAG,KAAK,CAAC,KAAK,CAAC,MAAM,
GAAG,CAAC,CAAC,KAAK,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;QACjE,IAAI,GAAG,CAAC,MAAM,CAAC,MAAM,KAAK,CAAC,IAAI,WAAW,EAAE;YAC1C,0DAA0D;YAC1D,KAAK,CAAC,GAAG,EAAE,CAAC;YACZ,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;YACjB,OAAO,CAAC,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;YACtB,IAAI,WAAW,EAAE;gBACf,KAAK,CAAC,GAAG,EAAE,CAAC;aACb;SACF;aAAM;YACL,oEAAoE;YACpE,8CAA8C;YAC9C,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YAC7B,KAAK,MAAM,KAAK,IAAI,GAAG,CAAC,MAAM,EAAE;gBAC9B,gEAAgE;gBAChE,oEAAoE;gBACpE,IAAI,YAAY,CAAC,KAAK,CAAC,IAAI,CAAC,IAAI,IAAI,EAAE;oBACpC,YAAY,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,IAAI,GAAG,EAAU,CAAC;iBAC9C;gBACD,YAAY,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;gBAEvC,IAAI,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,IAAI,CAAC,EAAE;oBAC3B,SAAS,CAAE,oDAAoD;iBAChE;gBACD,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;aACnB;SACF;KACF;IACD,OAAO,EAAC,MAAM,EAAE,YAAY,EAAC,CAAC;AAChC,CAAC;AAED;;;;;GAKG;AACH,SAAS,cAAc,CAAC,KAAqB;IAE3C,IAAI,YAA6C,CAAC;IAClD,IAAI,KAAK,CAAC,WAAW,CAAC,YAAY,CAAC,MAAM,KAAK,CAAC,EAAE;QAC/C,YAAY,GAAG,KAAK,CAAC,WAAW,CAAC,MAAM,CAAC;KACzC;SAAM;QACL,IAAI,SAAS,GAAW,IAAI,CAAC;QAC7B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,WAAW,CAAC,YAAY,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;YAC9D,KAAK,MAAM,YAAY,IAAI,KAAK,CAAC,WAAW,CAAC,YAAY,CAAC,CAAC,CAAC;iBAClD,aAAa,EAAE;gBACvB,IAAI,YAAY,CAAC,EAAE,KAAK,KAAK,CAAC,EAAE,EAAE;oBAChC,SAAS,GAAG,CAAC,CAAC;oBACd,MAAM;iBACP;aACF;SACF;QACD,YAAY,GAAG,KAAK,CAAC,WAAW,CAAC,WAAW,CAAC,SAAS,CAAC,CAAC;KACzD;IACD,OAAO,YAAY,CAAC;AACtB,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/**\n * Executor: Evaluates SymbolicTensor based on feeds.\n */\n\nimport {cast, dispose, memory, Tensor, util} from '@tensorflow/tfjs-core';\n\nimport {ValueError} from '../errors';\nimport {Kwargs} from 
'../types';\nimport {LruCache} from '../utils/executor_utils';\nimport {toList} from '../utils/generic_utils';\n\nimport {InputLayer} from './input_layer';\nimport {SymbolicTensor} from './topology';\n\n/**\n * Helper function to check the dtype and shape compatibility of a feed value.\n */\nfunction assertFeedCompatibility(key: SymbolicTensor, val: Tensor): Tensor {\n  // Check dtype compatibility.\n  if (key.dtype == null || key.dtype === val.dtype) {\n    //  a.  If types match, return val tensor as is.\n    return val;\n  }\n  try {\n    //  b. Attempt to convert to expected type.\n    return cast(val, key.dtype);\n  } catch (err) {\n    //  c. If conversion fails, return helpful error.\n    throw new ValueError(\n        `The dtype of the feed (${val.dtype}) can not be cast to the dtype ` +\n        `of the key '${key.name}' (${key.dtype}).`);\n  }\n}\n\n/**\n * A concrete Tensor value for a symbolic tensor as the key.\n */\nexport interface Feed {\n  key: SymbolicTensor;\n  value: Tensor;\n}\n\n/**\n * FeedDict: A mapping from unique SymbolicTensors to feed values for them.\n * A feed value is a concrete value represented as an `Tensor`.\n */\nexport class FeedDict {\n  private id2Value: {[id: number]: Tensor} = {};\n  private id2Mask: {[id: number]: Tensor} = {};\n  private name2Id: {[name: string]: number} = {};\n\n  /**\n   * Constructor, optionally does copy-construction.\n   * @param feeds An Array of `Feed`s, or another `FeedDict`, in which case\n   *   copy-construction will be performed.\n   */\n  constructor(feeds?: Feed[]|FeedDict) {\n    if (feeds instanceof FeedDict) {\n      for (const id in feeds.id2Value) {\n        this.id2Value[id] = feeds.id2Value[id];\n        if (id in feeds.id2Mask) {\n          this.id2Mask[id] = feeds.id2Mask[id];\n        }\n      }\n    } else {\n      if (feeds == null) {\n        return;\n      }\n      for (const feed of feeds) {\n        this.add(feed.key, feed.value);\n      }\n    }\n  }\n\n  /**\n   * Add a 
key-value pair to the FeedDict.\n   *\n   * @param key The key of the feed.\n   * @param value The value of the tensor feed.\n   * @param mask The value of the mask feed (optional).\n   * @returns This `FeedDict`.\n   * @throws ValueError: If the key `SymbolicTensor` already exists in the\n   *   `FeedDict`.\n   */\n  add(key: SymbolicTensor, value: Tensor, mask?: Tensor): FeedDict {\n    if (this.id2Value[key.id] == null) {\n      this.id2Value[key.id] = assertFeedCompatibility(key, value);\n      this.name2Id[key.name] = key.id;\n      if (mask != null) {\n        this.id2Mask[key.id] = mask;\n      }\n    } else {\n      throw new ValueError(`Duplicate key: name=${key.name}, id=${key.id}`);\n    }\n    return this;\n  }\n\n  /**\n   * Add a Feed to the FeedDict.\n   * @param feed The new `Feed` to add.\n   * @returns This `FeedDict`.\n   */\n  addFeed(feed: Feed) {\n    this.add(feed.key, feed.value);\n  }\n\n  /**\n   * Probe whether a key already exists in the FeedDict.\n   * @param key\n   */\n  hasKey(key: SymbolicTensor): boolean {\n    return this.id2Value[key.id] != null;\n  }\n\n  /**\n   * Get all the SymbolicTensor available in this FeedDict.\n   */\n  names(): string[] {\n    return Object.keys(this.name2Id);\n  }\n\n  /**\n   * Get the feed value for given key.\n   * @param key The SymbolicTensor, or its name (as a string), of which the\n   *     value is sought.\n   * @returns If `key` exists, the corresponding feed value.\n   * @throws ValueError: If `key` does not exist in this `FeedDict`.\n   */\n  getValue(key: SymbolicTensor|string): Tensor {\n    if (key instanceof SymbolicTensor) {\n      if (this.id2Value[key.id] == null) {\n        throw new ValueError(`Nonexistent key: ${key.name}`);\n      } else {\n        return this.id2Value[key.id];\n      }\n    } else {\n      const id = this.name2Id[key];\n      if (id == null) {\n        throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);\n      }\n      return 
this.id2Value[id];\n    }\n  }\n\n  /**\n   * Get the feed mask for given key.\n   * @param key The SymbolicTensor, or its name (as a string), of which the\n   *     value is sought.\n   * @returns If `key` exists, the corresponding feed mask.\n   * @throws ValueError: If `key` does not exist in this `FeedDict`.\n   */\n  getMask(key: SymbolicTensor|string): Tensor {\n    if (key instanceof SymbolicTensor) {\n      if (this.id2Value[key.id] == null) {\n        throw new ValueError(`Nonexistent key: ${key.name}`);\n      } else {\n        return this.id2Mask[key.id];\n      }\n    } else {\n      const id = this.name2Id[key];\n      if (id == null) {\n        throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);\n      }\n      return this.id2Mask[id];\n    }\n  }\n\n  /** Dispose all mask Tensors held by this object. */\n  disposeMasks() {\n    if (this.id2Mask != null) {\n      dispose(this.id2Mask);\n    }\n  }\n}\n\n// Cache for topologically sorted SymbolicTensors for given execution\n// targets (i.e., fetches).\nexport const cachedSorted: LruCache<SymbolicTensor[]> =\n    new LruCache<SymbolicTensor[]>();\n\n// Cache for recipient count maps for given execution targets (i.e., fetches).\nexport const cachedRecipientCounts: LruCache<RecipientCounts> =\n    new LruCache<RecipientCounts>();\n\nexport function updateCacheMaxEntries(maxEntries: number) {\n  if (cachedSorted != null) {\n    cachedSorted.setMaxEntries(maxEntries);\n  }\n  if (cachedRecipientCounts != null) {\n    cachedRecipientCounts.setMaxEntries(maxEntries);\n  }\n}\n\n/**\n * Interface for the optional object used for probing the memory\n * usage and other statistics during execution.\n */\nexport interface ExecutionProbe {\n  /**\n   * Maximum number of tensors that exist during all steps of the\n   * execution. 
Tensor counts are measured at the beginning of every\n   * step.\n   */\n  maxNumTensors?: number;\n\n  /**\n   * Minimum number of tensors that exist during all steps of the\n   * execution. Tensor counts are measured at the beginning of every\n   * step.\n   */\n  minNumTensors?: number;\n}\n\n/**\n * Execute a SymbolicTensor by using concrete feed values.\n *\n * A `SymbolicTensor` object is a node in a computation graph of TF.js\n * Layers. The object is backed by a source layer and input\n * `SymbolicTensor`s to the source layer. This method evaluates\n * the `call()` method of the source layer, using concrete values of the\n * inputs obtained from either\n * * `feedDict`, if the input key exists in `feedDict`, or else,\n * * a recursive call to `execute()` itself.\n *\n * @param x: The `SymbolicTensor` to execute.\n * @param feedDict: The feed values, as base condition of the recursion.\n *   execution.\n * @param kwargs: Optional keyword arguments.\n * @param probe: A probe object (of interface `ExecutionProbe`) used for\n *   testing memory footprint of `execute` calls.\n * @returns Result of the execution.\n * @throws ValueError: If any `SymbolicTensor`s from `InputLayer`s\n *   encountered during the execution lacks a feed value in `feedDict`.\n */\nexport function execute(\n    fetches: SymbolicTensor|SymbolicTensor[], feedDict: FeedDict,\n    kwargs?: Kwargs, probe?: ExecutionProbe): Tensor|\n    Tensor[]|[Tensor | Tensor[]] {\n  const training: boolean = kwargs == null ? false : kwargs['training'];\n\n  const arrayFetches = Array.isArray(fetches);\n  const fetchArray: SymbolicTensor[] =\n      arrayFetches ? 
fetches : [fetches];\n\n  const outputNames = fetchArray.map(t => t.name);\n  const finalOutputs: Tensor[] = [];\n  const feedNames = feedDict.names();\n  for (const outputName of outputNames) {\n    if (feedNames.indexOf(outputName) !== -1) {\n      finalOutputs.push(feedDict.getValue(outputName));\n    } else {\n      finalOutputs.push(null);\n    }\n  }\n\n  if (probe != null) {\n    // For optional probing of memory footprint during execution.\n    probe.maxNumTensors = -Infinity;\n    probe.minNumTensors = Infinity;\n  }\n\n  // Check cache.\n  const fetchAndFeedKey =\n      outputNames.join(',') + '|' + feedDict.names().sort().join(',');\n  let sorted: SymbolicTensor[] = cachedSorted.get(fetchAndFeedKey);\n  let recipientCounts: {[fetchName: string]: number};\n  if (sorted == null) {\n    // Cache doesn't contain the desired combination of fetches. Compute\n    // topological sort for the combination for the first time.\n    const out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict);\n    sorted = out.sorted;\n    recipientCounts = out.recipientCounts;\n\n    // Store results in cache for future use.\n    cachedSorted.put(fetchAndFeedKey, sorted);\n    cachedRecipientCounts.put(fetchAndFeedKey, recipientCounts);\n  }\n  recipientCounts = {};\n  if (!training) {\n    Object.assign(recipientCounts, cachedRecipientCounts.get(fetchAndFeedKey));\n  }\n\n  const internalFeedDict = new FeedDict(feedDict);\n\n  // Start iterative execution on the topologically-sorted SymbolicTensors.\n  for (let i = 0; i < sorted.length; ++i) {\n    if (probe != null) {\n      // For optional probing of memory usage during execution.\n      const numTensors = memory().numTensors;\n      if (numTensors > probe.maxNumTensors) {\n        probe.maxNumTensors = numTensors;\n      }\n      if (numTensors < probe.minNumTensors) {\n        probe.minNumTensors = numTensors;\n      }\n    }\n\n    const symbolic = sorted[i];\n    const srcLayer = symbolic.sourceLayer;\n    if 
(srcLayer instanceof InputLayer) {\n      continue;\n    }\n    const inputValues: Tensor[] = [];\n    const inputMasks: Tensor[] = [];\n    const tensorsToDispose: Tensor[] = [];\n\n    let maskExists = false;\n    for (const input of symbolic.inputs) {\n      const value = internalFeedDict.getValue(input);\n      const mask = internalFeedDict.getMask(input);\n      inputValues.push(value);\n      inputMasks.push(mask);\n      if (mask != null) {\n        maskExists = true;\n      }\n      if (!training) {\n        recipientCounts[input.name]--;\n        if (recipientCounts[input.name] === 0 && !feedDict.hasKey(input) &&\n            outputNames.indexOf(input.name) === -1 && !value.isDisposed &&\n            input.sourceLayer.stateful !== true) {\n          tensorsToDispose.push(value);\n        }\n      }\n    }\n\n    if (maskExists) {\n      kwargs = kwargs || {};\n      kwargs['mask'] = inputMasks[0];\n    }\n    const outputTensors =\n        toList(srcLayer.apply(inputValues, kwargs)) as Tensor[];\n    let outputMask: Tensor|Tensor[] = null;\n    if (srcLayer.supportsMasking) {\n      outputMask = srcLayer.computeMask(inputValues, inputMasks);\n    }\n    const layerOutputs = getNodeOutputs(symbolic);\n    const outputSymbolicTensors =\n        Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs];\n    for (let i = 0; i < outputSymbolicTensors.length; ++i) {\n      if (!internalFeedDict.hasKey(outputSymbolicTensors[i])) {\n        internalFeedDict.add(\n            outputSymbolicTensors[i], outputTensors[i],\n            Array.isArray(outputMask) ? 
outputMask[0] : outputMask);\n      }\n      const index = outputNames.indexOf(outputSymbolicTensors[i].name);\n      if (index !== -1) {\n        finalOutputs[index] = outputTensors[i];\n      }\n    }\n\n    if (!training) {\n      // Clean up Tensors that are no longer needed.\n      dispose(tensorsToDispose);\n    }\n  }\n  // NOTE(cais): Unlike intermediate tensors, we don't discard mask\n  // tensors as we go, because these tensors are sometimes passed over a\n  // series of mutliple layers, i.e., not obeying the immediate input\n  // relations in the graph. If this becomes a memory-usage concern,\n  // we can improve this in the future.\n  internalFeedDict.disposeMasks();\n\n  return arrayFetches ? finalOutputs : finalOutputs[0];\n}\n\ntype RecipientCounts = {\n  [fetchName: string]: number\n};\n\nexport type RecipientMap = {\n  [fetchName: string]: Set<string>;\n};\n\n/**\n * Sort the `SymbolicTensor`s topologically, for an array of fetches.\n *\n * This function calls getTopologicalSortAndRecipientCountsForOneFetch and\n * merges their results.\n *\n * @param fetch The array of fetches requested. 
Must be a non-empty array.\n * @param feedDict The dictionary of fed values.\n * @returns sorted: Topologically-sorted array of SymbolicTensors.\n *   recipientCounts: Recipient counts for all SymbolicTensors in `sorted`.\n */\nfunction getTopologicalSortAndRecipientCounts(\n    fetches: SymbolicTensor[], feedDict: FeedDict):\n    {sorted: SymbolicTensor[], recipientCounts: RecipientCounts} {\n  util.assert(\n      fetches != null && fetches.length > 0,\n      () => `Expected at least one fetch, got none`);\n\n  let finalSorted: SymbolicTensor[] = [];\n  let finalRecipientMap: RecipientMap = {};\n  if (fetches.length === 1) {\n    // Special-casing 1 fetch for efficiency.\n    const out =\n        getTopologicalSortAndRecipientCountsForOneFetch(fetches[0], feedDict);\n    finalSorted = out.sorted;\n    finalRecipientMap = out.recipientMap;\n  } else {\n    const visited = new Set<string>();\n    for (const fetch of fetches) {\n      const {sorted, recipientMap} =\n          getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict);\n\n      // Merge sorted SymbolicTensor Arrays.\n      for (const symbolicTensor of sorted) {\n        if (!visited.has(symbolicTensor.name)) {\n          finalSorted.push(symbolicTensor);\n          visited.add(symbolicTensor.name);\n        }\n      }\n\n      // Merge recipient maps.\n      for (const name in recipientMap) {\n        if (finalRecipientMap[name] == null) {\n          finalRecipientMap[name] = new Set<string>();\n        }\n        recipientMap[name].forEach(\n            recipient => finalRecipientMap[name].add(recipient));\n      }\n    }\n  }\n  return {\n    sorted: finalSorted,\n    recipientCounts: recipientMap2Counts(finalRecipientMap)\n  };\n}\n\nfunction recipientMap2Counts(recipientMap: RecipientMap): RecipientCounts {\n  const recipientCounts: RecipientCounts = {};\n  for (const name in recipientMap) {\n    recipientCounts[name] = recipientMap[name].size;\n  }\n  return recipientCounts;\n}\n\n/**\n * 
Sort the `SymbolicTensor`s topologically, for a single fetch.\n *\n * This helper function processes the upstream SymbolicTensors of a single\n * fetch.\n *\n * @param fetch The single fetch requested.\n * @param feedDict The dictionary of fed values.\n * @returns sorted: Topologically-sorted array of SymbolicTensors.\n *   recipientMap: Recipient names for all SymbolicTensors in `sorted`.\n */\nexport function getTopologicalSortAndRecipientCountsForOneFetch(\n    fetch: SymbolicTensor, feedDict: FeedDict):\n    {sorted: SymbolicTensor[], recipientMap: RecipientMap} {\n  const visited = new Set<string>();\n  const sorted: SymbolicTensor[] = [];\n  const recipientMap: RecipientMap = {};\n\n  // Put keys of the feedDict into visited first, so they don't have to be\n  // walked. This is needed in case where there are feeds for intermediate\n  // SymbolicTensors of the graph.\n  for (const key of feedDict.names()) {\n    visited.add(key);\n  }\n\n  const stack: SymbolicTensor[] = [];\n  const marks: number[] = [];\n\n  // Initial population of stack and marks.\n  stack.push(fetch);\n\n  while (stack.length > 0) {\n    const top = stack[stack.length - 1];\n    if (visited.has(top.name)) {\n      stack.pop();\n      continue;\n    }\n    const topIsMarked = marks[marks.length - 1] === stack.length - 1;\n    if (top.inputs.length === 0 || topIsMarked) {\n      // Input SymbolicTensor or all children have been visited.\n      stack.pop();\n      sorted.push(top);\n      visited.add(top.name);\n      if (topIsMarked) {\n        marks.pop();\n      }\n    } else {\n      // A non-input SymbolicTensor whose upstream SymbolicTensors haven't\n      // been visited yet. Push them onto the stack.\n      marks.push(stack.length - 1);\n      for (const input of top.inputs) {\n        // Increment the recipient count. 
Note that this needs to happen\n        // regardless of whether the SymbolicTensor has been visited before.\n        if (recipientMap[input.name] == null) {\n          recipientMap[input.name] = new Set<string>();\n        }\n        recipientMap[input.name].add(top.name);\n\n        if (visited.has(input.name)) {\n          continue;  // Avoid repeated visits to the same SymbolicTensor.\n        }\n        stack.push(input);\n      }\n    }\n  }\n  return {sorted, recipientMap};\n}\n\n/**\n * Get the symbolic output tensors of the node to which a given fetch belongs.\n * @param fetch The fetched symbolic tensor.\n * @returns The Array of symbolic tensors output by the node to which `fetch`\n *   belongs.\n */\nfunction getNodeOutputs(fetch: SymbolicTensor): SymbolicTensor|\n    SymbolicTensor[] {\n  let layerOutputs: SymbolicTensor|SymbolicTensor[];\n  if (fetch.sourceLayer.inboundNodes.length === 1) {\n    layerOutputs = fetch.sourceLayer.output;\n  } else {\n    let nodeIndex: number = null;\n    for (let i = 0; i < fetch.sourceLayer.inboundNodes.length; ++i) {\n      for (const outputTensor of fetch.sourceLayer.inboundNodes[i]\n               .outputTensors) {\n        if (outputTensor.id === fetch.id) {\n          nodeIndex = i;\n          break;\n        }\n      }\n    }\n    layerOutputs = fetch.sourceLayer.getOutputAt(nodeIndex);\n  }\n  return layerOutputs;\n}\n"]}
|