/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, util } from '@tensorflow/tfjs-core';
|
/**
 * Template that creates the implementation for an element-wise binary op.
 * Supports broadcast.
 *
 * The returned kernel works on flat value buffers plus their shapes:
 *   (aShape, bShape, aVals, bVals, dtype) => [result, newShape]
 *
 * - `newShape` is whatever `backend_util.assertAndGetBroadcastShape` returns
 *   for the two input shapes (presumably it throws when the shapes are not
 *   broadcast-compatible; that error is not caught here).
 * - `result` is the buffer returned by `util.getTypedArrayFromDType(dtype,
 *   size)`; element `i` holds `op(a, b)` for the pair of input elements that
 *   map onto flat output index `i`.
 *
 * @param {function(*, *): *} op Scalar operation, invoked as `op(aVal, bVal)`
 *     once per output element (first operand always comes from `aVals`).
 * @returns {function(Array<number>, Array<number>, *, *, string): Array}
 *     The kernel implementation described above.
 */
export function createSimpleBinaryKernelImpl(op) {
  return (aShape, bShape, aVals, bVals, dtype) => {
    const newShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);

    const resultRank = newShape.length;
    const resultStrides = util.computeStrides(newShape);
    const resultSize = util.sizeFromShape(newShape);
    const result = util.getTypedArrayFromDType(dtype, resultSize);

    const aRank = aShape.length;
    const bRank = bShape.length;
    const aStrides = util.computeStrides(aShape);
    const bStrides = util.computeStrides(bShape);

    // Axes along which each input has to be repeated to fill `newShape`.
    const aBroadcastDims = backend_util.getBroadcastDims(aShape, newShape);
    const bBroadcastDims = backend_util.getBroadcastDims(bShape, newShape);

    if (aBroadcastDims.length + bBroadcastDims.length === 0) {
      // Fast path: no axis needs coordinate remapping. The modulo wraps the
      // flat index for an input that holds fewer elements than the result
      // (e.g. a scalar operand), so it is simply cycled through.
      for (let i = 0; i < result.length; ++i) {
        result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
      }
      return [result, newShape];
    }

    // General path: turn each flat output index into coordinates, then map
    // those coordinates back to a flat index into each input.
    for (let i = 0; i < result.length; ++i) {
      const loc = util.indexToLoc(i, resultRank, resultStrides);

      // Keep only the trailing axes the input actually has, and pin every
      // broadcast axis to coordinate 0 so the same input element is reused.
      const aLoc = loc.slice(-aRank);
      for (const d of aBroadcastDims) {
        aLoc[d] = 0;
      }
      const aIndex = util.locToIndex(aLoc, aRank, aStrides);

      const bLoc = loc.slice(-bRank);
      for (const d of bBroadcastDims) {
        bLoc[d] = 0;
      }
      const bIndex = util.locToIndex(bLoc, bRank, bStrides);

      result[i] = op(aVals[aIndex], bVals[bIndex]);
    }

    return [result, newShape];
  };
}
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYmluYXJ5X2ltcGwuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy91dGlscy9iaW5hcnlfaW1wbC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsWUFBWSxFQUFxRCxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUk1Rzs7R0FFRztBQUNILE1BQU0sVUFBVSw0QkFBNEIsQ0FBQyxFQUF5QjtJQUVwRSxPQUFPLENBQUMsTUFBZ0IsRUFBRSxNQUFnQixFQUFFLEtBQWlCLEVBQ3JELEtBQWlCLEVBQUUsS0FBZSxFQUEwQixFQUFFO1FBQ3BFLE1BQU0sUUFBUSxHQUFHLFlBQVksQ0FBQywwQkFBMEIsQ0FBQyxNQUFNLEVBQUUsTUFBTSxDQUFDLENBQUM7UUFFekUsTUFBTSxVQUFVLEdBQUcsUUFBUSxDQUFDLE1BQU0sQ0FBQztRQUNuQyxNQUFNLGFBQWEsR0FBRyxJQUFJLENBQUMsY0FBYyxDQUFDLFFBQVEsQ0FBQyxDQUFDO1FBQ3BELE1BQU0sVUFBVSxHQUFHLElBQUksQ0FBQyxhQUFhLENBQUMsUUFBUSxDQUFDLENBQUM7UUFFaEQsTUFBTSxNQUFNLEdBQ1IsSUFBSSxDQUFDLHNCQUFzQixDQUFDLEtBQXdCLEVBQUUsVUFBVSxDQUFDLENBQUM7UUFFdEUsTUFBTSxLQUFLLEdBQUcsTUFBTSxDQUFDLE1BQU0sQ0FBQztRQUM1QixNQUFNLEtBQUssR0FBRyxNQUFNLENBQUMsTUFBTSxDQUFDO1FBRTVCLE1BQU0sUUFBUSxHQUFHLElBQUksQ0FBQyxjQUFjLENBQUMsTUFBTSxDQUFDLENBQUM7UUFDN0MsTUFBTSxRQUFRLEdBQUcsSUFBSSxDQUFDLGNBQWMsQ0FBQyxNQUFNLENBQUMsQ0FBQztRQUU3QyxNQUFNLGNBQWMsR0FBRyxZQUFZLENBQUMsZ0JBQWdCLENBQUMsTUFBTSxFQUFFLFFBQVEsQ0FBQyxDQUFDO1FBQ3ZFLE1BQU0sY0FBYyxHQUFHLFlBQVksQ0FBQyxnQkFBZ0IsQ0FBQyxNQUFNLEVBQUUsUUFBUSxDQUFDLENBQUM7UUFFdkUsSUFBSSxjQUFjLENBQUMsTUFBTSxHQUFHLGNBQWMsQ0FBQyxNQUFNLEtBQUssQ0FBQyxFQUFFO1lBQ3ZELEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxNQUFNLENBQUMsTUFBTSxFQUFFLEVBQUUsQ0FBQyxFQUFFO2dCQUN0QyxNQUFNLENBQUMsQ0FBQyxDQUFDLEdBQUcsRUFBRSxDQUFDLEtBQUssQ0FBQyxDQUFDLEdBQUcsS0FBSyxDQUFDLE1BQU0sQ0FBQyxFQUFFLEtBQUssQ0FBQyxDQUFDLEdBQUcsS0FBSyxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUM7YUFDbEU7U0FDRjthQUFNO1lBQ0wsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLE1BQU0sQ0FBQyxNQUFNLEVBQUUsRUFBRSxDQUFDLEVBQUU7Z0JBQ3RDLE1BQU0sR0FBRyxHQUFHLElBQUksQ0FBQyxVQUFVLENBQUMsQ0FBQyxFQUFFLFVBQVUsRUFBRSxhQUFhLENBQUMsQ0FBQztnQkFFMUQsTUFBTSxJQUFJLEdBQUcsR0FBRyxDQUFDLEtBQUssQ0FBQy
xDQUFDLEtBQUssQ0FBQyxDQUFDO2dCQUMvQixjQUFjLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDO2dCQUN6QyxNQUFNLE1BQU0sR0FBRyxJQUFJLENBQUMsVUFBVSxDQUFDLElBQUksRUFBRSxLQUFLLEVBQUUsUUFBUSxDQUFDLENBQUM7Z0JBRXRELE1BQU0sSUFBSSxHQUFHLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztnQkFDL0IsY0FBYyxDQUFDLE9BQU8sQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQztnQkFDekMsTUFBTSxNQUFNLEdBQUcsSUFBSSxDQUFDLFVBQVUsQ0FBQyxJQUFJLEVBQUUsS0FBSyxFQUFFLFFBQVEsQ0FBQyxDQUFDO2dCQUV0RCxNQUFNLENBQUMsQ0FBQyxDQUFDLEdBQUcsRUFBRSxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUMsRUFBRSxLQUFLLENBQUMsTUFBTSxDQUFDLENBQUMsQ0FBQzthQUM5QztTQUNGO1FBRUQsT0FBTyxDQUFDLE1BQU0sRUFBRSxRQUFRLENBQUMsQ0FBQztJQUM1QixDQUFDLENBQUM7QUFDSixDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgRGF0YVR5cGUsIERhdGFWYWx1ZXMsIE51bWVyaWNEYXRhVHlwZSwgVHlwZWRBcnJheSwgdXRpbH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtTaW1wbGVCaW5hcnlLZXJuZWxJbXBsLCBTaW1wbGVCaW5hcnlPcGVyYXRpb259IGZyb2
0gJy4vYmluYXJ5X3R5cGVzJztcblxuLyoqXG4gKiBUZW1wbGF0ZSB0aGF0IGNyZWF0ZXMgaW1wbGVtZW50YXRpb24gZm9yIGJpbmFyeSBvcHMuIFN1cHBvcnRzIGJyb2FkY2FzdC5cbiAqL1xuZXhwb3J0IGZ1bmN0aW9uIGNyZWF0ZVNpbXBsZUJpbmFyeUtlcm5lbEltcGwob3A6IFNpbXBsZUJpbmFyeU9wZXJhdGlvbik6XG4gICAgU2ltcGxlQmluYXJ5S2VybmVsSW1wbCB7XG4gIHJldHVybiAoYVNoYXBlOiBudW1iZXJbXSwgYlNoYXBlOiBudW1iZXJbXSwgYVZhbHM6IERhdGFWYWx1ZXMsXG4gICAgICAgICAgYlZhbHM6IERhdGFWYWx1ZXMsIGR0eXBlOiBEYXRhVHlwZSk6IFtUeXBlZEFycmF5LCBudW1iZXJbXV0gPT4ge1xuICAgIGNvbnN0IG5ld1NoYXBlID0gYmFja2VuZF91dGlsLmFzc2VydEFuZEdldEJyb2FkY2FzdFNoYXBlKGFTaGFwZSwgYlNoYXBlKTtcblxuICAgIGNvbnN0IHJlc3VsdFJhbmsgPSBuZXdTaGFwZS5sZW5ndGg7XG4gICAgY29uc3QgcmVzdWx0U3RyaWRlcyA9IHV0aWwuY29tcHV0ZVN0cmlkZXMobmV3U2hhcGUpO1xuICAgIGNvbnN0IHJlc3VsdFNpemUgPSB1dGlsLnNpemVGcm9tU2hhcGUobmV3U2hhcGUpO1xuXG4gICAgY29uc3QgcmVzdWx0ID1cbiAgICAgICAgdXRpbC5nZXRUeXBlZEFycmF5RnJvbURUeXBlKGR0eXBlIGFzIE51bWVyaWNEYXRhVHlwZSwgcmVzdWx0U2l6ZSk7XG5cbiAgICBjb25zdCBhUmFuayA9IGFTaGFwZS5sZW5ndGg7XG4gICAgY29uc3QgYlJhbmsgPSBiU2hhcGUubGVuZ3RoO1xuXG4gICAgY29uc3QgYVN0cmlkZXMgPSB1dGlsLmNvbXB1dGVTdHJpZGVzKGFTaGFwZSk7XG4gICAgY29uc3QgYlN0cmlkZXMgPSB1dGlsLmNvbXB1dGVTdHJpZGVzKGJTaGFwZSk7XG5cbiAgICBjb25zdCBhQnJvYWRjYXN0RGltcyA9IGJhY2tlbmRfdXRpbC5nZXRCcm9hZGNhc3REaW1zKGFTaGFwZSwgbmV3U2hhcGUpO1xuICAgIGNvbnN0IGJCcm9hZGNhc3REaW1zID0gYmFja2VuZF91dGlsLmdldEJyb2FkY2FzdERpbXMoYlNoYXBlLCBuZXdTaGFwZSk7XG5cbiAgICBpZiAoYUJyb2FkY2FzdERpbXMubGVuZ3RoICsgYkJyb2FkY2FzdERpbXMubGVuZ3RoID09PSAwKSB7XG4gICAgICBmb3IgKGxldCBpID0gMDsgaSA8IHJlc3VsdC5sZW5ndGg7ICsraSkge1xuICAgICAgICByZXN1bHRbaV0gPSBvcChhVmFsc1tpICUgYVZhbHMubGVuZ3RoXSwgYlZhbHNbaSAlIGJWYWxzLmxlbmd0aF0pO1xuICAgICAgfVxuICAgIH0gZWxzZSB7XG4gICAgICBmb3IgKGxldCBpID0gMDsgaSA8IHJlc3VsdC5sZW5ndGg7ICsraSkge1xuICAgICAgICBjb25zdCBsb2MgPSB1dGlsLmluZGV4VG9Mb2MoaSwgcmVzdWx0UmFuaywgcmVzdWx0U3RyaWRlcyk7XG5cbiAgICAgICAgY29uc3QgYUxvYyA9IGxvYy5zbGljZSgtYVJhbmspO1xuICAgICAgICBhQnJvYWRjYXN0RGltcy5mb3JFYWNoKGQgPT4gYUxvY1tkXSA9IDApO1xuICAgICAgICBjb25zdCBhSW5kZXggPSB1dGlsLmxvY1RvSW5kZXgoYUxvYywgYVJhbmssIGFTdHJpZG
VzKTtcblxuICAgICAgICBjb25zdCBiTG9jID0gbG9jLnNsaWNlKC1iUmFuayk7XG4gICAgICAgIGJCcm9hZGNhc3REaW1zLmZvckVhY2goZCA9PiBiTG9jW2RdID0gMCk7XG4gICAgICAgIGNvbnN0IGJJbmRleCA9IHV0aWwubG9jVG9JbmRleChiTG9jLCBiUmFuaywgYlN0cmlkZXMpO1xuXG4gICAgICAgIHJlc3VsdFtpXSA9IG9wKGFWYWxzW2FJbmRleF0sIGJWYWxzW2JJbmRleF0pO1xuICAgICAgfVxuICAgIH1cblxuICAgIHJldHVybiBbcmVzdWx0LCBuZXdTaGFwZV07XG4gIH07XG59XG4iXX0=
|