/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
|
import { backend_util, GatherV2, util } from '@tensorflow/tfjs-core';
|
import { assertNotComplex } from '../cpu_util';
|
import { gatherV2Impl } from './GatherV2_impl';
|
import { reshape } from './Reshape';
|
/**
 * GatherV2 kernel function: gathers slices of `x` along `axis` at the
 * positions listed in `indices`.
 *
 * Every index value is checked against the size of `x` along the parsed axis
 * before any intermediate tensor is created. `x` and `indices` are then
 * reshaped to the flattened layouts consumed by `gatherV2Impl`, and both
 * reshaped intermediates are disposed before the result is returned.
 *
 * @param args Kernel arguments.
 * @param args.inputs Object holding the `x` and `indices` tensor infos.
 * @param args.backend Backend used to read tensor data, create buffers,
 *     dispose intermediates and build the output tensor info.
 * @param args.attrs Object holding `axis` and `batchDims`; a null or
 *     undefined `batchDims` is treated as 0.
 * @returns The value of `backend.makeTensorInfo` called with
 *     `shapeInfo.outputShape` and the dtype/values produced by `gatherV2Impl`.
 * @throws Through `util.assert`, when an index is not in [0, axisDim - 1].
 */
export function gatherV2(args) {
  const { inputs, backend, attrs } = args;
  const { x, indices } = inputs;
  const { axis, batchDims } = attrs;

  assertNotComplex([x, indices], 'gatherV2');

  // Throw error when any index is out of bound.
  const parsedAxis = util.parseAxisParam(axis, x.shape)[0];
  const indicesVals = backend.data.get(indices.dataId).values;
  const axisDim = x.shape[parsedAxis];
  for (let i = 0; i < indicesVals.length; ++i) {
    const index = indicesVals[i];
    util.assert(
      index <= axisDim - 1 && index >= 0,
      () => `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`
    );
  }

  // A missing batchDims attribute means "no batch dimensions".
  const $batchDims = batchDims == null ? 0 : batchDims;

  const indicesSize = util.sizeFromShape(indices.shape);
  const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo(
    x, indices, parsedAxis, $batchDims
  );

  // View x as [batch, outer, dim, slice].
  const flattenX = reshape({
    inputs: { x },
    backend,
    attrs: {
      shape: [
        shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize,
        shapeInfo.sliceSize
      ]
    }
  });

  // View indices as [batch, indicesPerBatch].
  const flattenIndex = reshape({
    inputs: { x: indices },
    backend,
    attrs: { shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize] }
  });

  const flattenOutputShape = [
    shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize,
    shapeInfo.sliceSize
  ];

  const indicesBuf = backend.bufferSync(flattenIndex);
  const xBuf = backend.bufferSync(flattenX);
  const outBuf = gatherV2Impl(xBuf, indicesBuf, flattenOutputShape);

  // The reshaped inputs are only needed to build the buffers above.
  backend.disposeIntermediateTensorInfo(flattenX);
  backend.disposeIntermediateTensorInfo(flattenIndex);

  return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, outBuf.values);
}
|
/**
 * Kernel config object pairing the `GatherV2` kernel name and the 'cpu'
 * backend name with the `gatherV2` kernel function.
 */
