/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, env, util } from '@tensorflow/tfjs-core';
|
import { ArgMinMaxProgram } from '../argminmax_gpu';
|
import { ArgMinMaxPackedProgram } from '../argminmax_packed_gpu';
|
import { reshape } from '../kernels/Reshape';
|
/**
 * Multi-pass (unpacked) arg-min/arg-max reduction over the second dimension
 * of `x`.
 *
 * Each pass builds an `ArgMinMaxProgram` whose output size is
 * `ceil(inSize / windowSize)`. The function recurses, feeding the previous
 * pass's indices back in, until the output's second dimension is 1.
 *
 * @param backend WebGL backend used to run the shader programs and release
 *     intermediate results.
 * @param x Input tensor info. Only `shape[0]` and `shape[1]` are read, so it
 *     is treated as 2-D; the caller in this file reshapes to `[-1, inSize]`.
 *     `x` is passed as the first input of every pass.
 * @param reduceType 'min' or 'max'; forwarded to the program unchanged.
 * @param bestIndicesA Indices produced by the previous pass, or null on the
 *     first pass. When present, its shape (not x's) drives the pass sizing.
 * @returns The int32 output of the final pass (second dimension of 1). Every
 *     earlier pass's output is disposed before returning.
 */
function argReduce(backend, x, reduceType, bestIndicesA = null) {
  const isFirstPass = bestIndicesA == null;

  // Later passes shrink the previous pass's index tensor, not x itself.
  const sizingSource = isFirstPass ? x : bestIndicesA;
  const [batchSize, inSize] = sizingSource.shape;

  const windowSize = backend_util.computeOptimalWindowSize(inSize);
  const outSize = Math.ceil(inSize / windowSize);
  const program = new ArgMinMaxProgram(
      {windowSize, inSize, batchSize, outSize}, reduceType, isFirstPass);

  const programInputs = isFirstPass ? [x] : [x, bestIndicesA];
  const partial = backend.runWebGLProgram(program, programInputs, 'int32');

  if (partial.shape[1] !== 1) {
    // Not fully reduced yet: run another pass, then free this intermediate.
    const finalIndices = argReduce(backend, x, reduceType, partial);
    backend.disposeIntermediateTensorInfo(partial);
    return finalIndices;
  }
  // A single index per row remains; no further GPGPU program is needed.
  return partial;
}
|
/**
 * Multi-pass (packed) arg-min/arg-max reduction over the innermost dimension.
 *
 * Each pass runs an `ArgMinMaxPackedProgram` sized from the innermost
 * dimension of the current input shape. While a pass's output still has the
 * same rank as `x`, the function recurses with that output as the
 * best-indices input.
 *
 * @param backend WebGL backend used to run the shader programs and release
 *     intermediate results.
 * @param x Input tensor info; passed as the first input of every pass.
 * @param reduceType 'min' or 'max'; forwarded to the program unchanged.
 * @param bestIndicesA Indices from the previous pass, or null on the first
 *     pass. When present, its shape (not x's) sizes the pass.
 * @returns The int32 output of the first pass whose rank differs from x's.
 *     Earlier intermediates are disposed.
 */
function argReducePacked(backend, x, reduceType, bestIndicesA = null) {
  const isFirstPass = bestIndicesA == null;
  const currentShape = isFirstPass ? x.shape : bestIndicesA.shape;
  const innerSize = currentShape[currentShape.length - 1];

  const program = new ArgMinMaxPackedProgram(
      currentShape, backend_util.computeOptimalWindowSize(innerSize),
      reduceType, isFirstPass);
  const partial = backend.runWebGLProgram(
      program, isFirstPass ? [x] : [x, bestIndicesA], 'int32');

  // A rank change signals the reduction is complete. Presumably the packed
  // program drops the reduced axis once it collapses to 1; confirm in
  // argminmax_packed_gpu.
  if (partial.shape.length !== x.shape.length) {
    return partial;
  }

  const finalIndices = argReducePacked(backend, x, reduceType, partial);
  backend.disposeIntermediateTensorInfo(partial);
  return finalIndices;
}
|
/**
 * Shared implementation of the WebGL argMin/argMax kernels.
 *
 * Asserts that `axis` is an innermost dimension of `x`, then picks a path:
 * - Packed: used when the `WEBGL_PACK_REDUCE` flag is on and `x` has rank > 2.
 *   Delegates to `argReducePacked`.
 * - Unpacked: used otherwise. Unpacks `x` if its texture is packed, flattens
 *   it to `[-1, inSize]`, reduces with `argReduce`, and reshapes the indices
 *   back to the output shape. Every intermediate created here is disposed.
 *
 * @param backend WebGL backend.
 * @param x Input tensor info.
 * @param axis The axis to reduce over; must be innermost.
 * @param reduceType 'min' or 'max'.
 * @returns Tensor info holding the arg-min/arg-max indices.
 */
