gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Cumsum, upcastType, util } from '@tensorflow/tfjs-core';
import { assertNotComplex } from '../cpu_util';
import { transpose } from './Transpose';
/**
 * CPU kernel that computes a cumulative sum of `x` along `attrs.axis`.
 *
 * The scan itself always runs over the last dimension of a flat value buffer.
 * When `backend_util.getAxesPermutation` returns a permutation, the input is
 * transposed with it first, scanned, and the result is transposed back with
 * the undo permutation; both intermediates are then released through
 * `backend.disposeIntermediateTensorInfo`.
 *
 * @param args Kernel arguments.
 * @param args.inputs Object holding the input tensor info `x` (checked with
 *     `assertNotComplex`).
 * @param args.backend Backend providing `data.get`, `makeTensorInfo` and
 *     `disposeIntermediateTensorInfo`.
 * @param args.attrs Object holding `axis`, `exclusive` (when truthy, each
 *     output excludes the input at its own position) and `reverse` (when
 *     truthy, each row is scanned from its last element towards its first).
 * @returns Tensor info whose dtype is `upcastType(<input dtype>, 'int32')`.
 * @throws {Error} If the inner-most axis reported by
 *     `backend_util.getInnerMostAxes` is not the last axis of the scanned
 *     tensor.
 */
export function cumsum(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, exclusive, reverse } = attrs;
    assertNotComplex(x, 'cumsum');

    const rank = x.shape.length;
    const perm = backend_util.getAxesPermutation([axis], rank);
    const wasTransposed = perm != null;
    const scanInput = wasTransposed ?
        transpose({ inputs: { x }, backend, attrs: { perm } }) :
        x;

    const innerAxis = backend_util.getInnerMostAxes(1, rank)[0];
    const lastAxis = scanInput.shape.length - 1;
    if (innerAxis !== lastAxis) {
        throw new Error(`backend.cumsum in CPU expects an inner-most ` +
            `axis=${lastAxis} but got axis=${innerAxis}`);
    }

    const outDtype = upcastType(scanInput.dtype, 'int32');
    const outVals = util.makeZerosTypedArray(util.sizeFromShape(scanInput.shape), outDtype);
    const inVals = backend.data.get(scanInput.dataId).values;
    const rowLength = scanInput.shape[lastAxis];

    for (let rowStart = 0; rowStart < inVals.length; rowStart += rowLength) {
        let prev = -1;
        for (let step = 0; step < rowLength; step++) {
            // Walk the row front-to-back, or back-to-front when reversed.
            const cur = reverse ? rowStart + rowLength - step - 1 : rowStart + step;
            if (step === 0) {
                outVals[cur] = exclusive ? 0 : inVals[cur];
            }
            else if (exclusive) {
                // Exclusive: sum of everything strictly before this position.
                outVals[cur] = inVals[prev] + outVals[prev];
            }
            else {
                // The partial sum is read back from the typed output buffer so
                // accumulation happens in the output dtype's precision.
                outVals[cur] = inVals[cur] + outVals[prev];
            }
            prev = cur;
        }
    }

    const scanned = backend.makeTensorInfo(scanInput.shape, outDtype, outVals);
    if (!wasTransposed) {
        return scanned;
    }

    // Restore the caller's axis order, then release both intermediates.
    const undoPerm = backend_util.getUndoAxesPermutation(perm);
    const restored = transpose({ inputs: { x: scanned }, backend, attrs: { perm: undoPerm } });
    backend.disposeIntermediateTensorInfo(scanned);
    backend.disposeIntermediateTensorInfo(scanInput);
    return restored;
}
/**
 * Kernel configuration record for the cumulative-sum kernel: pairs the
 * imported `Cumsum` kernel name with the `'cpu'` backend name and the
 * `cumsum` function defined in this file.
 * NOTE(review): presumably consumed by the tfjs kernel registry — confirm
 * against the module that registers CPU kernels.
 */
