/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Einsum, util } from '@tensorflow/tfjs-core';
import { multiply } from './Multiply';
import { reshape } from './Reshape';
import { sum } from './Sum';
import { transpose } from './Transpose';

/**
 * CPU kernel for `Einsum`.
 *
 * Decodes `attrs.equation`, validates the input dimension sizes, and then
 * walks the compute path from `backend_util.getEinsumComputePath`. In each
 * step, every participating input is transposed and reshaped so that its
 * axes line up with the running product. It is then multiplied into that
 * product. Every step except the last then sums out the axis named by
 * `path[i]` (skipped when `path[i]` is negative).
 *
 * @param {{inputs: Object, backend: Object, attrs: {equation: string}}} args
 *     Kernel arguments: the input tensors, the CPU backend that owns their
 *     data, and the einsum equation string.
 * @returns {Object} TensorInfo for the result. When no intermediate op was
 *     needed, the result shares the input's data through a reference-count
 *     increment rather than being the input object itself.
 */
export function einsum(args) {
  const { inputs, backend, attrs } = args;
  const { equation } = attrs;
  const tensors = inputs;

  const { allDims, summedDims, idDims } =
      backend_util.decodeEinsumEquation(equation, tensors.length);
  backend_util.checkEinsumDimSizes(allDims.length, idDims, tensors);
  const { path, steps } = backend_util.getEinsumComputePath(summedDims, idDims);

  const nSteps = steps.length;
  let out = null;
  let numDimsRemaining = allDims.length;
  // Every TensorInfo this kernel allocates; all but the result are freed below.
  const tensorsToDispose = [];

  for (let i = 0; i < nSteps; ++i) {
    for (const idTerm of steps[i]) {
      const { permutationIndices: perm, expandDims: dimsToExpand } =
          backend_util.getEinsumPermutation(numDimsRemaining, idDims[idTerm]);

      let x;
      if (backend_util.isIdentityPermutation(perm)) {
        x = tensors[idTerm];
      } else {
        x = transpose({ inputs: { x: tensors[idTerm] }, backend, attrs: { perm } });
        tensorsToDispose.push(x);
      }

      // Insert size-1 axes at the positions reported by getEinsumPermutation.
      const targetShape = x.shape.slice();
      for (let k = 0; k < dimsToExpand.length; ++k) {
        targetShape.splice(dimsToExpand[k], 0, 1);
      }
      if (!util.arraysEqual(x.shape, targetShape)) {
        x = reshape({ inputs: { x }, backend, attrs: { shape: targetShape } });
        tensorsToDispose.push(x);
      }

      if (out === null) {
        out = x;
      } else {
        out = multiply({ inputs: { a: x, b: out }, backend });
        tensorsToDispose.push(out);
      }
    }

    if (i < nSteps - 1) {
      // A negative path entry means there is no axis to sum at this step.
      if (path[i] >= 0) {
        out = sum({
          inputs: { x: out },
          backend,
          attrs: {
            // path[i] indexes into allDims; offset it by the number of
            // dimensions already dropped (allDims.length - numDimsRemaining).
            axis: path[i] - (allDims.length - numDimsRemaining),
            keepDims: false,
          },
        });
        tensorsToDispose.push(out);
      }
      numDimsRemaining--;
    }
  }

  // Clean up intermediate tensors, keeping the one returned as the result.
  for (const tensorInfo of tensorsToDispose) {
    if (tensorInfo === out) {
      continue;
    }
    backend.disposeIntermediateTensorInfo(tensorInfo);
  }

  // If no transpose/reshape/multiply/sum was needed (e.g. 'ij->ij'), `out` is
  // still one of the caller's input objects. Returning it as-is would make the
  // kernel output alias that input without taking a reference on its data, so
  // bump the backend's ref count and hand back a fresh TensorInfo instead.
  if (!tensorsToDispose.includes(out)) {
    backend.incRef(out.dataId);
    return { dataId: out.dataId, shape: out.shape, dtype: out.dtype };
  }
  return out;
}

/** Registers `einsum` as the CPU backend's implementation of `Einsum`. */
export const einsumConfig = {
  kernelName: Einsum,
  backendName: 'cpu',
  kernelFunc: einsum,
};