export function argMinMaxReduce(backend, x, axis, reduceType) {
  const axes = [axis];
  const opName =
      'arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1);
  backend_util.assertAxesAreInnerMostDims(opName, axes, x.shape.length);

  const usePackedPath =
      env().getBool('WEBGL_PACK_REDUCE') && x.shape.length > 2;
  if (usePackedPath) {
    return argReducePacked(backend, x, reduceType);
  }

  // Tensor infos created on this path, released (in creation order) at the end.
  const toDispose = [];

  // The unpacked shaders read x on every pass and need unpacked inputs, so
  // unpack it once up front.
  const texData = backend.texData.get(x.dataId);
  let source = x;
  if (texData !== null && texData.isPacked) {
    source = backend.unpackTensor(x);
    toDispose.push(source);
  }

  const [outShape, reduceShape] =
      backend_util.computeOutAndReduceShapes(source.shape, axes);
  const flattened = reshape({
    inputs: {x: source},
    backend,
    attrs: {shape: [-1, util.sizeFromShape(reduceShape)]},
  });
  const indices = argReduce(backend, flattened, reduceType);
  toDispose.push(flattened, indices);

  const result =
      reshape({inputs: {x: indices}, backend, attrs: {shape: outShape}});

  for (const info of toDispose) {
    backend.disposeIntermediateTensorInfo(info);
  }
  return result;
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"arg_min_max.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-webgl/src/kernel_utils/arg_min_max.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAE,GAAG,EAAc,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAE1E,OAAO,EAAC,gBAAgB,EAAC,MAAM,kBAAkB,CAAC;AAClD,OAAO,EAAC,sBAAsB,EAAC,MAAM,yBAAyB,CAAC;AAE/D,OAAO,EAAC,OAAO,EAAC,MAAM,oBAAoB,CAAC;AAE3C,SAAS,SAAS,CACd,OAAyB,EAAE,CAAa,EAAE,UAAuB,EACjE,eAA2B,IAAI;IACjC,IAAI,SAAS,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IAC3B,IAAI,MAAM,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IACxB,IAAI,YAAY,IAAI,IAAI,EAAE;QACxB,SAAS,GAAG,YAAY,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAClC,MAAM,GAAG,YAAY,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;KAChC;IACD,MAAM,UAAU,GAAG,YAAY,CAAC,wBAAwB,CAAC,MAAM,CAAC,CAAC;IACjE,MAAM,UAAU,GACZ,EAAC,UAAU,EAAE,MAAM,EAAE,SAAS,EAAE,OAAO,EAAE,IAAI,CAAC,IAAI,CAAC,MAAM,GAAG,UAAU,CAAC,EAAC,CAAC;IAC7E,MAAM,OAAO,GACT,IAAI,gBAAgB,CAAC,UAAU,EAAE,UAAU,EAAE,YAAY,IAAI,IAAI,CAAC,CAAC;IACvE,MAAM,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC;IACnB,IAAI,YAAY,IAAI,IAAI,EAAE;QACxB,MAAM,CAAC,IAAI,CAAC,YAAY,CAAC,CAAC;KAC3B;IACD,MAAM,MAAM,GAAG,OAAO,CAAC,eAAe,CAAC,OAAO,EAAE,MAAM,EAAE,OAAO,CAAC,CAAC;IACjE,wCAAwC;IACxC,IAAI,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE;QACzB,OAAO,MAAM,CAAC;KACf;IACD,MAAM,MAAM,GAAG,SAAS,CAAC,OAAO,EAAE,CAAC,EAAE,UAAU,EAAE,MAAM,CAAC,CAAC;IACzD,OAAO,CAAC,6BAA6B,CAAC,MAAM,CAAC,CAAC;IAC9C,OAAO,MAAM,CAAC;AAChB,CAAC;AAED,SAAS,eAAe,CACpB,OAAyB,EAAE,CAAa,EAAE,UAAuB,EACjE,eAA2B,IAAI;IACjC,MAAM,OAAO,GAAG,YAAY,IAAI,IAAI,CAAC,CAAC,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC;IACpE,MAAM,MAAM,GAAG,OAAO,CAAC,OAAO,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;IAC3C,MAAM,UAAU,GAAG,YAAY,CAAC,wBAAwB,CAAC,MAAM,CAAC,CAAC;IACjE,MAAM,OAAO,GAAG,IAAI,sBAAsB,CACtC,OAAO,EAAE,UAAU,EAAE,UAAU,EAAE,YAAY,IAAI,IAAI,CAAC,CAAC;IAC3D,MAAM,MAAM,GAAG,YAAY,IAAI,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,YAAY,CAAC,CAAC;IAC9D,MAAM,MAAM,GAAG,OAAO,CAAC,eAAe,CAAC,OAAO,EAAE,MAAM,EAAE,OAAO,
CAAC,CAAC;IACjE,IAAI,MAAM,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC,CAAC,KAAK,CAAC,MAAM,EAAE;QAC1C,MAAM,MAAM,GAAG,eAAe,CAAC,OAAO,EAAE,CAAC,EAAE,UAAU,EAAE,MAAM,CAAC,CAAC;QAC/D,OAAO,CAAC,6BAA6B,CAAC,MAAM,CAAC,CAAC;QAC9C,OAAO,MAAM,CAAC;KACf;IACD,OAAO,MAAM,CAAC;AAChB,CAAC;AAED,MAAM,UAAU,eAAe,CAC3B,OAAyB,EAAE,CAAa,EAAE,IAAY,EACtD,UAAuB;IACzB,MAAM,IAAI,GAAG,CAAC,IAAI,CAAC,CAAC;IACpB,YAAY,CAAC,0BAA0B,CACnC,KAAK,GAAG,UAAU,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,WAAW,EAAE,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,IAAI,EACtE,CAAC,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC;IACpB,IAAI,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,mBAAmB,CAAC,IAAI,CAAC,CAAC,KAAK,CAAC,MAAM,IAAI,CAAC,EAAE;QAC9D,MAAM,uBAAuB,GAAG,EAAE,CAAC;QACnC,wEAAwE;QACxE,2BAA2B;QAC3B,MAAM,QAAQ,GAAG,OAAO,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;QAC/C,MAAM,SAAS,GAAG,QAAQ,KAAK,IAAI,IAAI,QAAQ,CAAC,QAAQ,CAAC;QACzD,IAAI,SAAS,GAAG,CAAC,CAAC;QAClB,IAAI,SAAS,EAAE;YACb,SAAS,GAAG,OAAO,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC;YACpC,uBAAuB,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;SACzC;QAED,MAAM,CAAC,QAAQ,EAAE,WAAW,CAAC,GACzB,YAAY,CAAC,yBAAyB,CAAC,SAAS,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;QAClE,MAAM,MAAM,GAAG,IAAI,CAAC,aAAa,CAAC,WAAW,CAAC,CAAC;QAC/C,MAAM,GAAG,GAAG,OAAO,CACf,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,SAAS,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,EAAC,EAAC,CAAC,CAAC;QACrE,uBAAuB,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;QAElC,MAAM,OAAO,GAAG,SAAS,CAAC,OAAO,EAAE,GAAG,EAAE,UAAU,CAAC,CAAC;QACpD,uBAAuB,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QACtC,MAAM,QAAQ,GACV,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,OAAO,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,QAAQ,EAAC,EAAC,CAAC,CAAC;QAEvE,uBAAuB,CAAC,OAAO,CAC3B,CAAC,CAAC,EAAE,CAAC,OAAO,CAAC,6BAA6B,CAAC,CAAC,CAAC,CAAC,CAAC;QACnD,OAAO,QAAQ,CAAC;KACjB;IACD,OAAO,eAAe,CAAC,OAAO,EAAE,CAAC,EAAE,UAAU,CAAC,CAAC;AACjD,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, env, TensorInfo, util} from '@tensorflow/tfjs-core';\n\nimport {ArgMinMaxProgram} from '../argminmax_gpu';\nimport {ArgMinMaxPackedProgram} from '../argminmax_packed_gpu';\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {reshape} from '../kernels/Reshape';\n\nfunction argReduce(\n    backend: MathBackendWebGL, x: TensorInfo, reduceType: 'max'|'min',\n    bestIndicesA: TensorInfo = null): TensorInfo {\n  let batchSize = x.shape[0];\n  let inSize = x.shape[1];\n  if (bestIndicesA != null) {\n    batchSize = bestIndicesA.shape[0];\n    inSize = bestIndicesA.shape[1];\n  }\n  const windowSize = backend_util.computeOptimalWindowSize(inSize);\n  const reduceInfo =\n      {windowSize, inSize, batchSize, outSize: Math.ceil(inSize / windowSize)};\n  const program =\n      new ArgMinMaxProgram(reduceInfo, reduceType, bestIndicesA == null);\n  const inputs = [x];\n  if (bestIndicesA != null) {\n    inputs.push(bestIndicesA);\n  }\n  const output = backend.runWebGLProgram(program, inputs, 'int32');\n  // No need to run another GPGPU program.\n  if (output.shape[1] === 1) {\n    return output;\n  }\n  const result = argReduce(backend, x, reduceType, output);\n  backend.disposeIntermediateTensorInfo(output);\n  return result;\n}\n\nfunction argReducePacked(\n    backend: 
MathBackendWebGL, x: TensorInfo, reduceType: 'max'|'min',\n    bestIndicesA: TensorInfo = null): TensorInfo {\n  const inShape = bestIndicesA != null ? bestIndicesA.shape : x.shape;\n  const inSize = inShape[inShape.length - 1];\n  const windowSize = backend_util.computeOptimalWindowSize(inSize);\n  const program = new ArgMinMaxPackedProgram(\n      inShape, windowSize, reduceType, bestIndicesA == null);\n  const inputs = bestIndicesA == null ? [x] : [x, bestIndicesA];\n  const output = backend.runWebGLProgram(program, inputs, 'int32');\n  if (output.shape.length === x.shape.length) {\n    const result = argReducePacked(backend, x, reduceType, output);\n    backend.disposeIntermediateTensorInfo(output);\n    return result;\n  }\n  return output;\n}\n\nexport function argMinMaxReduce(\n    backend: MathBackendWebGL, x: TensorInfo, axis: number,\n    reduceType: 'min'|'max'): TensorInfo {\n  const axes = [axis];\n  backend_util.assertAxesAreInnerMostDims(\n      'arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes,\n      x.shape.length);\n  if (!env().getBool('WEBGL_PACK_REDUCE') || x.shape.length <= 2) {\n    const intermediateTensorInfos = [];\n    // Eagerly unpack x input since it is passed in to all the shaders which\n    // require unpacked inputs.\n    const xtexData = backend.texData.get(x.dataId);\n    const xIsPacked = xtexData !== null && xtexData.isPacked;\n    let xUnPacked = x;\n    if (xIsPacked) {\n      xUnPacked = backend.unpackTensor(x);\n      intermediateTensorInfos.push(xUnPacked);\n    }\n\n    const [outShape, reduceShape] =\n        backend_util.computeOutAndReduceShapes(xUnPacked.shape, axes);\n    const inSize = util.sizeFromShape(reduceShape);\n    const a2D = reshape(\n        {inputs: {x: xUnPacked}, backend, attrs: {shape: [-1, inSize]}});\n    intermediateTensorInfos.push(a2D);\n\n    const reduced = argReduce(backend, a2D, reduceType);\n    intermediateTensorInfos.push(reduced);\n    const reshaped =\n        
reshape({inputs: {x: reduced}, backend, attrs: {shape: outShape}});\n\n    intermediateTensorInfos.forEach(\n        t => backend.disposeIntermediateTensorInfo(t));\n    return reshaped;\n  }\n  return argReducePacked(backend, x, reduceType);\n}\n"]}
|