gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Sum, util } from '@tensorflow/tfjs-core';
import { assertNotComplex } from '../cpu_util';
import { zeros } from '../utils/zeros_impl';
import { cast } from './Cast';
import { identity } from './Identity';
import { reshape } from './Reshape';
import { transpose } from './Transpose';
export function sum(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex(x, 'sum');
    let $x;
    if (x.dtype === 'bool') {
        $x = cast({ inputs: { x }, backend, attrs: { dtype: 'int32' } });
    }
    else {
        $x = identity({ inputs: { x }, backend });
    }
    const xRank = $x.shape.length;
    const axes = util.parseAxisParam(axis, $x.shape);
    const permutation = backend_util.getAxesPermutation(axes, xRank);
    let reductionAxes = axes;
    let permutedX = $x;
    if (permutation != null) {
        permutedX =
            transpose({ inputs: { x: $x }, backend, attrs: { perm: permutation } });
        reductionAxes = backend_util.getInnerMostAxes(reductionAxes.length, xRank);
    }
    backend_util.assertAxesAreInnerMostDims('sum', reductionAxes, permutedX.shape.length);
    const [outShape, reduceShape] = backend_util.computeOutAndReduceShapes(permutedX.shape, reductionAxes);
    const resultDtype = backend_util.upcastType(permutedX.dtype, 'int32');
    let result = zeros(backend, outShape, resultDtype);
    const reduceSize = util.sizeFromShape(reduceShape);
    const vals = backend.data.get(result.dataId).values;
    const aVals = backend.data.get(permutedX.dataId).values;
    for (let i = 0; i < vals.length; ++i) {
        const offset = i * reduceSize;
        let sum = 0;
        for (let j = 0; j < reduceSize; ++j) {
            sum += aVals[offset + j];
        }
        vals[i] = sum;
    }
    if (keepDims) {
        const newShape = backend_util.expandShapeToKeepDim(result.shape, axes);
        const oldResult = result;
        result = reshape({ inputs: { x: result }, backend, attrs: { shape: newShape } });
        backend.disposeIntermediateTensorInfo(oldResult);
    }
    backend.disposeIntermediateTensorInfo($x);
    if (permutation != null) {
        backend.disposeIntermediateTensorInfo(permutedX);
    }
    return result;
}
export const sumConfig = {
    kernelName: Sum,
    backendName: 'cpu',
    kernelFunc: sum
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3VtLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLWNwdS9zcmMva2VybmVscy9TdW0udHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBNEIsR0FBRyxFQUErQyxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUdySSxPQUFPLEVBQUMsZ0JBQWdCLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDN0MsT0FBTyxFQUFDLEtBQUssRUFBQyxNQUFNLHFCQUFxQixDQUFDO0FBQzFDLE9BQU8sRUFBQyxJQUFJLEVBQUMsTUFBTSxRQUFRLENBQUM7QUFDNUIsT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLFlBQVksQ0FBQztBQUNwQyxPQUFPLEVBQUMsT0FBTyxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBQ2xDLE9BQU8sRUFBQyxTQUFTLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFFdEMsTUFBTSxVQUFVLEdBQUcsQ0FDZixJQUFtRTtJQUVyRSxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUNuQixNQUFNLEVBQUMsSUFBSSxFQUFFLFFBQVEsRUFBQyxHQUFHLEtBQUssQ0FBQztJQUUvQixnQkFBZ0IsQ0FBQyxDQUFDLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFFM0IsSUFBSSxFQUFFLENBQUM7SUFDUCxJQUFJLENBQUMsQ0FBQyxLQUFLLEtBQUssTUFBTSxFQUFFO1FBQ3RCLEVBQUUsR0FBRyxJQUFJLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLE9BQU8sRUFBQyxFQUFDLENBQUMsQ0FBQztLQUM1RDtTQUFNO1FBQ0wsRUFBRSxHQUFHLFFBQVEsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBQyxFQUFFLE9BQU8sRUFBQyxDQUFDLENBQUM7S0FDdkM7SUFFRCxNQUFNLEtBQUssR0FBRyxFQUFFLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQztJQUM5QixNQUFNLElBQUksR0FBRyxJQUFJLENBQUMsY0FBYyxDQUFDLElBQUksRUFBRSxFQUFFLENBQUMsS0FBSyxDQUFDLENBQUM7SUFDakQsTUFBTSxXQUFXLEdBQUcsWUFBWSxDQUFDLGtCQUFrQixDQUFDLElBQUksRUFBRSxLQUFLLENBQUMsQ0FBQztJQUVqRSxJQUFJLGFBQWEsR0FBRyxJQUFJLENBQUM7SUFDekIsSUFBSSxTQUFTLEdBQUcsRUFBRSxDQUFDO0lBQ25CLElBQUksV0FBVyxJQUFJLElBQUksRUFBRTtRQUN2QixTQUFTO1lBQ0wsU0FBUyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLEVBQUUsRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxJQUFJLEVBQUUsV0FBVyxFQUFDLEVBQUMsQ0FBQyxDQUFDO1FBQ3RFLGFBQWEsR0FBRyxZQUFZLENBQUMsZ0JBQWdCLENBQUMsYUFBYSxDQUFDLE1BQU0sRUFBRSxLQUFLLENBQUMsQ0FBQztLQUM1RTtJQUVELFlBQVksQ0FBQywwQkFBMEIsQ0FDbkMsS0FBSyxFQUFFLGFBQWEsRUFBRSxTQUFTLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQyxDQUFDO0lBRWxELE1BQU0sQ0FBQyxRQUFRLEVBQUUsV0FBVyxDQUFDLEdBQ3pCLFlBQVksQ0FBQyx5QkFBeUIsQ0FBQyxTQUFTLENBQUMsS0FBSyxFQUFFLGFBQWEsQ0FBQyxDQUFDO0lBQzNFLE1BQU0sV0FBVyxHQUFHLFlBQVksQ0FBQyxVQUFVLENBQUMsU0FBUyxDQUFDLEtBQUssRUFBRSxPQUFPLENBQUMsQ0FBQztJQUN0RSxJQUFJLE1BQU0sR0FBRyxLQUFLLENBQUMsT0FBTyxFQUFFLFFBQVEsRUFBRSxXQUFXLENBQUMsQ0FBQztJQUNuRCxNQUFNLFVBQVUsR0FBRyxJQUFJLENBQUMsYUFBYSxDQUFDLFdBQVcsQ0FBQyxDQUFDO0lBQ25ELE1BQU0sSUFBSSxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLE1BQU0sQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDO0lBRWxFLE1BQU0sS0FBSyxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLFNBQVMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDO0lBQ3RFLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxJQUFJLENBQUMsTUFBTSxFQUFFLEVBQUUsQ0FBQyxFQUFFO1FBQ3BDLE1BQU0sTUFBTSxHQUFHLENBQUMsR0FBRyxVQUFVLENBQUM7UUFDOUIsSUFBSSxHQUFHLEdBQUcsQ0FBQyxDQUFDO1FBQ1osS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFVBQVUsRUFBRSxFQUFFLENBQUMsRUFBRTtZQUNuQyxHQUFHLElBQUksS0FBSyxDQUFDLE1BQU0sR0FBRyxDQUFDLENBQUMsQ0FBQztTQUMxQjtRQUNELElBQUksQ0FBQyxDQUFDLENBQUMsR0FBRyxHQUFHLENBQUM7S0FDZjtJQUVELElBQUksUUFBUSxFQUFFO1FBQ1osTUFBTSxRQUFRLEdBQUcsWUFBWSxDQUFDLG9CQUFvQixDQUFDLE1BQU0sQ0FBQyxLQUFLLEVBQUUsSUFBSSxDQUFDLENBQUM7UUFDdkUsTUFBTSxTQUFTLEdBQUcsTUFBTSxDQUFDO1FBQ3pCLE1BQU0sR0FBRyxPQUFPLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsTUFBTSxFQUFDLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxFQUFDLEtBQUssRUFBRSxRQUFRLEVBQUMsRUFBQyxDQUFDLENBQUM7UUFDM0UsT0FBTyxDQUFDLDZCQUE2QixDQUFDLFNBQVMsQ0FBQyxDQUFDO0tBQ2xEO0lBRUQsT0FBTyxDQUFDLDZCQUE2QixDQUFDLEVBQUUsQ0FBQyxDQUFDO0lBRTFDLElBQUksV0FBVyxJQUFJLElBQUksRUFBRTtRQUN2QixPQUFPLENBQUMsNkJBQTZCLENBQUMsU0FBUyxDQUFDLENBQUM7S0FDbEQ7SUFFRCxPQUFPLE1BQU0sQ0FBQztBQUNoQixDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sU0FBUyxHQUFpQjtJQUNyQyxVQUFVLEVBQUUsR0FBRztJQUNmLFdBQVcsRUFBRSxLQUFLO0lBQ2xCLFVBQVUsRUFBRSxHQUE0QjtDQUN6QyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBTdW0sIFN1bUF0dHJzLCBTdW1JbnB1dHMsIFRlbnNvckluZm8sIFR5cGVkQXJyYXksIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcbmltcG9ydCB7YXNzZXJ0Tm90Q29tcGxleH0gZnJvbSAnLi4vY3B1X3V0aWwnO1xuaW1wb3J0IHt6ZXJvc30gZnJvbSAnLi4vdXRpbHMvemVyb3NfaW1wbCc7XG5pbXBvcnQge2Nhc3R9IGZyb20gJy4vQ2FzdCc7XG5pbXBvcnQge2lkZW50aXR5fSBmcm9tICcuL0lkZW50aXR5JztcbmltcG9ydCB7cmVzaGFwZX0gZnJvbSAnLi9SZXNoYXBlJztcbmltcG9ydCB7dHJhbnNwb3NlfSBmcm9tICcuL1RyYW5zcG9zZSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBzdW0oXG4gICAgYXJnczoge2lucHV0czogU3VtSW5wdXRzLCBiYWNrZW5kOiBNYXRoQmFja2VuZENQVSwgYXR0cnM6IFN1bUF0dHJzfSk6XG4gICAgVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHt4fSA9IGlucHV0cztcbiAgY29uc3Qge2F4aXMsIGtlZXBEaW1zfSA9IGF0dHJzO1xuXG4gIGFzc2VydE5vdENvbXBsZXgoeCwgJ3N1bScpO1xuXG4gIGxldCAkeDtcbiAgaWYgKHguZHR5cGUgPT09ICdib29sJykge1xuICAgICR4ID0gY2FzdCh7aW5wdXRzOiB7eH0sIGJhY2tlbmQsIGF0dHJzOiB7ZHR5cGU6ICdpbnQzMid9fSk7XG4gIH0gZWxzZSB7XG4gICAgJHggPSBpZGVudGl0eSh7aW5wdXRzOiB7eH0sIGJhY2tlbmR9KTtcbiAgfVxuXG4gIGNvbnN0IHhSYW5rID0gJHguc2hhcGUubGVuZ3RoO1xuICBjb25zdCBheGVzID0gdXRpbC5wYXJzZUF4aXNQYXJhbShheGlzLCAkeC5zaGFwZSk7XG4gIGNvbnN0IHBlcm11dGF0aW9uID0gYmFja2VuZF91dGlsLmdldEF4ZXNQZXJtdXRhdGlvbihheGVzLCB4UmFuayk7XG5cbiAgbGV0IHJlZHVjdGlvbkF4ZXMgPSBheGVzO1xuICBsZXQgcGVybXV0ZWRYID0gJHg7XG4gIGlmIChwZXJtdXRhdGlvbiAhPSBudWxsKSB7XG4gICAgcGVybXV0ZWRYID1cbiAgICAgICAgdHJhbnNwb3NlKHtpbnB1dHM6IHt4OiAkeH0sIGJhY2tlbmQsIGF0dHJzOiB7cGVybTogcGVybXV0YXRpb259fSk7XG4gICAgcmVkdWN0aW9uQXhlcyA9IGJhY2tlbmRfdXRpbC5nZXRJbm5lck1vc3RBeGVzKHJlZHVjdGlvbkF4ZXMubGVuZ3RoLCB4UmFuayk7XG4gIH1cblxuICBiYWNrZW5kX3V0aWwuYXNzZXJ0QXhlc0FyZUlubmVyTW9zdERpbXMoXG4gICAgICAnc3VtJywgcmVkdWN0aW9uQXhlcywgcGVybXV0ZWRYLnNoYXBlLmxlbmd0aCk7XG5cbiAgY29uc3QgW291dFNoYXBlLCByZWR1Y2VTaGFwZV0gPVxuICAgICAgYmFja2VuZF91dGlsLmNvbXB1dGVPdXRBbmRSZWR1Y2VTaGFwZXMocGVybXV0ZWRYLnNoYXBlLCByZWR1Y3Rpb25BeGVzKTtcbiAgY29uc3QgcmVzdWx0RHR5cGUgPSBiYWNrZW5kX3V0aWwudXBjYXN0VHlwZShwZXJtdXRlZFguZHR5cGUsICdpbnQzMicpO1xuICBsZXQgcmVzdWx0ID0gemVyb3MoYmFja2VuZCwgb3V0U2hhcGUsIHJlc3VsdER0eXBlKTtcbiAgY29uc3QgcmVkdWNlU2l6ZSA9IHV0aWwuc2l6ZUZyb21TaGFwZShyZWR1Y2VTaGFwZSk7XG4gIGNvbnN0IHZhbHMgPSBiYWNrZW5kLmRhdGEuZ2V0KHJlc3VsdC5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5O1xuXG4gIGNvbnN0IGFWYWxzID0gYmFja2VuZC5kYXRhLmdldChwZXJtdXRlZFguZGF0YUlkKS52YWx1ZXMgYXMgVHlwZWRBcnJheTtcbiAgZm9yIChsZXQgaSA9IDA7IGkgPCB2YWxzLmxlbmd0aDsgKytpKSB7XG4gICAgY29uc3Qgb2Zmc2V0ID0gaSAqIHJlZHVjZVNpemU7XG4gICAgbGV0IHN1bSA9IDA7XG4gICAgZm9yIChsZXQgaiA9IDA7IGogPCByZWR1Y2VTaXplOyArK2opIHtcbiAgICAgIHN1bSArPSBhVmFsc1tvZmZzZXQgKyBqXTtcbiAgICB9XG4gICAgdmFsc1tpXSA9IHN1bTtcbiAgfVxuXG4gIGlmIChrZWVwRGltcykge1xuICAgIGNvbnN0IG5ld1NoYXBlID0gYmFja2VuZF91dGlsLmV4cGFuZFNoYXBlVG9LZWVwRGltKHJlc3VsdC5zaGFwZSwgYXhlcyk7XG4gICAgY29uc3Qgb2xkUmVzdWx0ID0gcmVzdWx0O1xuICAgIHJlc3VsdCA9IHJlc2hhcGUoe2lucHV0czoge3g6IHJlc3VsdH0sIGJhY2tlbmQsIGF0dHJzOiB7c2hhcGU6IG5ld1NoYXBlfX0pO1xuICAgIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8ob2xkUmVzdWx0KTtcbiAgfVxuXG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oJHgpO1xuXG4gIGlmIChwZXJtdXRhdGlvbiAhPSBudWxsKSB7XG4gICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhwZXJtdXRlZFgpO1xuICB9XG5cbiAgcmV0dXJuIHJlc3VsdDtcbn1cblxuZXhwb3J0IGNvbnN0IHN1bUNvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBTdW0sXG4gIGJhY2tlbmROYW1lOiAnY3B1JyxcbiAga2VybmVsRnVuYzogc3VtIGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==