/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { assert } from '../util_base';
|
// Token separating the input subscripts from the output subscripts.
const ARROW = '->';
// Global matcher for ARROW; used to count how many arrows an equation holds.
const ARROW_REGEX = new RegExp(ARROW, 'g');
// Token separating the subscripts of individual input operands.
const COMMA = ',';
// Broadcasting ellipsis token (not accepted in input subscripts).
const ELLIPSIS = '...';
|
/**
 * Parse an equation for einsum.
 *
 * @param equation The einsum equation (e.g., "ij,jk->ik"). All whitespace is
 *     removed before parsing.
 * @param numTensors Number of tensors provided along with `equation`. Used to
 *     check matching number of input tensors.
 * @returns An object consisting of the following fields:
 *   - allDims: all dimension names as strings. The output dimensions come
 *     first, in output order, followed by the remaining input-only
 *     dimensions in order of first appearance.
 *   - summedDims: a list of all dimensions being summed over, as indices to
 *     the elements of `allDims`.
 *   - idDims: indices of the dimensions in each input tensor, as indices to
 *     the elements of `allDims`.
 * @throws Error if the equation does not contain exactly one arrow, if the
 *     number of input terms differs from `numTensors` or exceeds 2, if an
 *     output label does not appear in any input, if an output label is
 *     repeated, or if an input term repeats a label. Input subscripts
 *     containing the ellipsis are rejected through `assert`.
 */
export function decodeEinsumEquation(equation, numTensors) {
  equation = equation.replace(/\s/g, ''); // Remove whitespace in equation.
  const numArrows =
      (equation.length - equation.replace(ARROW_REGEX, '').length) /
      ARROW.length;
  if (numArrows < 1) {
    throw new Error('Equations without an arrow are not supported.');
  } else if (numArrows > 1) {
    throw new Error(`Equation must contain exactly one arrow ("${ARROW}").`);
  }
  const [inputString, outputString] = equation.split(ARROW);
  assert(
      inputString.indexOf(ELLIPSIS) === -1,
      () => `The ellipsis notation ("${ELLIPSIS}") is not supported yet.`);
  const inputTerms = inputString.split(COMMA);
  const numInputs = inputTerms.length;
  if (numTensors !== numInputs) {
    throw new Error(
        `Expected ${numInputs} input tensors, received ${numTensors}`);
  }
  if (numInputs > 2) {
    throw new Error(
        'Support for more than 2 input tensors is not implemented yet.');
  }

  // Output labels are registered first, so indices [0, numOutDims) of allDims
  // are exactly the output dimensions and everything after them is summed.
  const allDims = [];
  for (let i = 0; i < outputString.length; ++i) {
    const dimName = outputString[i];
    if (!inputTerms.some(inputTerm => inputTerm.indexOf(dimName) !== -1)) {
      throw new Error(
          `Output subscripts contain the label ${dimName} ` +
          `not present in the input subscripts.`);
    }
    // A repeated output label would leave allDims shorter than outputString,
    // making the summedDims range below skip a dimension that must be summed.
    if (allDims.indexOf(dimName) !== -1) {
      throw new Error(
          `Output subscripts contain the label ${dimName} more than once.`);
    }
    allDims.push(dimName);
  }
  for (let i = 0; i < inputString.length; ++i) {
    const dimName = inputString[i];
    if (allDims.indexOf(dimName) === -1 && dimName !== COMMA) {
      allDims.push(dimName);
    }
  }

  const idDims = new Array(inputTerms.length);
  for (let i = 0; i < numInputs; ++i) {
    if (new Set(inputTerms[i].split('')).size !== inputTerms[i].length) {
      throw new Error(
          `Found duplicate axes in input component ${inputTerms[i]}. ` +
          `Support for duplicate axes in input is not implemented yet.`);
    }
    idDims[i] = [];
    for (let j = 0; j < inputTerms[i].length; ++j) {
      idDims[i].push(allDims.indexOf(inputTerms[i][j]));
    }
  }

  const numDims = allDims.length;          // Number of unique dimensions.
  const numOutDims = outputString.length;  // Output labels are unique here.
  const summedDims = [];                   // Dimensions being summed over.
  for (let i = numOutDims; i < numDims; ++i) {
    summedDims.push(i);
  }
  return {allDims, summedDims, idDims};
}
|
/**
 * Works out how a single input tensor maps onto the full set of einsum
 * dimensions.
 *
 * @param nDims Total number of dimensions of all tensors involved in the
 *     einsum operation.
 * @param idDims For the tensor in question, the global dimension index carried
 *     by each of its axes, in axis order.
 * @returns An object consisting of the following fields:
 *   - permutationIndices: the tensor's own axis numbers, listed in ascending
 *     order of the global dimension each axis carries; used to permute the
 *     axes of the tensor.
 *   - expandDims: the global dimension indices in [0, nDims), ascending, that
 *     the tensor does not carry, i.e. where axes need to be expanded after
 *     permutation.
 */
