chenyc
2025-05-29 92f69c57b920cf62ecc9f15f9ed196fa26dbf2ac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { ENGINE } from '../engine';
import { Einsum } from '../kernel_names';
import { convertToTensor } from '../tensor_util_env';
import { op } from './operation';
/**
 * Tensor contraction over specified indices and outer product.
 *
 * `einsum` allows defining Tensors by defining their element-wise computation.
 * This computation is based on
 * [Einstein summation](https://en.wikipedia.org/wiki/Einstein_notation).
 *
 * Some special cases include:
 *
 * Matrix multiplication:
 * ```js
 * const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
 * const y = tf.tensor2d([[0, 1], [2, 3], [4, 5]]);
 * x.print();
 * y.print();
 * tf.einsum('ij,jk->ik', x, y).print();
 * ```
 *
 * Dot product:
 * ```js
 * const x = tf.tensor1d([1, 2, 3]);
 * const y = tf.tensor1d([0, 1, 2]);
 * x.print();
 * y.print();
 * tf.einsum('i,i->', x, y).print();
 * ```
 *
 * Batch dot product:
 * ```js
 * const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
 * const y = tf.tensor2d([[0, 1, 2], [3, 4, 5]]);
 * x.print();
 * y.print();
 * tf.einsum('bi,bi->b', x, y).print();
 * ```
 *
 * Outer product:
 * ```js
 * const x = tf.tensor1d([1, 3, 5]);
 * const y = tf.tensor1d([2, 4, 6]);
 * x.print();
 * y.print();
 * tf.einsum('i,j->ij', x, y).print();
 * ```
 *
 * Matrix transpose:
 * ```js
 * const x = tf.tensor2d([[1, 2], [3, 4]]);
 * x.print();
 * tf.einsum('ij->ji', x).print();
 * ```
 *
 * Batch matrix transpose:
 * ```js
 * const x = tf.tensor3d([[[1, 2], [3, 4]], [[-1, -2], [-3, -4]]]);
 * x.print();
 * tf.einsum('bij->bji', x).print();
 * ```
 *
 * Limitations:
 *
 * This implementation of einsum has the following limitations:
 *
 * - Does not support >2 input tensors.
 * - Does not support duplicate axes for any given input tensor. E.g., equation
 *   'ii->' is not supported.
 * - The `...` notation is not supported.
 *
 * @param equation a string describing the contraction, in the same format as
 * [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).
 * @param tensors the input(s) to contract (each one a Tensor), whose shapes
 *     should be consistent with equation.
 * @returns The output tensor.
 *
 * @doc {heading: 'Tensors', subheading: 'Matrices'}
 */
