/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Einsum, util } from '@tensorflow/tfjs-core';
import { multiply } from './Multiply';
import { reshape } from './Reshape';
import { sum } from './Sum';
import { transpose } from './Transpose';

/**
 * Kernel function for `Einsum`, built from the transpose, reshape, multiply
 * and sum kernels imported above.
 *
 * The equation is decoded and a compute path is obtained through the
 * `backend_util` einsum helpers. The path is then walked step by step:
 *  - every operand belonging to the step is permuted (unless the permutation
 *    is the identity) and reshaped to gain the size-1 axes requested by
 *    `getEinsumPermutation`, then multiplied into the running product;
 *  - after every step except the last, the running product is summed over
 *    the axis given by `path` for that step, when that entry is >= 0.
 *
 * Every tensor created along the way is disposed before returning, except
 * the one that is returned. If no transpose, reshape, multiply or sum was
 * required, the returned object is the input operand itself.
 *
 * @param args Kernel arguments: `inputs` (indexed as an array of operands),
 *     `backend` (handed to the sub-kernels and used to dispose
 *     intermediates) and `attrs` (carrying the `equation` string).
 * @returns The tensor info holding the einsum result.
 */
export function einsum(args) {
  const { inputs, backend, attrs } = args;
  const operands = inputs;

  const { allDims, summedDims, idDims } =
      backend_util.decodeEinsumEquation(attrs.equation, operands.length);
  const totalDims = allDims.length;
  backend_util.checkEinsumDimSizes(totalDims, idDims, operands);
  const { path, steps } = backend_util.getEinsumComputePath(summedDims, idDims);

  // Everything produced by a sub-kernel is recorded here so that it can be
  // released once the final result is known.
  const intermediates = [];
  const track = (info) => {
    intermediates.push(info);
    return info;
  };

  const lastStep = steps.length - 1;
  let result = null;
  let dimsLeft = totalDims;

  for (let stepIdx = 0; stepIdx <= lastStep; stepIdx++) {
    for (const termId of steps[stepIdx]) {
      const { permutationIndices, expandDims } =
          backend_util.getEinsumPermutation(dimsLeft, idDims[termId]);

      // Bring the operand's axes into the order given by the permutation.
      let term = operands[termId];
      if (!backend_util.isIdentityPermutation(permutationIndices)) {
        term = track(transpose({
          inputs: { x: term },
          backend,
          attrs: { perm: permutationIndices }
        }));
      }

      // Insert size-1 axes at the requested positions.
      const expandedShape = term.shape.slice();
      for (const axis of expandDims) {
        expandedShape.splice(axis, 0, 1);
      }
      if (!util.arraysEqual(term.shape, expandedShape)) {
        term = track(reshape({
          inputs: { x: term },
          backend,
          attrs: { shape: expandedShape }
        }));
      }

      // Fold the prepared operand into the running product.
      result = result === null ?
          term :
          track(multiply({ inputs: { a: term, b: result }, backend }));
    }

    if (stepIdx < lastStep) {
      if (path[stepIdx] >= 0) {
        // The remaining-dimension count drops by one per completed step, so
        // the path entry is shifted down by the same amount before being
        // used as the reduction axis.
        const axis = path[stepIdx] - (totalDims - dimsLeft);
        result = track(sum({
          inputs: { x: result },
          backend,
          attrs: { axis, keepDims: false }
        }));
      }
      dimsLeft -= 1;
    }
  }

  // Release every intermediate except the tensor handed back to the caller.
  intermediates
      .filter((info) => info !== result)
      .forEach((info) => backend.disposeIntermediateTensorInfo(info));

  return result;
}

