/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, sumOutType, util } from '@tensorflow/tfjs-core';
import { reduce } from '../kernel_utils/reduce';
import { reshape } from './Reshape';
import { transposeImpl } from './Transpose_impl';

/**
 * Sums `x` over the requested axes and returns the result as a new tensor
 * info object.
 *
 * Steps, in order:
 *   1. Normalise `axis` against `x.shape`.
 *   2. If `getAxesPermutation` yields a permutation, transpose `x` with it and
 *      switch to the inner-most axes of the same count.
 *   3. Reshape the (possibly transposed) input to `[batchSize, inSize]`, where
 *      `inSize` is the element count of the reduced dimensions.
 *   4. Run the `'sum'` reduction and reshape the result to the output shape.
 *   5. Dispose every intermediate created along the way.
 *
 * @param {TensorInfo} x Input; only `shape` and `dtype` are read directly.
 * @param {number|number[]} axis Axis or axes to reduce over.
 * @param {boolean} keepDims When truthy, the output shape is produced by
 *     `expandShapeToKeepDim` using the originally requested axes.
 * @param {MathBackendWebGL} backend Backend passed to the transpose, reshape
 *     and reduce helpers and used to dispose the intermediates.
 * @returns {TensorInfo} The reshaped reduction result. It is not disposed
 *     here, so the caller receives it live.
 */
