gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '../index';
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
import { expectArraysEqual } from '../test_util';
// Test suite for tf.broadcastArgs, registered for every environment in
// ALL_ENVS. The op is fed two int32 shape tensors; the cases below pin the
// shape it is expected to return, and the inputs for which it is expected to
// throw.
describeWithFlags('broadcastArgs', ALL_ENVS, () => {
    // Rows are [test title, first shape, second shape, expected result].
    // The last row expects a -1 entry to be carried through to the result.
    const compatibleCases = [
        ['([1,1], [1,1]) -> [1,1]', [1, 1], [1, 1], [1, 1]],
        ['([1,1], [4,2]) -> [4,2]', [1, 1], [4, 2], [4, 2]],
        ['([1,6], [3,1]) -> [3,6]', [1, 6], [3, 1], [3, 6]],
        ['([1,6], [3,1,1,1]) -> [3,1,1,6]', [1, 6], [3, 1, 1, 1], [3, 1, 1, 6]],
        ['([1,6,-1], [3,1,1,1]) -> [3,1,6,-1]', [1, 6, -1], [3, 1, 1, 1], [3, 1, 6, -1]],
    ];
    for (const [title, firstShape, secondShape, wanted] of compatibleCases) {
        it(title, async () => {
            const lhs = tf.tensor1d(firstShape, 'int32');
            const rhs = tf.tensor1d(secondShape, 'int32');
            const broadcasted = await tf.broadcastArgs(lhs, rhs).array();
            expectArraysEqual(wanted, broadcasted);
        });
    }

    // Shared assertion for the failure cases: building the op and reading its
    // value synchronously must throw.
    const expectBroadcastToThrow = (lhs, rhs) => {
        expect(() => tf.broadcastArgs(lhs, rhs).arraySync()).toThrowError();
    };

    // Trailing dimensions 2 and 3 differ and neither is 1.
    it('([1,2], [1,3]) -> error', async () => {
        const lhs = tf.tensor1d([1, 2], 'int32');
        const rhs = tf.tensor1d([1, 3], 'int32');
        expectBroadcastToThrow(lhs, rhs);
    });

    // Both inputs are 2x2 tensors rather than 1-D shape vectors.
    it('([[1,1],[1,1]], [[1,1],[1,1]]) -> error', async () => {
        const makeOnesMatrix = () => tf.tensor2d([[1, 1], [1, 1]], [2, 2], 'int32');
        const lhs = makeOnesMatrix();
        const rhs = makeOnesMatrix();
        expectBroadcastToThrow(lhs, rhs);
    });
});
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYnJvYWRjYXN0X2FyZ3NfdGVzdC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtY29yZS9zcmMvb3BzL2Jyb2FkY2FzdF9hcmdzX3Rlc3QudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxLQUFLLEVBQUUsTUFBTSxVQUFVLENBQUM7QUFDL0IsT0FBTyxFQUFDLFFBQVEsRUFBRSxpQkFBaUIsRUFBQyxNQUFNLGlCQUFpQixDQUFDO0FBQzVELE9BQU8sRUFBQyxpQkFBaUIsRUFBQyxNQUFNLGNBQWMsQ0FBQztBQUUvQyxpQkFBaUIsQ0FBQyxlQUFlLEVBQUUsUUFBUSxFQUFFLEdBQUcsRUFBRTtJQUNoRCxFQUFFLENBQUMseUJBQXlCLEVBQUUsS0FBSyxJQUFJLEVBQUU7UUFDdkMsTUFBTSxFQUFFLEdBQUcsRUFBRSxDQUFDLFFBQVEsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBRSxPQUFPLENBQUMsQ0FBQztRQUN4QyxNQUFNLEVBQUUsR0FBRyxFQUFFLENBQUMsUUFBUSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxFQUFFLE9BQU8sQ0FBQyxDQUFDO1FBQ3hDLE1BQU0sUUFBUSxHQUFHLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDO1FBRXhCLGlCQUFpQixDQUFDLFFBQVEsRUFBRSxNQUFNLEVBQUUsQ0FBQyxhQUFhLENBQUMsRUFBRSxFQUFFLEVBQUUsQ0FBQyxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7SUFDdEUsQ0FBQyxDQUFDLENBQUM7SUFFSCxFQUFFLENBQUMseUJBQXlCLEVBQUUsS0FBSyxJQUFJLEVBQUU7UUFDdkMsTUFBTSxFQUFFLEdBQUcsRUFBRSxDQUFDLFFBQVEsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBRSxPQUFPLENBQUMsQ0FBQztRQUN4QyxNQUFNLEVBQUUsR0FBRyxFQUFFLENBQUMsUUFBUSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxFQUFFLE9BQU8sQ0FBQyxDQUFDO1FBQ3hDLE1BQU0sUUFBUSxHQUFHLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDO1FBRXhCLGlCQUFpQixDQUFDLFFBQVEsRUFBRSxNQUFNLEVBQUUsQ0FBQyxhQUFhLENBQUMsRUFBRSxFQUFFLEVBQUUsQ0FBQyxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7SUFDdEUsQ0FBQyxDQUFDLENBQUM7SUFFSCxFQUFFLENBQUMseUJBQXlCLEVBQUUsS0FBSyxJQUFJLEVBQUU7UUFDdkMsTUFBTSxFQUFFLEdBQUcsRUFBRSxDQUFDLFFBQVEsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBRSxPQUFPLENBQUMsQ0FBQztRQUN4QyxNQUFNLEVBQUUsR0FBRyxFQUFFLENBQUMsUUFBUSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxFQUFFLE9BQU8sQ0FBQyxDQUFDO1FBQ3hDLE1BQU0sUUFBUSxHQUFHLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDO1FBRXhCLGlCQUFpQixDQUFDLFFBQVEsRUFBRSxNQUFNLEVBQUUsQ0FBQyxhQUFhLENBQUMsRUFBRSxFQUFFLEVBQU
UsQ0FBQyxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7SUFDdEUsQ0FBQyxDQUFDLENBQUM7SUFFSCxFQUFFLENBQUMsaUNBQWlDLEVBQUUsS0FBSyxJQUFJLEVBQUU7UUFDL0MsTUFBTSxFQUFFLEdBQUcsRUFBRSxDQUFDLFFBQVEsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBRSxPQUFPLENBQUMsQ0FBQztRQUN4QyxNQUFNLEVBQUUsR0FBRyxFQUFFLENBQUMsUUFBUSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUUsT0FBTyxDQUFDLENBQUM7UUFDOUMsTUFBTSxRQUFRLEdBQUcsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQztRQUU5QixpQkFBaUIsQ0FBQyxRQUFRLEVBQUUsTUFBTSxFQUFFLENBQUMsYUFBYSxDQUFDLEVBQUUsRUFBRSxFQUFFLENBQUMsQ0FBQyxLQUFLLEVBQUUsQ0FBQyxDQUFDO0lBQ3RFLENBQUMsQ0FBQyxDQUFDO0lBRUgsRUFBRSxDQUFDLHFDQUFxQyxFQUFFLEtBQUssSUFBSSxFQUFFO1FBQ25ELE1BQU0sRUFBRSxHQUFHLEVBQUUsQ0FBQyxRQUFRLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLEVBQUUsT0FBTyxDQUFDLENBQUM7UUFDNUMsTUFBTSxFQUFFLEdBQUcsRUFBRSxDQUFDLFFBQVEsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsQ0FBQyxFQUFFLE9BQU8sQ0FBQyxDQUFDO1FBQzlDLE1BQU0sUUFBUSxHQUFHLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUUvQixpQkFBaUIsQ0FBQyxRQUFRLEVBQUUsTUFBTSxFQUFFLENBQUMsYUFBYSxDQUFDLEVBQUUsRUFBRSxFQUFFLENBQUMsQ0FBQyxLQUFLLEVBQUUsQ0FBQyxDQUFDO0lBQ3RFLENBQUMsQ0FBQyxDQUFDO0lBRUgsRUFBRSxDQUFDLHlCQUF5QixFQUFFLEtBQUssSUFBSSxFQUFFO1FBQ3ZDLE1BQU0sRUFBRSxHQUFHLEVBQUUsQ0FBQyxRQUFRLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUUsT0FBTyxDQUFDLENBQUM7UUFDeEMsTUFBTSxFQUFFLEdBQUcsRUFBRSxDQUFDLFFBQVEsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBRSxPQUFPLENBQUMsQ0FBQztRQUV4QyxNQUFNLENBQUMsR0FBRyxFQUFFLENBQUMsRUFBRSxDQUFDLGFBQWEsQ0FBQyxFQUFFLEVBQUUsRUFBRSxDQUFDLENBQUMsU0FBUyxFQUFFLENBQUMsQ0FBQyxZQUFZLEVBQUUsQ0FBQztJQUNwRSxDQUFDLENBQUMsQ0FBQztJQUVILEVBQUUsQ0FBQyx5Q0FBeUMsRUFBRSxLQUFLLElBQUksRUFBRTtRQUN2RCxNQUFNLEVBQUUsR0FBRyxFQUFFLENBQUMsUUFBUSxDQUFDLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBRSxPQUFPLENBQUMsQ0FBQztRQUMxRCxNQUFNLEVBQUUsR0FBRyxFQUFFLENBQUMsUUFBUS
xDQUFDLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBRSxPQUFPLENBQUMsQ0FBQztRQUUxRCxNQUFNLENBQUMsR0FBRyxFQUFFLENBQUMsRUFBRSxDQUFDLGFBQWEsQ0FBQyxFQUFFLEVBQUUsRUFBRSxDQUFDLENBQUMsU0FBUyxFQUFFLENBQUMsQ0FBQyxZQUFZLEVBQUUsQ0FBQztJQUNwRSxDQUFDLENBQUMsQ0FBQztBQUNMLENBQUMsQ0FBQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjEgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQgKiBhcyB0ZiBmcm9tICcuLi9pbmRleCc7XG5pbXBvcnQge0FMTF9FTlZTLCBkZXNjcmliZVdpdGhGbGFnc30gZnJvbSAnLi4vamFzbWluZV91dGlsJztcbmltcG9ydCB7ZXhwZWN0QXJyYXlzRXF1YWx9IGZyb20gJy4uL3Rlc3RfdXRpbCc7XG5cbmRlc2NyaWJlV2l0aEZsYWdzKCdicm9hZGNhc3RBcmdzJywgQUxMX0VOVlMsICgpID0+IHtcbiAgaXQoJyhbMSwxXSwgWzEsMV0pIC0+IFsxLDFdJywgYXN5bmMgKCkgPT4ge1xuICAgIGNvbnN0IHMxID0gdGYudGVuc29yMWQoWzEsIDFdLCAnaW50MzInKTtcbiAgICBjb25zdCBzMiA9IHRmLnRlbnNvcjFkKFsxLCAxXSwgJ2ludDMyJyk7XG4gICAgY29uc3QgZXhwZWN0ZWQgPSBbMSwgMV07XG5cbiAgICBleHBlY3RBcnJheXNFcXVhbChleHBlY3RlZCwgYXdhaXQgdGYuYnJvYWRjYXN0QXJncyhzMSwgczIpLmFycmF5KCkpO1xuICB9KTtcblxuICBpdCgnKFsxLDFdLCBbNCwyXSkgLT4gWzQsMl
0nLCBhc3luYyAoKSA9PiB7XG4gICAgY29uc3QgczEgPSB0Zi50ZW5zb3IxZChbMSwgMV0sICdpbnQzMicpO1xuICAgIGNvbnN0IHMyID0gdGYudGVuc29yMWQoWzQsIDJdLCAnaW50MzInKTtcbiAgICBjb25zdCBleHBlY3RlZCA9IFs0LCAyXTtcblxuICAgIGV4cGVjdEFycmF5c0VxdWFsKGV4cGVjdGVkLCBhd2FpdCB0Zi5icm9hZGNhc3RBcmdzKHMxLCBzMikuYXJyYXkoKSk7XG4gIH0pO1xuXG4gIGl0KCcoWzEsNl0sIFszLDFdKSAtPiBbMyw2XScsIGFzeW5jICgpID0+IHtcbiAgICBjb25zdCBzMSA9IHRmLnRlbnNvcjFkKFsxLCA2XSwgJ2ludDMyJyk7XG4gICAgY29uc3QgczIgPSB0Zi50ZW5zb3IxZChbMywgMV0sICdpbnQzMicpO1xuICAgIGNvbnN0IGV4cGVjdGVkID0gWzMsIDZdO1xuXG4gICAgZXhwZWN0QXJyYXlzRXF1YWwoZXhwZWN0ZWQsIGF3YWl0IHRmLmJyb2FkY2FzdEFyZ3MoczEsIHMyKS5hcnJheSgpKTtcbiAgfSk7XG5cbiAgaXQoJyhbMSw2XSwgWzMsMSwxLDFdKSAtPiBbMywxLDEsNl0nLCBhc3luYyAoKSA9PiB7XG4gICAgY29uc3QgczEgPSB0Zi50ZW5zb3IxZChbMSwgNl0sICdpbnQzMicpO1xuICAgIGNvbnN0IHMyID0gdGYudGVuc29yMWQoWzMsIDEsIDEsIDFdLCAnaW50MzInKTtcbiAgICBjb25zdCBleHBlY3RlZCA9IFszLCAxLCAxLCA2XTtcblxuICAgIGV4cGVjdEFycmF5c0VxdWFsKGV4cGVjdGVkLCBhd2FpdCB0Zi5icm9hZGNhc3RBcmdzKHMxLCBzMikuYXJyYXkoKSk7XG4gIH0pO1xuXG4gIGl0KCcoWzEsNiwtMV0sIFszLDEsMSwxXSkgLT4gWzMsMSw2LC0xXScsIGFzeW5jICgpID0+IHtcbiAgICBjb25zdCBzMSA9IHRmLnRlbnNvcjFkKFsxLCA2LCAtMV0sICdpbnQzMicpO1xuICAgIGNvbnN0IHMyID0gdGYudGVuc29yMWQoWzMsIDEsIDEsIDFdLCAnaW50MzInKTtcbiAgICBjb25zdCBleHBlY3RlZCA9IFszLCAxLCA2LCAtMV07XG5cbiAgICBleHBlY3RBcnJheXNFcXVhbChleHBlY3RlZCwgYXdhaXQgdGYuYnJvYWRjYXN0QXJncyhzMSwgczIpLmFycmF5KCkpO1xuICB9KTtcblxuICBpdCgnKFsxLDJdLCBbMSwzXSkgLT4gZXJyb3InLCBhc3luYyAoKSA9PiB7XG4gICAgY29uc3QgczEgPSB0Zi50ZW5zb3IxZChbMSwgMl0sICdpbnQzMicpO1xuICAgIGNvbnN0IHMyID0gdGYudGVuc29yMWQoWzEsIDNdLCAnaW50MzInKTtcblxuICAgIGV4cGVjdCgoKSA9PiB0Zi5icm9hZGNhc3RBcmdzKHMxLCBzMikuYXJyYXlTeW5jKCkpLnRvVGhyb3dFcnJvcigpO1xuICB9KTtcblxuICBpdCgnKFtbMSwxXSxbMSwxXV0sIFtbMSwxXSxbMSwxXV0pIC0+IGVycm9yJywgYXN5bmMgKCkgPT4ge1xuICAgIGNvbnN0IHMxID0gdGYudGVuc29yMmQoW1sxLCAxXSwgWzEsIDFdXSwgWzIsIDJdLCAnaW50MzInKTtcbiAgICBjb25zdCBzMiA9IHRmLnRlbnNvcjJkKFtbMSwgMV0sIFsxLCAxXV0sIFsyLCAyXSwgJ2ludDMyJyk7XG5cbiAgICBleHBlY3QoKCkgPT4gdGYuYnJvYWRjYXN0QXJncyhzMSwgczIpLmFycmF5U3
luYygpKS50b1Rocm93RXJyb3IoKTtcbiAgfSk7XG59KTtcbiJdfQ==