/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Sum, util } from '@tensorflow/tfjs-core';
import { assertNotComplex } from '../cpu_util';
import { zeros } from '../utils/zeros_impl';
import { cast } from './Cast';
import { identity } from './Identity';
import { reshape } from './Reshape';
import { transpose } from './Transpose';

/**
 * CPU kernel function for `Sum`: adds up the values of `x` along the
 * requested axes.
 *
 * @param {Object} args
 * @param {Object} args.inputs - Holds the input tensor info `x`.
 * @param {Object} args.backend - Backend used to look up tensor data and to
 *     dispose intermediate tensor infos.
 * @param {Object} args.attrs - Holds `axis` (the axes to reduce) and
 *     `keepDims` (when truthy the result is reshaped with the shape produced
 *     by `backend_util.expandShapeToKeepDim`).
 * @returns {Object} Tensor info of the reduced result. Its dtype is
 *     `backend_util.upcastType(<input dtype>, 'int32')`.
 */
export function sum(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex(x, 'sum');

    // Always work on a fresh tensor info (bool inputs are cast to int32,
    // everything else goes through identity) so that it can be disposed
    // unconditionally at the end.
    const $x = x.dtype === 'bool' ?
        cast({ inputs: { x }, backend, attrs: { dtype: 'int32' } }) :
        identity({ inputs: { x }, backend });

    const rank = $x.shape.length;
    const axes = util.parseAxisParam(axis, $x.shape);
    const permutation = backend_util.getAxesPermutation(axes, rank);
    const needsTranspose = permutation != null;

    // When a permutation is reported, transpose the input and reduce over
    // the axes returned by getInnerMostAxes instead of the original ones.
    let innerAxes = axes;
    let source = $x;
    if (needsTranspose) {
        source = transpose({ inputs: { x: $x }, backend, attrs: { perm: permutation } });
        innerAxes = backend_util.getInnerMostAxes(innerAxes.length, rank);
    }
    backend_util.assertAxesAreInnerMostDims('sum', innerAxes, source.shape.length);

    const [outShape, reduceShape] = backend_util.computeOutAndReduceShapes(source.shape, innerAxes);
    const outDtype = backend_util.upcastType(source.dtype, 'int32');
    let result = zeros(backend, outShape, outDtype);
    const reduceSize = util.sizeFromShape(reduceShape);
    const outVals = backend.data.get(result.dataId).values;
    const inVals = backend.data.get(source.dataId).values;

    // Each output element is the sum of one run of `reduceSize` consecutive
    // input values; `readIdx` walks the input buffer front to back.
    let readIdx = 0;
    for (let outIdx = 0; outIdx < outVals.length; ++outIdx) {
        const runEnd = readIdx + reduceSize;
        let total = 0;
        while (readIdx < runEnd) {
            total += inVals[readIdx];
            ++readIdx;
        }
        outVals[outIdx] = total;
    }

    if (keepDims) {
        const keptShape = backend_util.expandShapeToKeepDim(result.shape, axes);
        const unreshaped = result;
        result = reshape({ inputs: { x: unreshaped }, backend, attrs: { shape: keptShape } });
        backend.disposeIntermediateTensorInfo(unreshaped);
    }

    // Release the intermediates created above.
    backend.disposeIntermediateTensorInfo($x);
    if (needsTranspose) {
        backend.disposeIntermediateTensorInfo(source);
    }
    return result;
}