export const gatherV2Config = {
  kernelName: GatherV2,
  backendName: 'cpu',
  kernelFunc: gatherV2
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiR2F0aGVyVjIuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy9rZXJuZWxzL0dhdGhlclYyLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQUUsUUFBUSxFQUFtRixJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUdwSixPQUFPLEVBQUMsZ0JBQWdCLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDN0MsT0FBTyxFQUFDLFlBQVksRUFBQyxNQUFNLGlCQUFpQixDQUFDO0FBQzdDLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFFbEMsTUFBTSxVQUFVLFFBQVEsQ0FBQyxJQUl4QjtJQUNDLE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztJQUN0QyxNQUFNLEVBQUMsQ0FBQyxFQUFFLE9BQU8sRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUM1QixNQUFNLEVBQUMsSUFBSSxFQUFFLFNBQVMsRUFBQyxHQUFHLEtBQUssQ0FBQztJQUVoQyxnQkFBZ0IsQ0FBQyxDQUFDLENBQUMsRUFBRSxPQUFPLENBQUMsRUFBRSxVQUFVLENBQUMsQ0FBQztJQUUzQyw4Q0FBOEM7SUFDOUMsTUFBTSxVQUFVLEdBQUcsSUFBSSxDQUFDLGNBQWMsQ0FBQyxJQUFJLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQ3pELE1BQU0sV0FBVyxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLE9BQU8sQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDO0lBQzFFLE1BQU0sT0FBTyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsVUFBVSxDQUFDLENBQUM7SUFDcEMsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFdBQVcsQ0FBQyxNQUFNLEVBQUUsRUFBRSxDQUFDLEVBQUU7UUFDM0MsTUFBTSxLQUFLLEdBQUcsV0FBVyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQzdCLElBQUksQ0FBQyxNQUFNLENBQ1AsS0FBSyxJQUFJLE9BQU8sR0FBRyxDQUFDLElBQUksS0FBSyxJQUFJLENBQUMsRUFDbEMsR0FBRyxFQUFFLENBQ0QsNkJBQTZCLEtBQUssa0JBQWtCLE9BQU8sR0FBRyxDQUFDLEdBQUcsQ0FBQyxDQUFDO0tBQzdFO0lBRUQsSUFBSSxVQUFVLEdBQUcsU0FBUyxDQUFDO0lBRTNCLElBQUksU0FBUyxJQUFJLElBQUksRUFBRTtRQUNyQixVQUFVLEdBQUcsQ0FBQyxDQUFDO0tBQ2hCO0lBRUQsTUFBTSxXQUFXLEdBQUcsSUFBSSxDQUFDLGFBQWEsQ0FBQyxPQUFPLENBQUMsS0FBSyxDQUFDLENBQUM7SUFFdEQsTUFBTSxTQUFTLEdBQUcsWUFBWSxDQUFDLFlBQVksQ0FBQyx3QkFBd0IsQ0FDaEUsQ0FBQyxFQUFFLE9BQU8sRUFBRSxVQUFVLEVBQUUsVUFBVSxDQUFDLENBQUM7SUFFeEMsTUFBTSxRQUFRLEdBQUcsT0FBTyxDQUFDO1FBQ3ZCLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBQztRQUNYLE
9BQU87UUFDUCxLQUFLLEVBQUU7WUFDTCxLQUFLLEVBQUU7Z0JBQ0wsU0FBUyxDQUFDLFNBQVMsRUFBRSxTQUFTLENBQUMsU0FBUyxFQUFFLFNBQVMsQ0FBQyxPQUFPO2dCQUMzRCxTQUFTLENBQUMsU0FBUzthQUNwQjtTQUNGO0tBQ0YsQ0FBQyxDQUFDO0lBRUgsTUFBTSxZQUFZLEdBQUcsT0FBTyxDQUFDO1FBQzNCLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxPQUFPLEVBQUM7UUFDcEIsT0FBTztRQUNQLEtBQUssRUFBRSxFQUFDLEtBQUssRUFBRSxDQUFDLFNBQVMsQ0FBQyxTQUFTLEVBQUUsV0FBVyxHQUFHLFNBQVMsQ0FBQyxTQUFTLENBQUMsRUFBQztLQUN6RSxDQUFDLENBQUM7SUFFSCxNQUFNLGtCQUFrQixHQUFHO1FBQ3pCLFNBQVMsQ0FBQyxTQUFTLEVBQUUsU0FBUyxDQUFDLFNBQVMsRUFBRSxXQUFXLEdBQUcsU0FBUyxDQUFDLFNBQVM7UUFDM0UsU0FBUyxDQUFDLFNBQVM7S0FDcEIsQ0FBQztJQUVGLE1BQU0sVUFBVSxHQUFHLE9BQU8sQ0FBQyxVQUFVLENBQUMsWUFBWSxDQUFDLENBQUM7SUFDcEQsTUFBTSxJQUFJLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxRQUFRLENBQUMsQ0FBQztJQUMxQyxNQUFNLE1BQU0sR0FBRyxZQUFZLENBQUMsSUFBSSxFQUFFLFVBQVUsRUFBRSxrQkFBa0IsQ0FBQyxDQUFDO0lBRWxFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxRQUFRLENBQUMsQ0FBQztJQUNoRCxPQUFPLENBQUMsNkJBQTZCLENBQUMsWUFBWSxDQUFDLENBQUM7SUFFcEQsT0FBTyxPQUFPLENBQUMsY0FBYyxDQUN6QixTQUFTLENBQUMsV0FBVyxFQUFFLE1BQU0sQ0FBQyxLQUFLLEVBQUUsTUFBTSxDQUFDLE1BQU0sQ0FBQyxDQUFDO0FBQzFELENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxjQUFjLEdBQWlCO0lBQzFDLFVBQVUsRUFBRSxRQUFRO0lBQ3BCLFdBQVcsRUFBRSxLQUFLO0lBQ2xCLFVBQVUsRUFBRSxRQUFpQztDQUM5QyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW
5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgR2F0aGVyVjIsIEdhdGhlclYyQXR0cnMsIEdhdGhlclYySW5wdXRzLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvckluZm8sIFR5cGVkQXJyYXksIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcbmltcG9ydCB7YXNzZXJ0Tm90Q29tcGxleH0gZnJvbSAnLi4vY3B1X3V0aWwnO1xuaW1wb3J0IHtnYXRoZXJWMkltcGx9IGZyb20gJy4vR2F0aGVyVjJfaW1wbCc7XG5pbXBvcnQge3Jlc2hhcGV9IGZyb20gJy4vUmVzaGFwZSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBnYXRoZXJWMihhcmdzOiB7XG4gIGlucHV0czogR2F0aGVyVjJJbnB1dHMsXG4gIGJhY2tlbmQ6IE1hdGhCYWNrZW5kQ1BVLFxuICBhdHRyczogR2F0aGVyVjJBdHRyc1xufSk6IFRlbnNvckluZm8ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7eCwgaW5kaWNlc30gPSBpbnB1dHM7XG4gIGNvbnN0IHtheGlzLCBiYXRjaERpbXN9ID0gYXR0cnM7XG5cbiAgYXNzZXJ0Tm90Q29tcGxleChbeCwgaW5kaWNlc10sICdnYXRoZXJWMicpO1xuXG4gIC8vIFRocm93IGVycm9yIHdoZW4gYW55IGluZGV4IGlzIG91dCBvZiBib3VuZC5cbiAgY29uc3QgcGFyc2VkQXhpcyA9IHV0aWwucGFyc2VBeGlzUGFyYW0oYXhpcywgeC5zaGFwZSlbMF07XG4gIGNvbnN0IGluZGljZXNWYWxzID0gYmFja2VuZC5kYXRhLmdldChpbmRpY2VzLmRhdGFJZCkudmFsdWVzIGFzIFR5cGVkQXJyYXk7XG4gIGNvbnN0IGF4aXNEaW0gPSB4LnNoYXBlW3BhcnNlZEF4aXNdO1xuICBmb3IgKGxldCBpID0gMDsgaSA8IGluZGljZXNWYWxzLmxlbmd0aDsgKytpKSB7XG4gICAgY29uc3QgaW5kZXggPSBpbmRpY2VzVmFsc1tpXTtcbiAgICB1dGlsLmFzc2VydChcbiAgICAgICAgaW5kZXggPD0gYXhpc0RpbSAtIDEgJiYgaW5kZXggPj0gMCxcbiAgICAgICAgKCkgPT5cbiAgICAgICAgICAgIGBHYXRoZXJWMjogdGhlIGluZGV4IHZhbHVlICR7aW5kZXh9IGlzIG5vdCBpbiBbMCwgJHtheGlzRGltIC0gMX1dYCk7XG4gIH1cblxuICBsZXQgJGJhdGNoRGltcyA9IGJhdGNoRGltcztcblxuICBpZiAoYmF0Y2hEaW1zID09IG51bGwpIHtcbiAgICAkYmF0Y2hEaW1zID0gMDtcbiAgfVxuXG4gIGNvbnN0IGluZGljZXNTaXplID0gdXRpbC5zaXplRnJvbVNoYXBlKGluZGljZXMuc2hhcGUpO1xuXG4gIGNvbnN0IHNoYXBlSW5mbyA9IGJhY2tlbmRfdXRpbC5zZWdtZW50X3V0aWwuY29sbGVjdEdhdGhlck9wU2hhcG
VJbmZvKFxuICAgICAgeCwgaW5kaWNlcywgcGFyc2VkQXhpcywgJGJhdGNoRGltcyk7XG5cbiAgY29uc3QgZmxhdHRlblggPSByZXNoYXBlKHtcbiAgICBpbnB1dHM6IHt4fSxcbiAgICBiYWNrZW5kLFxuICAgIGF0dHJzOiB7XG4gICAgICBzaGFwZTogW1xuICAgICAgICBzaGFwZUluZm8uYmF0Y2hTaXplLCBzaGFwZUluZm8ub3V0ZXJTaXplLCBzaGFwZUluZm8uZGltU2l6ZSxcbiAgICAgICAgc2hhcGVJbmZvLnNsaWNlU2l6ZVxuICAgICAgXVxuICAgIH1cbiAgfSk7XG5cbiAgY29uc3QgZmxhdHRlbkluZGV4ID0gcmVzaGFwZSh7XG4gICAgaW5wdXRzOiB7eDogaW5kaWNlc30sXG4gICAgYmFja2VuZCxcbiAgICBhdHRyczoge3NoYXBlOiBbc2hhcGVJbmZvLmJhdGNoU2l6ZSwgaW5kaWNlc1NpemUgLyBzaGFwZUluZm8uYmF0Y2hTaXplXX1cbiAgfSk7XG5cbiAgY29uc3QgZmxhdHRlbk91dHB1dFNoYXBlID0gW1xuICAgIHNoYXBlSW5mby5iYXRjaFNpemUsIHNoYXBlSW5mby5vdXRlclNpemUsIGluZGljZXNTaXplIC8gc2hhcGVJbmZvLmJhdGNoU2l6ZSxcbiAgICBzaGFwZUluZm8uc2xpY2VTaXplXG4gIF07XG5cbiAgY29uc3QgaW5kaWNlc0J1ZiA9IGJhY2tlbmQuYnVmZmVyU3luYyhmbGF0dGVuSW5kZXgpO1xuICBjb25zdCB4QnVmID0gYmFja2VuZC5idWZmZXJTeW5jKGZsYXR0ZW5YKTtcbiAgY29uc3Qgb3V0QnVmID0gZ2F0aGVyVjJJbXBsKHhCdWYsIGluZGljZXNCdWYsIGZsYXR0ZW5PdXRwdXRTaGFwZSk7XG5cbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhmbGF0dGVuWCk7XG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oZmxhdHRlbkluZGV4KTtcblxuICByZXR1cm4gYmFja2VuZC5tYWtlVGVuc29ySW5mbyhcbiAgICAgIHNoYXBlSW5mby5vdXRwdXRTaGFwZSwgb3V0QnVmLmR0eXBlLCBvdXRCdWYudmFsdWVzKTtcbn1cblxuZXhwb3J0IGNvbnN0IGdhdGhlclYyQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IEdhdGhlclYyLFxuICBiYWNrZW5kTmFtZTogJ2NwdScsXG4gIGtlcm5lbEZ1bmM6IGdhdGhlclYyIGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==
|