export const cumsumConfig = {
    kernelName: Cumsum,
    backendName: 'cpu',
    kernelFunc: cumsum
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQ3Vtc3VtLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLWNwdS9zcmMva2VybmVscy9DdW1zdW0udHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBRSxNQUFNLEVBQStFLFVBQVUsRUFBRSxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUcxSixPQUFPLEVBQUMsZ0JBQWdCLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDN0MsT0FBTyxFQUFDLFNBQVMsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUV0QyxNQUFNLFVBQVUsTUFBTSxDQUNsQixJQUF5RTtJQUUzRSxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUNuQixNQUFNLEVBQUMsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFFekMsZ0JBQWdCLENBQUMsQ0FBQyxFQUFFLFFBQVEsQ0FBQyxDQUFDO0lBRTlCLE1BQU0sV0FBVyxHQUFHLFlBQVksQ0FBQyxrQkFBa0IsQ0FBQyxDQUFDLElBQUksQ0FBQyxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsTUFBTSxDQUFDLENBQUM7SUFDNUUsSUFBSSxFQUFFLEdBQUcsQ0FBQyxDQUFDO0lBQ1gsSUFBSSxXQUFXLElBQUksSUFBSSxFQUFFO1FBQ3ZCLEVBQUUsR0FBRyxTQUFTLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsSUFBSSxFQUFFLFdBQVcsRUFBQyxFQUFDLENBQUMsQ0FBQztLQUNwRTtJQUNELE1BQU0sWUFBWSxHQUFHLFlBQVksQ0FBQyxnQkFBZ0IsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUV6RSxJQUFJLFlBQVksS0FBSyxFQUFFLENBQUMsS0FBSyxDQUFDLE1BQU0sR0FBRyxDQUFDLEVBQUU7UUFDeEMsTUFBTSxJQUFJLEtBQUssQ0FDWCw4Q0FBOEM7WUFDOUMsUUFBUSxFQUFFLENBQUMsS0FBSyxDQUFDLE1BQU0sR0FBRyxDQUFDLGlCQUFpQixZQUFZLEVBQUUsQ0FBQyxDQUFDO0tBQ2pFO0lBRUQsTUFBTSxXQUFXLEdBQUcsVUFBVSxDQUFDLEVBQUUsQ0FBQyxLQUFLLEVBQUUsT0FBTyxDQUFDLENBQUM7SUFDbEQsTUFBTSxJQUFJLEdBQUcsSUFBSSxDQUFDLG1CQUFtQixDQUNwQixJQUFJLENBQUMsYUFBYSxDQUFDLEVBQUUsQ0FBQyxLQUFLLENBQUMsRUFBRSxXQUFXLENBQWUsQ0FBQztJQUUxRSxNQUFNLEtBQUssR0FBRyxPQUFPLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsTUFBTSxDQUFDLENBQUMsTUFBb0IsQ0FBQztJQUMvRCxNQUFNLFFBQVEsR0FBRyxFQUFFLENBQUMsS0FBSyxDQUFDLEVBQUUsQ0FBQyxLQUFLLENBQUMsTUFBTSxHQUFHLENBQUMsQ0FBQyxDQU
FDO0lBQy9DLE1BQU0sYUFBYSxHQUFHLE9BQU8sQ0FBQyxDQUFDO1FBQzNCLENBQUMsQ0FBUyxFQUFFLENBQVMsRUFBRSxFQUFFLENBQUMsQ0FBQyxHQUFHLFFBQVEsR0FBRyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUM7UUFDaEQsQ0FBQyxDQUFTLEVBQUUsQ0FBUyxFQUFFLEVBQUUsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxDQUFDO0lBQ3BDLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxLQUFLLENBQUMsTUFBTSxFQUFFLENBQUMsSUFBSSxRQUFRLEVBQUU7UUFDL0MsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFFBQVEsRUFBRSxDQUFDLEVBQUUsRUFBRTtZQUNqQyxNQUFNLEdBQUcsR0FBRyxhQUFhLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDO1lBQ2hDLElBQUksQ0FBQyxLQUFLLENBQUMsRUFBRTtnQkFDWCxJQUFJLENBQUMsR0FBRyxDQUFDLEdBQUcsU0FBUyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLEtBQUssQ0FBQyxHQUFHLENBQUMsQ0FBQzthQUN4QztpQkFBTTtnQkFDTCxNQUFNLE9BQU8sR0FBRyxhQUFhLENBQUMsQ0FBQyxFQUFFLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQztnQkFDeEMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxHQUFHLFNBQVMsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLE9BQU8sQ0FBQyxHQUFHLElBQUksQ0FBQyxPQUFPLENBQUMsQ0FBQyxDQUFDO29CQUNoQyxLQUFLLENBQUMsR0FBRyxDQUFDLEdBQUcsSUFBSSxDQUFDLE9BQU8sQ0FBQyxDQUFDO2FBQ3BEO1NBQ0Y7S0FDRjtJQUVELE1BQU0sTUFBTSxHQUFHLE9BQU8sQ0FBQyxjQUFjLENBQUMsRUFBRSxDQUFDLEtBQUssRUFBRSxXQUFXLEVBQUUsSUFBSSxDQUFDLENBQUM7SUFFbkUsSUFBSSxXQUFXLElBQUksSUFBSSxFQUFFO1FBQ3ZCLE1BQU0sa0JBQWtCLEdBQUcsWUFBWSxDQUFDLHNCQUFzQixDQUFDLFdBQVcsQ0FBQyxDQUFDO1FBQzVFLE1BQU0sdUJBQXVCLEdBQUcsU0FBUyxDQUNyQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxNQUFNLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsSUFBSSxFQUFFLGtCQUFrQixFQUFDLEVBQUMsQ0FBQyxDQUFDO1FBRXZFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxNQUFNLENBQUMsQ0FBQztRQUM5QyxPQUFPLENBQUMsNkJBQTZCLENBQUMsRUFBRSxDQUFDLENBQUM7UUFFMUMsT0FBTyx1QkFBdUIsQ0FBQztLQUNoQztJQUVELE9BQU8sTUFBTSxDQUFDO0FBQ2hCLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxZQUFZLEdBQWlCO0lBQ3hDLFVBQVUsRUFBRSxNQUFNO0lBQ2xCLFdBQVcsRUFBRSxLQUFLO0lBQ2xCLFVBQVUsRUFBRSxNQUErQjtDQUM1QyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIk
xpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgQ3Vtc3VtLCBDdW1zdW1BdHRycywgQ3Vtc3VtSW5wdXRzLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvckluZm8sIFR5cGVkQXJyYXksIHVwY2FzdFR5cGUsIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcbmltcG9ydCB7YXNzZXJ0Tm90Q29tcGxleH0gZnJvbSAnLi4vY3B1X3V0aWwnO1xuaW1wb3J0IHt0cmFuc3Bvc2V9IGZyb20gJy4vVHJhbnNwb3NlJztcblxuZXhwb3J0IGZ1bmN0aW9uIGN1bXN1bShcbiAgICBhcmdzOiB7aW5wdXRzOiBDdW1zdW1JbnB1dHMsIGJhY2tlbmQ6IE1hdGhCYWNrZW5kQ1BVLCBhdHRyczogQ3Vtc3VtQXR0cnN9KTpcbiAgICBUZW5zb3JJbmZvIHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZCwgYXR0cnN9ID0gYXJncztcbiAgY29uc3Qge3h9ID0gaW5wdXRzO1xuICBjb25zdCB7YXhpcywgZXhjbHVzaXZlLCByZXZlcnNlfSA9IGF0dHJzO1xuXG4gIGFzc2VydE5vdENvbXBsZXgoeCwgJ2N1bXN1bScpO1xuXG4gIGNvbnN0IHBlcm11dGF0aW9uID0gYmFja2VuZF91dGlsLmdldEF4ZXNQZXJtdXRhdGlvbihbYXhpc10sIHguc2hhcGUubGVuZ3RoKTtcbiAgbGV0ICR4ID0geDtcbiAgaWYgKHBlcm11dGF0aW9uICE9IG51bGwpIHtcbiAgICAkeCA9IHRyYW5zcG9zZSh7aW5wdXRzOiB7eH0sIGJhY2tlbmQsIGF0dHJzOiB7cGVybTogcGVybXV0YXRpb259fSk7XG4gIH1cbiAgY29uc3QgcGVybXV0ZWRBeGlzID0gYmFja2VuZF91dGlsLmdldElubmVyTW9zdEF4ZXMoMSwgeC5zaGFwZS5sZW5ndGgpWzBdO1xuXG4gIGlmIChwZXJtdXRlZEF4aXMgIT09ICR4LnNoYXBlLmxlbmd0aCAtIDEpIHtcbiAgICB0aH
JvdyBuZXcgRXJyb3IoXG4gICAgICAgIGBiYWNrZW5kLmN1bXN1bSBpbiBDUFUgZXhwZWN0cyBhbiBpbm5lci1tb3N0IGAgK1xuICAgICAgICBgYXhpcz0keyR4LnNoYXBlLmxlbmd0aCAtIDF9IGJ1dCBnb3QgYXhpcz0ke3Blcm11dGVkQXhpc31gKTtcbiAgfVxuXG4gIGNvbnN0IHJlc3VsdER0eXBlID0gdXBjYXN0VHlwZSgkeC5kdHlwZSwgJ2ludDMyJyk7XG4gIGNvbnN0IHZhbHMgPSB1dGlsLm1ha2VaZXJvc1R5cGVkQXJyYXkoXG4gICAgICAgICAgICAgICAgICAgdXRpbC5zaXplRnJvbVNoYXBlKCR4LnNoYXBlKSwgcmVzdWx0RHR5cGUpIGFzIFR5cGVkQXJyYXk7XG5cbiAgY29uc3QgYVZhbHMgPSBiYWNrZW5kLmRhdGEuZ2V0KCR4LmRhdGFJZCkudmFsdWVzIGFzIFR5cGVkQXJyYXk7XG4gIGNvbnN0IGZpbmFsRGltID0gJHguc2hhcGVbJHguc2hhcGUubGVuZ3RoIC0gMV07XG4gIGNvbnN0IGluZGV4QWRqdXN0ZXIgPSByZXZlcnNlID9cbiAgICAgIChpOiBudW1iZXIsIGo6IG51bWJlcikgPT4gaSArIGZpbmFsRGltIC0gaiAtIDEgOlxuICAgICAgKGk6IG51bWJlciwgajogbnVtYmVyKSA9PiBpICsgajtcbiAgZm9yIChsZXQgaSA9IDA7IGkgPCBhVmFscy5sZW5ndGg7IGkgKz0gZmluYWxEaW0pIHtcbiAgICBmb3IgKGxldCBqID0gMDsgaiA8IGZpbmFsRGltOyBqKyspIHtcbiAgICAgIGNvbnN0IGlkeCA9IGluZGV4QWRqdXN0ZXIoaSwgaik7XG4gICAgICBpZiAoaiA9PT0gMCkge1xuICAgICAgICB2YWxzW2lkeF0gPSBleGNsdXNpdmUgPyAwIDogYVZhbHNbaWR4XTtcbiAgICAgIH0gZWxzZSB7XG4gICAgICAgIGNvbnN0IHByZXZJZHggPSBpbmRleEFkanVzdGVyKGksIGogLSAxKTtcbiAgICAgICAgdmFsc1tpZHhdID0gZXhjbHVzaXZlID8gYVZhbHNbcHJldklkeF0gKyB2YWxzW3ByZXZJZHhdIDpcbiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgYVZhbHNbaWR4XSArIHZhbHNbcHJldklkeF07XG4gICAgICB9XG4gICAgfVxuICB9XG5cbiAgY29uc3QgcmVzdWx0ID0gYmFja2VuZC5tYWtlVGVuc29ySW5mbygkeC5zaGFwZSwgcmVzdWx0RHR5cGUsIHZhbHMpO1xuXG4gIGlmIChwZXJtdXRhdGlvbiAhPSBudWxsKSB7XG4gICAgY29uc3QgcmV2ZXJzZVBlcm11dGF0aW9uID0gYmFja2VuZF91dGlsLmdldFVuZG9BeGVzUGVybXV0YXRpb24ocGVybXV0YXRpb24pO1xuICAgIGNvbnN0IHJldmVyc2VUcmFuc3Bvc2VkUmVzdWx0ID0gdHJhbnNwb3NlKFxuICAgICAgICB7aW5wdXRzOiB7eDogcmVzdWx0fSwgYmFja2VuZCwgYXR0cnM6IHtwZXJtOiByZXZlcnNlUGVybXV0YXRpb259fSk7XG5cbiAgICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHJlc3VsdCk7XG4gICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbygkeCk7XG5cbiAgICByZXR1cm4gcmV2ZXJzZVRyYW5zcG9zZWRSZXN1bHQ7XG4gIH1cblxuICByZXR1cm4gcmVzdWx0O1xufVxuXG5leHBvcnQgY29uc3QgY3Vtc3VtQ2
9uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IEN1bXN1bSxcbiAgYmFja2VuZE5hbWU6ICdjcHUnLFxuICBrZXJuZWxGdW5jOiBjdW1zdW0gYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19