/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, Mean, util } from '@tensorflow/tfjs-core';
|
import { meanImpl } from './Mean_impl';
|
import { transposeImpl, transposeImplCPU } from './Transpose_impl';
|
/**
 * Transposes `x` on the CPU and stores the result in a new tensor on the
 * WebGL backend.
 *
 * The input values are read from `webglBackend.texData` for `x.dataId`; the
 * transposed values are written into the `texData` entry of the new tensor.
 *
 * @param {TensorInfo} x Tensor whose CPU-side `values` are transposed.
 * @param {number[]} perm Axis permutation; output axis `i` takes its size from
 *     input axis `perm[i]`.
 * @param {MathBackendWebGL} webglBackend Backend that owns `x` and the result.
 * @returns {TensorInfo} The transposed tensor (the caller must dispose it).
 */
function transposeOnCPU(x, perm, webglBackend) {
  const { values } = webglBackend.texData.get(x.dataId);
  const newShape = perm.map((inputAxis) => x.shape[inputAxis]);
  const transposedValues = transposeImplCPU(values, x.shape, x.dtype, perm, newShape);
  const result = webglBackend.makeTensorInfo(newShape, x.dtype);
  webglBackend.texData.get(result.dataId).values = transposedValues;
  return result;
}

/**
 * WebGL kernel registration for the `Mean` op.
 *
 * `meanImpl` reduces over the innermost axes. When the requested axes are not
 * innermost (`getAxesPermutation` returns a non-null permutation), `x` is
 * first transposed: on the CPU when `shouldExecuteOnCPU([x])` is true,
 * otherwise via `transposeImpl`. Any transposed copy is an intermediate tensor
 * and is always disposed, including when a later step throws.
 */