export function getEinsumPermutation(nDims, idDims) {
  // axisOf[d] is the tensor axis holding global dimension d, or -1 if absent.
  const axisOf = new Array(nDims).fill(-1);
  idDims.forEach((globalDim, axis) => {
    axisOf[globalDim] = axis;
  });

  const expandDims = [];
  for (let dim = 0; dim < nDims; ++dim) {
    if (axisOf[dim] === -1) {
      expandDims.push(dim);
    }
  }

  const permutationIndices = axisOf.filter(axis => axis !== -1);
  return {permutationIndices, expandDims};
}
|
/**
 * Checks that the dimension sizes from different input tensors match the
 * equation.
 *
 * @param nDims Length of the dimension-size table (presumably the number of
 *     unique dimension labels in the equation — confirm against callers).
 * @param idDims For each input tensor, the dimension index carried by each of
 *     its axes, in axis order.
 * @param tensors The input tensors; only their `shape` is read.
 * @throws Error (via `assert`) if an input's rank differs from the number of
 *     subscripts in its term, or if one dimension has different sizes in
 *     different places.
 */
export function checkEinsumDimSizes(nDims, idDims, tensors) {
  const dimSizes = new Array(nDims);
  for (let i = 0; i < tensors.length; ++i) {
    const shape = tensors[i].shape;
    const termDims = idDims[i];
    // Without this check, extra axes of a higher-rank input are silently
    // ignored, and a lower-rank input records `undefined` sizes below.
    assert(
        shape.length === termDims.length,
        () => `Expected input ${i} to have rank ${termDims.length} ` +
            `to match its subscripts, but got input shaped ` +
            `${JSON.stringify(shape)}`);
    for (let j = 0; j < termDims.length; ++j) {
      const dim = termDims[j];
      if (dimSizes[dim] === undefined) {
        dimSizes[dim] = shape[j];
      } else {
        assert(
            dimSizes[dim] === shape[j],
            () => `Expected dimension ${dimSizes[dim]} at axis ${j} ` +
                `of input shaped ${JSON.stringify(shape)}, ` +
                `but got dimension ${shape[j]}`);
      }
    }
  }
}
|
/**
 * Gets path of computation for einsum.
 *
 * @param summedDims Indices to the dimensions being summed over. This array is
 *     not modified.
 * @param idDims A lookup table for the dimensions present in each input
 *     tensor. Each constituent array contains indices for the dimensions in
 *     the corresponding input tensor.
 *
 * @return A map with two fields:
 *   - path: The path of computation, with each element indicating the dimension
 *     being summed over after the element-wise multiplication in that step.
 *     When nothing is summed (e.g. transpose or outer product) it is `[-1]`,
 *     and that step selects every input.
 *   - steps: One entry longer than `path`. Entry `i` (for `i < path.length`)
 *     contains the indices to the input tensors being used for element-wise
 *     multiplication in the corresponding step. Each input index appears in at
 *     most one step, and the final entry is always empty.
 */
export function getEinsumComputePath(summedDims, idDims) {
  // Build a fresh array instead of aliasing `summedDims`, so the -1 sentinel
  // for summation-free einsums does not leak into the caller's array.
  const path = summedDims.length === 0 ? [-1] : summedDims.slice();
  const nSteps = path.length + 1;
  const steps = [];
  for (let i = 0; i < nSteps; ++i) {
    steps.push([]);
  }
  const computedTermIndices = [];
  for (let i = 0; i < path.length; ++i) {
    const termIndices = findTermsWithDim(idDims, path[i]);
    for (const termIndex of termIndices) {
      // Each input takes part in the first step that touches it, and only that.
      if (computedTermIndices.indexOf(termIndex) === -1) {
        steps[i].push(termIndex);
        computedTermIndices.push(termIndex);
      }
    }
  }
  // NOTE(review): when at least one dimension is summed, an input term that is
  // non-empty and contains none of the summed dimensions is never assigned to
  // a step (e.g. the `k` term in "ij,k->ik"). Confirm callers handle this.
  return {path, steps};
}
|
/**
 * Reports whether `perm` is the identity permutation, i.e. every entry equals
 * its own index. An empty permutation counts as the identity.
 */
