/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, sumOutType, util } from '@tensorflow/tfjs-core';
|
import { reduce } from '../kernel_utils/reduce';
|
import { reshape } from './Reshape';
|
import { transposeImpl } from './Transpose_impl';
|
/**
 * Sums `x` over the requested axis/axes and returns the result as a new
 * tensor info produced by the final `reshape` call.
 *
 * The reduction itself is delegated to `reduce`, which is fed a 2D
 * `[batchSize, inSize]` view of the input. When
 * `backend_util.getAxesPermutation` reports that a permutation is needed, the
 * input is transposed first and the reduction axes are replaced by the
 * inner-most axes of the transposed tensor.
 *
 * @param x Input tensor info; its `shape` and `dtype` are read.
 * @param axis Axis or axes to reduce, as accepted by `util.parseAxisParam`.
 * @param keepDims When truthy, the output shape is expanded via
 *     `backend_util.expandShapeToKeepDim` using the originally requested axes.
 * @param backend Backend passed through to `transposeImpl`, `reshape` and
 *     `reduce`, and used to dispose the intermediates created here.
 * @returns The reshaped reduction result. Every intermediate created by this
 *     function is disposed before returning; the returned value is not.
 */
export function sumImpl(x, axis, keepDims, backend) {
  const rank = x.shape.length;
  const requestedAxes = util.parseAxisParam(axis, x.shape);

  // A non-null permutation means the input has to be transposed before the
  // reduction axes can be treated as the inner-most dimensions.
  const permutation = backend_util.getAxesPermutation(requestedAxes, rank);
  const needsTranspose = permutation != null;

  const source = needsTranspose ? transposeImpl(x, permutation, backend) : x;
  const innerAxes = needsTranspose ?
      backend_util.getInnerMostAxes(requestedAxes.length, rank) :
      requestedAxes;

  backend_util.assertAxesAreInnerMostDims('sum', innerAxes, rank);
  const [reducedShape, reduceShape] =
      backend_util.computeOutAndReduceShapes(source.shape, innerAxes);

  // Pick the final shape up front so a single reshape at the end is enough,
  // whether or not the reduced dimensions are kept.
  const finalShape = keepDims ?
      backend_util.expandShapeToKeepDim(reducedShape, requestedAxes) :
      reducedShape;

  const reduceSize = util.sizeFromShape(reduceShape);
  const rowCount = util.sizeFromShape(x.shape) / reduceSize;

  const flattened = reshape({
    inputs: { x: source },
    attrs: { shape: [rowCount, reduceSize] },
    backend
  });
  const summed = reduce(flattened, sumOutType(x.dtype), 'sum', backend);
  const result = reshape({
    inputs: { x: summed },
    attrs: { shape: finalShape },
    backend
  });

  // Release everything created here except the returned tensor. The
  // transposed copy exists only when a permutation was applied; otherwise
  // `source` is the caller's `x` and must not be disposed.
  const intermediates = [flattened, summed];
  if (needsTranspose) {
    intermediates.push(source);
  }
  for (const info of intermediates) {
    backend.disposeIntermediateTensorInfo(info);
  }

  return result;
}
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3VtX2ltcGwuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2ViZ2wvc3JjL2tlcm5lbHMvU3VtX2ltcGwudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBRSxVQUFVLEVBQWMsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFHakYsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLHdCQUF3QixDQUFDO0FBQzlDLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFFbEMsT0FBTyxFQUFDLGFBQWEsRUFBQyxNQUFNLGtCQUFrQixDQUFDO0FBRS9DLE1BQU0sVUFBVSxPQUFPLENBQ25CLENBQWEsRUFBRSxJQUFxQixFQUFFLFFBQWlCLEVBQ3ZELE9BQXlCO0lBQzNCLE1BQU0sZ0JBQWdCLEdBQUcsSUFBSSxDQUFDO0lBRTlCLE1BQU0sS0FBSyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsTUFBTSxDQUFDO0lBRTdCLE1BQU0sUUFBUSxHQUFHLElBQUksQ0FBQyxjQUFjLENBQUMsZ0JBQWdCLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDO0lBQ2hFLElBQUksSUFBSSxHQUFHLFFBQVEsQ0FBQztJQUNwQixNQUFNLFlBQVksR0FBRyxZQUFZLENBQUMsa0JBQWtCLENBQUMsSUFBSSxFQUFFLEtBQUssQ0FBQyxDQUFDO0lBQ2xFLE1BQU0sb0JBQW9CLEdBQUcsWUFBWSxJQUFJLElBQUksQ0FBQztJQUVsRCxJQUFJLFFBQVEsR0FBRyxDQUFDLENBQUM7SUFDakIsSUFBSSxvQkFBb0IsRUFBRTtRQUN4QixRQUFRLEdBQUcsYUFBYSxDQUFDLENBQUMsRUFBRSxZQUFZLEVBQUUsT0FBTyxDQUFDLENBQUM7UUFFbkQsSUFBSSxHQUFHLFlBQVksQ0FBQyxnQkFBZ0IsQ0FBQyxJQUFJLENBQUMsTUFBTSxFQUFFLEtBQUssQ0FBQyxDQUFDO0tBQzFEO0lBRUQsWUFBWSxDQUFDLDBCQUEwQixDQUFDLEtBQUssRUFBRSxJQUFJLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFDNUQsTUFBTSxDQUFDLFdBQVcsRUFBRSxXQUFXLENBQUMsR0FDNUIsWUFBWSxDQUFDLHlCQUF5QixDQUFDLFFBQVEsQ0FBQyxLQUFLLEVBQUUsSUFBSSxDQUFDLENBQUM7SUFFakUsSUFBSSxRQUFRLEdBQUcsV0FBVyxDQUFDO0lBQzNCLElBQUksUUFBUSxFQUFFO1FBQ1osNkRBQTZEO1FBQzdELFFBQVEsR0FBRyxZQUFZLENBQUMsb0JBQW9CLENBQUMsV0FBVyxFQUFFLFFBQVEsQ0FBQyxDQUFDO0tBQ3JFO0lBRUQsTUFBTSxNQUFNLEdBQUcsSUFBSSxDQUFDLGFBQWEsQ0FBQyxXQUFXLENBQUMsQ0FBQztJQUMvQyxNQUFNLEtBQUssR0FBRyxJQUFJLENBQUMsYUFBYSxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUMxQyxNQUFNLFNBQVMsR0FBRyxLQUFLLEdBQUcsTUFBTSxDQUFDO0lBQ2pDLE1BQU0sYUFBYSxHQUFHLE9BQU8sQ0FDekIsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsUUFBUSxFQUFDLEVBQUUsS0FBSy
xFQUFFLEVBQUMsS0FBSyxFQUFFLENBQUMsU0FBUyxFQUFFLE1BQU0sQ0FBQyxFQUFDLEVBQUUsT0FBTyxFQUFDLENBQUMsQ0FBQztJQUUzRSxNQUFNLE9BQU8sR0FBRyxVQUFVLENBQUMsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDO0lBRXBDLE1BQU0sT0FBTyxHQUFHLE1BQU0sQ0FBQyxhQUFhLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxPQUFPLENBQUMsQ0FBQztJQUMvRCxNQUFNLEdBQUcsR0FDTCxPQUFPLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsT0FBTyxFQUFDLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLFFBQVEsRUFBQyxFQUFFLE9BQU8sRUFBQyxDQUFDLENBQUM7SUFFdkUsT0FBTyxDQUFDLDZCQUE2QixDQUFDLGFBQWEsQ0FBQyxDQUFDO0lBQ3JELE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxPQUFPLENBQUMsQ0FBQztJQUMvQyxJQUFJLG9CQUFvQixFQUFFO1FBQ3hCLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxRQUFRLENBQUMsQ0FBQztLQUNqRDtJQUVELE9BQU8sR0FBRyxDQUFDO0FBQ2IsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWwsIHN1bU91dFR5cGUsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5pbXBvcnQge3JlZHVjZX0gZnJvbSAnLi4va2VybmVsX3V0aWxzL3JlZHVjZSc7XG5pbXBvcnQge3Jlc2hhcGV9IGZyb20gJy4vUmVzaGFwZSc7XG5cbm
ltcG9ydCB7dHJhbnNwb3NlSW1wbH0gZnJvbSAnLi9UcmFuc3Bvc2VfaW1wbCc7XG5cbmV4cG9ydCBmdW5jdGlvbiBzdW1JbXBsKFxuICAgIHg6IFRlbnNvckluZm8sIGF4aXM6IG51bWJlcnxudW1iZXJbXSwga2VlcERpbXM6IGJvb2xlYW4sXG4gICAgYmFja2VuZDogTWF0aEJhY2tlbmRXZWJHTCk6IFRlbnNvckluZm8ge1xuICBjb25zdCByZWR1Y3Rpb25JbmRpY2VzID0gYXhpcztcblxuICBjb25zdCB4UmFuayA9IHguc2hhcGUubGVuZ3RoO1xuXG4gIGNvbnN0IG9yaWdBeGVzID0gdXRpbC5wYXJzZUF4aXNQYXJhbShyZWR1Y3Rpb25JbmRpY2VzLCB4LnNoYXBlKTtcbiAgbGV0IGF4ZXMgPSBvcmlnQXhlcztcbiAgY29uc3QgcGVybXV0ZWRBeGVzID0gYmFja2VuZF91dGlsLmdldEF4ZXNQZXJtdXRhdGlvbihheGVzLCB4UmFuayk7XG4gIGNvbnN0IHN1bUlucHV0SXNUcmFuc3Bvc2VkID0gcGVybXV0ZWRBeGVzICE9IG51bGw7XG5cbiAgbGV0IHN1bUlucHV0ID0geDtcbiAgaWYgKHN1bUlucHV0SXNUcmFuc3Bvc2VkKSB7XG4gICAgc3VtSW5wdXQgPSB0cmFuc3Bvc2VJbXBsKHgsIHBlcm11dGVkQXhlcywgYmFja2VuZCk7XG5cbiAgICBheGVzID0gYmFja2VuZF91dGlsLmdldElubmVyTW9zdEF4ZXMoYXhlcy5sZW5ndGgsIHhSYW5rKTtcbiAgfVxuXG4gIGJhY2tlbmRfdXRpbC5hc3NlcnRBeGVzQXJlSW5uZXJNb3N0RGltcygnc3VtJywgYXhlcywgeFJhbmspO1xuICBjb25zdCBbc3VtT3V0U2hhcGUsIHJlZHVjZVNoYXBlXSA9XG4gICAgICBiYWNrZW5kX3V0aWwuY29tcHV0ZU91dEFuZFJlZHVjZVNoYXBlcyhzdW1JbnB1dC5zaGFwZSwgYXhlcyk7XG5cbiAgbGV0IG91dFNoYXBlID0gc3VtT3V0U2hhcGU7XG4gIGlmIChrZWVwRGltcykge1xuICAgIC8vIHJhdGhlciB0aGFuIHJlc2hhcGUgYXQgdGhlIGVuZCwgc2V0IHRoZSB0YXJnZXQgc2hhcGUgaGVyZS5cbiAgICBvdXRTaGFwZSA9IGJhY2tlbmRfdXRpbC5leHBhbmRTaGFwZVRvS2VlcERpbShzdW1PdXRTaGFwZSwgb3JpZ0F4ZXMpO1xuICB9XG5cbiAgY29uc3QgaW5TaXplID0gdXRpbC5zaXplRnJvbVNoYXBlKHJlZHVjZVNoYXBlKTtcbiAgY29uc3QgeFNpemUgPSB1dGlsLnNpemVGcm9tU2hhcGUoeC5zaGFwZSk7XG4gIGNvbnN0IGJhdGNoU2l6ZSA9IHhTaXplIC8gaW5TaXplO1xuICBjb25zdCByZXNoYXBlZElucHV0ID0gcmVzaGFwZShcbiAgICAgIHtpbnB1dHM6IHt4OiBzdW1JbnB1dH0sIGF0dHJzOiB7c2hhcGU6IFtiYXRjaFNpemUsIGluU2l6ZV19LCBiYWNrZW5kfSk7XG5cbiAgY29uc3Qgb3V0VHlwZSA9IHN1bU91dFR5cGUoeC5kdHlwZSk7XG5cbiAgY29uc3QgcmVkdWNlZCA9IHJlZHVjZShyZXNoYXBlZElucHV0LCBvdXRUeXBlLCAnc3VtJywgYmFja2VuZCk7XG4gIGNvbnN0IG91dCA9XG4gICAgICByZXNoYXBlKHtpbnB1dHM6IHt4OiByZWR1Y2VkfSwgYXR0cnM6IHtzaGFwZTogb3V0U2hhcGV9LCBiYWNrZW5kfSk7XG5cbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZW
RpYXRlVGVuc29ySW5mbyhyZXNoYXBlZElucHV0KTtcbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhyZWR1Y2VkKTtcbiAgaWYgKHN1bUlucHV0SXNUcmFuc3Bvc2VkKSB7XG4gICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhzdW1JbnB1dCk7XG4gIH1cblxuICByZXR1cm4gb3V0O1xufVxuIl19
|