gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, util } from '@tensorflow/tfjs-core';
/**
 * Factory for element-wise binary kernels that support broadcasting.
 *
 * The returned function combines two flat value buffers, element by element,
 * using `op`, after resolving the broadcast output shape of the two operand
 * shapes.
 *
 * @param op Callback invoked as `op(aValue, bValue)` once per output element;
 *     its return value is stored in the output buffer.
 * @returns A function `(aShape, bShape, aVals, bVals, dtype)` that returns a
 *     two-element array `[values, shape]`: the buffer obtained from
 *     `util.getTypedArrayFromDType(dtype, size)` filled with the results, and
 *     the broadcast output shape.
 */
export function createSimpleBinaryKernelImpl(op) {
    return (aShape, bShape, aVals, bVals, dtype) => {
        const outShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
        const outRank = outShape.length;
        const outStrides = util.computeStrides(outShape);
        const outSize = util.sizeFromShape(outShape);
        const out = util.getTypedArrayFromDType(dtype, outSize);
        const aRank = aShape.length;
        const bRank = bShape.length;
        const aStrides = util.computeStrides(aShape);
        const bStrides = util.computeStrides(bShape);
        // Axes (in each operand's own coordinates) reported as broadcast.
        const aPinnedDims = backend_util.getBroadcastDims(aShape, outShape);
        const bPinnedDims = backend_util.getBroadcastDims(bShape, outShape);
        const hasPinnedDims = aPinnedDims.length + bPinnedDims.length !== 0;
        if (!hasPinnedDims) {
            // Fast path: no per-axis remapping is required. The modulo wraps
            // the flat index around each operand's own buffer length, so a
            // shorter operand is simply cycled through.
            for (let flat = 0; flat < out.length; flat++) {
                out[flat] = op(aVals[flat % aVals.length], bVals[flat % bVals.length]);
            }
            return [out, outShape];
        }
        // General path: map every output coordinate back to a flat offset in
        // each operand, pinning the broadcast axes to index 0.
        for (let flat = 0; flat < out.length; flat++) {
            const coords = util.indexToLoc(flat, outRank, outStrides);
            // Keep only the trailing coordinates that line up with operand a.
            const aCoords = coords.slice(-aRank);
            for (const dim of aPinnedDims) {
                aCoords[dim] = 0;
            }
            const aOffset = util.locToIndex(aCoords, aRank, aStrides);
            // Same for operand b.
            const bCoords = coords.slice(-bRank);
            for (const dim of bPinnedDims) {
                bCoords[dim] = 0;
            }
            const bOffset = util.locToIndex(bCoords, bRank, bStrides);
            out[flat] = op(aVals[aOffset], bVals[bOffset]);
        }
        return [out, outShape];
    };
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYmluYXJ5X2ltcGwuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy91dGlscy9iaW5hcnlfaW1wbC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsWUFBWSxFQUFxRCxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUk1Rzs7R0FFRztBQUNILE1BQU0sVUFBVSw0QkFBNEIsQ0FBQyxFQUF5QjtJQUVwRSxPQUFPLENBQUMsTUFBZ0IsRUFBRSxNQUFnQixFQUFFLEtBQWlCLEVBQ3JELEtBQWlCLEVBQUUsS0FBZSxFQUEwQixFQUFFO1FBQ3BFLE1BQU0sUUFBUSxHQUFHLFlBQVksQ0FBQywwQkFBMEIsQ0FBQyxNQUFNLEVBQUUsTUFBTSxDQUFDLENBQUM7UUFFekUsTUFBTSxVQUFVLEdBQUcsUUFBUSxDQUFDLE1BQU0sQ0FBQztRQUNuQyxNQUFNLGFBQWEsR0FBRyxJQUFJLENBQUMsY0FBYyxDQUFDLFFBQVEsQ0FBQyxDQUFDO1FBQ3BELE1BQU0sVUFBVSxHQUFHLElBQUksQ0FBQyxhQUFhLENBQUMsUUFBUSxDQUFDLENBQUM7UUFFaEQsTUFBTSxNQUFNLEdBQ1IsSUFBSSxDQUFDLHNCQUFzQixDQUFDLEtBQXdCLEVBQUUsVUFBVSxDQUFDLENBQUM7UUFFdEUsTUFBTSxLQUFLLEdBQUcsTUFBTSxDQUFDLE1BQU0sQ0FBQztRQUM1QixNQUFNLEtBQUssR0FBRyxNQUFNLENBQUMsTUFBTSxDQUFDO1FBRTVCLE1BQU0sUUFBUSxHQUFHLElBQUksQ0FBQyxjQUFjLENBQUMsTUFBTSxDQUFDLENBQUM7UUFDN0MsTUFBTSxRQUFRLEdBQUcsSUFBSSxDQUFDLGNBQWMsQ0FBQyxNQUFNLENBQUMsQ0FBQztRQUU3QyxNQUFNLGNBQWMsR0FBRyxZQUFZLENBQUMsZ0JBQWdCLENBQUMsTUFBTSxFQUFFLFFBQVEsQ0FBQyxDQUFDO1FBQ3ZFLE1BQU0sY0FBYyxHQUFHLFlBQVksQ0FBQyxnQkFBZ0IsQ0FBQyxNQUFNLEVBQUUsUUFBUSxDQUFDLENBQUM7UUFFdkUsSUFBSSxjQUFjLENBQUMsTUFBTSxHQUFHLGNBQWMsQ0FBQyxNQUFNLEtBQUssQ0FBQyxFQUFFO1lBQ3ZELEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxNQUFNLENBQUMsTUFBTSxFQUFFLEVBQUUsQ0FBQyxFQUFFO2dCQUN0QyxNQUFNLENBQUMsQ0FBQyxDQUFDLEdBQUcsRUFBRSxDQUFDLEtBQUssQ0FBQyxDQUFDLEdBQUcsS0FBSyxDQUFDLE1BQU0sQ0FBQyxFQUFFLEtBQUssQ0FBQyxDQUFDLEdBQUcsS0FBSyxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUM7YUFDbEU7U0FDRjthQUFNO1lBQ0wsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLE1BQU0sQ0FBQyxNQUFNLEVBQUUsRUFBRSxDQUFDLEVBQUU7Z0JBQ3RDLE1BQU0sR0FBRyxHQUFHLElBQUksQ0FBQyxVQUFVLENBQUMsQ0FBQyxFQUFFLFVBQVUsRUFBRSxhQUFhLENBQUMsQ0FBQztnQkFFMUQsTUFBTSxJQUFJLEdBQUcsR0FBRyxDQUFDLEtBQUssQ0FBQy
xDQUFDLEtBQUssQ0FBQyxDQUFDO2dCQUMvQixjQUFjLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDO2dCQUN6QyxNQUFNLE1BQU0sR0FBRyxJQUFJLENBQUMsVUFBVSxDQUFDLElBQUksRUFBRSxLQUFLLEVBQUUsUUFBUSxDQUFDLENBQUM7Z0JBRXRELE1BQU0sSUFBSSxHQUFHLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztnQkFDL0IsY0FBYyxDQUFDLE9BQU8sQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQztnQkFDekMsTUFBTSxNQUFNLEdBQUcsSUFBSSxDQUFDLFVBQVUsQ0FBQyxJQUFJLEVBQUUsS0FBSyxFQUFFLFFBQVEsQ0FBQyxDQUFDO2dCQUV0RCxNQUFNLENBQUMsQ0FBQyxDQUFDLEdBQUcsRUFBRSxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUMsRUFBRSxLQUFLLENBQUMsTUFBTSxDQUFDLENBQUMsQ0FBQzthQUM5QztTQUNGO1FBRUQsT0FBTyxDQUFDLE1BQU0sRUFBRSxRQUFRLENBQUMsQ0FBQztJQUM1QixDQUFDLENBQUM7QUFDSixDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgRGF0YVR5cGUsIERhdGFWYWx1ZXMsIE51bWVyaWNEYXRhVHlwZSwgVHlwZWRBcnJheSwgdXRpbH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtTaW1wbGVCaW5hcnlLZXJuZWxJbXBsLCBTaW1wbGVCaW5hcnlPcGVyYXRpb259IGZyb2
0gJy4vYmluYXJ5X3R5cGVzJztcblxuLyoqXG4gKiBUZW1wbGF0ZSB0aGF0IGNyZWF0ZXMgaW1wbGVtZW50YXRpb24gZm9yIGJpbmFyeSBvcHMuIFN1cHBvcnRzIGJyb2FkY2FzdC5cbiAqL1xuZXhwb3J0IGZ1bmN0aW9uIGNyZWF0ZVNpbXBsZUJpbmFyeUtlcm5lbEltcGwob3A6IFNpbXBsZUJpbmFyeU9wZXJhdGlvbik6XG4gICAgU2ltcGxlQmluYXJ5S2VybmVsSW1wbCB7XG4gIHJldHVybiAoYVNoYXBlOiBudW1iZXJbXSwgYlNoYXBlOiBudW1iZXJbXSwgYVZhbHM6IERhdGFWYWx1ZXMsXG4gICAgICAgICAgYlZhbHM6IERhdGFWYWx1ZXMsIGR0eXBlOiBEYXRhVHlwZSk6IFtUeXBlZEFycmF5LCBudW1iZXJbXV0gPT4ge1xuICAgIGNvbnN0IG5ld1NoYXBlID0gYmFja2VuZF91dGlsLmFzc2VydEFuZEdldEJyb2FkY2FzdFNoYXBlKGFTaGFwZSwgYlNoYXBlKTtcblxuICAgIGNvbnN0IHJlc3VsdFJhbmsgPSBuZXdTaGFwZS5sZW5ndGg7XG4gICAgY29uc3QgcmVzdWx0U3RyaWRlcyA9IHV0aWwuY29tcHV0ZVN0cmlkZXMobmV3U2hhcGUpO1xuICAgIGNvbnN0IHJlc3VsdFNpemUgPSB1dGlsLnNpemVGcm9tU2hhcGUobmV3U2hhcGUpO1xuXG4gICAgY29uc3QgcmVzdWx0ID1cbiAgICAgICAgdXRpbC5nZXRUeXBlZEFycmF5RnJvbURUeXBlKGR0eXBlIGFzIE51bWVyaWNEYXRhVHlwZSwgcmVzdWx0U2l6ZSk7XG5cbiAgICBjb25zdCBhUmFuayA9IGFTaGFwZS5sZW5ndGg7XG4gICAgY29uc3QgYlJhbmsgPSBiU2hhcGUubGVuZ3RoO1xuXG4gICAgY29uc3QgYVN0cmlkZXMgPSB1dGlsLmNvbXB1dGVTdHJpZGVzKGFTaGFwZSk7XG4gICAgY29uc3QgYlN0cmlkZXMgPSB1dGlsLmNvbXB1dGVTdHJpZGVzKGJTaGFwZSk7XG5cbiAgICBjb25zdCBhQnJvYWRjYXN0RGltcyA9IGJhY2tlbmRfdXRpbC5nZXRCcm9hZGNhc3REaW1zKGFTaGFwZSwgbmV3U2hhcGUpO1xuICAgIGNvbnN0IGJCcm9hZGNhc3REaW1zID0gYmFja2VuZF91dGlsLmdldEJyb2FkY2FzdERpbXMoYlNoYXBlLCBuZXdTaGFwZSk7XG5cbiAgICBpZiAoYUJyb2FkY2FzdERpbXMubGVuZ3RoICsgYkJyb2FkY2FzdERpbXMubGVuZ3RoID09PSAwKSB7XG4gICAgICBmb3IgKGxldCBpID0gMDsgaSA8IHJlc3VsdC5sZW5ndGg7ICsraSkge1xuICAgICAgICByZXN1bHRbaV0gPSBvcChhVmFsc1tpICUgYVZhbHMubGVuZ3RoXSwgYlZhbHNbaSAlIGJWYWxzLmxlbmd0aF0pO1xuICAgICAgfVxuICAgIH0gZWxzZSB7XG4gICAgICBmb3IgKGxldCBpID0gMDsgaSA8IHJlc3VsdC5sZW5ndGg7ICsraSkge1xuICAgICAgICBjb25zdCBsb2MgPSB1dGlsLmluZGV4VG9Mb2MoaSwgcmVzdWx0UmFuaywgcmVzdWx0U3RyaWRlcyk7XG5cbiAgICAgICAgY29uc3QgYUxvYyA9IGxvYy5zbGljZSgtYVJhbmspO1xuICAgICAgICBhQnJvYWRjYXN0RGltcy5mb3JFYWNoKGQgPT4gYUxvY1tkXSA9IDApO1xuICAgICAgICBjb25zdCBhSW5kZXggPSB1dGlsLmxvY1RvSW5kZXgoYUxvYywgYVJhbmssIGFTdHJpZG
VzKTtcblxuICAgICAgICBjb25zdCBiTG9jID0gbG9jLnNsaWNlKC1iUmFuayk7XG4gICAgICAgIGJCcm9hZGNhc3REaW1zLmZvckVhY2goZCA9PiBiTG9jW2RdID0gMCk7XG4gICAgICAgIGNvbnN0IGJJbmRleCA9IHV0aWwubG9jVG9JbmRleChiTG9jLCBiUmFuaywgYlN0cmlkZXMpO1xuXG4gICAgICAgIHJlc3VsdFtpXSA9IG9wKGFWYWxzW2FJbmRleF0sIGJWYWxzW2JJbmRleF0pO1xuICAgICAgfVxuICAgIH1cblxuICAgIHJldHVybiBbcmVzdWx0LCBuZXdTaGFwZV07XG4gIH07XG59XG4iXX0=