/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, GatherV2, util } from '@tensorflow/tfjs-core';
import { assertNotComplex } from '../cpu_util';
import { gatherV2Impl } from './GatherV2_impl';
import { reshape } from './Reshape';

/**
 * CPU kernel for GatherV2: gathers slices of `x` along `axis` at the
 * positions listed in `indices`, optionally with leading batch dimensions.
 *
 * @param {{inputs: GatherV2Inputs, backend: MathBackendCPU,
 *     attrs: GatherV2Attrs}} args - `inputs` holds `x` and `indices`;
 *     `attrs` holds `axis` and the optional `batchDims` (missing => 0).
 * @returns {TensorInfo} A new tensor whose shape is the `outputShape`
 *     reported by `collectGatherOpShapeInfo`, with the gathered dtype.
 * @throws {Error} If an input is complex, or if any index value is not in
 *     `[0, x.shape[axis] - 1]`.
 */
export function gatherV2(args) {
  const { inputs, backend, attrs } = args;
  const { x, indices } = inputs;
  const { axis, batchDims } = attrs;

  assertNotComplex([x, indices], 'gatherV2');

  // Throw error when any index is out of bound. The predicate is phrased as
  // "in range" so that a NaN index also fails it.
  const parsedAxis = util.parseAxisParam(axis, x.shape)[0];
  const indicesVals = backend.data.get(indices.dataId).values;
  const axisDim = x.shape[parsedAxis];
  for (let i = 0; i < indicesVals.length; ++i) {
    const index = indicesVals[i];
    util.assert(
        index <= axisDim - 1 && index >= 0,
        () => `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
  }

  // batchDims is optional; a missing (null/undefined) value means 0.
  const $batchDims = batchDims == null ? 0 : batchDims;

  const indicesSize = util.sizeFromShape(indices.shape);
  const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo(
      x, indices, parsedAxis, $batchDims);
  // Number of index entries that belong to each batch.
  const indicesPerBatch = indicesSize / shapeInfo.batchSize;

  // Reshape x to [batch, outer, dim, slice] so the impl can gather along dim.
  const flattenX = reshape({
    inputs: { x },
    backend,
    attrs: {
      shape: [
        shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize,
        shapeInfo.sliceSize
      ]
    }
  });

  let flattenIndex = null;
  let outBuf;
  try {
    // Reshape indices to [batch, indicesPerBatch].
    flattenIndex = reshape({
      inputs: { x: indices },
      backend,
      attrs: { shape: [shapeInfo.batchSize, indicesPerBatch] }
    });

    const flattenOutputShape = [
      shapeInfo.batchSize, shapeInfo.outerSize, indicesPerBatch,
      shapeInfo.sliceSize
    ];

    const indicesBuf = backend.bufferSync(flattenIndex);
    const xBuf = backend.bufferSync(flattenX);
    outBuf = gatherV2Impl(xBuf, indicesBuf, flattenOutputShape);
  } finally {
    // Release the intermediate reshape outputs even when a step above throws;
    // previously they were only disposed on the success path.
    backend.disposeIntermediateTensorInfo(flattenX);
    if (flattenIndex != null) {
      backend.disposeIntermediateTensorInfo(flattenIndex);
    }
  }

  return backend.makeTensorInfo(
      shapeInfo.outputShape, outBuf.dtype, outBuf.values);
}

export const gatherV2Config = {
  kernelName: GatherV2,
  backendName: 'cpu',
  kernelFunc: gatherV2
};

// NOTE(review): this file is compiled output of
// tfjs-backend-cpu/src/kernels/GatherV2.ts (per the inline source map below).
// That map was generated for the original compiled output and will not line
// up with hand edits; prefer applying changes to the .ts source and rebuilding.
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiR2F0aGVyVjIuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy9rZXJuZWxzL0dhdGhlclYyLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQUUsUUFBUSxFQUFtRixJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUdwSixPQUFPLEVBQUMsZ0JBQWdCLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDN0MsT0FBTyxFQUFDLFlBQVksRUFBQyxNQUFNLGlCQUFpQixDQUFDO0FBQzdDLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFFbEMsTUFBTSxVQUFVLFFBQVEsQ0FBQyxJQUl4QjtJQUNDLE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztJQUN0QyxNQUFNLEVBQUMsQ0FBQyxFQUFFLE9BQU8sRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUM1QixNQUFNLEVBQUMsSUFBSSxFQUFFLFNBQVMsRUFBQyxHQUFHLEtBQUssQ0FBQztJQUVoQyxnQkFBZ0IsQ0FBQyxDQUFDLENBQUMsRUFBRSxPQUFPLENBQUMsRUFBRSxVQUFVLENBQUMsQ0FBQztJQUUzQyw4Q0FBOEM7SUFDOUMsTUFBTSxVQUFVLEdBQUcsSUFBSSxDQUFDLGNBQWMsQ0FBQyxJQUFJLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQ3pELE1BQU0sV0FBVyxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLE9BQU8sQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDO0lBQzFFLE1BQU0sT0FBTyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsVUFBVSxDQUFDLENBQUM7SUFDcEMsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFdBQVcsQ0FBQyxNQUFNLEVBQUUsRUFBRSxDQUFDLEVBQUU7UUFDM0MsTUFBTSxLQUFLLEdBQUcsV0FBVyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQzdCLElBQUksQ0FBQyxNQUFNLENBQ1AsS0FBSyxJQUFJLE9BQU8sR0FBRyxDQUFDLElBQUksS0FBSyxJQUFJLENBQUMsRUFDbEMsR0FBRyxFQUFFLENBQ0QsNkJBQTZCLEtBQUssa0JBQWtCLE9BQU8sR0FBRyxDQUFDLEdBQUcsQ0FBQyxDQUFDO0tBQzdFO0lBRUQsSUFBSSxVQUFVLEdBQUcsU0FBUyxDQUFDO0lBRTNCLElBQUksU0FBUyxJQUFJLElBQUksRUFBRTtRQUNyQixVQUFVLEdBQUcsQ0FBQyxDQUFDO0tBQ2hCO0lBRUQsTUFBTSxXQUFXLEdBQUcsSUFBSSxDQUFDLGFBQWEsQ0FBQyxPQUFPLENBQUMsS0FBSyxDQUFDLENBQUM7SUFFdEQsTUFBTSxTQUFTLEdBQUcsWUFBWSxDQUFDLFlBQVksQ0FBQyx3QkFBd0IsQ0FDaEUsQ0FBQyxFQUFFLE9BQU8sRUFBRSxVQUFVLEVBQUUsVUFBVSxDQUFDLENBQUM7SUFFeEMsTUFBTSxRQUFRLEdBQUcsT0FBTyxDQUFDO1FBQ3ZCLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBQztRQUNYLE9BQU
87UUFDUCxLQUFLLEVBQUU7WUFDTCxLQUFLLEVBQUU7Z0JBQ0wsU0FBUyxDQUFDLFNBQVMsRUFBRSxTQUFTLENBQUMsU0FBUyxFQUFFLFNBQVMsQ0FBQyxPQUFPO2dCQUMzRCxTQUFTLENBQUMsU0FBUzthQUNwQjtTQUNGO0tBQ0YsQ0FBQyxDQUFDO0lBRUgsTUFBTSxZQUFZLEdBQUcsT0FBTyxDQUFDO1FBQzNCLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxPQUFPLEVBQUM7UUFDcEIsT0FBTztRQUNQLEtBQUssRUFBRSxFQUFDLEtBQUssRUFBRSxDQUFDLFNBQVMsQ0FBQyxTQUFTLEVBQUUsV0FBVyxHQUFHLFNBQVMsQ0FBQyxTQUFTLENBQUMsRUFBQztLQUN6RSxDQUFDLENBQUM7SUFFSCxNQUFNLGtCQUFrQixHQUFHO1FBQ3pCLFNBQVMsQ0FBQyxTQUFTLEVBQUUsU0FBUyxDQUFDLFNBQVMsRUFBRSxXQUFXLEdBQUcsU0FBUyxDQUFDLFNBQVM7UUFDM0UsU0FBUyxDQUFDLFNBQVM7S0FDcEIsQ0FBQztJQUVGLE1BQU0sVUFBVSxHQUFHLE9BQU8sQ0FBQyxVQUFVLENBQUMsWUFBWSxDQUFDLENBQUM7SUFDcEQsTUFBTSxJQUFJLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxRQUFRLENBQUMsQ0FBQztJQUMxQyxNQUFNLE1BQU0sR0FBRyxZQUFZLENBQUMsSUFBSSxFQUFFLFVBQVUsRUFBRSxrQkFBa0IsQ0FBQyxDQUFDO0lBRWxFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxRQUFRLENBQUMsQ0FBQztJQUNoRCxPQUFPLENBQUMsNkJBQTZCLENBQUMsWUFBWSxDQUFDLENBQUM7SUFFcEQsT0FBTyxPQUFPLENBQUMsY0FBYyxDQUN6QixTQUFTLENBQUMsV0FBVyxFQUFFLE1BQU0sQ0FBQyxLQUFLLEVBQUUsTUFBTSxDQUFDLE1BQU0sQ0FBQyxDQUFDO0FBQzFELENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxjQUFjLEdBQWlCO0lBQzFDLFVBQVUsRUFBRSxRQUFRO0lBQ3BCLFdBQVcsRUFBRSxLQUFLO0lBQ2xCLFVBQVUsRUFBRSxRQUFpQztDQUM5QyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZS
Bmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgR2F0aGVyVjIsIEdhdGhlclYyQXR0cnMsIEdhdGhlclYySW5wdXRzLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvckluZm8sIFR5cGVkQXJyYXksIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcbmltcG9ydCB7YXNzZXJ0Tm90Q29tcGxleH0gZnJvbSAnLi4vY3B1X3V0aWwnO1xuaW1wb3J0IHtnYXRoZXJWMkltcGx9IGZyb20gJy4vR2F0aGVyVjJfaW1wbCc7XG5pbXBvcnQge3Jlc2hhcGV9IGZyb20gJy4vUmVzaGFwZSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBnYXRoZXJWMihhcmdzOiB7XG4gIGlucHV0czogR2F0aGVyVjJJbnB1dHMsXG4gIGJhY2tlbmQ6IE1hdGhCYWNrZW5kQ1BVLFxuICBhdHRyczogR2F0aGVyVjJBdHRyc1xufSk6IFRlbnNvckluZm8ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7eCwgaW5kaWNlc30gPSBpbnB1dHM7XG4gIGNvbnN0IHtheGlzLCBiYXRjaERpbXN9ID0gYXR0cnM7XG5cbiAgYXNzZXJ0Tm90Q29tcGxleChbeCwgaW5kaWNlc10sICdnYXRoZXJWMicpO1xuXG4gIC8vIFRocm93IGVycm9yIHdoZW4gYW55IGluZGV4IGlzIG91dCBvZiBib3VuZC5cbiAgY29uc3QgcGFyc2VkQXhpcyA9IHV0aWwucGFyc2VBeGlzUGFyYW0oYXhpcywgeC5zaGFwZSlbMF07XG4gIGNvbnN0IGluZGljZXNWYWxzID0gYmFja2VuZC5kYXRhLmdldChpbmRpY2VzLmRhdGFJZCkudmFsdWVzIGFzIFR5cGVkQXJyYXk7XG4gIGNvbnN0IGF4aXNEaW0gPSB4LnNoYXBlW3BhcnNlZEF4aXNdO1xuICBmb3IgKGxldCBpID0gMDsgaSA8IGluZGljZXNWYWxzLmxlbmd0aDsgKytpKSB7XG4gICAgY29uc3QgaW5kZXggPSBpbmRpY2VzVmFsc1tpXTtcbiAgICB1dGlsLmFzc2VydChcbiAgICAgICAgaW5kZXggPD0gYXhpc0RpbSAtIDEgJiYgaW5kZXggPj0gMCxcbiAgICAgICAgKCkgPT5cbiAgICAgICAgICAgIGBHYXRoZXJWMjogdGhlIGluZGV4IHZhbHVlICR7aW5kZXh9IGlzIG5vdCBpbiBbMCwgJHtheGlzRGltIC0gMX1dYCk7XG4gIH1cblxuICBsZXQgJGJhdGNoRGltcyA9IGJhdGNoRGltcztcblxuICBpZiAoYmF0Y2hEaW1zID09IG51bGwpIHtcbiAgICAkYmF0Y2hEaW1zID0gMDtcbiAgfVxuXG4gIGNvbnN0IGluZGljZXNTaXplID0gdXRpbC5zaXplRnJvbVNoYXBlKGluZGljZXMuc2hhcGUpO1xuXG4gIGNvbnN0IHNoYXBlSW5mbyA9IGJhY2tlbmRfdXRpbC5zZWdtZW50X3V0aWwuY29sbGVjdEdhdGhlck9wU2hhcGVJbm
ZvKFxuICAgICAgeCwgaW5kaWNlcywgcGFyc2VkQXhpcywgJGJhdGNoRGltcyk7XG5cbiAgY29uc3QgZmxhdHRlblggPSByZXNoYXBlKHtcbiAgICBpbnB1dHM6IHt4fSxcbiAgICBiYWNrZW5kLFxuICAgIGF0dHJzOiB7XG4gICAgICBzaGFwZTogW1xuICAgICAgICBzaGFwZUluZm8uYmF0Y2hTaXplLCBzaGFwZUluZm8ub3V0ZXJTaXplLCBzaGFwZUluZm8uZGltU2l6ZSxcbiAgICAgICAgc2hhcGVJbmZvLnNsaWNlU2l6ZVxuICAgICAgXVxuICAgIH1cbiAgfSk7XG5cbiAgY29uc3QgZmxhdHRlbkluZGV4ID0gcmVzaGFwZSh7XG4gICAgaW5wdXRzOiB7eDogaW5kaWNlc30sXG4gICAgYmFja2VuZCxcbiAgICBhdHRyczoge3NoYXBlOiBbc2hhcGVJbmZvLmJhdGNoU2l6ZSwgaW5kaWNlc1NpemUgLyBzaGFwZUluZm8uYmF0Y2hTaXplXX1cbiAgfSk7XG5cbiAgY29uc3QgZmxhdHRlbk91dHB1dFNoYXBlID0gW1xuICAgIHNoYXBlSW5mby5iYXRjaFNpemUsIHNoYXBlSW5mby5vdXRlclNpemUsIGluZGljZXNTaXplIC8gc2hhcGVJbmZvLmJhdGNoU2l6ZSxcbiAgICBzaGFwZUluZm8uc2xpY2VTaXplXG4gIF07XG5cbiAgY29uc3QgaW5kaWNlc0J1ZiA9IGJhY2tlbmQuYnVmZmVyU3luYyhmbGF0dGVuSW5kZXgpO1xuICBjb25zdCB4QnVmID0gYmFja2VuZC5idWZmZXJTeW5jKGZsYXR0ZW5YKTtcbiAgY29uc3Qgb3V0QnVmID0gZ2F0aGVyVjJJbXBsKHhCdWYsIGluZGljZXNCdWYsIGZsYXR0ZW5PdXRwdXRTaGFwZSk7XG5cbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhmbGF0dGVuWCk7XG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oZmxhdHRlbkluZGV4KTtcblxuICByZXR1cm4gYmFja2VuZC5tYWtlVGVuc29ySW5mbyhcbiAgICAgIHNoYXBlSW5mby5vdXRwdXRTaGFwZSwgb3V0QnVmLmR0eXBlLCBvdXRCdWYudmFsdWVzKTtcbn1cblxuZXhwb3J0IGNvbnN0IGdhdGhlclYyQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IEdhdGhlclYyLFxuICBiYWNrZW5kTmFtZTogJ2NwdScsXG4gIGtlcm5lbEZ1bmM6IGdhdGhlclYyIGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==