gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { AddN, env, upcastType } from '@tensorflow/tfjs-core';
import { AddNProgram } from '../addn_gpu';
import { AddNPackedProgram } from '../addn_packed_gpu';
import { identity } from './Identity';
/**
 * WebGL kernel function that sums a list of tensors element-wise.
 *
 * - A single input is forwarded to `identity`.
 * - When the number of inputs exceeds the `WEBGL_MAX_TEXTURES_IN_SHADER`
 *   environment value, the list is halved, each half is summed recursively,
 *   and the two partial sums are added together. The partial sums are
 *   released before returning because they are created here and never handed
 *   to the caller.
 * - Otherwise a single AddN shader program (packed or unpacked, depending on
 *   the `WEBGL_PACK` flag) is run over all inputs.
 *
 * @param {{inputs: Array, backend: Object}} args
 *     `inputs` is the array of tensor infos to sum (each is read for `dtype`
 *     and `shape`); `backend` is the backend instance used to run the program.
 * @returns {Object} The tensor info produced by `identity` or by
 *     `backend.runWebGLProgram`, whose dtype is the upcast of all input dtypes.
 */
export function addN(args) {
    const { inputs, backend } = args;
    const tensors = inputs;
    if (tensors.length === 1) {
        return identity({ inputs: { x: tensors[0] }, backend });
    }
    // Limit the number of uploaded textures for optimization.
    if (tensors.length > env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) {
        const midIndex = Math.floor(tensors.length / 2);
        const leftSide = addN({ inputs: tensors.slice(0, midIndex), backend });
        const rightSide = addN({ inputs: tensors.slice(midIndex), backend });
        const combined = addN({ inputs: [leftSide, rightSide], backend });
        // The partial sums are kernel-internal: they were allocated by the
        // recursive calls above and only `combined` is returned, so release
        // them here instead of leaving them orphaned in the backend.
        backend.disposeIntermediateTensorInfo(leftSide);
        backend.disposeIntermediateTensorInfo(rightSide);
        return combined;
    }
    // Fold every input dtype into the widest one; the output uses that dtype.
    const dtype = tensors.map(t => t.dtype).reduce((d1, d2) => upcastType(d1, d2));
    const shapes = tensors.map(t => t.shape);
    // We can make sure shapes are identical in op level, so the first input's
    // shape is used as the output shape.
    const usePackedOp = env().getBool('WEBGL_PACK');
    const program = usePackedOp ?
        new AddNPackedProgram(tensors[0].shape, shapes) :
        new AddNProgram(tensors[0].shape, shapes);
    return backend.runWebGLProgram(program, tensors, dtype);
}
/**
 * Kernel configuration record for AddN on the WebGL backend: it ties the
 * `AddN` kernel name imported from tfjs-core to the `'webgl'` backend name
 * and to the `addN` kernel function defined in this file.
 * NOTE(review): presumably consumed by the backend's kernel registration
 * code — confirm at the registration site.
 */
