/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Cumsum, upcastType, util } from '@tensorflow/tfjs-core';
import { assertNotComplex } from '../cpu_util';
import { transpose } from './Transpose';

/**
 * CPU kernel for the `Cumsum` op: a running sum of `inputs.x` along
 * `attrs.axis`.
 *
 * When `axis` is not already the innermost dimension, the input is first
 * transposed so that it is. The scan then runs row by row over that last
 * dimension, and the result is transposed back with the undo permutation.
 *
 * @param {Object} args
 * @param {Object} args.inputs `{x}`, the tensor to scan (complex input is
 *     rejected through `assertNotComplex`, presumably by throwing).
 * @param {Object} args.backend CPU backend supplying `data`, `makeTensorInfo`
 *     and `disposeIntermediateTensorInfo`.
 * @param {Object} args.attrs `{axis, exclusive, reverse}`.
 *     - `exclusive`: each output excludes its own input; the first element
 *       visited in every row is 0.
 *     - `reverse`: each row is scanned from its last element to its first.
 * @returns {Object} TensorInfo with dtype `upcastType(<input dtype>, 'int32')`,
 *     transposed back to the input's axis order when a permutation was needed.
 */
export function cumsum({ inputs, backend, attrs }) {
  const { x } = inputs;
  const { axis, exclusive, reverse } = attrs;
  assertNotComplex(x, 'cumsum');

  const rank = x.shape.length;
  const permutation = backend_util.getAxesPermutation([axis], rank);
  const needsTranspose = permutation != null;
  const permuted = needsTranspose
    ? transpose({ inputs: { x }, backend, attrs: { perm: permutation } })
    : x;

  // Sanity check that the scan axis is now the innermost one.
  // NOTE(review): appears unreachable if getInnerMostAxes(1, r) returns
  // [r - 1] and transpose keeps the rank — confirm before removing.
  const innerAxis = permuted.shape.length - 1;
  const permutedAxis = backend_util.getInnerMostAxes(1, rank)[0];
  if (permutedAxis !== innerAxis) {
    throw new Error(
      `backend.cumsum in CPU expects an inner-most ` +
        `axis=${innerAxis} but got axis=${permutedAxis}`,
    );
  }

  const resultDtype = upcastType(permuted.dtype, 'int32');
  const outVals = util.makeZerosTypedArray(
    util.sizeFromShape(permuted.shape),
    resultDtype,
  );
  const inVals = backend.data.get(permuted.dataId).values;
  const rowLength = permuted.shape[innerAxis];
  const step = reverse ? -1 : 1;

  for (let rowStart = 0; rowStart < inVals.length; rowStart += rowLength) {
    // First element visited in this row; `reverse` walks the row backwards.
    const head = reverse ? rowStart + rowLength - 1 : rowStart;
    let prevIdx = -1;
    for (let k = 0, idx = head; k < rowLength; k++, idx += step) {
      if (k === 0) {
        outVals[idx] = exclusive ? 0 : inVals[idx];
      } else {
        // The running total is read back from `outVals` rather than kept in a
        // JS number, so every step sees the value as stored for resultDtype.
        outVals[idx] =
          (exclusive ? inVals[prevIdx] : inVals[idx]) + outVals[prevIdx];
      }
      prevIdx = idx;
    }
  }

  const result = backend.makeTensorInfo(permuted.shape, resultDtype, outVals);
  if (!needsTranspose) {
    return result;
  }

  // Restore the caller's axis order, then release both intermediates.
  const undoPerm = backend_util.getUndoAxesPermutation(permutation);
  const restored = transpose({
    inputs: { x: result },
    backend,
    attrs: { perm: undoPerm },
  });
  backend.disposeIntermediateTensorInfo(result);
  backend.disposeIntermediateTensorInfo(permuted);
  return restored;
}

/** Registers `cumsum` as the CPU backend's implementation of `Cumsum`. */
export const cumsumConfig = {
  kernelName: Cumsum,
  backendName: 'cpu',
  kernelFunc: cumsum,
};

// Compiled from tfjs-backend-cpu/src/kernels/Cumsum.ts. The inline source map
// that used to follow was dropped because it described the previous compiled
// output and no longer matches this code.