/** Kernel config object pairing the `Sum` kernel name with `sum` for the 'cpu' backend. */
export const sumConfig = {
    kernelName: Sum,
    backendName: 'cpu',
    kernelFunc: sum
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3VtLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLWNwdS9zcmMva2VybmVscy9TdW0udHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBNEIsR0FBRyxFQUErQyxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUdySSxPQUFPLEVBQUMsZ0JBQWdCLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDN0MsT0FBTyxFQUFDLEtBQUssRUFBQyxNQUFNLHFCQUFxQixDQUFDO0FBQzFDLE9BQU8sRUFBQyxJQUFJLEVBQUMsTUFBTSxRQUFRLENBQUM7QUFDNUIsT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLFlBQVksQ0FBQztBQUNwQyxPQUFPLEVBQUMsT0FBTyxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBQ2xDLE9BQU8sRUFBQyxTQUFTLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFFdEMsTUFBTSxVQUFVLEdBQUcsQ0FDZixJQUFtRTtJQUVyRSxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUNuQixNQUFNLEVBQUMsSUFBSSxFQUFFLFFBQVEsRUFBQyxHQUFHLEtBQUssQ0FBQztJQUUvQixnQkFBZ0IsQ0FBQyxDQUFDLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFFM0IsSUFBSSxFQUFFLENBQUM7SUFDUCxJQUFJLENBQUMsQ0FBQyxLQUFLLEtBQUssTUFBTSxFQUFFO1FBQ3RCLEVBQUUsR0FBRyxJQUFJLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLE9BQU
8sRUFBQyxFQUFDLENBQUMsQ0FBQztLQUM1RDtTQUFNO1FBQ0wsRUFBRSxHQUFHLFFBQVEsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBQyxFQUFFLE9BQU8sRUFBQyxDQUFDLENBQUM7S0FDdkM7SUFFRCxNQUFNLEtBQUssR0FBRyxFQUFFLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQztJQUM5QixNQUFNLElBQUksR0FBRyxJQUFJLENBQUMsY0FBYyxDQUFDLElBQUksRUFBRSxFQUFFLENBQUMsS0FBSyxDQUFDLENBQUM7SUFDakQsTUFBTSxXQUFXLEdBQUcsWUFBWSxDQUFDLGtCQUFrQixDQUFDLElBQUksRUFBRSxLQUFLLENBQUMsQ0FBQztJQUVqRSxJQUFJLGFBQWEsR0FBRyxJQUFJLENBQUM7SUFDekIsSUFBSSxTQUFTLEdBQUcsRUFBRSxDQUFDO0lBQ25CLElBQUksV0FBVyxJQUFJLElBQUksRUFBRTtRQUN2QixTQUFTO1lBQ0wsU0FBUyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLEVBQUUsRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxJQUFJLEVBQUUsV0FBVyxFQUFDLEVBQUMsQ0FBQyxDQUFDO1FBQ3RFLGFBQWEsR0FBRyxZQUFZLENBQUMsZ0JBQWdCLENBQUMsYUFBYSxDQUFDLE1BQU0sRUFBRSxLQUFLLENBQUMsQ0FBQztLQUM1RTtJQUVELFlBQVksQ0FBQywwQkFBMEIsQ0FDbkMsS0FBSyxFQUFFLGFBQWEsRUFBRSxTQUFTLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQyxDQUFDO0lBRWxELE1BQU0sQ0FBQyxRQUFRLEVBQUUsV0FBVyxDQUFDLEdBQ3pCLFlBQVksQ0FBQyx5QkFBeUIsQ0FBQyxTQUFTLENBQUMsS0FBSyxFQUFFLGFBQWEsQ0FBQyxDQUFDO0lBQzNFLE1BQU0sV0FBVyxHQUFHLFlBQVksQ0FBQyxVQUFVLENBQUMsU0FBUyxDQUFDLEtBQUssRUFBRSxPQUFPLENBQUMsQ0FBQztJQUN0RSxJQUFJLE1BQU0sR0FBRyxLQUFLLENBQUMsT0FBTyxFQUFFLFFBQVEsRUFBRSxXQUFXLENBQUMsQ0FBQztJQUNuRCxNQUFNLFVBQVUsR0FBRyxJQUFJLENBQUMsYUFBYSxDQUFDLFdBQVcsQ0FBQyxDQUFDO0lBQ25ELE1BQU0sSUFBSSxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLE1BQU0sQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDO0lBRWxFLE1BQU0sS0FBSyxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLFNBQVMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDO0lBQ3RFLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxJQUFJLENBQUMsTUFBTSxFQUFFLEVBQUUsQ0FBQyxFQUFFO1FBQ3BDLE1BQU0sTUFBTSxHQUFHLENBQUMsR0FBRyxVQUFVLENBQUM7UUFDOUIsSUFBSSxHQUFHLEdBQUcsQ0FBQyxDQUFDO1FBQ1osS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFVBQVUsRUFBRSxFQUFFLENBQUMsRUFBRTtZQUNuQyxHQUFHLElBQUksS0FBSyxDQUFDLE1BQU0sR0FBRyxDQUFDLENBQUMsQ0FBQztTQUMxQjtRQUNELElBQUksQ0FBQyxDQUFDLENBQUMsR0FBRyxHQUFHLENBQUM7S0FDZjtJQUVELE
lBQUksUUFBUSxFQUFFO1FBQ1osTUFBTSxRQUFRLEdBQUcsWUFBWSxDQUFDLG9CQUFvQixDQUFDLE1BQU0sQ0FBQyxLQUFLLEVBQUUsSUFBSSxDQUFDLENBQUM7UUFDdkUsTUFBTSxTQUFTLEdBQUcsTUFBTSxDQUFDO1FBQ3pCLE1BQU0sR0FBRyxPQUFPLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsTUFBTSxFQUFDLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxFQUFDLEtBQUssRUFBRSxRQUFRLEVBQUMsRUFBQyxDQUFDLENBQUM7UUFDM0UsT0FBTyxDQUFDLDZCQUE2QixDQUFDLFNBQVMsQ0FBQyxDQUFDO0tBQ2xEO0lBRUQsT0FBTyxDQUFDLDZCQUE2QixDQUFDLEVBQUUsQ0FBQyxDQUFDO0lBRTFDLElBQUksV0FBVyxJQUFJLElBQUksRUFBRTtRQUN2QixPQUFPLENBQUMsNkJBQTZCLENBQUMsU0FBUyxDQUFDLENBQUM7S0FDbEQ7SUFFRCxPQUFPLE1BQU0sQ0FBQztBQUNoQixDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sU0FBUyxHQUFpQjtJQUNyQyxVQUFVLEVBQUUsR0FBRztJQUNmLFdBQVcsRUFBRSxLQUFLO0lBQ2xCLFVBQVUsRUFBRSxHQUE0QjtDQUN6QyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBTdW0sIFN1bUF0dHJzLCBTdW1JbnB1dHMsIFRlbnNvckluZm8sIFR5cGVkQXJyYXksIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcbmltcG9ydCB7YX
NzZXJ0Tm90Q29tcGxleH0gZnJvbSAnLi4vY3B1X3V0aWwnO1xuaW1wb3J0IHt6ZXJvc30gZnJvbSAnLi4vdXRpbHMvemVyb3NfaW1wbCc7XG5pbXBvcnQge2Nhc3R9IGZyb20gJy4vQ2FzdCc7XG5pbXBvcnQge2lkZW50aXR5fSBmcm9tICcuL0lkZW50aXR5JztcbmltcG9ydCB7cmVzaGFwZX0gZnJvbSAnLi9SZXNoYXBlJztcbmltcG9ydCB7dHJhbnNwb3NlfSBmcm9tICcuL1RyYW5zcG9zZSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBzdW0oXG4gICAgYXJnczoge2lucHV0czogU3VtSW5wdXRzLCBiYWNrZW5kOiBNYXRoQmFja2VuZENQVSwgYXR0cnM6IFN1bUF0dHJzfSk6XG4gICAgVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHt4fSA9IGlucHV0cztcbiAgY29uc3Qge2F4aXMsIGtlZXBEaW1zfSA9IGF0dHJzO1xuXG4gIGFzc2VydE5vdENvbXBsZXgoeCwgJ3N1bScpO1xuXG4gIGxldCAkeDtcbiAgaWYgKHguZHR5cGUgPT09ICdib29sJykge1xuICAgICR4ID0gY2FzdCh7aW5wdXRzOiB7eH0sIGJhY2tlbmQsIGF0dHJzOiB7ZHR5cGU6ICdpbnQzMid9fSk7XG4gIH0gZWxzZSB7XG4gICAgJHggPSBpZGVudGl0eSh7aW5wdXRzOiB7eH0sIGJhY2tlbmR9KTtcbiAgfVxuXG4gIGNvbnN0IHhSYW5rID0gJHguc2hhcGUubGVuZ3RoO1xuICBjb25zdCBheGVzID0gdXRpbC5wYXJzZUF4aXNQYXJhbShheGlzLCAkeC5zaGFwZSk7XG4gIGNvbnN0IHBlcm11dGF0aW9uID0gYmFja2VuZF91dGlsLmdldEF4ZXNQZXJtdXRhdGlvbihheGVzLCB4UmFuayk7XG5cbiAgbGV0IHJlZHVjdGlvbkF4ZXMgPSBheGVzO1xuICBsZXQgcGVybXV0ZWRYID0gJHg7XG4gIGlmIChwZXJtdXRhdGlvbiAhPSBudWxsKSB7XG4gICAgcGVybXV0ZWRYID1cbiAgICAgICAgdHJhbnNwb3NlKHtpbnB1dHM6IHt4OiAkeH0sIGJhY2tlbmQsIGF0dHJzOiB7cGVybTogcGVybXV0YXRpb259fSk7XG4gICAgcmVkdWN0aW9uQXhlcyA9IGJhY2tlbmRfdXRpbC5nZXRJbm5lck1vc3RBeGVzKHJlZHVjdGlvbkF4ZXMubGVuZ3RoLCB4UmFuayk7XG4gIH1cblxuICBiYWNrZW5kX3V0aWwuYXNzZXJ0QXhlc0FyZUlubmVyTW9zdERpbXMoXG4gICAgICAnc3VtJywgcmVkdWN0aW9uQXhlcywgcGVybXV0ZWRYLnNoYXBlLmxlbmd0aCk7XG5cbiAgY29uc3QgW291dFNoYXBlLCByZWR1Y2VTaGFwZV0gPVxuICAgICAgYmFja2VuZF91dGlsLmNvbXB1dGVPdXRBbmRSZWR1Y2VTaGFwZXMocGVybXV0ZWRYLnNoYXBlLCByZWR1Y3Rpb25BeGVzKTtcbiAgY29uc3QgcmVzdWx0RHR5cGUgPSBiYWNrZW5kX3V0aWwudXBjYXN0VHlwZShwZXJtdXRlZFguZHR5cGUsICdpbnQzMicpO1xuICBsZXQgcmVzdWx0ID0gemVyb3MoYmFja2VuZCwgb3V0U2hhcGUsIHJlc3VsdER0eXBlKTtcbiAgY29uc3QgcmVkdWNlU2l6ZSA9IHV0aWwuc2l6ZUZyb21TaGFwZShyZWR1Y2VTaGFwZSk7XG4gIGNvbnN0IHZhbHMgPSBiYWNrZW5kLmRhdGEuZ2V0KHJlc3
VsdC5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5O1xuXG4gIGNvbnN0IGFWYWxzID0gYmFja2VuZC5kYXRhLmdldChwZXJtdXRlZFguZGF0YUlkKS52YWx1ZXMgYXMgVHlwZWRBcnJheTtcbiAgZm9yIChsZXQgaSA9IDA7IGkgPCB2YWxzLmxlbmd0aDsgKytpKSB7XG4gICAgY29uc3Qgb2Zmc2V0ID0gaSAqIHJlZHVjZVNpemU7XG4gICAgbGV0IHN1bSA9IDA7XG4gICAgZm9yIChsZXQgaiA9IDA7IGogPCByZWR1Y2VTaXplOyArK2opIHtcbiAgICAgIHN1bSArPSBhVmFsc1tvZmZzZXQgKyBqXTtcbiAgICB9XG4gICAgdmFsc1tpXSA9IHN1bTtcbiAgfVxuXG4gIGlmIChrZWVwRGltcykge1xuICAgIGNvbnN0IG5ld1NoYXBlID0gYmFja2VuZF91dGlsLmV4cGFuZFNoYXBlVG9LZWVwRGltKHJlc3VsdC5zaGFwZSwgYXhlcyk7XG4gICAgY29uc3Qgb2xkUmVzdWx0ID0gcmVzdWx0O1xuICAgIHJlc3VsdCA9IHJlc2hhcGUoe2lucHV0czoge3g6IHJlc3VsdH0sIGJhY2tlbmQsIGF0dHJzOiB7c2hhcGU6IG5ld1NoYXBlfX0pO1xuICAgIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8ob2xkUmVzdWx0KTtcbiAgfVxuXG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oJHgpO1xuXG4gIGlmIChwZXJtdXRhdGlvbiAhPSBudWxsKSB7XG4gICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhwZXJtdXRlZFgpO1xuICB9XG5cbiAgcmV0dXJuIHJlc3VsdDtcbn1cblxuZXhwb3J0IGNvbnN0IHN1bUNvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBTdW0sXG4gIGJhY2tlbmROYW1lOiAnY3B1JyxcbiAga2VybmVsRnVuYzogc3VtIGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==