/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { env, TopK, util } from '@tensorflow/tfjs-core';
|
import { topKImplCPU } from '../kernel_utils/shared';
|
import { MergeProgram, SwapProgram } from '../top_k_gpu';
|
import { fill } from './Fill';
|
import { gatherV2 } from './GatherV2';
|
import { reshape } from './Reshape';
|
import { slice } from './Slice';
|
/**
 * Releases `tensorInfo` through the backend's intermediate-disposal hook.
 * A `null` tensorInfo means "nothing allocated yet" and is skipped.
 *
 * @param backend Backend exposing `disposeIntermediateTensorInfo`.
 * @param tensorInfo Tensor info to release, or `null`.
 */
function disposeIntermediateTensorInfoOrNull(backend, tensorInfo) {
  if (tensorInfo === null) {
    return;
  }
  backend.disposeIntermediateTensorInfo(tensorInfo);
}
|
/**
 * Returns the smallest power of two, starting from 1, that is >= `num`.
 * Any `num` <= 1 yields 1.
 *
 * @param num Value to round up.
 * @returns A power of two no smaller than `num`.
 */
function roundUpToPow2(num) {
  let candidate = 1;
  for (;;) {
    if (!(candidate < num)) {
      return candidate;
    }
    candidate *= 2;
  }
}
|
/**
 * WebGL kernel for `TopK`: returns the `k` largest entries along the last
 * dimension of `x`, together with their int32 indices.
 *
 * Based on Algorithm 2 of Bitonic Top K, ref:
 * https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf
 *
 * Execution is handed off to the CPU implementation (`topKImplCPU`) in three
 * cases:
 * - the backend asks for CPU execution;
 * - the last dimension is below the `TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD`
 *   flag;
 * - `k` exceeds the `TOPK_K_CPU_HANDOFF_THRESHOLD` flag.
 *
 * @param args.inputs.x Input tensor info.
 * @param args.backend The WebGL backend.
 * @param args.attrs.k Number of entries to keep along the last dimension.
 * @param args.attrs.sorted Forwarded to the CPU implementation; the GPU path
 *     below does not read it.
 * @returns `[values, indices]`. Both have the shape of `x`, except that the
 *     last dimension is `k`.
 */
