/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/// <amd-module name="@tensorflow/tfjs-core/dist/backends/einsum_util" />
|
/**
|
* Utility functions for computing einsum (tensor contraction and summation
|
* based on the Einstein summation convention).
|
*/
|
import { Tensor } from '../tensor';
|
/**
|
* Parse an equation for einsum.
|
*
|
* @param equation The einsum equation (e.g., "ij,jk->ik").
|
* @param numTensors Number of tensors provided along with `equation`. Used to
|
* check matching number of input tensors.
|
* @returns An object consisting of the following fields:
|
* - allDims: all dimension names as strings.
|
* - summedDims: a list of all dimensions being summed over, as indices to
|
* the elements of `allDims`.
|
* - idDims: indices of the dimensions in each input tensor, as indices to
|
* the elements of `allDims`.
|
*/
|
export declare function decodeEinsumEquation(equation: string, numTensors: number): {
|
allDims: string[];
|
summedDims: number[];
|
idDims: number[][];
|
};
|
/**
|
* Get the permutation for a given input tensor.
|
*
|
* @param nDims Total number of dimensions of all tensors involved in the einsum
|
* operation.
|
* @param idDims Dimension indices involved in the tensor in question.
|
* @returns An object consisting of the following fields:
|
* - permutationIndices: Indices to permute the axes of the tensor with.
|
* - expandDims: Indices to the dimensions that need to be expanded from the
|
* tensor after permutation.
|
*/
|
export declare function getEinsumPermutation(nDims: number, idDims: number[]): {
|
permutationIndices: number[];
|
expandDims: number[];
|
};
|
/**
|
* Checks that the dimension sizes from different input tensors match the
|
* equation.
|
*
|
* @param nDims Number of dimensions. NOTE(review): presumably the total number
|
* of dimensions across all tensors involved in the einsum operation; confirm
|
* against the implementation.
|
* @param idDims One array of dimension indices per input tensor.
|
* NOTE(review): presumably the `idDims` field produced by
|
* `decodeEinsumEquation`; confirm against callers.
|
* @param tensors The input tensors whose dimension sizes are checked.
|
* @returns Nothing. NOTE(review): a mismatch is presumably reported by
|
* throwing; confirm against the implementation.
|
*/
|
export declare function checkEinsumDimSizes(nDims: number, idDims: number[][], tensors: Tensor[]): void;
|
/**
|
* Gets path of computation for einsum.
|
*
|
* @param summedDims Indices to the dimensions being summed over.
|
* @param idDims A look up table for the dimensions present in each input
|
* tensor. Each constituent array contains indices for the dimensions in the
|
* corresponding input tensor.
|
*
|
* @returns An object with two fields:
|
* - path: The path of computation, with each element indicating the dimension
|
* being summed over after the element-wise multiplication in that step.
|
* - steps: With the same length as `path`. Each element contains the indices
|
* to the input tensors being used for element-wise multiplication in the
|
* corresponding step.
|
*/
|
export declare function getEinsumComputePath(summedDims: number[], idDims: number[][]): {
|
path: number[];
|
steps: number[][];
|
};
|
/**
|
* Determines if an axes permutation is the identity permutation.
|
*
|
* @param perm The axes permutation to test.
|
* @returns `true` if `perm` is the identity permutation, `false` otherwise.
|
*/
|
export declare function isIdentityPermutation(perm: number[]): boolean;
|