/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { AddN, env, upcastType } from '@tensorflow/tfjs-core';
import { AddNProgram } from '../addn_gpu';
import { AddNPackedProgram } from '../addn_packed_gpu';
import { identity } from './Identity';
/**
 * WebGL kernel for `AddN`: sums all tensors in `args.inputs` element-wise.
 *
 * A single input is forwarded through `identity`. When there are more inputs
 * than the `WEBGL_MAX_TEXTURES_IN_SHADER` flag allows in one shader, the list
 * is split in half. Each half is summed recursively, and the two partial sums
 * are then added. The partial sums are intermediates created by this kernel,
 * so they are disposed once the final sum exists.
 *
 * @param {{inputs: !Array<!TensorInfo>, backend: !MathBackendWebGL}} args
 *     `inputs` is assumed non-empty (an empty list makes the dtype `reduce`
 *     below throw); presumably the op level guarantees this.
 * @returns {!TensorInfo} The element-wise sum, whose dtype is the upcast of
 *     every input dtype.
 */
export function addN(args) {
    const { inputs, backend } = args;
    const tensors = inputs;
    if (tensors.length === 1) {
        return identity({ inputs: { x: tensors[0] }, backend });
    }
    // Keep the number of input textures bound to a single shader within the
    // configured limit. The smallest program the split below can produce reads
    // two inputs, so a limit under 2 is treated as 2. Otherwise a limit of 0
    // or 1 would keep re-splitting two inputs forever.
    const maxTexturesInShader = Math.max(2, env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER'));
    if (tensors.length > maxTexturesInShader) {
        const midIndex = Math.floor(tensors.length / 2);
        const leftSide = addN({ inputs: tensors.slice(0, midIndex), backend });
        const rightSide = addN({ inputs: tensors.slice(midIndex), backend });
        const result = addN({ inputs: [leftSide, rightSide], backend });
        // The partial sums are needed only as inputs to the final program.
        // Release them so their backing data is not leaked.
        backend.disposeIntermediateTensorInfo(leftSide);
        backend.disposeIntermediateTensorInfo(rightSide);
        return result;
    }
    const dtype = tensors.map(t => t.dtype).reduce((d1, d2) => upcastType(d1, d2));
    const shapes = tensors.map(t => t.shape);
    // Input shapes are assumed identical. Presumably the op-level addN
    // validates this before dispatching to the kernel.
    const usePackedOp = env().getBool('WEBGL_PACK');
    const program = usePackedOp ?
        new AddNPackedProgram(tensors[0].shape, shapes) :
        new AddNProgram(tensors[0].shape, shapes);
    return backend.runWebGLProgram(program, tensors, dtype);
}
export const addNConfig = {
    kernelName: AddN,
    backendName: 'webgl',
    kernelFunc: addN
};
// NOTE(review): this file is compiled output of
// tfjs-backend-webgl/src/kernels/AddN.ts (per the inline source map). Port this
// change to the TypeScript source and rebuild so the source map matches again.
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQWRkTi5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMva2VybmVscy9BZGROLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxJQUFJLEVBQWMsR0FBRyxFQUF3QyxVQUFVLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUU5RyxPQUFPLEVBQUMsV0FBVyxFQUFDLE1BQU0sYUFBYSxDQUFDO0FBQ3hDLE9BQU8sRUFBQyxpQkFBaUIsRUFBQyxNQUFNLG9CQUFvQixDQUFDO0FBRXJELE9BQU8sRUFBQyxRQUFRLEVBQUMsTUFBTSxZQUFZLENBQUM7QUFFcEMsTUFBTSxVQUFVLElBQUksQ0FBQyxJQUFxRDtJQUV4RSxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBQyxHQUFHLElBQUksQ0FBQztJQUUvQixNQUFNLE9BQU8sR0FBRyxNQUFNLENBQUM7SUFDdkIsSUFBSSxPQUFPLENBQUMsTUFBTSxLQUFLLENBQUMsRUFBRTtRQUN4QixPQUFPLFFBQVEsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxPQUFPLENBQUMsQ0FBQyxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQUFDO0tBQ3JEO0lBRUQsMERBQTBEO0lBQzFELElBQUksT0FBTyxDQUFDLE1BQU0sR0FBRyxHQUFHLEVBQUUsQ0FBQyxTQUFTLENBQUMsOEJBQThCLENBQUMsRUFBRTtRQUNwRSxNQUFNLFFBQVEsR0FBRyxJQUFJLENBQUMsS0FBSyxDQUFDLE9BQU8sQ0FBQyxNQUFNLEdBQUcsQ0FBQyxDQUFDLENBQUM7UUFDaEQsTUFBTSxRQUFRLEdBQUcsSUFBSSxDQUFDLEVBQUMsTUFBTSxFQUFFLE9BQU8sQ0FBQyxLQUFLLENBQUMsQ0FBQyxFQUFFLFFBQVEsQ0FBQyxFQUFFLE9BQU8sRUFBQyxDQUFDLENBQUM7UUFDckUsTUFBTSxTQUFTLEdBQUcsSUFBSSxDQUFDLEVBQUMsTUFBTSxFQUFFLE9BQU8sQ0FBQyxLQUFLLENBQUMsUUFBUSxDQUFDLEVBQUUsT0FBTyxFQUFDLENBQUMsQ0FBQztRQUNuRSxPQUFPLElBQUksQ0FBQyxFQUFDLE1BQU0sRUFBRSxDQUFDLFFBQVEsRUFBRSxTQUFTLENBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQUFDO0tBQ3ZEO0lBRUQsTUFBTSxLQUFLLEdBQ1AsT0FBTyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLEVBQUUsRUFBRSxFQUFFLEVBQUUsQ0FBQyxVQUFVLENBQUMsRUFBRSxFQUFFLEVBQUUsQ0FBQyxDQUFDLENBQUM7SUFDckUsTUFBTSxNQUFNLEdBQUcsT0FBTyxDQUFDLE
dBQUcsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUN6QyxxREFBcUQ7SUFDckQsTUFBTSxXQUFXLEdBQUcsR0FBRyxFQUFFLENBQUMsT0FBTyxDQUFDLFlBQVksQ0FBQyxDQUFDO0lBQ2hELE1BQU0sT0FBTyxHQUFHLFdBQVcsQ0FBQyxDQUFDO1FBQ3pCLElBQUksaUJBQWlCLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxDQUFDLEtBQUssRUFBRSxNQUFNLENBQUMsQ0FBQyxDQUFDO1FBQ2pELElBQUksV0FBVyxDQUFDLE9BQU8sQ0FBQyxDQUFDLENBQUMsQ0FBQyxLQUFLLEVBQUUsTUFBTSxDQUFDLENBQUM7SUFDOUMsT0FBTyxPQUFPLENBQUMsZUFBZSxDQUFDLE9BQU8sRUFBRSxPQUFPLEVBQUUsS0FBSyxDQUFDLENBQUM7QUFDMUQsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLFVBQVUsR0FBaUI7SUFDdEMsVUFBVSxFQUFFLElBQUk7SUFDaEIsV0FBVyxFQUFFLE9BQU87SUFDcEIsVUFBVSxFQUFFLElBQTZCO0NBQzFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7QWRkTiwgQWRkTklucHV0cywgZW52LCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvckluZm8sIHVwY2FzdFR5cGV9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7QWRkTlByb2dyYW19IGZyb20gJy4uL2FkZG5fZ3B1JztcbmltcG9ydCB7QWRkTlBhY2tlZFByb2dyYW19IGZyb20gJy4uL2FkZG5fcGFja2VkX2dwdSc7XG5pbXBvcnQge01hdGhCYWNrZW5kV2ViR0x9IGZyb20gJy4uL2JhY2tlbmRfd2ViZ2
wnO1xuaW1wb3J0IHtpZGVudGl0eX0gZnJvbSAnLi9JZGVudGl0eSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBhZGROKGFyZ3M6IHtpbnB1dHM6IEFkZE5JbnB1dHMsIGJhY2tlbmQ6IE1hdGhCYWNrZW5kV2ViR0x9KTpcbiAgICBUZW5zb3JJbmZvIHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZH0gPSBhcmdzO1xuXG4gIGNvbnN0IHRlbnNvcnMgPSBpbnB1dHM7XG4gIGlmICh0ZW5zb3JzLmxlbmd0aCA9PT0gMSkge1xuICAgIHJldHVybiBpZGVudGl0eSh7aW5wdXRzOiB7eDogdGVuc29yc1swXX0sIGJhY2tlbmR9KTtcbiAgfVxuXG4gIC8vIExpbWl0IHRoZSBudW1iZXIgb2YgdXBsb2FkZWQgdGV4dHVyZXMgZm9yIG9wdGltaXphdGlvbi5cbiAgaWYgKHRlbnNvcnMubGVuZ3RoID4gZW52KCkuZ2V0TnVtYmVyKCdXRUJHTF9NQVhfVEVYVFVSRVNfSU5fU0hBREVSJykpIHtcbiAgICBjb25zdCBtaWRJbmRleCA9IE1hdGguZmxvb3IodGVuc29ycy5sZW5ndGggLyAyKTtcbiAgICBjb25zdCBsZWZ0U2lkZSA9IGFkZE4oe2lucHV0czogdGVuc29ycy5zbGljZSgwLCBtaWRJbmRleCksIGJhY2tlbmR9KTtcbiAgICBjb25zdCByaWdodFNpZGUgPSBhZGROKHtpbnB1dHM6IHRlbnNvcnMuc2xpY2UobWlkSW5kZXgpLCBiYWNrZW5kfSk7XG4gICAgcmV0dXJuIGFkZE4oe2lucHV0czogW2xlZnRTaWRlLCByaWdodFNpZGVdLCBiYWNrZW5kfSk7XG4gIH1cblxuICBjb25zdCBkdHlwZSA9XG4gICAgICB0ZW5zb3JzLm1hcCh0ID0+IHQuZHR5cGUpLnJlZHVjZSgoZDEsIGQyKSA9PiB1cGNhc3RUeXBlKGQxLCBkMikpO1xuICBjb25zdCBzaGFwZXMgPSB0ZW5zb3JzLm1hcCh0ID0+IHQuc2hhcGUpO1xuICAvLyBXZSBjYW4gbWFrZSBzdXJlIHNoYXBlcyBhcmUgaWRlbnRpY2FsIGluIG9wIGxldmVsLlxuICBjb25zdCB1c2VQYWNrZWRPcCA9IGVudigpLmdldEJvb2woJ1dFQkdMX1BBQ0snKTtcbiAgY29uc3QgcHJvZ3JhbSA9IHVzZVBhY2tlZE9wID9cbiAgICAgIG5ldyBBZGROUGFja2VkUHJvZ3JhbSh0ZW5zb3JzWzBdLnNoYXBlLCBzaGFwZXMpIDpcbiAgICAgIG5ldyBBZGROUHJvZ3JhbSh0ZW5zb3JzWzBdLnNoYXBlLCBzaGFwZXMpO1xuICByZXR1cm4gYmFja2VuZC5ydW5XZWJHTFByb2dyYW0ocHJvZ3JhbSwgdGVuc29ycywgZHR5cGUpO1xufVxuXG5leHBvcnQgY29uc3QgYWRkTkNvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBBZGROLFxuICBiYWNrZW5kTmFtZTogJ3dlYmdsJyxcbiAga2VybmVsRnVuYzogYWRkTiBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmNcbn07XG4iXX0=