/** Kernel registration entry binding `einsum` to the 'webgl' backend. */
export const einsumConfig = {
  kernelName: Einsum,
  backendName: 'webgl',
  kernelFunc: einsum
};
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"Einsum.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-webgl/src/kernels/Einsum.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAE,MAAM,EAA2E,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAI1I,OAAO,EAAC,QAAQ,EAAC,MAAM,YAAY,CAAC;AACpC,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAClC,OAAO,EAAC,GAAG,EAAC,MAAM,OAAO,CAAC;AAC1B,OAAO,EAAC,SAAS,EAAC,MAAM,aAAa,CAAC;AAEtC,MAAM,UAAU,MAAM,CAClB,IACyE;IAE3E,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,QAAQ,EAAC,GAAG,KAAK,CAAC;IACzB,MAAM,OAAO,GAAG,MAAkB,CAAC;IAEnC,MAAM,EAAC,OAAO,EAAE,UAAU,EAAE,MAAM,EAAC,GAC/B,YAAY,CAAC,oBAAoB,CAAC,QAAQ,EAAE,OAAO,CAAC,MAAM,CAAC,CAAC;IAChE,YAAY,CAAC,mBAAmB,CAAC,OAAO,CAAC,MAAM,EAAE,MAAM,EAAE,OAAO,CAAC,CAAC;IAClE,MAAM,EAAC,IAAI,EAAE,KAAK,EAAC,GAAG,YAAY,CAAC,oBAAoB,CAAC,UAAU,EAAE,MAAM,CAAC,CAAC;IAE5E,MAAM,MAAM,GAAG,KAAK,CAAC,MAAM,CAAC;IAC5B,IAAI,GAAG,GAAoB,IAAI,CAAC;IAChC,IAAI,gBAAgB,GAAG,OAAO,CAAC,MAAM,CAAC;IACtC,MAAM,gBAAgB,GAAiB,EAAE,CAAC;IAC1C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,EAAE,EAAE,CAAC,EAAE;QAC/B,KAAK,MAAM,MAAM,IAAI,KAAK,CAAC,CAAC,CAAC,EAAE;YAC7B,MAAM,EAAC,kBAAkB,EAAE,IAAI,EAAE,UAAU,EAAE,YAAY,EAAC,GACtD,YAAY,CAAC,oBAAoB,CAAC,gBAAgB,EAAE,MAAM,CAAC,MA
AM,CAAC,CAAC,CAAC;YACxE,IAAI,CAAa,CAAC;YAClB,IAAI,YAAY,CAAC,qBAAqB,CAAC,IAAI,CAAC,EAAE;gBAC5C,CAAC,GAAG,OAAO,CAAC,MAAM,CAAC,CAAC;aACrB;iBAAM;gBACL,CAAC,GAAG,SAAS,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,OAAO,CAAC,MAAM,CAAC,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,IAAI,EAAC,EAAC,CAAC,CAAC;gBACtE,gBAAgB,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;aAC1B;YACD,MAAM,WAAW,GAAa,CAAC,CAAC,KAAK,CAAC,KAAK,EAAE,CAAC;YAC9C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,YAAY,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;gBAC5C,WAAW,CAAC,MAAM,CAAC,YAAY,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;aAC3C;YAED,IAAI,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC,KAAK,EAAE,WAAW,CAAC,EAAE;gBAC3C,CAAC,GAAG,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,WAAW,EAAC,EAAC,CAAC,CAAC;gBACjE,gBAAgB,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;aAC1B;YACD,IAAI,GAAG,KAAK,IAAI,EAAE;gBAChB,GAAG,GAAG,CAAC,CAAC;aACT;iBAAM;gBACL,0DAA0D;gBAC1D,GAAG,GAAG,QAAQ,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,GAAG,EAAC,EAAE,OAAO,EAAC,CAAe,CAAC;gBAChE,gBAAgB,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;aAC5B;SACF;QACD,IAAI,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE;YAClB,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE;gBAChB,GAAG,GAAG,GAAG,CAAC;oBACR,MAAM,EAAE,EAAC,CAAC,EAAE,GAAG,EAAC;oBAChB,OAAO;oBACP,KAAK,EAAE;wBACL,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,MAAM,GAAG,gBAAgB,CAAC;wBACnD,QAAQ,EAAE,KAAK;qBAChB;iBACF,CAAC,CAAC;gBACH,gBAAgB,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;aAC5B;YACD,gBAAgB,EAAE,CAAC;SACpB;KACF;IAED,iCAAiC;IACjC,KAAK,MAAM,UAAU,IAAI,gBAAgB,EAAE;QACzC,IAAI,UAAU,KAAK,GAAG,EAAE;YACtB,SAAS;SACV;QACD,OAAO,CAAC,6BAA6B,CAAC,UAAU,CAAC,CAAC;KACnD;IAED,OAAO,GAAG,CAAC;AACb,CAAC;AAED,MAAM,CAAC,MAAM,YAAY,GAAiB;IACxC,UAAU,EAAE,MAAM;IAClB,WAAW,EAAE,OAAO;IACpB,UAAU,EAAE,MAA+B;CAC5C,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2021 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, Einsum, EinsumAttrs, EinsumInputs, KernelConfig, KernelFunc, Tensor, TensorInfo, util} from '@tensorflow/tfjs-core';\n\nimport {MathBackendWebGL} from '../backend_webgl';\n\nimport {multiply} from './Multiply';\nimport {reshape} from './Reshape';\nimport {sum} from './Sum';\nimport {transpose} from './Transpose';\n\nexport function einsum(\n    args:\n        {inputs: EinsumInputs, backend: MathBackendWebGL, attrs: EinsumAttrs}):\n    TensorInfo {\n  const {inputs, backend, attrs} = args;\n  const {equation} = attrs;\n  const tensors = inputs as Tensor[];\n\n  const {allDims, summedDims, idDims} =\n      backend_util.decodeEinsumEquation(equation, tensors.length);\n  backend_util.checkEinsumDimSizes(allDims.length, idDims, tensors);\n  const {path, steps} = backend_util.getEinsumComputePath(summedDims, idDims);\n\n  const nSteps = steps.length;\n  let out: TensorInfo|null = null;\n  let numDimsRemaining = allDims.length;\n  const tensorsToDispose: TensorInfo[] = [];\n  for (let i = 0; i < nSteps; ++i) {\n    for (const idTerm of steps[i]) {\n      const {permutationIndices: perm, expandDims: dimsToExpand} =\n          backend_util.getEinsumPermutation(numDimsRemaining, idDims[idTerm]);\n      let x: TensorInfo;\n      if (backend_util.isIdentityPermutation(perm)) {\n        x = 
tensors[idTerm];\n      } else {\n        x = transpose({inputs: {x: tensors[idTerm]}, backend, attrs: {perm}});\n        tensorsToDispose.push(x);\n      }\n      const targetShape: number[] = x.shape.slice();\n      for (let k = 0; k < dimsToExpand.length; ++k) {\n        targetShape.splice(dimsToExpand[k], 0, 1);\n      }\n\n      if (!util.arraysEqual(x.shape, targetShape)) {\n        x = reshape({inputs: {x}, backend, attrs: {shape: targetShape}});\n        tensorsToDispose.push(x);\n      }\n      if (out === null) {\n        out = x;\n      } else {\n        // tslint:disable-next-line: no-unnecessary-type-assertion\n        out = multiply({inputs: {a: x, b: out}, backend}) as TensorInfo;\n        tensorsToDispose.push(out);\n      }\n    }\n    if (i < nSteps - 1) {\n      if (path[i] >= 0) {\n        out = sum({\n          inputs: {x: out},\n          backend,\n          attrs: {\n            axis: path[i] - (allDims.length - numDimsRemaining),\n            keepDims: false\n          }\n        });\n        tensorsToDispose.push(out);\n      }\n      numDimsRemaining--;\n    }\n  }\n\n  // Clean up intermediate tensors.\n  for (const tensorInfo of tensorsToDispose) {\n    if (tensorInfo === out) {\n      continue;\n    }\n    backend.disposeIntermediateTensorInfo(tensorInfo);\n  }\n\n  return out;\n}\n\nexport const einsumConfig: KernelConfig = {\n  kernelName: Einsum,\n  backendName: 'webgl',\n  kernelFunc: einsum as unknown as KernelFunc\n};\n"]}