gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Mean, util } from '@tensorflow/tfjs-core';
import { meanImpl } from './Mean_impl';
import { transposeImpl, transposeImplCPU } from './Transpose_impl';
/**
 * Transposes `x` by `perm` using the values already stored in `x`'s texData
 * entry, and returns a new tensor info on `backend` holding the result.
 * Only used by the Mean kernel below when `backend.shouldExecuteOnCPU([x])`
 * is true.
 *
 * @param {TensorInfo} x Input tensor info whose texData entry has `values`.
 * @param {number[]} perm Axis permutation; expected to have one entry per
 *     dimension of `x`.
 * @param {MathBackendWebGL} backend Backend that owns `x`.
 * @returns {TensorInfo} Newly allocated transposed tensor info; the caller is
 *     responsible for disposing it.
 */
function transposeInputOnCPU(x, perm, backend) {
    const { values } = backend.texData.get(x.dataId);
    const newShape = Array.from({ length: x.shape.length }, (_, i) => x.shape[perm[i]]);
    const transposedValues = transposeImplCPU(values, x.shape, x.dtype, perm, newShape);
    const result = backend.makeTensorInfo(newShape, x.dtype);
    backend.texData.get(result.dataId).values = transposedValues;
    return result;
}
/**
 * WebGL kernel for the `Mean` op.
 *
 * The reduction path requires the reduced axes to be the innermost ones (see
 * the `assertAxesAreInnerMostDims` check). When `getAxesPermutation` returns a
 * permutation, the input is first transposed: on the CPU when
 * `shouldExecuteOnCPU` allows it, otherwise via `transposeImpl`. Any such
 * transposed intermediate is disposed before the kernel returns, including
 * when the reduction throws.
 */
