/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, Sum, util } from '@tensorflow/tfjs-core';
|
import { assertNotComplex } from '../cpu_util';
|
import { zeros } from '../utils/zeros_impl';
|
import { cast } from './Cast';
|
import { identity } from './Identity';
|
import { reshape } from './Reshape';
|
import { transpose } from './Transpose';
|
/**
 * CPU kernel function for `Sum`: adds up the elements of `x` along the
 * requested axes.
 *
 * Steps, in order:
 *  1. Reject complex input via `assertNotComplex`.
 *  2. Produce a working copy `$x`: `bool` input is cast to `int32`, any other
 *     dtype goes through `identity`.
 *  3. If `backend_util.getAxesPermutation` returns a permutation, transpose
 *     `$x` with it and reduce over the inner-most axes instead.
 *  4. Accumulate each contiguous run of `reduceSize` input values into one
 *     output value.
 *  5. When `keepDims` is truthy, reshape the result using
 *     `backend_util.expandShapeToKeepDim`.
 *
 * All intermediates created along the way (`$x`, the transposed copy) are
 * released through `backend.disposeIntermediateTensorInfo` in a `finally`
 * block, so they are also released when one of the helpers throws part-way
 * through (for example if the axis helper rejects `axis`).
 *
 * @param {Object} args
 * @param {Object} args.inputs - Kernel inputs; only `inputs.x` is read.
 * @param {Object} args.backend - Backend exposing `data.get(dataId)` and
 *     `disposeIntermediateTensorInfo(info)`.
 * @param {Object} args.attrs - Kernel attributes `axis` and `keepDims`.
 * @returns {Object} Tensor info of the reduced result; its dtype is
 *     `backend_util.upcastType(<working dtype>, 'int32')`.
 * @throws Propagates any error raised by the helpers it calls.
 */
export function sum(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { axis, keepDims } = attrs;

  assertNotComplex(x, 'sum');

  // Intermediates are recorded in creation order so the `finally` block can
  // release them on both the success path and the error path.
  const intermediates = [];

  try {
    let $x;
    if (x.dtype === 'bool') {
      $x = cast({ inputs: { x }, backend, attrs: { dtype: 'int32' } });
    } else {
      $x = identity({ inputs: { x }, backend });
    }
    intermediates.push($x);

    const xRank = $x.shape.length;
    const axes = util.parseAxisParam(axis, $x.shape);
    const permutation = backend_util.getAxesPermutation(axes, xRank);

    let reductionAxes = axes;
    let permutedX = $x;
    if (permutation != null) {
      permutedX = transpose({
        inputs: { x: $x },
        backend,
        attrs: { perm: permutation },
      });
      intermediates.push(permutedX);
      reductionAxes =
        backend_util.getInnerMostAxes(reductionAxes.length, xRank);
    }

    backend_util.assertAxesAreInnerMostDims(
      'sum', reductionAxes, permutedX.shape.length);

    const [outShape, reduceShape] =
      backend_util.computeOutAndReduceShapes(permutedX.shape, reductionAxes);
    const resultDtype = backend_util.upcastType(permutedX.dtype, 'int32');
    let result = zeros(backend, outShape, resultDtype);
    const reduceSize = util.sizeFromShape(reduceShape);
    const vals = backend.data.get(result.dataId).values;
    const aVals = backend.data.get(permutedX.dataId).values;

    // Output element i is the total of the i-th run of `reduceSize`
    // consecutive input values.
    for (let i = 0; i < vals.length; ++i) {
      const offset = i * reduceSize;
      let total = 0;
      for (let j = 0; j < reduceSize; ++j) {
        total += aVals[offset + j];
      }
      vals[i] = total;
    }

    if (keepDims) {
      // The expanded shape is computed from the original (un-permuted) axes.
      const newShape = backend_util.expandShapeToKeepDim(result.shape, axes);
      const oldResult = result;
      result = reshape({
        inputs: { x: result },
        backend,
        attrs: { shape: newShape },
      });
      backend.disposeIntermediateTensorInfo(oldResult);
    }

    return result;
  } finally {
    for (const info of intermediates) {
      backend.disposeIntermediateTensorInfo(info);
    }
  }
}
|
/**
 * Kernel configuration that associates the `Sum` kernel name with the `sum`
 * kernel function for the backend named 'cpu'.
 */