export const meanConfig = {
  kernelName: Mean,
  backendName: 'webgl',
  kernelFunc: ({ inputs, attrs, backend }) => {
    const { x } = inputs;
    const { keepDims, axis } = attrs;
    const webglBackend = backend;

    const xRank = x.shape.length;
    const origAxes = util.parseAxisParam(axis, x.shape);
    let axes = origAxes;

    // null means no transpose is needed before reducing.
    const permutedAxes = backend_util.getAxesPermutation(axes, xRank);
    const meanInputIsTransposed = permutedAxes != null;
    const shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]);

    const intermediates = [];
    try {
      let meanInput = x;
      if (meanInputIsTransposed) {
        meanInput = shouldExecuteOnCPU
          ? transposeOnCPU(x, permutedAxes, webglBackend)
          : transposeImpl(x, permutedAxes, webglBackend);
        intermediates.push(meanInput);
        // After the transpose the reduced axes are the trailing ones.
        axes = backend_util.getInnerMostAxes(axes.length, xRank);
      }

      backend_util.assertAxesAreInnerMostDims('mean', axes, xRank);
      const [meanOutShape, reduceShape] =
        backend_util.computeOutAndReduceShapes(meanInput.shape, axes);

      // Rather than reshape at the end, set the target shape here.
      const outShape = keepDims
        ? backend_util.expandShapeToKeepDim(meanOutShape, origAxes)
        : meanOutShape;

      return meanImpl(meanInput, reduceShape, outShape, webglBackend);
    } finally {
      // Release transposed copies even if a shape helper or meanImpl throws.
      for (const info of intermediates) {
        webglBackend.disposeIntermediateTensorInfo(info);
      }
    }
  },
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiTWVhbi5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMva2VybmVscy9NZWFuLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQWdCLElBQUksRUFBaUQsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFJNUgsT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUNyQyxPQUFPLEVBQUMsYUFBYSxFQUFFLGdCQUFnQixFQUFDLE1BQU0sa0JBQWtCLENBQUM7QUFFakUsTUFBTSxDQUFDLE1BQU0sVUFBVSxHQUFpQjtJQUN0QyxVQUFVLEVBQUUsSUFBSTtJQUNoQixXQUFXLEVBQUUsT0FBTztJQUNwQixVQUFVLEVBQUUsQ0FBQyxFQUFDLE1BQU0sRUFBRSxLQUFLLEVBQUUsT0FBTyxFQUFDLEVBQUUsRUFBRTtRQUN2QyxNQUFNLEVBQUMsQ0FBQyxFQUFDLEdBQUcsTUFBb0IsQ0FBQztRQUNqQyxNQUFNLEVBQUMsUUFBUSxFQUFFLElBQUksRUFBQyxHQUFHLEtBQTZCLENBQUM7UUFDdkQsTUFBTSxZQUFZLEdBQUcsT0FBMkIsQ0FBQztRQUVqRCxNQUFNLEtBQUssR0FBRyxDQUFDLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQztRQUM3QixNQUFNLFFBQVEsR0FBRyxJQUFJLENBQUMsY0FBYyxDQUFDLElBQUksRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUM7UUFFcEQsSUFBSSxJQUFJLEdBQUcsUUFBUSxDQUFDO1FBQ3BCLE1BQU0sWUFBWSxHQUFHLFlBQVksQ0FBQyxrQkFBa0IsQ0FBQyxJQUFJLEVBQUUsS0FBSyxDQUFDLENBQUM7UUFDbEUsTUFBTSxxQkFBcUIsR0FBRyxZQUFZLElBQUksSUFBSSxDQUFDO1FBQ25ELE1BQU0sa0JBQWtCLEdBQUcsWUFBWSxDQUFDLGtCQUFrQixDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUVoRSxNQUFNLGFBQWEsR0FBaUIsRUFBRSxDQUFDO1FBRXZDLElBQUksU0FBUyxHQUFHLENBQUMsQ0FBQztRQUNsQixJQUFJLHFCQUFxQixFQUFFO1lBQ3pCLElBQUksa0JBQWtCLEVBQUU7Z0JBQ3RCLE1BQU0sUUFBUSxHQUFHLFlBQVksQ0FBQyxPQUFPLENBQUMsR0FBRyxDQUFDLFNBQVMsQ0FBQyxNQUFNLENBQUMsQ0FBQztnQkFDNUQsTUFBTSxNQUFNLEdBQUcsUUFBUSxDQUFDLE1BQW9CLENBQUM7Z0JBRTdDLE1BQU0sUUFBUSxHQUFhLElBQUksS0FBSyxDQUFDLEtBQUssQ0FBQyxDQUFDO2dCQUM1QyxLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsUUFBUSxDQUFDLE1BQU0sRUFBRSxDQUFDLEVBQUUsRUFBRTtvQkFDeEMsUUFBUSxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsWUFBWSxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7aUJBQ3hDO2dCQUNELE1BQU0sZUFBZSxHQUNqQixnQkFBZ0IsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUMsS0FBSyxFQU
FFLFlBQVksRUFBRSxRQUFRLENBQUMsQ0FBQztnQkFFdkUsU0FBUyxHQUFHLFlBQVksQ0FBQyxjQUFjLENBQUMsUUFBUSxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztnQkFDM0QsTUFBTSxhQUFhLEdBQUcsWUFBWSxDQUFDLE9BQU8sQ0FBQyxHQUFHLENBQUMsU0FBUyxDQUFDLE1BQU0sQ0FBQyxDQUFDO2dCQUNqRSxhQUFhLENBQUMsTUFBTSxHQUFHLGVBQWUsQ0FBQzthQUN4QztpQkFBTTtnQkFDTCxTQUFTLEdBQUcsYUFBYSxDQUFDLENBQUMsRUFBRSxZQUFZLEVBQUUsWUFBWSxDQUFDLENBQUM7YUFDMUQ7WUFFRCxhQUFhLENBQUMsSUFBSSxDQUFDLFNBQVMsQ0FBQyxDQUFDO1lBQzlCLElBQUksR0FBRyxZQUFZLENBQUMsZ0JBQWdCLENBQUMsSUFBSSxDQUFDLE1BQU0sRUFBRSxLQUFLLENBQUMsQ0FBQztTQUMxRDtRQUVELFlBQVksQ0FBQywwQkFBMEIsQ0FBQyxLQUFLLEVBQUUsSUFBSSxFQUFFLEtBQUssQ0FBQyxDQUFDO1FBQzVELE1BQU0sQ0FBQyxZQUFZLEVBQUUsV0FBVyxDQUFDLEdBQzdCLFlBQVksQ0FBQyx5QkFBeUIsQ0FBQyxTQUFTLENBQUMsS0FBSyxFQUFFLElBQUksQ0FBQyxDQUFDO1FBRWxFLElBQUksUUFBUSxHQUFHLFlBQVksQ0FBQztRQUM1QixJQUFJLFFBQVEsRUFBRTtZQUNaLDZEQUE2RDtZQUM3RCxRQUFRLEdBQUcsWUFBWSxDQUFDLG9CQUFvQixDQUFDLFlBQVksRUFBRSxRQUFRLENBQUMsQ0FBQztTQUN0RTtRQUVELE1BQU0sR0FBRyxHQUFHLFFBQVEsQ0FBQyxTQUFTLEVBQUUsV0FBVyxFQUFFLFFBQVEsRUFBRSxZQUFZLENBQUMsQ0FBQztRQUNyRSxLQUFLLE1BQU0sQ0FBQyxJQUFJLGFBQWEsRUFBRTtZQUM3QixZQUFZLENBQUMsNkJBQTZCLENBQUMsQ0FBQyxDQUFDLENBQUM7U0FDL0M7UUFFRCxPQUFPLEdBQUcsQ0FBQztJQUNiLENBQUM7Q0FDRixDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIG
dvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBNZWFuLCBNZWFuQXR0cnMsIE1lYW5JbnB1dHMsIFRlbnNvckluZm8sIFR5cGVkQXJyYXksIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5cbmltcG9ydCB7bWVhbkltcGx9IGZyb20gJy4vTWVhbl9pbXBsJztcbmltcG9ydCB7dHJhbnNwb3NlSW1wbCwgdHJhbnNwb3NlSW1wbENQVX0gZnJvbSAnLi9UcmFuc3Bvc2VfaW1wbCc7XG5cbmV4cG9ydCBjb25zdCBtZWFuQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IE1lYW4sXG4gIGJhY2tlbmROYW1lOiAnd2ViZ2wnLFxuICBrZXJuZWxGdW5jOiAoe2lucHV0cywgYXR0cnMsIGJhY2tlbmR9KSA9PiB7XG4gICAgY29uc3Qge3h9ID0gaW5wdXRzIGFzIE1lYW5JbnB1dHM7XG4gICAgY29uc3Qge2tlZXBEaW1zLCBheGlzfSA9IGF0dHJzIGFzIHVua25vd24gYXMgTWVhbkF0dHJzO1xuICAgIGNvbnN0IHdlYmdsQmFja2VuZCA9IGJhY2tlbmQgYXMgTWF0aEJhY2tlbmRXZWJHTDtcblxuICAgIGNvbnN0IHhSYW5rID0geC5zaGFwZS5sZW5ndGg7XG4gICAgY29uc3Qgb3JpZ0F4ZXMgPSB1dGlsLnBhcnNlQXhpc1BhcmFtKGF4aXMsIHguc2hhcGUpO1xuXG4gICAgbGV0IGF4ZXMgPSBvcmlnQXhlcztcbiAgICBjb25zdCBwZXJtdXRlZEF4ZXMgPSBiYWNrZW5kX3V0aWwuZ2V0QXhlc1Blcm11dGF0aW9uKGF4ZXMsIHhSYW5rKTtcbiAgICBjb25zdCBtZWFuSW5wdXRJc1RyYW5zcG9zZWQgPSBwZXJtdXRlZEF4ZXMgIT0gbnVsbDtcbiAgICBjb25zdCBzaG91bGRFeGVjdXRlT25DUFUgPSB3ZWJnbEJhY2tlbmQuc2hvdWxkRXhlY3V0ZU9uQ1BVKFt4XSk7XG5cbiAgICBjb25zdCBpbnRlcm1lZGlhdGVzOiBUZW5zb3JJbmZvW10gPSBbXTtcblxuICAgIGxldCBtZWFuSW5wdXQgPSB4O1xuICAgIGlmIChtZWFuSW5wdXRJc1RyYW5zcG9zZWQpIHtcbiAgICAgIGlmIChzaG91bGRFeGVjdXRlT25DUFUpIHtcbiAgICAgICAgY29uc3QgeFRleERhdGEgPSB3ZWJnbEJhY2tlbmQudGV4RGF0YS5nZXQobWVhbklucHV0LmRhdGFJZCk7XG4gICAgICAgIGNvbnN0IHZhbHVlcyA9IHhUZXhEYXRhLnZhbHVlcyBhcyBUeXBlZEFycmF5O1xuXG4gICAgICAgIGNvbnN0IG5ld1NoYXBlOiBudW1iZXJbXSA9IG5ldyBBcnJheSh4UmFuayk7XG4gICAgICAgIGZvciAobGV0IGkgPSAwOyBpIDwgbmV3U2hhcGUubGVuZ3RoOyBpKyspIHtcbiAgICAgICAgICBuZXdTaGFwZVtpXSA9IHguc2hhcGVbcGVybXV0ZWRBeGVzW2ldXTtcbiAgICAgICAgfVxuICAgICAgICBjb2
5zdCBtZWFuSW5wdXRWYWx1ZXMgPVxuICAgICAgICAgICAgdHJhbnNwb3NlSW1wbENQVSh2YWx1ZXMsIHguc2hhcGUsIHguZHR5cGUsIHBlcm11dGVkQXhlcywgbmV3U2hhcGUpO1xuXG4gICAgICAgIG1lYW5JbnB1dCA9IHdlYmdsQmFja2VuZC5tYWtlVGVuc29ySW5mbyhuZXdTaGFwZSwgeC5kdHlwZSk7XG4gICAgICAgIGNvbnN0IG1lYW5JbnB1dERhdGEgPSB3ZWJnbEJhY2tlbmQudGV4RGF0YS5nZXQobWVhbklucHV0LmRhdGFJZCk7XG4gICAgICAgIG1lYW5JbnB1dERhdGEudmFsdWVzID0gbWVhbklucHV0VmFsdWVzO1xuICAgICAgfSBlbHNlIHtcbiAgICAgICAgbWVhbklucHV0ID0gdHJhbnNwb3NlSW1wbCh4LCBwZXJtdXRlZEF4ZXMsIHdlYmdsQmFja2VuZCk7XG4gICAgICB9XG5cbiAgICAgIGludGVybWVkaWF0ZXMucHVzaChtZWFuSW5wdXQpO1xuICAgICAgYXhlcyA9IGJhY2tlbmRfdXRpbC5nZXRJbm5lck1vc3RBeGVzKGF4ZXMubGVuZ3RoLCB4UmFuayk7XG4gICAgfVxuXG4gICAgYmFja2VuZF91dGlsLmFzc2VydEF4ZXNBcmVJbm5lck1vc3REaW1zKCdzdW0nLCBheGVzLCB4UmFuayk7XG4gICAgY29uc3QgW21lYW5PdXRTaGFwZSwgcmVkdWNlU2hhcGVdID1cbiAgICAgICAgYmFja2VuZF91dGlsLmNvbXB1dGVPdXRBbmRSZWR1Y2VTaGFwZXMobWVhbklucHV0LnNoYXBlLCBheGVzKTtcblxuICAgIGxldCBvdXRTaGFwZSA9IG1lYW5PdXRTaGFwZTtcbiAgICBpZiAoa2VlcERpbXMpIHtcbiAgICAgIC8vIHJhdGhlciB0aGFuIHJlc2hhcGUgYXQgdGhlIGVuZCwgc2V0IHRoZSB0YXJnZXQgc2hhcGUgaGVyZS5cbiAgICAgIG91dFNoYXBlID0gYmFja2VuZF91dGlsLmV4cGFuZFNoYXBlVG9LZWVwRGltKG1lYW5PdXRTaGFwZSwgb3JpZ0F4ZXMpO1xuICAgIH1cblxuICAgIGNvbnN0IG91dCA9IG1lYW5JbXBsKG1lYW5JbnB1dCwgcmVkdWNlU2hhcGUsIG91dFNoYXBlLCB3ZWJnbEJhY2tlbmQpO1xuICAgIGZvciAoY29uc3QgaSBvZiBpbnRlcm1lZGlhdGVzKSB7XG4gICAgICB3ZWJnbEJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oaSk7XG4gICAgfVxuXG4gICAgcmV0dXJuIG91dDtcbiAgfVxufTtcbiJdfQ==
|