export function sumImpl(x, axis, keepDims, backend) {
    const rank = x.shape.length;
    const requestedAxes = util.parseAxisParam(axis, x.shape);

    // A non-null permutation means the input has to be transposed before the
    // inner-most-axes assertion below can hold.
    const permutation = backend_util.getAxesPermutation(requestedAxes, rank);
    const needsTranspose = permutation != null;

    const source = needsTranspose ? transposeImpl(x, permutation, backend) : x;
    const innerAxes = needsTranspose ?
        backend_util.getInnerMostAxes(requestedAxes.length, rank) :
        requestedAxes;

    backend_util.assertAxesAreInnerMostDims('sum', innerAxes, rank);
    const [collapsedShape, reducedDims] =
        backend_util.computeOutAndReduceShapes(source.shape, innerAxes);

    // Choose the final shape up front so that a single reshape at the end is
    // enough, whether or not the reduced dimensions are kept.
    const resultShape = keepDims ?
        backend_util.expandShapeToKeepDim(collapsedShape, requestedAxes) :
        collapsedShape;

    const reduceSize = util.sizeFromShape(reducedDims);
    const rowCount = util.sizeFromShape(x.shape) / reduceSize;

    const flattened = reshape({
        inputs: { x: source },
        attrs: { shape: [rowCount, reduceSize] },
        backend
    });
    const summed = reduce(flattened, sumOutType(x.dtype), 'sum', backend);
    const result = reshape({
        inputs: { x: summed },
        attrs: { shape: resultShape },
        backend
    });

    // `source` is only an intermediate when it came from transposeImpl;
    // otherwise it is the caller's own `x` and must be left alone.
    const intermediates = [flattened, summed];
    if (needsTranspose) {
        intermediates.push(source);
    }
    for (const info of intermediates) {
        backend.disposeIntermediateTensorInfo(info);
    }

    return result;
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3VtX2ltcGwuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2ViZ2wvc3JjL2tlcm5lbHMvU3VtX2ltcGwudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBRSxVQUFVLEVBQWMsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFHakYsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLHdCQUF3QixDQUFDO0FBQzlDLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFFbEMsT0FBTyxFQUFDLGFBQWEsRUFBQyxNQUFNLGtCQUFrQixDQUFDO0FBRS9DLE1BQU0sVUFBVSxPQUFPLENBQ25CLENBQWEsRUFBRSxJQUFxQixFQUFFLFFBQWlCLEVBQ3ZELE9BQXlCO0lBQzNCLE1BQU0sZ0JBQWdCLEdBQUcsSUFBSSxDQUFDO0lBRTlCLE1BQU0sS0FBSyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsTUFBTSxDQUFDO0lBRTdCLE1BQU0sUUFBUSxHQUFHLElBQUksQ0FBQyxjQUFjLENBQUMsZ0JBQWdCLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDO0lBQ2hFLElBQUksSUFBSSxHQUFHLFFBQVEsQ0FBQztJQUNwQixNQUFNLFlBQVksR0FBRyxZQUFZLENBQUMsa0JBQWtCLENBQUMsSUFBSSxFQUFFLEtBQUssQ0FBQyxDQUFDO0lBQ2xFLE1BQU0sb0JBQW9CLEdBQUcsWUFBWSxJQUFJLElBQUksQ0FBQztJQUVsRCxJQUFJLFFBQVEsR0FBRyxDQUFDLENBQUM7SUFDakIsSUFBSSxvQkFBb0IsRUFBRTtRQUN4QixRQUFRLEdBQUcsYUFBYSxDQUFDLENBQUMsRUFBRSxZQUFZLEVBQUUsT0FBTyxDQUFDLENBQUM7UUFFbkQsSUFBSSxHQUFHLFlBQVksQ0FBQyxnQkFBZ0IsQ0FBQyxJQUFJLENBQUMsTUFBTSxFQUFFLEtBQUssQ0FBQyxDQUFDO0tBQzFEO0lBRUQsWUFBWSxDQUFDLDBCQUEwQixD
QUFDLEtBQUssRUFBRSxJQUFJLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFDNUQsTUFBTSxDQUFDLFdBQVcsRUFBRSxXQUFXLENBQUMsR0FDNUIsWUFBWSxDQUFDLHlCQUF5QixDQUFDLFFBQVEsQ0FBQyxLQUFLLEVBQUUsSUFBSSxDQUFDLENBQUM7SUFFakUsSUFBSSxRQUFRLEdBQUcsV0FBVyxDQUFDO0lBQzNCLElBQUksUUFBUSxFQUFFO1FBQ1osNkRBQTZEO1FBQzdELFFBQVEsR0FBRyxZQUFZLENBQUMsb0JBQW9CLENBQUMsV0FBVyxFQUFFLFFBQVEsQ0FBQyxDQUFDO0tBQ3JFO0lBRUQsTUFBTSxNQUFNLEdBQUcsSUFBSSxDQUFDLGFBQWEsQ0FBQyxXQUFXLENBQUMsQ0FBQztJQUMvQyxNQUFNLEtBQUssR0FBRyxJQUFJLENBQUMsYUFBYSxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUMxQyxNQUFNLFNBQVMsR0FBRyxLQUFLLEdBQUcsTUFBTSxDQUFDO0lBQ2pDLE1BQU0sYUFBYSxHQUFHLE9BQU8sQ0FDekIsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsUUFBUSxFQUFDLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLENBQUMsU0FBUyxFQUFFLE1BQU0sQ0FBQyxFQUFDLEVBQUUsT0FBTyxFQUFDLENBQUMsQ0FBQztJQUUzRSxNQUFNLE9BQU8sR0FBRyxVQUFVLENBQUMsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDO0lBRXBDLE1BQU0sT0FBTyxHQUFHLE1BQU0sQ0FBQyxhQUFhLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxPQUFPLENBQUMsQ0FBQztJQUMvRCxNQUFNLEdBQUcsR0FDTCxPQUFPLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsT0FBTyxFQUFDLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLFFBQVEsRUFBQyxFQUFFLE9BQU8sRUFBQyxDQUFDLENBQUM7SUFFdkUsT0FBTyxDQUFDLDZCQUE2QixDQUFDLGFBQWEsQ0FBQyxDQUFDO0lBQ3JELE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxPQUFPLENBQUMsQ0FBQztJQUMvQyxJQUFJLG9CQUFvQixFQUFFO1FBQ3hCLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxRQUFRLENBQUMsQ0FBQztLQUNqRDtJQUVELE9BQU8sR0FBRyxDQUFDO0FBQ2IsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4g
XCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWwsIHN1bU91dFR5cGUsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5pbXBvcnQge3JlZHVjZX0gZnJvbSAnLi4va2VybmVsX3V0aWxzL3JlZHVjZSc7XG5pbXBvcnQge3Jlc2hhcGV9IGZyb20gJy4vUmVzaGFwZSc7XG5cbmltcG9ydCB7dHJhbnNwb3NlSW1wbH0gZnJvbSAnLi9UcmFuc3Bvc2VfaW1wbCc7XG5cbmV4cG9ydCBmdW5jdGlvbiBzdW1JbXBsKFxuICAgIHg6IFRlbnNvckluZm8sIGF4aXM6IG51bWJlcnxudW1iZXJbXSwga2VlcERpbXM6IGJvb2xlYW4sXG4gICAgYmFja2VuZDogTWF0aEJhY2tlbmRXZWJHTCk6IFRlbnNvckluZm8ge1xuICBjb25zdCByZWR1Y3Rpb25JbmRpY2VzID0gYXhpcztcblxuICBjb25zdCB4UmFuayA9IHguc2hhcGUubGVuZ3RoO1xuXG4gIGNvbnN0IG9yaWdBeGVzID0gdXRpbC5wYXJzZUF4aXNQYXJhbShyZWR1Y3Rpb25JbmRpY2VzLCB4LnNoYXBlKTtcbiAgbGV0IGF4ZXMgPSBvcmlnQXhlcztcbiAgY29uc3QgcGVybXV0ZWRBeGVzID0gYmFja2VuZF91dGlsLmdldEF4ZXNQZXJtdXRhdGlvbihheGVzLCB4UmFuayk7XG4gIGNvbnN0IHN1bUlucHV0SXNUcmFuc3Bvc2VkID0gcGVybXV0ZWRBeGVzICE9IG51bGw7XG5cbiAgbGV0IHN1bUlucHV0ID0geDtcbiAgaWYgKHN1bUlucHV0SXNUcmFuc3Bvc2VkKSB7XG4gICAgc3VtSW5wdXQgPSB0cmFuc3Bvc2VJbXBsKHgsIHBlcm11dGVkQXhlcywgYmFja2VuZCk7XG5cbiAgICBheGVzID0gYmFja2VuZF91dGlsLmdldElubmVyTW9zdEF4ZXMoYXhlcy5sZW5ndGgsIHhSYW5rKTtcbiAgfVxuXG4gIGJhY2tlbmRfdXRpbC5hc3NlcnRBeGVzQXJlSW5uZXJNb3N0RGltcygnc3VtJywgYXhlcywgeFJhbmspO1xuICBjb25zdCBbc3VtT3V0U2hhcGUsIHJlZHVjZVNoYXBlXSA9XG4gICAgICBiYWNrZW5kX3V0aWwuY29tcHV0ZU91dEFuZFJlZHVjZVNoYXBlcyhzdW1JbnB1dC5zaGFwZSwgYXhlcyk7XG5cbiAgbGV0IG91dFNoYXBlID0gc3VtT3V0U2hhcGU7XG4gIGlmIChrZWVwRGltcykge1xuICAgIC8vIHJhdGhlciB0aGFuIHJlc2hhcGUgYXQgdGhlIGVuZCwgc2V0IHRoZSB0YXJnZXQgc2hhcGUgaGVyZS5cbiAgICBvdXRTaGFwZSA9IGJhY2tlbmRfdXRpbC5leHBhbmRTaGFwZVRv
S2VlcERpbShzdW1PdXRTaGFwZSwgb3JpZ0F4ZXMpO1xuICB9XG5cbiAgY29uc3QgaW5TaXplID0gdXRpbC5zaXplRnJvbVNoYXBlKHJlZHVjZVNoYXBlKTtcbiAgY29uc3QgeFNpemUgPSB1dGlsLnNpemVGcm9tU2hhcGUoeC5zaGFwZSk7XG4gIGNvbnN0IGJhdGNoU2l6ZSA9IHhTaXplIC8gaW5TaXplO1xuICBjb25zdCByZXNoYXBlZElucHV0ID0gcmVzaGFwZShcbiAgICAgIHtpbnB1dHM6IHt4OiBzdW1JbnB1dH0sIGF0dHJzOiB7c2hhcGU6IFtiYXRjaFNpemUsIGluU2l6ZV19LCBiYWNrZW5kfSk7XG5cbiAgY29uc3Qgb3V0VHlwZSA9IHN1bU91dFR5cGUoeC5kdHlwZSk7XG5cbiAgY29uc3QgcmVkdWNlZCA9IHJlZHVjZShyZXNoYXBlZElucHV0LCBvdXRUeXBlLCAnc3VtJywgYmFja2VuZCk7XG4gIGNvbnN0IG91dCA9XG4gICAgICByZXNoYXBlKHtpbnB1dHM6IHt4OiByZWR1Y2VkfSwgYXR0cnM6IHtzaGFwZTogb3V0U2hhcGV9LCBiYWNrZW5kfSk7XG5cbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhyZXNoYXBlZElucHV0KTtcbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhyZWR1Y2VkKTtcbiAgaWYgKHN1bUlucHV0SXNUcmFuc3Bvc2VkKSB7XG4gICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhzdW1JbnB1dCk7XG4gIH1cblxuICByZXR1cm4gb3V0O1xufVxuIl19