gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/**
 * @license
 * Copyright 2022 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util } from '@tensorflow/tfjs-core';
import { CumProgram } from '../cum_gpu';
import { identity } from './Identity';
import { transpose } from './Transpose';
/**
 * Runs a cumulative scan of `x` along `axis` on the WebGL backend by
 * dispatching a sequence of `CumProgram` shader passes.
 *
 * When `axis` is not already the inner-most axis, the input is transposed so
 * that it is, scanned, and then transposed back to the original layout.
 *
 * @param op Operator selector forwarded unchanged to `CumProgram`.
 * @param x Input tensor info; only its `shape` is read directly here.
 * @param backend Backend used to run programs (`runWebGLProgram`) and to
 *     release intermediates (`disposeIntermediateTensorInfo`).
 * @param axis Axis of `x` to scan along.
 * @param exclusive When truthy, one extra `CumProgram` pass built with this
 *     flag is run after the scan passes.
 * @param reverse Forwarded unchanged to every `CumProgram`.
 * @returns Tensor info holding the scan result, in the axis order of `x`.
 * @throws {Error} If the computed inner-most axis is not `rank - 1`.
 */
export function cumImpl(op, x, backend, axis, exclusive, reverse) {
    const rank = x.shape.length;
    const perm = backend_util.getAxesPermutation([axis], rank);
    const needsTranspose = perm != null;
    // Move the scan axis to the inner-most position when it is not there yet.
    const input = needsTranspose ?
        transpose({ inputs: { x }, backend, attrs: { perm } }) :
        x;
    const scanAxis = backend_util.getInnerMostAxes(1, rank)[0];
    if (scanAxis !== rank - 1) {
        throw new Error(`WebGL cumprod shader expects an inner-most axis=${x.shape.length - 1} ` +
            `but got axis=${axis}`);
    }
    const scanLength = input.shape[scanAxis];

    let current = identity({ inputs: { x: input }, backend });
    // Installs `next` as the working buffer and releases the one it replaces.
    // The argument (a program run) is evaluated before the old buffer is
    // disposed, so the run always reads a live tensor.
    const replaceCurrent = (next) => {
        backend.disposeIntermediateTensorInfo(current);
        current = next;
    };

    // Parallel scan, inspired by:
    // https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
    // Note: although that algorithm is described for sums, it works for any
    // associative operator with an identity.
    // ceil(log2(length)) passes are issued; each receives its pass index as
    // the program's custom value.
    const passCount = Math.ceil(Math.log2(scanLength));
    for (let pass = 0; pass < passCount; pass++) {
        const scanProgram = new CumProgram(op, input.shape, false, reverse);
        const passValues = [[pass]];
        replaceCurrent(backend.runWebGLProgram(scanProgram, [current], current.dtype, passValues));
    }

    // For an exclusive scan, shift the end result in the direction of the
    // product or sum and put 1 (product) or 0 (sum) at the front index.
    if (exclusive) {
        const shiftProgram = new CumProgram(op, input.shape, exclusive, reverse);
        replaceCurrent(backend.runWebGLProgram(shiftProgram, [current], current.dtype));
    }

    if (!needsTranspose) {
        return current;
    }

    // Restore the caller's axis order and drop both transposed intermediates.
    const undoPerm = backend_util.getUndoAxesPermutation(perm);
    const restored = transpose({ inputs: { x: current }, backend, attrs: { perm: undoPerm } });
    backend.disposeIntermediateTensorInfo(current);
    backend.disposeIntermediateTensorInfo(input);
    return restored;
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQ3VtX2ltcGwuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2ViZ2wvc3JjL2tlcm5lbHMvQ3VtX2ltcGwudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBYSxNQUFNLHVCQUF1QixDQUFDO0FBRy9ELE9BQU8sRUFBWSxVQUFVLEVBQUMsTUFBTSxZQUFZLENBQUM7QUFFakQsT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLFlBQVksQ0FBQztBQUNwQyxPQUFPLEVBQUMsU0FBUyxFQUFDLE1BQU0sYUFBYSxDQUFDO0FBRXRDLE1BQU0sVUFBVSxPQUFPLENBQ25CLEVBQWEsRUFBRSxDQUFhLEVBQUUsT0FBeUIsRUFBRSxJQUFZLEVBQ3JFLFNBQWtCLEVBQUUsT0FBZ0I7SUFDdEMsTUFBTSxLQUFLLEdBQUcsQ0FBQyxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUM7SUFDN0IsTUFBTSxXQUFXLEdBQUcsWUFBWSxDQUFDLGtCQUFrQixDQUFDLENBQUMsSUFBSSxDQUFDLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFDbkUsSUFBSSxTQUFTLEdBQUcsQ0FBQyxDQUFDO0lBQ2xCLElBQUksV0FBVyxJQUFJLElBQUksRUFBRTtRQUN2QixTQUFTLEdBQUcsU0FBUyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFDLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxFQUFDLElBQUksRUFBRSxXQUFXLEVBQUMsRUFBQyxDQUFDLENBQUM7S0FDM0U7SUFDRCxNQUFNLFlBQVksR0FBRyxZQUFZLENBQUMsZ0JBQWdCLENBQUMsQ0FBQyxFQUFFLEtBQUssQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBRWhFLElBQUksWUFBWSxLQUFLLEtBQUssR0FBRyxDQUFDLEVBQUU7UUFDOUIsTUFBTSxJQUFJLEtBQUssQ0FDWCxtREFDSSxDQUFDLENBQUMsS0FBSyxDQUFDLE1BQU0sR0FBRyxDQUFDLEdBQUc7WUFDekIsZ0JBQWdCLElBQUksRUFBRSxDQUFDLENBQUM7S0FDN0I7SUFDRCxNQUFNLElBQUksR0FBRyxTQUFTLENBQUMsS0FBSyxDQUFDLFlBQVksQ0FBQyxDQUFDO0lBQzNDLElBQUksTUFBTSxHQUFHLFFBQVEsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxTQUFTLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQUFDO0lBQ3pELDJDQUEyQztJQUMzQywrR0FBK0c7SUFDL0csNEVBQTRFO0lBQzVFLDZCQUE2QjtJQUU3QixLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLElBQUksSUFBSSxDQUFDLElBQUksQ0FBQyxJQUFJLENBQUMsSUFBSSxDQUFDLElBQUksQ0FBQyxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsRUFBRSxFQUFFO1FBQ3hELE1BQU0sT0FBTyxHQUFHLElBQUksVUFBVSxDQUFDLEVBQUUsRUFBRSxTQUFTLENBQUMsS0FBSyxFQUFFLEtBQUssRUFBRSxPQUFPLENBQUMsQ0FBQztRQUNwRSxNQUFNLFlBQVksR0FBRyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUMzQixNQUFNLF
VBQVUsR0FBRyxNQUFNLENBQUM7UUFDMUIsTUFBTTtZQUNGLE9BQU8sQ0FBQyxlQUFlLENBQUMsT0FBTyxFQUFFLENBQUMsTUFBTSxDQUFDLEVBQUUsTUFBTSxDQUFDLEtBQUssRUFBRSxZQUFZLENBQUMsQ0FBQztRQUMzRSxPQUFPLENBQUMsNkJBQTZCLENBQUMsVUFBVSxDQUFDLENBQUM7S0FDbkQ7SUFDRCw2RUFBNkU7SUFDN0UseURBQXlEO0lBQ3pELElBQUksU0FBUyxFQUFFO1FBQ2IsTUFBTSxPQUFPLEdBQUcsSUFBSSxVQUFVLENBQUMsRUFBRSxFQUFFLFNBQVMsQ0FBQyxLQUFLLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBQyxDQUFDO1FBQ3hFLE1BQU0sVUFBVSxHQUFHLE1BQU0sQ0FBQztRQUMxQixNQUFNLEdBQUcsT0FBTyxDQUFDLGVBQWUsQ0FBQyxPQUFPLEVBQUUsQ0FBQyxNQUFNLENBQUMsRUFBRSxNQUFNLENBQUMsS0FBSyxDQUFDLENBQUM7UUFDbEUsT0FBTyxDQUFDLDZCQUE2QixDQUFDLFVBQVUsQ0FBQyxDQUFDO0tBQ25EO0lBRUQsSUFBSSxXQUFXLElBQUksSUFBSSxFQUFFO1FBQ3ZCLE1BQU0sa0JBQWtCLEdBQUcsWUFBWSxDQUFDLHNCQUFzQixDQUFDLFdBQVcsQ0FBQyxDQUFDO1FBQzVFLE1BQU0sdUJBQXVCLEdBQUcsU0FBUyxDQUNyQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxNQUFNLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsSUFBSSxFQUFFLGtCQUFrQixFQUFDLEVBQUMsQ0FBQyxDQUFDO1FBRXZFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxNQUFNLENBQUMsQ0FBQztRQUM5QyxPQUFPLENBQUMsNkJBQTZCLENBQUMsU0FBUyxDQUFDLENBQUM7UUFFakQsT0FBTyx1QkFBdUIsQ0FBQztLQUNoQztJQUVELE9BQU8sTUFBTSxDQUFDO0FBQ2hCLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMiBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIE
xpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7YmFja2VuZF91dGlsLCBUZW5zb3JJbmZvfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQge01hdGhCYWNrZW5kV2ViR0x9IGZyb20gJy4uL2JhY2tlbmRfd2ViZ2wnO1xuaW1wb3J0IHtDdW1PcFR5cGUsIEN1bVByb2dyYW19IGZyb20gJy4uL2N1bV9ncHUnO1xuXG5pbXBvcnQge2lkZW50aXR5fSBmcm9tICcuL0lkZW50aXR5JztcbmltcG9ydCB7dHJhbnNwb3NlfSBmcm9tICcuL1RyYW5zcG9zZSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBjdW1JbXBsKFxuICAgIG9wOiBDdW1PcFR5cGUsIHg6IFRlbnNvckluZm8sIGJhY2tlbmQ6IE1hdGhCYWNrZW5kV2ViR0wsIGF4aXM6IG51bWJlcixcbiAgICBleGNsdXNpdmU6IGJvb2xlYW4sIHJldmVyc2U6IGJvb2xlYW4pOiBUZW5zb3JJbmZvIHtcbiAgY29uc3QgeFJhbmsgPSB4LnNoYXBlLmxlbmd0aDtcbiAgY29uc3QgcGVybXV0YXRpb24gPSBiYWNrZW5kX3V0aWwuZ2V0QXhlc1Blcm11dGF0aW9uKFtheGlzXSwgeFJhbmspO1xuICBsZXQgcGVybXV0ZWRYID0geDtcbiAgaWYgKHBlcm11dGF0aW9uICE9IG51bGwpIHtcbiAgICBwZXJtdXRlZFggPSB0cmFuc3Bvc2Uoe2lucHV0czoge3h9LCBiYWNrZW5kLCBhdHRyczoge3Blcm06IHBlcm11dGF0aW9ufX0pO1xuICB9XG4gIGNvbnN0IHBlcm11dGVkQXhpcyA9IGJhY2tlbmRfdXRpbC5nZXRJbm5lck1vc3RBeGVzKDEsIHhSYW5rKVswXTtcblxuICBpZiAocGVybXV0ZWRBeGlzICE9PSB4UmFuayAtIDEpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgIGBXZWJHTCBjdW1wcm9kIHNoYWRlciBleHBlY3RzIGFuIGlubmVyLW1vc3QgYXhpcz0ke1xuICAgICAgICAgICAgeC5zaGFwZS5sZW5ndGggLSAxfSBgICtcbiAgICAgICAgYGJ1dCBnb3QgYXhpcz0ke2F4aXN9YCk7XG4gIH1cbiAgY29uc3Qgc2l6ZSA9IHBlcm11dGVkWC5zaGFwZVtwZXJtdXRlZEF4aXNdO1xuICBsZXQgcmVzdWx0ID0gaWRlbnRpdHkoe2lucHV0czoge3g6IHBlcm11dGVkWH0sIGJhY2tlbmR9KTtcbiAgLy8gVXNlIGN1bSBwYXJhbGxlbCBhbGdvcml0aG0sIGluc3BpcmVkIGJ5OlxuICAvLyBodHRwczovL2RldmVsb3Blci5udmlkaWEuY29tL2dwdWdlbXMvZ3B1Z2VtczMvcGFydC12aS1ncHUtY29tcHV0aW5nL2NoYXB0ZXItMzktcGFyYWxsZWwtcHJlZml4LXN1bS1zY2FuLWN1ZGFcbiAgLy8gTm90ZTogYWx0aG91Z2ggdGhlIGFsZ29yaXRobSBpcyBjYWxsZWQgc3VtLCBpdCB3b3JrcyBmb3IgYW55IGFzc29jaXRhdGl2ZVxuICAvLyBvcGVyYXRvciB3aXRoIGFuIGlkZW50aXR5LlxuXG4gIGZvciAobGV0IGkgPSAwOyBpIDw9IE1hdGguY2VpbChNYXRoLmxvZzIoc2l6ZSkpIC0gMTsgaSsrKSB7XG4gICAgY29uc3QgcHJvZ3JhbSA9IG5ldyBDdW1Qcm9ncmFtKG
9wLCBwZXJtdXRlZFguc2hhcGUsIGZhbHNlLCByZXZlcnNlKTtcbiAgICBjb25zdCBjdXN0b21WYWx1ZXMgPSBbW2ldXTtcbiAgICBjb25zdCBwcmV2UmVzdWx0ID0gcmVzdWx0O1xuICAgIHJlc3VsdCA9XG4gICAgICAgIGJhY2tlbmQucnVuV2ViR0xQcm9ncmFtKHByb2dyYW0sIFtyZXN1bHRdLCByZXN1bHQuZHR5cGUsIGN1c3RvbVZhbHVlcyk7XG4gICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhwcmV2UmVzdWx0KTtcbiAgfVxuICAvLyBGb3IgZXhjbHVzaXZlIGN1bSwgc2hpZnQgdGhlIGVuZCByZXN1bHQgaW4gdGhlIGRpcmVjdGlvbiBvZiBwcm9kdWN0IG9yIHN1bVxuICAvLyBhbmQgYWRkIDEgZm9yIHByb2R1Y3Qgb3IgMCBmb3Igc3VtIHRvIHRoZSBmcm9udCBpbmRleC5cbiAgaWYgKGV4Y2x1c2l2ZSkge1xuICAgIGNvbnN0IHByb2dyYW0gPSBuZXcgQ3VtUHJvZ3JhbShvcCwgcGVybXV0ZWRYLnNoYXBlLCBleGNsdXNpdmUsIHJldmVyc2UpO1xuICAgIGNvbnN0IHByZXZSZXN1bHQgPSByZXN1bHQ7XG4gICAgcmVzdWx0ID0gYmFja2VuZC5ydW5XZWJHTFByb2dyYW0ocHJvZ3JhbSwgW3Jlc3VsdF0sIHJlc3VsdC5kdHlwZSk7XG4gICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhwcmV2UmVzdWx0KTtcbiAgfVxuXG4gIGlmIChwZXJtdXRhdGlvbiAhPSBudWxsKSB7XG4gICAgY29uc3QgcmV2ZXJzZVBlcm11dGF0aW9uID0gYmFja2VuZF91dGlsLmdldFVuZG9BeGVzUGVybXV0YXRpb24ocGVybXV0YXRpb24pO1xuICAgIGNvbnN0IHJldmVyc2VUcmFuc3Bvc2VkUmVzdWx0ID0gdHJhbnNwb3NlKFxuICAgICAgICB7aW5wdXRzOiB7eDogcmVzdWx0fSwgYmFja2VuZCwgYXR0cnM6IHtwZXJtOiByZXZlcnNlUGVybXV0YXRpb259fSk7XG5cbiAgICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHJlc3VsdCk7XG4gICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhwZXJtdXRlZFgpO1xuXG4gICAgcmV0dXJuIHJldmVyc2VUcmFuc3Bvc2VkUmVzdWx0O1xuICB9XG5cbiAgcmV0dXJuIHJlc3VsdDtcbn1cbiJdfQ==