/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Mean, util } from '@tensorflow/tfjs-core';
import { meanImpl } from './Mean_impl';
import { transposeImpl, transposeImplCPU } from './Transpose_impl';

/**
 * Kernel config that registers the `Mean` kernel for the 'webgl' backend.
 *
 * `kernelFunc` receives `{ inputs, attrs, backend }` and:
 *   1. parses `attrs.axis` against `inputs.x.shape`;
 *   2. when `backend_util.getAxesPermutation` returns a non-null permutation,
 *      transposes `x` with it (via `transposeImplCPU` when
 *      `backend.shouldExecuteOnCPU([x])` is truthy, otherwise via
 *      `transposeImpl`) and switches to the axes returned by
 *      `backend_util.getInnerMostAxes`;
 *   3. calls `meanImpl` with the (possibly transposed) input, the reduce shape
 *      and the output shape; when `attrs.keepDims` is truthy the output shape
 *      is first passed through `backend_util.expandShapeToKeepDim` with the
 *      originally requested axes;
 *   4. disposes the transposed intermediate, if one was created.
 *
 * The value returned by `meanImpl` is returned unchanged.
 */
export const meanConfig = {
    kernelName: Mean,
    backendName: 'webgl',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { x } = inputs;
        const { keepDims, axis } = attrs;
        const webglBackend = backend;
        const xRank = x.shape.length;
        const origAxes = util.parseAxisParam(axis, x.shape);
        let axes = origAxes;
        const permutedAxes = backend_util.getAxesPermutation(axes, xRank);
        const meanInputIsTransposed = permutedAxes != null;
        const shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]);
        // Tensors created here that must be released before returning.
        const intermediates = [];
        try {
            let meanInput = x;
            if (meanInputIsTransposed) {
                if (shouldExecuteOnCPU) {
                    // Transpose the raw values and wrap the result in a
                    // fresh tensor info owned by the backend.
                    const xTexData = webglBackend.texData.get(x.dataId);
                    const values = xTexData.values;
                    const newShape = Array.from({ length: xRank }, (_, i) => x.shape[permutedAxes[i]]);
                    const meanInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape);
                    meanInput = webglBackend.makeTensorInfo(newShape, x.dtype);
                    const meanInputData = webglBackend.texData.get(meanInput.dataId);
                    meanInputData.values = meanInputValues;
                }
                else {
                    meanInput = transposeImpl(x, permutedAxes, webglBackend);
                }
                intermediates.push(meanInput);
                axes = backend_util.getInnerMostAxes(axes.length, xRank);
            }
            // The first argument is the op label passed to the assertion
            // helper; it names this kernel ('mean'), not 'sum'.
            backend_util.assertAxesAreInnerMostDims('mean', axes, xRank);
            const [meanOutShape, reduceShape] = backend_util.computeOutAndReduceShapes(meanInput.shape, axes);
            let outShape = meanOutShape;
            if (keepDims) {
                // Rather than reshape at the end, set the target shape here.
                outShape = backend_util.expandShapeToKeepDim(meanOutShape, origAxes);
            }
            return meanImpl(meanInput, reduceShape, outShape, webglBackend);
        }
        finally {
            // Runs after meanImpl has produced its result and also when
            // anything above throws, so the transposed tensor never leaks.
            for (const intermediate of intermediates) {
                webglBackend.disposeIntermediateTensorInfo(intermediate);
            }
        }
    }
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiTWVhbi5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMva2VybmVscy9NZWFuLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQWdCLElBQUksRUFBaUQsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFJNUgsT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUNyQyxPQUFPLEVBQUMsYUFBYSxFQUFFLGdCQUFnQixFQUFDLE1BQU0sa0JBQWtCLENBQUM7QUFFakUsTUFBTSxDQUFDLE1BQU0sVUFBVSxHQUFpQjtJQUN0QyxVQUFVLEVBQUUsSUFBSTtJQUNoQixXQUFXLEVBQUUsT0FBTztJQUNwQixVQUFVLEVBQUUsQ0FBQyxFQUFDLE1BQU0sRUFBRSxLQUFLLEVBQUUsT0FBTyxFQUFDLEVBQUUsRUFBRTtRQUN2QyxNQUFNLEVBQUMsQ0FBQyxFQUFDLEdBQUcsTUFBb0IsQ0FBQztRQUNqQyxNQUFNLEVBQUMsUUFBUSxFQUFFLElBQUksRUFBQyxHQUFHLEtBQTZCLENBQUM7UUFDdkQsTUFBTSxZQUFZLEdBQUcsT0FBMkIsQ0FBQztRQUVqRCxNQUFNLEtBQUssR0FBRyxDQUFDLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQztRQUM3QixNQUFNLFFBQVEsR0FBRyxJQUFJLENBQUMsY0FBYyxDQUFDLElBQUksRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUM7UUFFcEQsSUFBSSxJQUFJLEdBQUcsUUFBUSxDQUFDO1FBQ3BCLE1BQU0sWUFBWSxHQUFHLFlBQVksQ0FBQyxrQkFBa0IsQ0FBQyxJQUFJLEVBQUUsS0FBSyxDQUFDLENBQUM7UUFDbEUsTUFBTSxxQkFBcUIsR0FBRyxZQUFZLElBQUksSUFBSSxDQUFDO1FBQ25ELE1BQU0sa0JBQWtCLEdBQUcsWUFBWSxDQUFDLGtCQUFrQixDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUVoRSxNQUFNLGFBQWEsR0FBaUIsRUFBRSxDQUFDO1FBRXZDLElBQUksU0FBUyxHQUFHLENBQUMsQ0FBQztRQUNsQixJQUFJLHFCQUFxQixFQUF
FO1lBQ3pCLElBQUksa0JBQWtCLEVBQUU7Z0JBQ3RCLE1BQU0sUUFBUSxHQUFHLFlBQVksQ0FBQyxPQUFPLENBQUMsR0FBRyxDQUFDLFNBQVMsQ0FBQyxNQUFNLENBQUMsQ0FBQztnQkFDNUQsTUFBTSxNQUFNLEdBQUcsUUFBUSxDQUFDLE1BQW9CLENBQUM7Z0JBRTdDLE1BQU0sUUFBUSxHQUFhLElBQUksS0FBSyxDQUFDLEtBQUssQ0FBQyxDQUFDO2dCQUM1QyxLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsUUFBUSxDQUFDLE1BQU0sRUFBRSxDQUFDLEVBQUUsRUFBRTtvQkFDeEMsUUFBUSxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsWUFBWSxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7aUJBQ3hDO2dCQUNELE1BQU0sZUFBZSxHQUNqQixnQkFBZ0IsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUMsS0FBSyxFQUFFLFlBQVksRUFBRSxRQUFRLENBQUMsQ0FBQztnQkFFdkUsU0FBUyxHQUFHLFlBQVksQ0FBQyxjQUFjLENBQUMsUUFBUSxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztnQkFDM0QsTUFBTSxhQUFhLEdBQUcsWUFBWSxDQUFDLE9BQU8sQ0FBQyxHQUFHLENBQUMsU0FBUyxDQUFDLE1BQU0sQ0FBQyxDQUFDO2dCQUNqRSxhQUFhLENBQUMsTUFBTSxHQUFHLGVBQWUsQ0FBQzthQUN4QztpQkFBTTtnQkFDTCxTQUFTLEdBQUcsYUFBYSxDQUFDLENBQUMsRUFBRSxZQUFZLEVBQUUsWUFBWSxDQUFDLENBQUM7YUFDMUQ7WUFFRCxhQUFhLENBQUMsSUFBSSxDQUFDLFNBQVMsQ0FBQyxDQUFDO1lBQzlCLElBQUksR0FBRyxZQUFZLENBQUMsZ0JBQWdCLENBQUMsSUFBSSxDQUFDLE1BQU0sRUFBRSxLQUFLLENBQUMsQ0FBQztTQUMxRDtRQUVELFlBQVksQ0FBQywwQkFBMEIsQ0FBQyxLQUFLLEVBQUUsSUFBSSxFQUFFLEtBQUssQ0FBQyxDQUFDO1FBQzVELE1BQU0sQ0FBQyxZQUFZLEVBQUUsV0FBVyxDQUFDLEdBQzdCLFlBQVksQ0FBQyx5QkFBeUIsQ0FBQyxTQUFTLENBQUMsS0FBSyxFQUFFLElBQUksQ0FBQyxDQUFDO1FBRWxFLElBQUksUUFBUSxHQUFHLFlBQVksQ0FBQztRQUM1QixJQUFJLFFBQVEsRUFBRTtZQUNaLDZEQUE2RDtZQUM3RCxRQUFRLEdBQUcsWUFBWSxDQUFDLG9CQUFvQixDQUFDLFlBQVksRUFBRSxRQUFRLENBQUMsQ0FBQztTQUN0RTtRQUVELE1BQU0sR0FBRyxHQUFHLFFBQVEsQ0FBQyxTQUFTLEVBQUUsV0FBVyxFQUFFLFFBQVEsRUFBRSxZQUFZLENBQUMsQ0FBQztRQUNyRSxLQUFLLE1BQU0sQ0FBQyxJQUFJLGFBQWEsRUFBRTtZQUM3QixZQUFZLENBQUMsNkJBQTZCLENBQUMsQ0FBQyxDQUFDLENBQUM7U0FDL0M7UUFFRCxPQUFPLEdBQUcsQ0FBQztJQUNiLENBQUM7Q0FDRixDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSB
cIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBNZWFuLCBNZWFuQXR0cnMsIE1lYW5JbnB1dHMsIFRlbnNvckluZm8sIFR5cGVkQXJyYXksIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5cbmltcG9ydCB7bWVhbkltcGx9IGZyb20gJy4vTWVhbl9pbXBsJztcbmltcG9ydCB7dHJhbnNwb3NlSW1wbCwgdHJhbnNwb3NlSW1wbENQVX0gZnJvbSAnLi9UcmFuc3Bvc2VfaW1wbCc7XG5cbmV4cG9ydCBjb25zdCBtZWFuQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IE1lYW4sXG4gIGJhY2tlbmROYW1lOiAnd2ViZ2wnLFxuICBrZXJuZWxGdW5jOiAoe2lucHV0cywgYXR0cnMsIGJhY2tlbmR9KSA9PiB7XG4gICAgY29uc3Qge3h9ID0gaW5wdXRzIGFzIE1lYW5JbnB1dHM7XG4gICAgY29uc3Qge2tlZXBEaW1zLCBheGlzfSA9IGF0dHJzIGFzIHVua25vd24gYXMgTWVhbkF0dHJzO1xuICAgIGNvbnN0IHdlYmdsQmFja2VuZCA9IGJhY2tlbmQgYXMgTWF0aEJhY2tlbmRXZWJHTDtcblxuICAgIGNvbnN0IHhSYW5rID0geC5zaGFwZS5sZW5ndGg7XG4gICAgY29uc3Qgb3JpZ0F4ZXMgPSB1dGlsLnBhcnNlQXhpc1BhcmFtKGF4aXMsIHguc2hhcGUpO1xuXG4gICAgbGV0IGF4ZXMgPSBvcmlnQXhlcztcbiAgICBjb25zdCBwZXJtdXRlZEF4ZXMgPSBiYWNrZW5kX3V0aWwuZ2V0QXhlc1Blcm11dGF0aW9uKGF4ZXMsIHhSYW5rKTtcbiAgICBjb25zdCBtZWFuSW5wdXRJc1RyYW5zcG9zZWQgPSBwZXJtdXRlZEF4ZXMgIT0gbnVsbDtcbiAgICBjb25zdCBzaG91bGRFeGVjdXRlT25DUFUgPSB3ZWJnbEJhY2tlbmQuc2hvdWxkRXhlY3V0ZU9uQ1BVKFt4XSk
7XG5cbiAgICBjb25zdCBpbnRlcm1lZGlhdGVzOiBUZW5zb3JJbmZvW10gPSBbXTtcblxuICAgIGxldCBtZWFuSW5wdXQgPSB4O1xuICAgIGlmIChtZWFuSW5wdXRJc1RyYW5zcG9zZWQpIHtcbiAgICAgIGlmIChzaG91bGRFeGVjdXRlT25DUFUpIHtcbiAgICAgICAgY29uc3QgeFRleERhdGEgPSB3ZWJnbEJhY2tlbmQudGV4RGF0YS5nZXQobWVhbklucHV0LmRhdGFJZCk7XG4gICAgICAgIGNvbnN0IHZhbHVlcyA9IHhUZXhEYXRhLnZhbHVlcyBhcyBUeXBlZEFycmF5O1xuXG4gICAgICAgIGNvbnN0IG5ld1NoYXBlOiBudW1iZXJbXSA9IG5ldyBBcnJheSh4UmFuayk7XG4gICAgICAgIGZvciAobGV0IGkgPSAwOyBpIDwgbmV3U2hhcGUubGVuZ3RoOyBpKyspIHtcbiAgICAgICAgICBuZXdTaGFwZVtpXSA9IHguc2hhcGVbcGVybXV0ZWRBeGVzW2ldXTtcbiAgICAgICAgfVxuICAgICAgICBjb25zdCBtZWFuSW5wdXRWYWx1ZXMgPVxuICAgICAgICAgICAgdHJhbnNwb3NlSW1wbENQVSh2YWx1ZXMsIHguc2hhcGUsIHguZHR5cGUsIHBlcm11dGVkQXhlcywgbmV3U2hhcGUpO1xuXG4gICAgICAgIG1lYW5JbnB1dCA9IHdlYmdsQmFja2VuZC5tYWtlVGVuc29ySW5mbyhuZXdTaGFwZSwgeC5kdHlwZSk7XG4gICAgICAgIGNvbnN0IG1lYW5JbnB1dERhdGEgPSB3ZWJnbEJhY2tlbmQudGV4RGF0YS5nZXQobWVhbklucHV0LmRhdGFJZCk7XG4gICAgICAgIG1lYW5JbnB1dERhdGEudmFsdWVzID0gbWVhbklucHV0VmFsdWVzO1xuICAgICAgfSBlbHNlIHtcbiAgICAgICAgbWVhbklucHV0ID0gdHJhbnNwb3NlSW1wbCh4LCBwZXJtdXRlZEF4ZXMsIHdlYmdsQmFja2VuZCk7XG4gICAgICB9XG5cbiAgICAgIGludGVybWVkaWF0ZXMucHVzaChtZWFuSW5wdXQpO1xuICAgICAgYXhlcyA9IGJhY2tlbmRfdXRpbC5nZXRJbm5lck1vc3RBeGVzKGF4ZXMubGVuZ3RoLCB4UmFuayk7XG4gICAgfVxuXG4gICAgYmFja2VuZF91dGlsLmFzc2VydEF4ZXNBcmVJbm5lck1vc3REaW1zKCdzdW0nLCBheGVzLCB4UmFuayk7XG4gICAgY29uc3QgW21lYW5PdXRTaGFwZSwgcmVkdWNlU2hhcGVdID1cbiAgICAgICAgYmFja2VuZF91dGlsLmNvbXB1dGVPdXRBbmRSZWR1Y2VTaGFwZXMobWVhbklucHV0LnNoYXBlLCBheGVzKTtcblxuICAgIGxldCBvdXRTaGFwZSA9IG1lYW5PdXRTaGFwZTtcbiAgICBpZiAoa2VlcERpbXMpIHtcbiAgICAgIC8vIHJhdGhlciB0aGFuIHJlc2hhcGUgYXQgdGhlIGVuZCwgc2V0IHRoZSB0YXJnZXQgc2hhcGUgaGVyZS5cbiAgICAgIG91dFNoYXBlID0gYmFja2VuZF91dGlsLmV4cGFuZFNoYXBlVG9LZWVwRGltKG1lYW5PdXRTaGFwZSwgb3JpZ0F4ZXMpO1xuICAgIH1cblxuICAgIGNvbnN0IG91dCA9IG1lYW5JbXBsKG1lYW5JbnB1dCwgcmVkdWNlU2hhcGUsIG91dFNoYXBlLCB3ZWJnbEJhY2tlbmQpO1xuICAgIGZvciAoY29uc3QgaSBvZiBpbnRlcm1lZGlhdGVzKSB7XG4gICAgICB3ZWJnbEJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnN
vckluZm8oaSk7XG4gICAgfVxuXG4gICAgcmV0dXJuIG91dDtcbiAgfVxufTtcbiJdfQ==