export const sumConfig = {
  kernelName: Sum,
  backendName: 'cpu',
  kernelFunc: sum,
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3VtLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLWNwdS9zcmMva2VybmVscy9TdW0udHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBNEIsR0FBRyxFQUErQyxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUdySSxPQUFPLEVBQUMsZ0JBQWdCLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDN0MsT0FBTyxFQUFDLEtBQUssRUFBQyxNQUFNLHFCQUFxQixDQUFDO0FBQzFDLE9BQU8sRUFBQyxJQUFJLEVBQUMsTUFBTSxRQUFRLENBQUM7QUFDNUIsT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLFlBQVksQ0FBQztBQUNwQyxPQUFPLEVBQUMsT0FBTyxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBQ2xDLE9BQU8sRUFBQyxTQUFTLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFFdEMsTUFBTSxVQUFVLEdBQUcsQ0FDZixJQUFtRTtJQUVyRSxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUNuQixNQUFNLEVBQUMsSUFBSSxFQUFFLFFBQVEsRUFBQyxHQUFHLEtBQUssQ0FBQztJQUUvQixnQkFBZ0IsQ0FBQyxDQUFDLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFFM0IsSUFBSSxFQUFFLENBQUM7SUFDUCxJQUFJLENBQUMsQ0FBQyxLQUFLLEtBQUssTUFBTSxFQUFFO1FBQ3RCLEVBQUUsR0FBRyxJQUFJLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLE9BQU8sRUFBQyxFQUFDLENBQUMsQ0FBQztLQUM1RDtTQUFNO1FBQ0wsRUFBRSxHQUFHLFFBQVEsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBQyxFQUFFLE9BQU8sRUFBQyxDQUFDLENBQUM7S0FDdkM7SUFFRCxNQUFNLEtBQUssR0FBRyxFQUFFLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQztJQUM5QixNQUFNLElBQUksR0FBRyxJQUFJLENBQUMsY0FBYyxDQUFDLElBQUksRUFBRSxFQUFFLENBQUMsS0FBSyxDQUFDLENBQUM7SUFDakQsTUFBTSxXQUFXLEdBQUcsWUFBWSxDQUFDLGtCQUFrQixDQUFDLElBQUksRUFBRSxLQUFLLENBQUMsQ0FBQztJQUVqRSxJQUFJLGFBQWEsR0FBRyxJQUFJLENBQUM7SUFDekIsSUFBSSxTQUFTLEdBQUcsRUFBRSxDQUFDO0lBQ25CLElBQUksV0FBVyxJQUFJLElBQUksRUFBRTtRQUN2QixTQUFTO1lBQ0wsU0FBUyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLEVBQUUsRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxJQUFJLEVBQUUsV0FBVyxFQUFDLEVBQUMsQ0FBQyxDQUFDO1FBQ3RFLGFBQWEsR0FBRyxZQUFZLENBQUMsZ0JBQWdCLENBQUMsYUFBYSxDQUFDLE1BQU0sRUFBRSxLQUFLLE
NBQUMsQ0FBQztLQUM1RTtJQUVELFlBQVksQ0FBQywwQkFBMEIsQ0FDbkMsS0FBSyxFQUFFLGFBQWEsRUFBRSxTQUFTLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQyxDQUFDO0lBRWxELE1BQU0sQ0FBQyxRQUFRLEVBQUUsV0FBVyxDQUFDLEdBQ3pCLFlBQVksQ0FBQyx5QkFBeUIsQ0FBQyxTQUFTLENBQUMsS0FBSyxFQUFFLGFBQWEsQ0FBQyxDQUFDO0lBQzNFLE1BQU0sV0FBVyxHQUFHLFlBQVksQ0FBQyxVQUFVLENBQUMsU0FBUyxDQUFDLEtBQUssRUFBRSxPQUFPLENBQUMsQ0FBQztJQUN0RSxJQUFJLE1BQU0sR0FBRyxLQUFLLENBQUMsT0FBTyxFQUFFLFFBQVEsRUFBRSxXQUFXLENBQUMsQ0FBQztJQUNuRCxNQUFNLFVBQVUsR0FBRyxJQUFJLENBQUMsYUFBYSxDQUFDLFdBQVcsQ0FBQyxDQUFDO0lBQ25ELE1BQU0sSUFBSSxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLE1BQU0sQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDO0lBRWxFLE1BQU0sS0FBSyxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLFNBQVMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDO0lBQ3RFLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxJQUFJLENBQUMsTUFBTSxFQUFFLEVBQUUsQ0FBQyxFQUFFO1FBQ3BDLE1BQU0sTUFBTSxHQUFHLENBQUMsR0FBRyxVQUFVLENBQUM7UUFDOUIsSUFBSSxHQUFHLEdBQUcsQ0FBQyxDQUFDO1FBQ1osS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFVBQVUsRUFBRSxFQUFFLENBQUMsRUFBRTtZQUNuQyxHQUFHLElBQUksS0FBSyxDQUFDLE1BQU0sR0FBRyxDQUFDLENBQUMsQ0FBQztTQUMxQjtRQUNELElBQUksQ0FBQyxDQUFDLENBQUMsR0FBRyxHQUFHLENBQUM7S0FDZjtJQUVELElBQUksUUFBUSxFQUFFO1FBQ1osTUFBTSxRQUFRLEdBQUcsWUFBWSxDQUFDLG9CQUFvQixDQUFDLE1BQU0sQ0FBQyxLQUFLLEVBQUUsSUFBSSxDQUFDLENBQUM7UUFDdkUsTUFBTSxTQUFTLEdBQUcsTUFBTSxDQUFDO1FBQ3pCLE1BQU0sR0FBRyxPQUFPLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsTUFBTSxFQUFDLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxFQUFDLEtBQUssRUFBRSxRQUFRLEVBQUMsRUFBQyxDQUFDLENBQUM7UUFDM0UsT0FBTyxDQUFDLDZCQUE2QixDQUFDLFNBQVMsQ0FBQyxDQUFDO0tBQ2xEO0lBRUQsT0FBTyxDQUFDLDZCQUE2QixDQUFDLEVBQUUsQ0FBQyxDQUFDO0lBRTFDLElBQUksV0FBVyxJQUFJLElBQUksRUFBRTtRQUN2QixPQUFPLENBQUMsNkJBQTZCLENBQUMsU0FBUyxDQUFDLENBQUM7S0FDbEQ7SUFFRCxPQUFPLE1BQU0sQ0FBQztBQUNoQixDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sU0FBUyxHQUFpQjtJQUNyQyxVQUFVLEVBQUUsR0FBRztJQUNmLFdBQVcsRUFBRSxLQUFLO0lBQ2xCLFVBQVUsRUFBRSxHQUE0QjtDQUN6QyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuIC
ogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBTdW0sIFN1bUF0dHJzLCBTdW1JbnB1dHMsIFRlbnNvckluZm8sIFR5cGVkQXJyYXksIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcbmltcG9ydCB7YXNzZXJ0Tm90Q29tcGxleH0gZnJvbSAnLi4vY3B1X3V0aWwnO1xuaW1wb3J0IHt6ZXJvc30gZnJvbSAnLi4vdXRpbHMvemVyb3NfaW1wbCc7XG5pbXBvcnQge2Nhc3R9IGZyb20gJy4vQ2FzdCc7XG5pbXBvcnQge2lkZW50aXR5fSBmcm9tICcuL0lkZW50aXR5JztcbmltcG9ydCB7cmVzaGFwZX0gZnJvbSAnLi9SZXNoYXBlJztcbmltcG9ydCB7dHJhbnNwb3NlfSBmcm9tICcuL1RyYW5zcG9zZSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBzdW0oXG4gICAgYXJnczoge2lucHV0czogU3VtSW5wdXRzLCBiYWNrZW5kOiBNYXRoQmFja2VuZENQVSwgYXR0cnM6IFN1bUF0dHJzfSk6XG4gICAgVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHt4fSA9IGlucHV0cztcbiAgY29uc3Qge2F4aXMsIGtlZXBEaW1zfSA9IGF0dHJzO1xuXG4gIGFzc2VydE5vdENvbXBsZXgoeCwgJ3N1bScpO1xuXG4gIGxldCAkeDtcbiAgaWYgKHguZHR5cGUgPT09ICdib29sJykge1xuICAgICR4ID0gY2FzdCh7aW5wdXRzOiB7eH0sIGJhY2tlbmQsIGF0dHJzOiB7ZHR5cGU6ICdpbnQzMid9fSk7XG4gIH0gZWxzZSB7XG4gICAgJH
ggPSBpZGVudGl0eSh7aW5wdXRzOiB7eH0sIGJhY2tlbmR9KTtcbiAgfVxuXG4gIGNvbnN0IHhSYW5rID0gJHguc2hhcGUubGVuZ3RoO1xuICBjb25zdCBheGVzID0gdXRpbC5wYXJzZUF4aXNQYXJhbShheGlzLCAkeC5zaGFwZSk7XG4gIGNvbnN0IHBlcm11dGF0aW9uID0gYmFja2VuZF91dGlsLmdldEF4ZXNQZXJtdXRhdGlvbihheGVzLCB4UmFuayk7XG5cbiAgbGV0IHJlZHVjdGlvbkF4ZXMgPSBheGVzO1xuICBsZXQgcGVybXV0ZWRYID0gJHg7XG4gIGlmIChwZXJtdXRhdGlvbiAhPSBudWxsKSB7XG4gICAgcGVybXV0ZWRYID1cbiAgICAgICAgdHJhbnNwb3NlKHtpbnB1dHM6IHt4OiAkeH0sIGJhY2tlbmQsIGF0dHJzOiB7cGVybTogcGVybXV0YXRpb259fSk7XG4gICAgcmVkdWN0aW9uQXhlcyA9IGJhY2tlbmRfdXRpbC5nZXRJbm5lck1vc3RBeGVzKHJlZHVjdGlvbkF4ZXMubGVuZ3RoLCB4UmFuayk7XG4gIH1cblxuICBiYWNrZW5kX3V0aWwuYXNzZXJ0QXhlc0FyZUlubmVyTW9zdERpbXMoXG4gICAgICAnc3VtJywgcmVkdWN0aW9uQXhlcywgcGVybXV0ZWRYLnNoYXBlLmxlbmd0aCk7XG5cbiAgY29uc3QgW291dFNoYXBlLCByZWR1Y2VTaGFwZV0gPVxuICAgICAgYmFja2VuZF91dGlsLmNvbXB1dGVPdXRBbmRSZWR1Y2VTaGFwZXMocGVybXV0ZWRYLnNoYXBlLCByZWR1Y3Rpb25BeGVzKTtcbiAgY29uc3QgcmVzdWx0RHR5cGUgPSBiYWNrZW5kX3V0aWwudXBjYXN0VHlwZShwZXJtdXRlZFguZHR5cGUsICdpbnQzMicpO1xuICBsZXQgcmVzdWx0ID0gemVyb3MoYmFja2VuZCwgb3V0U2hhcGUsIHJlc3VsdER0eXBlKTtcbiAgY29uc3QgcmVkdWNlU2l6ZSA9IHV0aWwuc2l6ZUZyb21TaGFwZShyZWR1Y2VTaGFwZSk7XG4gIGNvbnN0IHZhbHMgPSBiYWNrZW5kLmRhdGEuZ2V0KHJlc3VsdC5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5O1xuXG4gIGNvbnN0IGFWYWxzID0gYmFja2VuZC5kYXRhLmdldChwZXJtdXRlZFguZGF0YUlkKS52YWx1ZXMgYXMgVHlwZWRBcnJheTtcbiAgZm9yIChsZXQgaSA9IDA7IGkgPCB2YWxzLmxlbmd0aDsgKytpKSB7XG4gICAgY29uc3Qgb2Zmc2V0ID0gaSAqIHJlZHVjZVNpemU7XG4gICAgbGV0IHN1bSA9IDA7XG4gICAgZm9yIChsZXQgaiA9IDA7IGogPCByZWR1Y2VTaXplOyArK2opIHtcbiAgICAgIHN1bSArPSBhVmFsc1tvZmZzZXQgKyBqXTtcbiAgICB9XG4gICAgdmFsc1tpXSA9IHN1bTtcbiAgfVxuXG4gIGlmIChrZWVwRGltcykge1xuICAgIGNvbnN0IG5ld1NoYXBlID0gYmFja2VuZF91dGlsLmV4cGFuZFNoYXBlVG9LZWVwRGltKHJlc3VsdC5zaGFwZSwgYXhlcyk7XG4gICAgY29uc3Qgb2xkUmVzdWx0ID0gcmVzdWx0O1xuICAgIHJlc3VsdCA9IHJlc2hhcGUoe2lucHV0czoge3g6IHJlc3VsdH0sIGJhY2tlbmQsIGF0dHJzOiB7c2hhcGU6IG5ld1NoYXBlfX0pO1xuICAgIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8ob2xkUmVzdWx0KTtcbiAgfVxuXG4gIGJhY2tlbmQuZG
lzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oJHgpO1xuXG4gIGlmIChwZXJtdXRhdGlvbiAhPSBudWxsKSB7XG4gICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhwZXJtdXRlZFgpO1xuICB9XG5cbiAgcmV0dXJuIHJlc3VsdDtcbn1cblxuZXhwb3J0IGNvbnN0IHN1bUNvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBTdW0sXG4gIGJhY2tlbmROYW1lOiAnY3B1JyxcbiAga2VybmVsRnVuYzogc3VtIGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==
|