export function isIdentityPermutation(perm) {
  const isDisplaced = (dim, index) => dim !== index;
  return !perm.some(isDisplaced);
}
|
/**
 * Lists, in ascending order, the indices of the input terms selected for the
 * computation step associated with dimension `dim`.
 *
 * A term is selected when it has no axes at all, when it contains `dim`, or
 * unconditionally when `dim` is -1 (the "no summation" sentinel).
 */
function findTermsWithDim(idDims, dim) {
  const selected = [];
  for (const [termIndex, termDims] of idDims.entries()) {
    const hasNoAxes = termDims.length === 0;
    const containsDim = termDims.indexOf(dim) !== -1;
    if (hasNoAxes || containsDim || dim === -1) {
      selected.push(termIndex);
    }
  }
  return selected;
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"einsum_util.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/backends/einsum_util.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAQH,OAAO,EAAC,MAAM,EAAC,MAAM,cAAc,CAAC;AAEpC,MAAM,KAAK,GAAG,IAAI,CAAC;AACnB,MAAM,WAAW,GAAG,KAAK,CAAC;AAC1B,MAAM,KAAK,GAAG,GAAG,CAAC;AAClB,MAAM,QAAQ,GAAG,KAAK,CAAC;AAEvB;;;;;;;;;;;;GAYG;AACH,MAAM,UAAU,oBAAoB,CAAC,QAAgB,EAAE,UAAkB;IAKvE,QAAQ,GAAG,QAAQ,CAAC,OAAO,CAAC,KAAK,EAAE,EAAE,CAAC,CAAC,CAAE,gCAAgC;IACzE,MAAM,SAAS,GACX,CAAC,QAAQ,CAAC,MAAM,GAAG,QAAQ,CAAC,OAAO,CAAC,WAAW,EAAE,EAAE,CAAC,CAAC,MAAM,CAAC;QAC5D,KAAK,CAAC,MAAM,CAAC;IACjB,IAAI,SAAS,GAAG,CAAC,EAAE;QACjB,MAAM,IAAI,KAAK,CAAC,+CAA+C,CAAC,CAAC;KAClE;SAAM,IAAI,SAAS,GAAG,CAAC,EAAE;QACxB,MAAM,IAAI,KAAK,CAAC,6CAA6C,KAAK,KAAK,CAAC,CAAC;KAC1E;IACD,MAAM,CAAC,WAAW,EAAE,YAAY,CAAC,GAAG,QAAQ,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;IAC1D,MAAM,CACF,WAAW,CAAC,OAAO,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,EACpC,GAAG,EAAE,CAAC,2BAA2B,QAAQ,0BAA0B,CAAC,CAAC;IACzE,MAAM,UAAU,GAAG,WAAW,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;IAC5C,MAAM,SAAS,GAAG,UAAU,CAAC,MAAM,CAAC;IACpC,IAAI,UAAU,KAAK,SAAS,EAAE;QAC5B,MAAM,IAAI,KAAK,CACX,YAAY,SAAS,4BAA4B,UAAU,EAAE,CAAC,CAAC;KACpE;IACD,IAAI,SAAS,GAAG,CAAC,EAAE;QACjB,MAAM,IAAI,KAAK,CACX,+DAA+D,CAAC,CAAC;KACtE;IAED,MAAM,OAAO,GAAa,EAAE,CAAC;IAC7B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,YAAY,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QAC5C,MAAM,OAAO,GAAG,YAAY,CAAC,CAAC,CAAC,CAAC;QAChC,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,SAAS,CAAC,EAAE,CAAC,SAAS,CAAC,OAAO,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE;YACpE,MAAM,IAAI,KAAK,CACX,uCAAuC,OAAO,GAAG;gBACjD,sCAAsC,CAAC,CAAC;SAC7C;QACD,IAAI,OAAO,CAAC,OAAO,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,EAAE;YACnC,OAAO,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;SACvB;KACF;IACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,WAAW,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QAC3C,MAAM,OAAO,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC;QAC/B,IAAI,OAAO,CAAC,OAAO,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,IAAI,OAAO,KAAK,KAAK,EAAE;YACxD,OAAO,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;SACvB;KACF;IAED,MA
AM,MAAM,GAAe,IAAI,KAAK,CAAW,UAAU,CAAC,MAAM,CAAC,CAAC;IAClE,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,EAAE,CAAC,EAAE;QAClC,IAAI,IAAI,GAAG,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,KAAK,UAAU,CAAC,CAAC,CAAC,CAAC,MAAM,EAAE;YAClE,MAAM,IAAI,KAAK,CACX,2CAA2C,UAAU,CAAC,CAAC,CAAC,IAAI;gBAC5D,6DAA6D,CAAC,CAAC;SACpE;QACD,MAAM,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC;QACf,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;YAC7C,MAAM,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;SACnD;KACF;IAED,MAAM,OAAO,GAAG,OAAO,CAAC,MAAM,CAAC,CAAU,+BAA+B;IACxE,MAAM,UAAU,GAAG,YAAY,CAAC,MAAM,CAAC,CAAE,+BAA+B;IACxE,MAAM,UAAU,GAAa,EAAE,CAAC,CAAS,gCAAgC;IACzE,KAAK,IAAI,CAAC,GAAG,UAAU,EAAE,CAAC,GAAG,OAAO,EAAE,EAAE,CAAC,EAAE;QACzC,UAAU,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;KACpB;IACD,OAAO,EAAC,OAAO,EAAE,UAAU,EAAE,MAAM,EAAC,CAAC;AACvC,CAAC;AAED;;;;;;;;;;GAUG;AACH,MAAM,UAAU,oBAAoB,CAAC,KAAa,EAAE,MAAgB;IAElE,IAAI,kBAAkB,GAAa,IAAI,KAAK,CAAS,KAAK,CAAC,CAAC;IAC5D,kBAAkB,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;IAC5B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QACtC,kBAAkB,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;KACnC;IACD,MAAM,UAAU,GAAa,EAAE,CAAC;IAChC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,EAAE,EAAE,CAAC,EAAE;QAC9B,IAAI,kBAAkB,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,EAAE;YAChC,UAAU,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;SACpB;KACF;IACD,kBAAkB,GAAG,kBAAkB,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IAC9D,OAAO,EAAC,kBAAkB,EAAE,UAAU,EAAC,CAAC;AAC1C,CAAC;AAED;;;GAGG;AACH,MAAM,UAAU,mBAAmB,CAC/B,KAAa,EAAE,MAAkB,EAAE,OAAiB;IACtD,MAAM,QAAQ,GAAa,IAAI,KAAK,CAAS,KAAK,CAAC,CAAC;IACpD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QACvC,MAAM,KAAK,GAAa,OAAO,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC;QACzC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;YACzC,IAAI,QAAQ,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,SAAS
,EAAE;gBACxC,QAAQ,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC;aACnC;iBAAM;gBACL,MAAM,CACF,QAAQ,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,KAAK,CAAC,CAAC,CAAC,EACnC,GAAG,EAAE,CAAC,sBAAsB,QAAQ,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,GAAG;oBAC9D,mBAAmB,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,IAAI;oBAC5C,qBAAqB,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;aAC1C;SACF;KACF;AACH,CAAC;AAED;;;;;;;;;;;;;;GAcG;AACH,MAAM,UAAU,oBAAoB,CAAC,UAAoB,EAAE,MAAkB;IAE3E,MAAM,IAAI,GAAa,UAAU,CAAC;IAClC,MAAM,KAAK,GAAe,EAAE,CAAC;IAC7B,IAAI,MAAM,GAAG,CAAC,CAAC;IACf,IAAI,UAAU,CAAC,MAAM,KAAK,CAAC,EAAE;QAC3B,qEAAqE;QACrE,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;KACf;IACD,MAAM,GAAG,UAAU,CAAC,MAAM,GAAG,CAAC,CAAC;IAC/B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,EAAE,EAAE,CAAC,EAAE;QAC/B,KAAK,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;KAChB;IACD,MAAM,mBAAmB,GAAa,EAAE,CAAC;IACzC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QACpC,MAAM,SAAS,GAAG,IAAI,CAAC,CAAC,CAAC,CAAC;QAC1B,MAAM,WAAW,GAAG,gBAAgB,CAAC,MAAM,EAAE,SAAS,CAAC,CAAC;QACxD,KAAK,MAAM,SAAS,IAAI,WAAW,EAAE;YACnC,IAAI,mBAAmB,CAAC,OAAO,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,EAAE;gBACjD,KAAK,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;gBACzB,mBAAmB,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;aACrC;SACF;KACF;IACD,OAAO,EAAC,IAAI,EAAE,KAAK,EAAC,CAAC;AACvB,CAAC;AAED,qEAAqE;AACrE,MAAM,UAAU,qBAAqB,CAAC,IAAc;IAClD,OAAO,IAAI,CAAC,KAAK,CAAC,CAAC,GAAW,EAAE,KAAa,EAAE,EAAE,CAAC,GAAG,KAAK,KAAK,CAAC,CAAC;AACnE,CAAC;AAED,SAAS,gBAAgB,CAAC,MAAkB,EAAE,GAAW;IACvD,MAAM,WAAW,GAAa,EAAE,CAAC;IACjC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QACtC,IAAI,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,KAAK,CAAC,IAAI,MAAM,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,IAAI,GAAG,KAAK,CAAC,CAAC,EAAE;YACzE,WAAW,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;SACrB;KACF;IACD,OAAO,WAAW,CAAC;AACrB,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2021 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n * Utility functions for computing einsum (tensor contraction and summation\n * based on Einstein summation.)\n */\n\nimport {Tensor} from '../tensor';\nimport {assert} from '../util_base';\n\nconst ARROW = '->';\nconst ARROW_REGEX = /->/g;\nconst COMMA = ',';\nconst ELLIPSIS = '...';\n\n/**\n * Parse an equation for einsum.\n *\n * @param equation The einsum equation (e.g., \"ij,jk->ik\").\n * @param numTensors Number of tensors provided along with `equation`. 
Used to\n *   check matching number of input tensors.\n * @returns An object consisting of the following fields:\n *   - allDims: all dimension names as strings.\n *   - summedDims: a list of all dimensions being summed over, as indices to\n *     the elements of `allDims`.\n *   - idDims: indices of the dimensions in each input tensor, as indices to\n *     the elements of `allDims.\n */\nexport function decodeEinsumEquation(equation: string, numTensors: number): {\n  allDims: string[],\n  summedDims: number[],\n  idDims: number[][],\n} {\n  equation = equation.replace(/\\s/g, '');  // Remove witespace in equation.\n  const numArrows =\n      (equation.length - equation.replace(ARROW_REGEX, '').length) /\n      ARROW.length;\n  if (numArrows < 1) {\n    throw new Error('Equations without an arrow are not supported.');\n  } else if (numArrows > 1) {\n    throw new Error(`Equation must contain exactly one arrow (\"${ARROW}\").`);\n  }\n  const [inputString, outputString] = equation.split(ARROW);\n  assert(\n      inputString.indexOf(ELLIPSIS) === -1,\n      () => `The ellipsis notation (\"${ELLIPSIS}\") is not supported yet.`);\n  const inputTerms = inputString.split(COMMA);\n  const numInputs = inputTerms.length;\n  if (numTensors !== numInputs) {\n    throw new Error(\n        `Expected ${numInputs} input tensors, received ${numTensors}`);\n  }\n  if (numInputs > 2) {\n    throw new Error(\n        'Support for more than 2 input tensors is not implemented yet.');\n  }\n\n  const allDims: string[] = [];\n  for (let i = 0; i < outputString.length; ++i) {\n    const dimName = outputString[i];\n    if (!inputTerms.some(inputTerm => inputTerm.indexOf(dimName) !== -1)) {\n      throw new Error(\n          `Output subscripts contain the label ${dimName} ` +\n          `not present in the input subscripts.`);\n    }\n    if (allDims.indexOf(dimName) === -1) {\n      allDims.push(dimName);\n    }\n  }\n  for (let i = 0; i < inputString.length; ++i) {\n    const dimName = 
inputString[i];\n    if (allDims.indexOf(dimName) === -1 && dimName !== COMMA) {\n      allDims.push(dimName);\n    }\n  }\n\n  const idDims: number[][] = new Array<number[]>(inputTerms.length);\n  for (let i = 0; i < numInputs; ++i) {\n    if (new Set(inputTerms[i].split('')).size !== inputTerms[i].length) {\n      throw new Error(\n          `Found duplicate axes in input component ${inputTerms[i]}. ` +\n          `Support for duplicate axes in input is not implemented yet.`);\n    }\n    idDims[i] = [];\n    for (let j = 0; j < inputTerms[i].length; ++j) {\n      idDims[i].push(allDims.indexOf(inputTerms[i][j]));\n    }\n  }\n\n  const numDims = allDims.length;          // Number of unique dimensions.\n  const numOutDims = outputString.length;  // Number of output dimensions.\n  const summedDims: number[] = [];         // Dimensions being summed over.\n  for (let i = numOutDims; i < numDims; ++i) {\n    summedDims.push(i);\n  }\n  return {allDims, summedDims, idDims};\n}\n\n/**\n * Get the permutation for a given input tensor.\n *\n * @param nDims Total number of dimension of all tensors involved in the einsum\n *   operation.\n * @param idDims Dimension indices involve in the tensor in question.\n * @returns An object consisting of the following fields:\n *   - permutationIndices: Indices to permute the axes of the tensor with.\n *   - expandDims: Indices to the dimension that need to be expanded from the\n *     tensor after permutation.\n */\nexport function getEinsumPermutation(nDims: number, idDims: number[]):\n    {permutationIndices: number[], expandDims: number[]} {\n  let permutationIndices: number[] = new Array<number>(nDims);\n  permutationIndices.fill(-1);\n  for (let i = 0; i < idDims.length; ++i) {\n    permutationIndices[idDims[i]] = i;\n  }\n  const expandDims: number[] = [];\n  for (let i = 0; i < nDims; ++i) {\n    if (permutationIndices[i] === -1) {\n      expandDims.push(i);\n    }\n  }\n  permutationIndices = permutationIndices.filter(d => d 
!== -1);\n  return {permutationIndices, expandDims};\n}\n\n/**\n * Checks that the dimension sizes from different input tensors match the\n * equation.\n */\nexport function checkEinsumDimSizes(\n    nDims: number, idDims: number[][], tensors: Tensor[]) {\n  const dimSizes: number[] = new Array<number>(nDims);\n  for (let i = 0; i < tensors.length; ++i) {\n    const shape: number[] = tensors[i].shape;\n    for (let j = 0; j < idDims[i].length; ++j) {\n      if (dimSizes[idDims[i][j]] === undefined) {\n        dimSizes[idDims[i][j]] = shape[j];\n      } else {\n        assert(\n            dimSizes[idDims[i][j]] === shape[j],\n            () => `Expected dimension ${dimSizes[idDims[i][j]]} at axis ${j} ` +\n                `of input shaped ${JSON.stringify(shape)}, ` +\n                `but got dimension ${shape[j]}`);\n      }\n    }\n  }\n}\n\n/**\n * Gets path of computation for einsum.\n *\n * @param summedDims indices to the dimensions being summed over.\n * @param idDims A look up table for the dimensions present in each input\n *     tensor. Each consituent array contains indices for the dimensions in the\n *     corresponding input tensor.\n *\n * @return A map with two fields:\n *   - path: The path of computation, with each element indicating the dimension\n *     being summed over after the element-wise multiplication in that step.\n *   - steps: With the same length as `path`. 
Each element contains the indices\n *     to the input tensors being used for element-wise multiplication in the\n *     corresponding step.\n */\nexport function getEinsumComputePath(summedDims: number[], idDims: number[][]):\n    {path: number[], steps: number[][]} {\n  const path: number[] = summedDims;\n  const steps: number[][] = [];\n  let nSteps = 0;\n  if (summedDims.length === 0) {\n    // Einsum that involes no summing: e.g., transpose and outer product.\n    path.push(-1);\n  }\n  nSteps = summedDims.length + 1;\n  for (let i = 0; i < nSteps; ++i) {\n    steps.push([]);\n  }\n  const computedTermIndices: number[] = [];\n  for (let i = 0; i < path.length; ++i) {\n    const summedDim = path[i];\n    const termIndices = findTermsWithDim(idDims, summedDim);\n    for (const termIndex of termIndices) {\n      if (computedTermIndices.indexOf(termIndex) === -1) {\n        steps[i].push(termIndex);\n        computedTermIndices.push(termIndex);\n      }\n    }\n  }\n  return {path, steps};\n}\n\n/** Determines if an axes permutation is the identity permutation. */\nexport function isIdentityPermutation(perm: number[]): boolean {\n  return perm.every((dim: number, index: number) => dim === index);\n}\n\nfunction findTermsWithDim(idDims: number[][], dim: number): number[] {\n  const termIndices: number[] = [];\n  for (let i = 0; i < idDims.length; ++i) {\n    if (idDims[i].length === 0 || idDims[i].indexOf(dim) !== -1 || dim === -1) {\n      termIndices.push(i);\n    }\n  }\n  return termIndices;\n}\n"]}
|