export const meanConfig = {
    kernelName: Mean,
    backendName: 'webgl',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { x } = inputs;
        const { keepDims, axis } = attrs;
        const webglBackend = backend;
        const xRank = x.shape.length;
        const origAxes = util.parseAxisParam(axis, x.shape);
        // A non-null permutation means the input must be transposed before
        // reducing.
        const permutedAxes = backend_util.getAxesPermutation(origAxes, xRank);
        const intermediates = [];
        try {
            let meanInput = x;
            let axes = origAxes;
            if (permutedAxes != null) {
                meanInput = webglBackend.shouldExecuteOnCPU([x]) ?
                    transposeInputOnCPU(x, permutedAxes, webglBackend) :
                    transposeImpl(x, permutedAxes, webglBackend);
                intermediates.push(meanInput);
                // Reduce over the innermost dims of the transposed input
                // (presumably where the permutation moved the reduced axes).
                axes = backend_util.getInnerMostAxes(origAxes.length, xRank);
            }
            backend_util.assertAxesAreInnerMostDims('mean', axes, xRank);
            const [meanOutShape, reduceShape] = backend_util.computeOutAndReduceShapes(meanInput.shape, axes);
            // Rather than reshape at the end, set the target shape here. The
            // caller's original axes (not the post-transpose ones) are used.
            const outShape = keepDims ?
                backend_util.expandShapeToKeepDim(meanOutShape, origAxes) :
                meanOutShape;
            return meanImpl(meanInput, reduceShape, outShape, webglBackend);
        }
        finally {
            // Runs on success and on failure so transposed copies never leak.
            for (const intermediate of intermediates) {
                webglBackend.disposeIntermediateTensorInfo(intermediate);
            }
        }
    }
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiTWVhbi5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMva2VybmVscy9NZWFuLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQWdCLElBQUksRUFBaUQsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFJNUgsT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUNyQyxPQUFPLEVBQUMsYUFBYSxFQUFFLGdCQUFnQixFQUFDLE1BQU0sa0JBQWtCLENBQUM7QUFFakUsTUFBTSxDQUFDLE1BQU0sVUFBVSxHQUFpQjtJQUN0QyxVQUFVLEVBQUUsSUFBSTtJQUNoQixXQUFXLEVBQUUsT0FBTztJQUNwQixVQUFVLEVBQUUsQ0FBQyxFQUFDLE1BQU0sRUFBRSxLQUFLLEVBQUUsT0FBTyxFQUFDLEVBQUUsRUFBRTtRQUN2QyxNQUFNLEVBQUMsQ0FBQyxFQUFDLEdBQUcsTUFBb0IsQ0FBQztRQUNqQyxNQUFNLEVBQUMsUUFBUSxFQUFFLElBQUksRUFBQyxHQUFHLEtBQTZCLENBQUM7UUFDdkQsTUFBTSxZQUFZLEdBQUcsT0FBMkIsQ0FBQztRQUVqRCxNQUFNLEtBQUssR0FBRyxDQUFDLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQztRQUM3QixNQUFNLFFBQVEsR0FBRyxJQUFJLENBQUMsY0FBYyxDQUFDLElBQUksRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUM7UUFFcEQsSUFBSSxJQUFJLEdBQUcsUUFBUSxDQUFDO1FBQ3BCLE1BQU0sWUFBWSxHQUFHLFlBQVksQ0FBQyxrQkFBa0IsQ0FBQyxJQUFJLEVBQUUsS0FBSyxDQUFDLENBQUM7UUFDbEUsTUFBTSxxQkFBcUIsR0FBRyxZQUFZLElBQUksSUFBSSxDQUFDO1FBQ25ELE1BQU0sa0JBQWtCLEdBQUcsWUFBWSxDQUFDLGtCQUFrQixDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUVoRSxNQUFNLGFBQWEsR0FBaUIsRUFBRSxDQUFDO1FBRXZDLElBQUksU0FBUyxHQUFHLENBQUMsQ0FBQztRQUNsQixJQUFJLHFCQUFxQixFQUFFO1lBQ3pCLElBQUksa0JBQWtCLEVBQUU7Z0JBQ3RCLE1BQU0sUUFBUSxHQUFHLFlBQVksQ0FBQyxPQUFPLENBQUMsR0FBRyxDQUFDLFNBQVMsQ0FBQyxNQUFNLENBQUMsQ0FBQztnQkFDNUQsTUFBTSxNQUFNLEdBQUcsUUFBUSxDQUFDLE1BQW9CLENBQUM7Z0JBRTdDLE1BQU0sUUFBUSxHQUFhLElBQUksS0FBSyxDQUFDLEtBQUssQ0FBQyxDQUFDO2dCQUM1QyxLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsUUFBUSxDQUFDLE1BQU0sRUFBRSxDQUFDLEVBQUUsRUFBRTtvQkFDeEMsUUFBUSxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsWUFBWSxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7aUJBQ3hDO2dCQUNELE1BQU0sZUFBZSxHQUNqQixnQkFBZ0IsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUMsS0FBSyxFQU
FFLFlBQVksRUFBRSxRQUFRLENBQUMsQ0FBQztnQkFFdkUsU0FBUyxHQUFHLFlBQVksQ0FBQyxjQUFjLENBQUMsUUFBUSxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztnQkFDM0QsTUFBTSxhQUFhLEdBQUcsWUFBWSxDQUFDLE9BQU8sQ0FBQyxHQUFHLENBQUMsU0FBUyxDQUFDLE1BQU0sQ0FBQyxDQUFDO2dCQUNqRSxhQUFhLENBQUMsTUFBTSxHQUFHLGVBQWUsQ0FBQzthQUN4QztpQkFBTTtnQkFDTCxTQUFTLEdBQUcsYUFBYSxDQUFDLENBQUMsRUFBRSxZQUFZLEVBQUUsWUFBWSxDQUFDLENBQUM7YUFDMUQ7WUFFRCxhQUFhLENBQUMsSUFBSSxDQUFDLFNBQVMsQ0FBQyxDQUFDO1lBQzlCLElBQUksR0FBRyxZQUFZLENBQUMsZ0JBQWdCLENBQUMsSUFBSSxDQUFDLE1BQU0sRUFBRSxLQUFLLENBQUMsQ0FBQztTQUMxRDtRQUVELFlBQVksQ0FBQywwQkFBMEIsQ0FBQyxLQUFLLEVBQUUsSUFBSSxFQUFFLEtBQUssQ0FBQyxDQUFDO1FBQzVELE1BQU0sQ0FBQyxZQUFZLEVBQUUsV0FBVyxDQUFDLEdBQzdCLFlBQVksQ0FBQyx5QkFBeUIsQ0FBQyxTQUFTLENBQUMsS0FBSyxFQUFFLElBQUksQ0FBQyxDQUFDO1FBRWxFLElBQUksUUFBUSxHQUFHLFlBQVksQ0FBQztRQUM1QixJQUFJLFFBQVEsRUFBRTtZQUNaLDZEQUE2RDtZQUM3RCxRQUFRLEdBQUcsWUFBWSxDQUFDLG9CQUFvQixDQUFDLFlBQVksRUFBRSxRQUFRLENBQUMsQ0FBQztTQUN0RTtRQUVELE1BQU0sR0FBRyxHQUFHLFFBQVEsQ0FBQyxTQUFTLEVBQUUsV0FBVyxFQUFFLFFBQVEsRUFBRSxZQUFZLENBQUMsQ0FBQztRQUNyRSxLQUFLLE1BQU0sQ0FBQyxJQUFJLGFBQWEsRUFBRTtZQUM3QixZQUFZLENBQUMsNkJBQTZCLENBQUMsQ0FBQyxDQUFDLENBQUM7U0FDL0M7UUFFRCxPQUFPLEdBQUcsQ0FBQztJQUNiLENBQUM7Q0FDRixDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIG
dvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBNZWFuLCBNZWFuQXR0cnMsIE1lYW5JbnB1dHMsIFRlbnNvckluZm8sIFR5cGVkQXJyYXksIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5cbmltcG9ydCB7bWVhbkltcGx9IGZyb20gJy4vTWVhbl9pbXBsJztcbmltcG9ydCB7dHJhbnNwb3NlSW1wbCwgdHJhbnNwb3NlSW1wbENQVX0gZnJvbSAnLi9UcmFuc3Bvc2VfaW1wbCc7XG5cbmV4cG9ydCBjb25zdCBtZWFuQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IE1lYW4sXG4gIGJhY2tlbmROYW1lOiAnd2ViZ2wnLFxuICBrZXJuZWxGdW5jOiAoe2lucHV0cywgYXR0cnMsIGJhY2tlbmR9KSA9PiB7XG4gICAgY29uc3Qge3h9ID0gaW5wdXRzIGFzIE1lYW5JbnB1dHM7XG4gICAgY29uc3Qge2tlZXBEaW1zLCBheGlzfSA9IGF0dHJzIGFzIHVua25vd24gYXMgTWVhbkF0dHJzO1xuICAgIGNvbnN0IHdlYmdsQmFja2VuZCA9IGJhY2tlbmQgYXMgTWF0aEJhY2tlbmRXZWJHTDtcblxuICAgIGNvbnN0IHhSYW5rID0geC5zaGFwZS5sZW5ndGg7XG4gICAgY29uc3Qgb3JpZ0F4ZXMgPSB1dGlsLnBhcnNlQXhpc1BhcmFtKGF4aXMsIHguc2hhcGUpO1xuXG4gICAgbGV0IGF4ZXMgPSBvcmlnQXhlcztcbiAgICBjb25zdCBwZXJtdXRlZEF4ZXMgPSBiYWNrZW5kX3V0aWwuZ2V0QXhlc1Blcm11dGF0aW9uKGF4ZXMsIHhSYW5rKTtcbiAgICBjb25zdCBtZWFuSW5wdXRJc1RyYW5zcG9zZWQgPSBwZXJtdXRlZEF4ZXMgIT0gbnVsbDtcbiAgICBjb25zdCBzaG91bGRFeGVjdXRlT25DUFUgPSB3ZWJnbEJhY2tlbmQuc2hvdWxkRXhlY3V0ZU9uQ1BVKFt4XSk7XG5cbiAgICBjb25zdCBpbnRlcm1lZGlhdGVzOiBUZW5zb3JJbmZvW10gPSBbXTtcblxuICAgIGxldCBtZWFuSW5wdXQgPSB4O1xuICAgIGlmIChtZWFuSW5wdXRJc1RyYW5zcG9zZWQpIHtcbiAgICAgIGlmIChzaG91bGRFeGVjdXRlT25DUFUpIHtcbiAgICAgICAgY29uc3QgeFRleERhdGEgPSB3ZWJnbEJhY2tlbmQudGV4RGF0YS5nZXQobWVhbklucHV0LmRhdGFJZCk7XG4gICAgICAgIGNvbnN0IHZhbHVlcyA9IHhUZXhEYXRhLnZhbHVlcyBhcyBUeXBlZEFycmF5O1xuXG4gICAgICAgIGNvbnN0IG5ld1NoYXBlOiBudW1iZXJbXSA9IG5ldyBBcnJheSh4UmFuayk7XG4gICAgICAgIGZvciAobGV0IGkgPSAwOyBpIDwgbmV3U2hhcGUubGVuZ3RoOyBpKyspIHtcbiAgICAgICAgICBuZXdTaGFwZVtpXSA9IHguc2hhcGVbcGVybXV0ZWRBeGVzW2ldXTtcbiAgICAgICAgfVxuICAgICAgICBjb2
5zdCBtZWFuSW5wdXRWYWx1ZXMgPVxuICAgICAgICAgICAgdHJhbnNwb3NlSW1wbENQVSh2YWx1ZXMsIHguc2hhcGUsIHguZHR5cGUsIHBlcm11dGVkQXhlcywgbmV3U2hhcGUpO1xuXG4gICAgICAgIG1lYW5JbnB1dCA9IHdlYmdsQmFja2VuZC5tYWtlVGVuc29ySW5mbyhuZXdTaGFwZSwgeC5kdHlwZSk7XG4gICAgICAgIGNvbnN0IG1lYW5JbnB1dERhdGEgPSB3ZWJnbEJhY2tlbmQudGV4RGF0YS5nZXQobWVhbklucHV0LmRhdGFJZCk7XG4gICAgICAgIG1lYW5JbnB1dERhdGEudmFsdWVzID0gbWVhbklucHV0VmFsdWVzO1xuICAgICAgfSBlbHNlIHtcbiAgICAgICAgbWVhbklucHV0ID0gdHJhbnNwb3NlSW1wbCh4LCBwZXJtdXRlZEF4ZXMsIHdlYmdsQmFja2VuZCk7XG4gICAgICB9XG5cbiAgICAgIGludGVybWVkaWF0ZXMucHVzaChtZWFuSW5wdXQpO1xuICAgICAgYXhlcyA9IGJhY2tlbmRfdXRpbC5nZXRJbm5lck1vc3RBeGVzKGF4ZXMubGVuZ3RoLCB4UmFuayk7XG4gICAgfVxuXG4gICAgYmFja2VuZF91dGlsLmFzc2VydEF4ZXNBcmVJbm5lck1vc3REaW1zKCdzdW0nLCBheGVzLCB4UmFuayk7XG4gICAgY29uc3QgW21lYW5PdXRTaGFwZSwgcmVkdWNlU2hhcGVdID1cbiAgICAgICAgYmFja2VuZF91dGlsLmNvbXB1dGVPdXRBbmRSZWR1Y2VTaGFwZXMobWVhbklucHV0LnNoYXBlLCBheGVzKTtcblxuICAgIGxldCBvdXRTaGFwZSA9IG1lYW5PdXRTaGFwZTtcbiAgICBpZiAoa2VlcERpbXMpIHtcbiAgICAgIC8vIHJhdGhlciB0aGFuIHJlc2hhcGUgYXQgdGhlIGVuZCwgc2V0IHRoZSB0YXJnZXQgc2hhcGUgaGVyZS5cbiAgICAgIG91dFNoYXBlID0gYmFja2VuZF91dGlsLmV4cGFuZFNoYXBlVG9LZWVwRGltKG1lYW5PdXRTaGFwZSwgb3JpZ0F4ZXMpO1xuICAgIH1cblxuICAgIGNvbnN0IG91dCA9IG1lYW5JbXBsKG1lYW5JbnB1dCwgcmVkdWNlU2hhcGUsIG91dFNoYXBlLCB3ZWJnbEJhY2tlbmQpO1xuICAgIGZvciAoY29uc3QgaSBvZiBpbnRlcm1lZGlhdGVzKSB7XG4gICAgICB3ZWJnbEJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oaSk7XG4gICAgfVxuXG4gICAgcmV0dXJuIG91dDtcbiAgfVxufTtcbiJdfQ==