export function einsum_(equation, ...tensors) {
    // Run every variadic input through convertToTensor, giving each a
    // positional argument name (`tensors0`, `tensors1`, ...) and the op
    // name 'einsum'.
    const inputs = [];
    for (let index = 0; index < tensors.length; index++) {
        inputs.push(convertToTensor(tensors[index], `tensors${index}`, 'einsum'));
    }
    // The equation string is forwarded unchanged as the kernel's only attribute.
    return ENGINE.runKernel(Einsum, inputs, { equation });
}
// Public op: `einsum_` handed to `op` inside an object keyed by the
// implementation's name. What `op` does with that key lives in ./operation
// (presumably it derives the op name from it — confirm there). The
// @__PURE__ annotation must stay directly before the call so bundlers can
// tree-shake this export when unused.
export const einsum = /* @__PURE__ */ op({ einsum_: einsum_ });
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiZWluc3VtLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvZWluc3VtLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDakMsT0FBTyxFQUFDLE1BQU0sRUFBYyxNQUFNLGlCQUFpQixDQUFDO0FBSXBELE9BQU8sRUFBQyxlQUFlLEVBQUMsTUFBTSxvQkFBb0IsQ0FBQztBQUVuRCxPQUFPLEVBQUMsRUFBRSxFQUFDLE1BQU0sYUFBYSxDQUFDO0FBRS9COzs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7R0EyRUc7QUFDSCxNQUFNLFVBQVUsT0FBTyxDQUFDLFFBQWdCLEVBQUUsR0FBRyxPQUFpQjtJQUM1RCxNQUFNLFFBQVEsR0FDVixPQUFPLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxFQUFFLENBQUMsZUFBZSxDQUFDLENBQUMsRUFBRSxVQUFVLENBQUMsRUFBRSxFQUFFLFFBQVEsQ0FBQyxDQUFDLENBQUM7SUFDdkUsTUFBTSxLQUFLLEdBQWdCLEVBQUMsUUFBUSxFQUFDLENBQUM7SUFDdEMsT0FBTyxNQUFNLENBQUMsU0FBUyxDQUNuQixNQUFNLEVBQUUsUUFBcUMsRUFDN0MsS0FBZ0MsQ0FBQyxDQUFDO0FBQ3hDLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxNQUFNLEdBQUcsZUFBZSxDQUFDLEVBQUUsQ0FBQyxFQUFDLE9BQU8sRUFBQyxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMSBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT
09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7RU5HSU5FfSBmcm9tICcuLi9lbmdpbmUnO1xuaW1wb3J0IHtFaW5zdW0sIEVpbnN1bUF0dHJzfSBmcm9tICcuLi9rZXJuZWxfbmFtZXMnO1xuaW1wb3J0IHtOYW1lZEF0dHJNYXB9IGZyb20gJy4uL2tlcm5lbF9yZWdpc3RyeSc7XG5pbXBvcnQge1RlbnNvcn0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCB7TmFtZWRUZW5zb3JNYXB9IGZyb20gJy4uL3RlbnNvcl90eXBlcyc7XG5pbXBvcnQge2NvbnZlcnRUb1RlbnNvcn0gZnJvbSAnLi4vdGVuc29yX3V0aWxfZW52JztcblxuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuXG4vKipcbiAqIFRlbnNvciBjb250cmFjdGlvbiBvdmVyIHNwZWNpZmllZCBpbmRpY2VzIGFuZCBvdXRlciBwcm9kdWN0LlxuICpcbiAqIGBlaW5zdW1gIGFsbG93cyBkZWZpbmluZyBUZW5zb3JzIGJ5IGRlZmluaW5nIHRoZWlyIGVsZW1lbnQtd2lzZSBjb21wdXRhdGlvbi5cbiAqIFRoaXMgY29tcHV0YXRpb24gaXMgYmFzZWQgb25cbiAqIFtFaW5zdGVpbiBzdW1tYXRpb25dKGh0dHBzOi8vZW4ud2lraXBlZGlhLm9yZy93aWtpL0VpbnN0ZWluX25vdGF0aW9uKS5cbiAqXG4gKiBTb21lIHNwZWNpYWwgY2FzZXMgaW5jbHVkZTpcbiAqXG4gKiBNYXRyaXggbXVsdGlwbGljYXRpb246XG4gKiBgYGBqc1xuICogY29uc3QgeCA9IHRmLnRlbnNvcjJkKFtbMSwgMiwgM10sIFs0LCA1LCA2XV0pO1xuICogY29uc3QgeSA9IHRmLnRlbnNvcjJkKFtbMCwgMV0sIFsyLCAzXSwgWzQsIDVdXSk7XG4gKiB4LnByaW50KCk7XG4gKiB5LnByaW50KCk7XG4gKiB0Zi5laW5zdW0oJ2lqLGprLT5paycsIHgsIHkpLnByaW50KCk7XG4gKiBgYGBcbiAqXG4gKiBEb3QgcHJvZHVjdDpcbiAqIGBgYGpzXG4gKiBjb25zdCB4ID0gdGYudGVuc29yMWQoWzEsIDIsIDNdKTtcbiAqIGNvbnN0IHkgPSB0Zi50ZW5zb3IxZChbMCwgMSwgMl0pO1xuICogeC5wcmludCgpO1xuICogeS5wcmludCgpO1xuICogdGYuZWluc3VtKCdpLGktPicsIHgsIHkpLnByaW50KCk7XG4gKiBgYGBcbiAqXG4gKiBCYXRjaCBkb3QgcHJvZHVjdDpcbiAqIGBgYGpzXG4gKiBjb25zdCB4ID0gdGYudGVuc29yMmQoW1sxLCAyLCAzXSwgWzQsIDUsIDZdXSk7XG4gKiBjb25zdCB5ID0gdGYudGVuc29yMmQoW1swLCAxLCAyXSwgWzMsIDQsIDVdXSk7XG4gKiB4LnByaW50KCk7XG4gKiB5LnByaW50KCk7XG4gKiB0Zi5laW5zdW0oJ2JpLGJpLT5iJywgeCwgeSkucHJpbnQoKTtcbiAqIGBgYFxuICpcbiAqIE91dGVyIHByb3VkdWN0OlxuICogYGBganNcbiAqIGNvbnN0IHggPSB0Zi50ZW5zb3IxZChbMSwgMywgNV0pO1xuICogY29uc3QgeSA9IHRmLnRlbnNvcjFkKFsyLCA0LCA2XSk7XG4gKiB4LnByaW50KCk7XG4gKiB5LnByaW50KCk7XG4gKiB0Zi5laW5zdW0oJ2ksai0+aWonLCB4LCB5KS5wcmludCgpO1xuICogYGBgXG4gKl
xuICogTWF0cml4IHRyYW5zcG9zZTpcbiAqIGBgYGpzXG4gKiBjb25zdCB4ID0gdGYudGVuc29yMmQoW1sxLCAyXSwgWzMsIDRdXSk7XG4gKiB4LnByaW50KCk7XG4gKiB0Zi5laW5zdW0oJ2lqLT5qaScsIHgpLnByaW50KCk7XG4gKiBgYGBcbiAqXG4gKiBCYXRjaCBtYXRyaXggdHJhbnNwb3NlOlxuICogYGBganNcbiAqIGNvbnN0IHggPSB0Zi50ZW5zb3IzZChbW1sxLCAyXSwgWzMsIDRdXSwgW1stMSwgLTJdLCBbLTMsIC00XV1dKTtcbiAqIHgucHJpbnQoKTtcbiAqIHRmLmVpbnN1bSgnYmlqLT5iamknLCB4KS5wcmludCgpO1xuICogYGBgXG4gKlxuICogTGltaXRhdGlvbnM6XG4gKlxuICogVGhpcyBpbXBsZW1lbnRhdGlvbiBvZiBlaW5zdW0gaGFzIHRoZSBmb2xsb3dpbmcgbGltaXRhdGlvbnM6XG4gKlxuICogLSBEb2VzIG5vdCBzdXBwb3J0ID4yIGlucHV0IHRlbnNvcnMuXG4gKiAtIERvZXMgbm90IHN1cHBvcnQgZHVwbGljYXRlIGF4ZXMgZm9yIGFueSBnaXZlbiBpbnB1dCB0ZW5zb3IuIEUuZy4sIGVxdWF0aW9uXG4gKiAgICdpaS0+JyBpcyBub3Qgc3VwcG9ydGVkLlxuICogLSBUaGUgYC4uLmAgbm90YXRpb24gaXMgbm90IHN1cHBvcnRlZC5cbiAqXG4gKiBAcGFyYW0gZXF1YXRpb24gYSBzdHJpbmcgZGVzY3JpYmluZyB0aGUgY29udHJhY3Rpb24sIGluIHRoZSBzYW1lIGZvcm1hdCBhc1xuICogW251bXB5LmVpbnN1bV0oaHR0cHM6Ly9udW1weS5vcmcvZG9jL3N0YWJsZS9yZWZlcmVuY2UvZ2VuZXJhdGVkL251bXB5LmVpbnN1bS5odG1sKS5cbiAqIEBwYXJhbSB0ZW5zb3JzIHRoZSBpbnB1dChzKSB0byBjb250cmFjdCAoZWFjaCBvbmUgYSBUZW5zb3IpLCB3aG9zZSBzaGFwZXNcbiAqICAgICBzaG91bGQgYmUgY29uc2lzdGVudCB3aXRoIGVxdWF0aW9uLlxuICogQHJldHVybnMgVGhlIG91dHB1dCB0ZW5zb3IuXG4gKlxuICogQGRvYyB7aGVhZGluZzogJ1RlbnNvcnMnLCBzdWJoZWFkaW5nOiAnTWF0cmljZXMnfVxuICovXG5leHBvcnQgZnVuY3Rpb24gZWluc3VtXyhlcXVhdGlvbjogc3RyaW5nLCAuLi50ZW5zb3JzOiBUZW5zb3JbXSk6IFRlbnNvciB7XG4gIGNvbnN0ICR0ZW5zb3JzID1cbiAgICAgIHRlbnNvcnMubWFwKCh0LCBpKSA9PiBjb252ZXJ0VG9UZW5zb3IodCwgYHRlbnNvcnMke2l9YCwgJ2VpbnN1bScpKTtcbiAgY29uc3QgYXR0cnM6IEVpbnN1bUF0dHJzID0ge2VxdWF0aW9ufTtcbiAgcmV0dXJuIEVOR0lORS5ydW5LZXJuZWwoXG4gICAgICBFaW5zdW0sICR0ZW5zb3JzIGFzIHVua25vd24gYXMgTmFtZWRUZW5zb3JNYXAsXG4gICAgICBhdHRycyBhcyB1bmtub3duIGFzIE5hbWVkQXR0ck1hcCk7XG59XG5cbmV4cG9ydCBjb25zdCBlaW5zdW0gPSAvKiBAX19QVVJFX18gKi8gb3Aoe2VpbnN1bV99KTtcbiJdfQ==