export function topK(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { k, sorted } = attrs;

  // Empirically determined constant used to determine last dim threshold for
  // handing off execution to the CPU.
  const TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD =
      env().getNumber('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD');

  // Empirically determined constant used to determine k threshold for handing
  // off execution to the CPU.
  const TOPK_K_CPU_HANDOFF_THRESHOLD =
      env().getNumber('TOPK_K_CPU_HANDOFF_THRESHOLD');

  const xShape = x.shape;
  const lastDim = xShape[xShape.length - 1];

  if (backend.shouldExecuteOnCPU([x]) ||
      lastDim < TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD ||
      k > TOPK_K_CPU_HANDOFF_THRESHOLD) {
    const xVals = backend.readSync(x.dataId);
    const [allTopKVals, allTopKIndices] =
        topKImplCPU(xVals, xShape, x.dtype, k, sorted);

    return [
      backend.makeTensorInfo(
          allTopKVals.shape, allTopKVals.dtype, allTopKVals.values),
      backend.makeTensorInfo(
          allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)
    ];
  }

  if (k === 0) {
    // Build the empty output shape on a copy. `xShape` is the input's own
    // shape array, and writing into it would corrupt the shape of `x`.
    const emptyShape = xShape.slice();
    emptyShape[emptyShape.length - 1] = 0;
    return [
      backend.makeTensorInfo(emptyShape, x.dtype, []),
      backend.makeTensorInfo(emptyShape, 'int32', [])
    ];
  }

  if (lastDim === 1) {
    // With a single entry per row, the values are `x` itself and every index
    // is 0. Return a new TensorInfo that shares x's data and holds an extra
    // reference, instead of returning the input object. Returning `x` itself
    // would alias the caller's input tensor with the kernel output.
    backend.incRef(x.dataId);
    return [
      { dataId: x.dataId, shape: xShape.slice(), dtype: x.dtype },
      fill({ attrs: { shape: xShape, dtype: 'int32', value: 0 }, backend })
    ];
  }

  // Eagerly unpack x input since it is passed in to all the shaders which
  // require unpacked inputs.
  const xTexData = backend.texData.get(x.dataId);
  const xIsPacked = xTexData != null && xTexData.isPacked;
  const xUnPacked = xIsPacked ? backend.unpackTensor(x) : x;

  // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.
  const xSize = util.sizeFromShape(xShape);
  const batch = xSize / lastDim;
  const x2D = reshape(
      { inputs: { x: xUnPacked }, attrs: { shape: [batch, lastDim] }, backend });

  if (xIsPacked) {
    // The unpacked copy was created only for this kernel; release it now that
    // x2D has been derived from it.
    disposeIntermediateTensorInfoOrNull(backend, xUnPacked);
  }

  const kPow2 = roundUpToPow2(k);
  const lastDimPow2 = roundUpToPow2(lastDim);

  // Only the indices containing the top K are kept at every step to reduce
  // number of outputs in the GPU algorithms, so once the final set of indices
  // is computed then gather is used to grab the corresponding values
  // from the original input.
  let indices = null;

  // GPU algorithm always takes in an indices input but this input is not used
  // on the first run of a GPU algorithm, therefore if indices is null we simply
  // pass in x2D instead of it but the value will not actually be used.
  const getInputs = () => indices === null ? [x2D, x2D] : [x2D, indices];

  // Runs one SwapProgram pass and replaces `indices` with its output, freeing
  // the previous indices tensor.
  const runSwap = (dir, inc, shape) => {
    const swapInputs = getInputs();
    const program = new SwapProgram(shape);
    const firstPass = indices === null ? 1 : 0;
    const customValues =
        [[lastDim], [firstPass], [Number.NEGATIVE_INFINITY], [dir], [inc]];
    const prevIndices = indices;
    indices =
        backend.runWebGLProgram(program, swapInputs, 'int32', customValues);
    disposeIntermediateTensorInfoOrNull(backend, prevIndices);
  };

  // Step 1: local sort
  for (let len = 1; len < kPow2; len *= 2) {
    const dir = len * 2;
    for (let inc = len; inc >= 1; inc /= 2) {
      runSwap(dir, inc, [batch, lastDimPow2]);
    }
  }

  // Step 2: merge. Each iteration halves the number of candidate indices per
  // row until only kPow2 remain.
  for (let indicesSize = lastDimPow2; indicesSize > kPow2; indicesSize /= 2) {
    const mergeInputs = getInputs();
    const mergeProgram = new MergeProgram([batch, indicesSize / 2]);
    const firstPass = indices === null ? 1 : 0;
    const customValues = [[lastDim], [firstPass], [kPow2]];
    const prevIndices = indices;
    indices = backend.runWebGLProgram(
        mergeProgram, mergeInputs, 'int32', customValues);
    disposeIntermediateTensorInfoOrNull(backend, prevIndices);

    // Step 3: rebuild
    const len = kPow2 / 2;
    const dir = len * 2;
    for (let inc = len; inc >= 1; inc /= 2) {
      runSwap(dir, inc, indices.shape);
    }
  }

  // Keep only the requested top K results instead of kPow2
  let prevIndices = indices;
  indices = slice(
      { inputs: { x: indices }, backend, attrs: { begin: 0, size: [batch, k] } });
  disposeIntermediateTensorInfoOrNull(backend, prevIndices);

  // Gather values on last dimension
  let values = gatherV2(
      { inputs: { x: x2D, indices }, backend, attrs: { axis: 1, batchDims: 1 } });
  disposeIntermediateTensorInfoOrNull(backend, x2D);

  // Reshape back to the original input shape, except that the last
  // dimension is k.
  const newShape = xShape.slice(0, -1);
  newShape.push(k);

  prevIndices = indices;
  indices = reshape({ inputs: { x: indices }, attrs: { shape: newShape }, backend });
  disposeIntermediateTensorInfoOrNull(backend, prevIndices);

  const prevValues = values;
  values = reshape({ inputs: { x: values }, attrs: { shape: newShape }, backend });
  disposeIntermediateTensorInfoOrNull(backend, prevValues);

  return [values, indices];
}
|
/**
 * Kernel registration entry that binds the `TopK` kernel name to the `topK`
 * implementation for the 'webgl' backend.
 */