export const addNConfig = {
    kernelName: AddN,
    backendName: 'webgl',
    kernelFunc: addN
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQWRkTi5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMva2VybmVscy9BZGROLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxJQUFJLEVBQWMsR0FBRyxFQUF3QyxVQUFVLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUU5RyxPQUFPLEVBQUMsV0FBVyxFQUFDLE1BQU0sYUFBYSxDQUFDO0FBQ3hDLE9BQU8sRUFBQyxpQkFBaUIsRUFBQyxNQUFNLG9CQUFvQixDQUFDO0FBRXJELE9BQU8sRUFBQyxRQUFRLEVBQUMsTUFBTSxZQUFZLENBQUM7QUFFcEMsTUFBTSxVQUFVLElBQUksQ0FBQyxJQUFxRDtJQUV4RSxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBQyxHQUFHLElBQUksQ0FBQztJQUUvQixNQUFNLE9BQU8sR0FBRyxNQUFNLENBQUM7SUFDdkIsSUFBSSxPQUFPLENBQUMsTUFBTSxLQUFLLENBQUMsRUFBRTtRQUN4QixPQUFPLFFBQVEsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxPQUFPLENBQUMsQ0FBQyxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQUFDO0tBQ3JEO0lBRUQsMERBQTBEO0lBQzFELElBQUksT0FBTyxDQUFDLE1BQU0sR0FBRyxHQUFHLEVBQUUsQ0FBQyxTQUFTLENBQUMsOEJBQThCLENBQUMsRUFBRTtRQUNwRSxNQUFNLFFBQVEsR0FBRyxJQUFJLENBQUMsS0FBSyxDQUFDLE9BQU8sQ0FBQyxNQUFNLEdBQUcsQ0FBQyxDQUFDLENBQUM7UUFDaEQsTUFBTSxRQUFRLEdBQUcsSUFBSSxDQUFDLEVBQUMsTUFBTSxFQUFFLE9BQU8sQ0FBQyxLQUFLLENBQUMsQ0FBQyxFQUFFLFFBQVEsQ0FBQyxFQUFFLE9BQU8sRUFBQyxDQUFDLENBQUM7UUFDckUsTUFBTSxTQUFTLEdBQUcsSUFBSSxDQUFDLEVBQUMsTUFBTSxFQUFFLE9BQU8sQ0FBQyxLQUFLLENBQUMsUUFBUSxDQUFDLEVBQUUsT0FBTyxFQUFDLENBQUMsQ0FBQztRQUNuRSxPQUFPLElBQUksQ0FBQyxFQUFDLE1BQU0sRUFBRSxDQUFDLFFBQVEsRUFBRSxTQUFTLENBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQUFDO0tBQ3ZEO0lBRUQsTUFBTSxLQUFLLEdBQ1AsT0FBTyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLEVBQUUsRUFBRSxFQUFFLEVBQUUsQ0FBQyxVQUFVLENBQUMsRUFBRSxFQUFFLEVBQUUsQ0FBQyxDQUFDLENBQUM7SUFDckUsTUFBTSxNQUFNLEdBQUcsT0FBTyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUN6QyxxREFBcUQ7SUFDckQsTUFBTSxXQUFXLEdBQUcsR0FBRyxFQUFFLENBQUMsT0FBTyxDQUFDLFlBQVksQ0FBQyxDQUFDO0lBQ2hELE1BQU0sT0FBTyxHQUFHLFdBQVcsQ0FBQyxDQUFDO1FBQ3pCLElBQUksaUJBQW
lCLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxDQUFDLEtBQUssRUFBRSxNQUFNLENBQUMsQ0FBQyxDQUFDO1FBQ2pELElBQUksV0FBVyxDQUFDLE9BQU8sQ0FBQyxDQUFDLENBQUMsQ0FBQyxLQUFLLEVBQUUsTUFBTSxDQUFDLENBQUM7SUFDOUMsT0FBTyxPQUFPLENBQUMsZUFBZSxDQUFDLE9BQU8sRUFBRSxPQUFPLEVBQUUsS0FBSyxDQUFDLENBQUM7QUFDMUQsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLFVBQVUsR0FBaUI7SUFDdEMsVUFBVSxFQUFFLElBQUk7SUFDaEIsV0FBVyxFQUFFLE9BQU87SUFDcEIsVUFBVSxFQUFFLElBQTZCO0NBQzFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7QWRkTiwgQWRkTklucHV0cywgZW52LCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvckluZm8sIHVwY2FzdFR5cGV9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7QWRkTlByb2dyYW19IGZyb20gJy4uL2FkZG5fZ3B1JztcbmltcG9ydCB7QWRkTlBhY2tlZFByb2dyYW19IGZyb20gJy4uL2FkZG5fcGFja2VkX2dwdSc7XG5pbXBvcnQge01hdGhCYWNrZW5kV2ViR0x9IGZyb20gJy4uL2JhY2tlbmRfd2ViZ2wnO1xuaW1wb3J0IHtpZGVudGl0eX0gZnJvbSAnLi9JZGVudGl0eSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBhZGROKGFyZ3M6IHtpbnB1dHM6IEFkZE5JbnB1dHMsIGJhY2tlbmQ6IE1hdGhCYWNrZW5kV2ViR0x9KTpcbiAgICBUZW5zb3JJbmZvIHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZH0gPSBhcmdzO1xuXG4gIG
NvbnN0IHRlbnNvcnMgPSBpbnB1dHM7XG4gIGlmICh0ZW5zb3JzLmxlbmd0aCA9PT0gMSkge1xuICAgIHJldHVybiBpZGVudGl0eSh7aW5wdXRzOiB7eDogdGVuc29yc1swXX0sIGJhY2tlbmR9KTtcbiAgfVxuXG4gIC8vIExpbWl0IHRoZSBudW1iZXIgb2YgdXBsb2FkZWQgdGV4dHVyZXMgZm9yIG9wdGltaXphdGlvbi5cbiAgaWYgKHRlbnNvcnMubGVuZ3RoID4gZW52KCkuZ2V0TnVtYmVyKCdXRUJHTF9NQVhfVEVYVFVSRVNfSU5fU0hBREVSJykpIHtcbiAgICBjb25zdCBtaWRJbmRleCA9IE1hdGguZmxvb3IodGVuc29ycy5sZW5ndGggLyAyKTtcbiAgICBjb25zdCBsZWZ0U2lkZSA9IGFkZE4oe2lucHV0czogdGVuc29ycy5zbGljZSgwLCBtaWRJbmRleCksIGJhY2tlbmR9KTtcbiAgICBjb25zdCByaWdodFNpZGUgPSBhZGROKHtpbnB1dHM6IHRlbnNvcnMuc2xpY2UobWlkSW5kZXgpLCBiYWNrZW5kfSk7XG4gICAgcmV0dXJuIGFkZE4oe2lucHV0czogW2xlZnRTaWRlLCByaWdodFNpZGVdLCBiYWNrZW5kfSk7XG4gIH1cblxuICBjb25zdCBkdHlwZSA9XG4gICAgICB0ZW5zb3JzLm1hcCh0ID0+IHQuZHR5cGUpLnJlZHVjZSgoZDEsIGQyKSA9PiB1cGNhc3RUeXBlKGQxLCBkMikpO1xuICBjb25zdCBzaGFwZXMgPSB0ZW5zb3JzLm1hcCh0ID0+IHQuc2hhcGUpO1xuICAvLyBXZSBjYW4gbWFrZSBzdXJlIHNoYXBlcyBhcmUgaWRlbnRpY2FsIGluIG9wIGxldmVsLlxuICBjb25zdCB1c2VQYWNrZWRPcCA9IGVudigpLmdldEJvb2woJ1dFQkdMX1BBQ0snKTtcbiAgY29uc3QgcHJvZ3JhbSA9IHVzZVBhY2tlZE9wID9cbiAgICAgIG5ldyBBZGROUGFja2VkUHJvZ3JhbSh0ZW5zb3JzWzBdLnNoYXBlLCBzaGFwZXMpIDpcbiAgICAgIG5ldyBBZGROUHJvZ3JhbSh0ZW5zb3JzWzBdLnNoYXBlLCBzaGFwZXMpO1xuICByZXR1cm4gYmFja2VuZC5ydW5XZWJHTFByb2dyYW0ocHJvZ3JhbSwgdGVuc29ycywgZHR5cGUpO1xufVxuXG5leHBvcnQgY29uc3QgYWRkTkNvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBBZGROLFxuICBiYWNrZW5kTmFtZTogJ3dlYmdsJyxcbiAga2VybmVsRnVuYzogYWRkTiBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmNcbn07XG4iXX0=