export const topKConfig = { kernelName: TopK, backendName: 'webgl', kernelFunc: topK };
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"TopK.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-webgl/src/kernels/TopK.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,GAAG,EAAyD,IAAI,EAAqC,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAGhJ,OAAO,EAAC,WAAW,EAAC,MAAM,wBAAwB,CAAC;AACnD,OAAO,EAAC,YAAY,EAAE,WAAW,EAAC,MAAM,cAAc,CAAC;AACvD,OAAO,EAAC,IAAI,EAAC,MAAM,QAAQ,CAAC;AAC5B,OAAO,EAAC,QAAQ,EAAC,MAAM,YAAY,CAAC;AACpC,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAClC,OAAO,EAAC,KAAK,EAAC,MAAM,SAAS,CAAC;AAE9B,SAAS,mCAAmC,CACxC,OAAyB,EAAE,UAAsB;IACnD,IAAI,UAAU,KAAK,IAAI,EAAE;QACvB,OAAO,CAAC,6BAA6B,CAAC,UAAU,CAAC,CAAC;KACnD;AACH,CAAC;AAED,SAAS,aAAa,CAAC,GAAW;IAChC,IAAI,IAAI,GAAG,CAAC,CAAC;IACb,OAAO,IAAI,GAAG,GAAG,EAAE;QACjB,IAAI,IAAI,CAAC,CAAC;KACX;IACD,OAAO,IAAI,CAAC;AACd,CAAC;AAED,8CAA8C;AAC9C,6DAA6D;AAC7D,MAAM,UAAU,IAAI,CAChB,IAAuE;IAEzE,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAC,GAAG,MAAM,CAAC;IACnB,MAAM,EAAC,CAAC,EAAE,MAAM,EAAC,GAAG,KAAK,CAAC;IAE1B,2EAA2E;IAC3E,oCAAoC;IACpC,MAAM,wCAAwC,GAC1C,GAAG,EAAE,CAAC,SAAS,CAAC,0CAA0C,CAAC,CAAC;IAEhE,4EAA4E;IAC5E,4BAA4B;IAC5B,MAAM,4BAA4B,GAC9B,GAAG,EAAE,CAAC,SAAS,CAAC,8BAA8B,CAAC,CAAC;IAEpD,MAAM,MAAM,GAAG,CAAC,CAAC,KAAK,CAAC;IACvB,MAAM,OAAO,GAAG,MAAM,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;IAE1C,IAAI,OAAO,CAAC,kBAAkB,CAAC,CAAC,CAAC,CAAC,CAAC;QAC/B,OAAO,GAAG,wCAAwC;QAClD,CAAC,GAAG,4BAA4B,EAAE;QACpC,MAAM,KAAK,GAAG,OAAO,CAAC,QAAQ,CAAC,CAAC,CAAC,MAAM,CAAe,CAAC;QACvD,MAAM,CAAC,WAAW,EAAE,cAAc,CAAC,GAC/B,WAAW,CAAC,KAAK,EAAE,MAAM,EAAE,CAAC,CAAC,KAAwB,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QAEtE,OAAO;YACL,OAAO,CAAC,cAAc,CAClB,WAAW,CAAC,KAAK,EAAE,WAAW,CAAC,KAAK,EAAE,WAAW,CAAC,MAAM,CAAC;YAC7D,OAAO,CAAC,cAAc,CAClB,cAAc,CAAC,KAAK,EAAE,cAAc,CAAC,KAAK,EAAE,cAAc,CAAC,MAAM,CAAC;SACvE,CAAC;KACH;IAED,IAAI,CAAC,KAAK,CAAC,EAAE;QACX,MAAM,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;QAC9B,OAAO;YACL,OAAO,CAAC,cAAc,CAAC,MAAM,EAAE,CAAC,CAAC,KAAK,EAAE,EAAE,CAAC;YAC3C,OAAO,CAAC,cAAc,CAAC,MAAM
,EAAE,OAAO,EAAE,EAAE,CAAC;SAC5C,CAAC;KACH;IAED,IAAI,OAAO,KAAK,CAAC,CAAC,eAAe,EAAE;QACjC,OAAO;YACL,CAAC,EAAE,IAAI,CAAC,EAAC,KAAK,EAAE,EAAC,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,OAAO,EAAE,KAAK,EAAE,CAAC,EAAC,EAAE,OAAO,EAAC,CAAC;SACrE,CAAC;KACH;IAED,wEAAwE;IACxE,2BAA2B;IAC3B,MAAM,QAAQ,GAAG,OAAO,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;IAC/C,MAAM,SAAS,GAAG,QAAQ,KAAK,IAAI,IAAI,QAAQ,CAAC,QAAQ,CAAC;IACzD,MAAM,SAAS,GAAG,SAAS,CAAC,CAAC,CAAC,OAAO,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAE1D,4EAA4E;IAC5E,MAAM,KAAK,GAAG,IAAI,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;IACzC,MAAM,KAAK,GAAG,KAAK,GAAG,OAAO,CAAC;IAC9B,MAAM,GAAG,GAAG,OAAO,CACf,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,SAAS,EAAC,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,CAAC,KAAK,EAAE,OAAO,CAAC,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;IAEzE,IAAI,SAAS,EAAE;QACb,mCAAmC,CAAC,OAAO,EAAE,SAAS,CAAC,CAAC;KACzD;IAED,MAAM,KAAK,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC;IAC/B,MAAM,WAAW,GAAG,aAAa,CAAC,OAAO,CAAC,CAAC;IAE3C,yEAAyE;IACzE,4EAA4E;IAC5E,mEAAmE;IACnE,2BAA2B;IAC3B,IAAI,OAAO,GAAe,IAAI,CAAC;IAE/B,4EAA4E;IAC5E,8EAA8E;IAC9E,oEAAoE;IACpE,MAAM,SAAS,GAAG,GAAG,EAAE,CAAC,OAAO,KAAK,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,EAAE,OAAO,CAAC,CAAC;IAEvE,MAAM,OAAO,GAAG,CAAC,GAAW,EAAE,GAAW,EAAE,KAAe,EAAE,EAAE;QAC5D,MAAM,MAAM,GAAG,SAAS,EAAE,CAAC;QAC3B,MAAM,OAAO,GAAG,IAAI,WAAW,CAAC,KAAK,CAAC,CAAC;QACvC,MAAM,QAAQ,GAAG,OAAO,KAAK,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC1C,MAAM,YAAY,GACd,CAAC,CAAC,OAAO,CAAC,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,MAAM,CAAC,iBAAiB,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC;QACtE,MAAM,WAAW,GAAG,OAAO,CAAC;QAC5B,OAAO,GAAG,OAAO,CAAC,eAAe,CAAC,OAAO,EAAE,MAAM,EAAE,OAAO,EAAE,YAAY,CAAC,CAAC;QAC1E,mCAAmC,CAAC,OAAO,EAAE,WAAW,CAAC,CAAC;IAC5D,CAAC,CAAC;IAEF,qBAAqB;IACrB,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,GAAG,KAAK,EAAE,GAAG,IAAI,CAAC,EAAE;QACvC,MAAM,GAAG,GAAG,GAAG,GAAG,CAAC,CAAC;QACpB,KAAK,IAAI,GAAG,GAAG,GAAG,EAAE,GAAG,IAAI,CAAC,EAAE,GAAG,IAAI,CAAC,EAAE;YACtC,OAAO,CAAC,GAAG,EAAE,GAAG,EAAE,CAAC,KAAK,EAAE,WAAW,CAAC,CAAC,CAAC;SACzC
;KACF;IAED,gBAAgB;IAChB,KAAK,IAAI,WAAW,GAAG,WAAW,EAAE,WAAW,GAAG,KAAK,EAAE,WAAW,IAAI,CAAC,EAAE;QACzE,MAAM,MAAM,GAAG,SAAS,EAAE,CAAC;QAC3B,MAAM,YAAY,GAAG,IAAI,YAAY,CAAC,CAAC,KAAK,EAAE,WAAW,GAAG,CAAC,CAAC,CAAC,CAAC;QAChE,MAAM,SAAS,GAAG,OAAO,KAAK,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3C,MAAM,YAAY,GAAG,CAAC,CAAC,OAAO,CAAC,EAAE,CAAC,SAAS,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC;QACvD,MAAM,WAAW,GAAG,OAAO,CAAC;QAC5B,OAAO;YACH,OAAO,CAAC,eAAe,CAAC,YAAY,EAAE,MAAM,EAAE,OAAO,EAAE,YAAY,CAAC,CAAC;QACzE,mCAAmC,CAAC,OAAO,EAAE,WAAW,CAAC,CAAC;QAE1D,kBAAkB;QAClB,MAAM,GAAG,GAAG,KAAK,GAAG,CAAC,CAAC;QACtB,MAAM,GAAG,GAAG,GAAG,GAAG,CAAC,CAAC;QACpB,KAAK,IAAI,GAAG,GAAG,GAAG,EAAE,GAAG,IAAI,CAAC,EAAE,GAAG,IAAI,CAAC,EAAE;YACtC,OAAO,CAAC,GAAG,EAAE,GAAG,EAAE,OAAO,CAAC,KAAK,CAAC,CAAC;SAClC;KACF;IAED,yDAAyD;IACzD,IAAI,WAAW,GAAG,OAAO,CAAC;IAC1B,OAAO,GAAG,KAAK,CACX,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,OAAO,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,CAAC,EAAE,IAAI,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,EAAC,EAAC,CAAC,CAAC;IAC1E,mCAAmC,CAAC,OAAO,EAAE,WAAW,CAAC,CAAC;IAE1D,kCAAkC;IAClC,IAAI,MAAM,GAAG,QAAQ,CACjB,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,GAAG,EAAE,OAAO,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,IAAI,EAAE,CAAC,EAAE,SAAS,EAAE,CAAC,EAAC,EAAC,CAAC,CAAC;IAC1E,mCAAmC,CAAC,OAAO,EAAE,GAAG,CAAC,CAAC;IAElD,iEAAiE;IACjE,kBAAkB;IAClB,MAAM,QAAQ,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACrC,QAAQ,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAEjB,WAAW,GAAG,OAAO,CAAC;IACtB,OAAO,GAAG,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,OAAO,EAAC,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,QAAQ,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;IAC7E,mCAAmC,CAAC,OAAO,EAAE,WAAW,CAAC,CAAC;IAE1D,MAAM,UAAU,GAAG,MAAM,CAAC;IAC1B,MAAM,GAAG,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAC,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,QAAQ,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;IAC3E,mCAAmC,CAAC,OAAO,EAAE,UAAU,CAAC,CAAC;IAEzD,OAAO,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;AAC3B,CAAC;AAED,MAAM,CAAC,MAAM,UAAU,GAAiB;IACtC,UAAU,EAAE,IAAI;IAChB,WAAW,EAAE,OAAO;IACpB,UAAU,EAAE,IAA6B;CAC1C,CAAC","sourcesContent":["/**\n * 
@license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env, KernelConfig, KernelFunc, NumericDataType, TensorInfo, TopK, TopKAttrs, TopKInputs, TypedArray, util} from '@tensorflow/tfjs-core';\n\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {topKImplCPU} from '../kernel_utils/shared';\nimport {MergeProgram, SwapProgram} from '../top_k_gpu';\nimport {fill} from './Fill';\nimport {gatherV2} from './GatherV2';\nimport {reshape} from './Reshape';\nimport {slice} from './Slice';\n\nfunction disposeIntermediateTensorInfoOrNull(\n    backend: MathBackendWebGL, tensorInfo: TensorInfo) {\n  if (tensorInfo !== null) {\n    backend.disposeIntermediateTensorInfo(tensorInfo);\n  }\n}\n\nfunction roundUpToPow2(num: number) {\n  let pow2 = 1;\n  while (pow2 < num) {\n    pow2 *= 2;\n  }\n  return pow2;\n}\n\n// Based on Algorithm 2 of Bitonic Top K, ref:\n// https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf\nexport function topK(\n    args: {inputs: TopKInputs, backend: MathBackendWebGL, attrs: TopKAttrs}):\n    TensorInfo[] {\n  const {inputs, backend, attrs} = args;\n  const {x} = inputs;\n  const {k, sorted} = attrs;\n\n  // Empirically determined constant used to determine last dim threshold for\n  // handing off execution to the CPU.\n  const TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD =\n      
env().getNumber('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD');\n\n  // Empirically determined constant used to determine k threshold for handing\n  // off execution to the CPU.\n  const TOPK_K_CPU_HANDOFF_THRESHOLD =\n      env().getNumber('TOPK_K_CPU_HANDOFF_THRESHOLD');\n\n  const xShape = x.shape;\n  const lastDim = xShape[xShape.length - 1];\n\n  if (backend.shouldExecuteOnCPU([x]) ||\n      lastDim < TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD ||\n      k > TOPK_K_CPU_HANDOFF_THRESHOLD) {\n    const xVals = backend.readSync(x.dataId) as TypedArray;\n    const [allTopKVals, allTopKIndices] =\n        topKImplCPU(xVals, xShape, x.dtype as NumericDataType, k, sorted);\n\n    return [\n      backend.makeTensorInfo(\n          allTopKVals.shape, allTopKVals.dtype, allTopKVals.values),\n      backend.makeTensorInfo(\n          allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)\n    ];\n  }\n\n  if (k === 0) {\n    xShape[xShape.length - 1] = 0;\n    return [\n      backend.makeTensorInfo(xShape, x.dtype, []),\n      backend.makeTensorInfo(xShape, 'int32', [])\n    ];\n  }\n\n  if (lastDim === 1 /* firstPass */) {\n    return [\n      x, fill({attrs: {shape: xShape, dtype: 'int32', value: 0}, backend})\n    ];\n  }\n\n  // Eagerly unpack x input since it is passed in to all the shaders which\n  // require unpacked inputs.\n  const xtexData = backend.texData.get(x.dataId);\n  const xIsPacked = xtexData !== null && xtexData.isPacked;\n  const xUnPacked = xIsPacked ? 
backend.unpackTensor(x) : x;\n\n  // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.\n  const xSize = util.sizeFromShape(xShape);\n  const batch = xSize / lastDim;\n  const x2D = reshape(\n      {inputs: {x: xUnPacked}, attrs: {shape: [batch, lastDim]}, backend});\n\n  if (xIsPacked) {\n    disposeIntermediateTensorInfoOrNull(backend, xUnPacked);\n  }\n\n  const kPow2 = roundUpToPow2(k);\n  const lastDimPow2 = roundUpToPow2(lastDim);\n\n  // Only the indices containing the top K are kept at every step to reduce\n  // number of outputs in the GPU algorithms, so once the final set of indices\n  // is computed then gather is used to grab the corresponding values\n  // from the original input.\n  let indices: TensorInfo = null;\n\n  // GPU algorithm always takes in an indices input but this input is not used\n  // on the first run of a GPU algorithm, therefore if indices is null we simply\n  // pass in x2D instead of it but the value will not actually be used\n  const getInputs = () => indices === null ? [x2D, x2D] : [x2D, indices];\n\n  const runSwap = (dir: number, inc: number, shape: number[]) => {\n    const inputs = getInputs();\n    const program = new SwapProgram(shape);\n    const fistPass = indices === null ? 
1 : 0;\n    const customValues =\n        [[lastDim], [fistPass], [Number.NEGATIVE_INFINITY], [dir], [inc]];\n    const prevIndices = indices;\n    indices = backend.runWebGLProgram(program, inputs, 'int32', customValues);\n    disposeIntermediateTensorInfoOrNull(backend, prevIndices);\n  };\n\n  // Step 1: local sort\n  for (let len = 1; len < kPow2; len *= 2) {\n    const dir = len * 2;\n    for (let inc = len; inc >= 1; inc /= 2) {\n      runSwap(dir, inc, [batch, lastDimPow2]);\n    }\n  }\n\n  // Step 2: merge\n  for (let indicesSize = lastDimPow2; indicesSize > kPow2; indicesSize /= 2) {\n    const inputs = getInputs();\n    const mergeProgram = new MergeProgram([batch, indicesSize / 2]);\n    const firstPass = indices === null ? 1 : 0;\n    const customValues = [[lastDim], [firstPass], [kPow2]];\n    const prevIndices = indices;\n    indices =\n        backend.runWebGLProgram(mergeProgram, inputs, 'int32', customValues);\n    disposeIntermediateTensorInfoOrNull(backend, prevIndices);\n\n    // Step 3: rebuild\n    const len = kPow2 / 2;\n    const dir = len * 2;\n    for (let inc = len; inc >= 1; inc /= 2) {\n      runSwap(dir, inc, indices.shape);\n    }\n  }\n\n  // Keep only the requested top K results instead of kPow2\n  let prevIndices = indices;\n  indices = slice(\n      {inputs: {x: indices}, backend, attrs: {begin: 0, size: [batch, k]}});\n  disposeIntermediateTensorInfoOrNull(backend, prevIndices);\n\n  // Gather values on last dimension\n  let values = gatherV2(\n      {inputs: {x: x2D, indices}, backend, attrs: {axis: 1, batchDims: 1}});\n  disposeIntermediateTensorInfoOrNull(backend, x2D);\n\n  // Reshape back to the original input shape, except that the last\n  // dimension is k.\n  const newShape = xShape.slice(0, -1);\n  newShape.push(k);\n\n  prevIndices = indices;\n  indices = reshape({inputs: {x: indices}, attrs: {shape: newShape}, backend});\n  disposeIntermediateTensorInfoOrNull(backend, prevIndices);\n\n  const prevValues = values;\n  
values = reshape({inputs: {x: values}, attrs: {shape: newShape}, backend});\n  disposeIntermediateTensorInfoOrNull(backend, prevValues);\n\n  return [values, indices];\n}\n\nexport const topKConfig: KernelConfig = {\n  kernelName: TopK,\n  backendName: 'webgl',\n  kernelFunc: topK as unknown as KernelFunc\n};\n"]}
|