/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, GatherNd, util } from '@tensorflow/tfjs-core';
|
import { gatherNdImpl } from './GatherNd_Impl';
|
/**
 * CPU implementation of the `GatherNd` kernel: gathers slices of `params`
 * using the coordinate tuples stored in `indices`.
 *
 * @param {{inputs: {params: Object, indices: Object}, backend: Object}} args
 *     Kernel arguments. `inputs.params` / `inputs.indices` are tensor infos
 *     (read via `.shape`, `.dtype`, `.dataId`); `backend` is the CPU backend
 *     that owns their data.
 * @returns {Object} The tensor info created by `backend.makeTensorInfo`, with
 *     the result shape reported by `backend_util.prepareAndValidate` and
 *     `params.dtype`.
 */
export function gatherNd(args) {
  const { inputs, backend } = args;
  const { params, indices } = inputs;

  // Output shape, slice count/size and strides come from core's shared
  // helper (presumably it also throws on incompatible inputs — see its docs).
  const [resultShape, numSlices, sliceSize, strides] =
      backend_util.prepareAndValidate(params, indices);

  // Nothing to gather: emit an empty tensor without reading any input data.
  if (numSlices === 0) {
    return backend.makeTensorInfo(resultShape, params.dtype, []);
  }

  // Only needed on the non-empty path, so computed after the early return.
  // `sliceRank` is the size of the last dimension of `indices`.
  const indicesShape = indices.shape;
  const sliceRank = indicesShape[indicesShape.length - 1];
  const paramsSize = util.sizeFromShape(params.shape);

  const indicesData = backend.data.get(indices.dataId).values;
  const paramsBuf = backend.bufferSync(params);
  const outBuf = gatherNdImpl(
      indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize,
      strides, params.shape, paramsSize);

  return backend.makeTensorInfo(resultShape, params.dtype, outBuf.values);
}
|
/**
 * Kernel configuration pairing the `GatherNd` kernel name with the
 * `gatherNd` implementation for the 'cpu' backend.
 */
export const gatherNdConfig = {
  kernelName: GatherNd,
  backendName: 'cpu',
  kernelFunc: gatherNd,
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiR2F0aGVyTmQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy9rZXJuZWxzL0dhdGhlck5kLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQUUsUUFBUSxFQUEwRSxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUkzSSxPQUFPLEVBQUMsWUFBWSxFQUFDLE1BQU0saUJBQWlCLENBQUM7QUFFN0MsTUFBTSxVQUFVLFFBQVEsQ0FDcEIsSUFBdUQ7SUFDekQsTUFBTSxFQUFDLE1BQU0sRUFBRSxPQUFPLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDL0IsTUFBTSxFQUFDLE1BQU0sRUFBRSxPQUFPLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFFakMsTUFBTSxVQUFVLEdBQUcsSUFBSSxDQUFDLGFBQWEsQ0FBQyxNQUFNLENBQUMsS0FBSyxDQUFDLENBQUM7SUFFcEQsTUFBTSxZQUFZLEdBQUcsT0FBTyxDQUFDLEtBQUssQ0FBQztJQUNuQyxNQUFNLFNBQVMsR0FBRyxZQUFZLENBQUMsWUFBWSxDQUFDLE1BQU0sR0FBRyxDQUFDLENBQUMsQ0FBQztJQUV4RCxNQUFNLENBQUMsV0FBVyxFQUFFLFNBQVMsRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFDLEdBQzlDLFlBQVksQ0FBQyxrQkFBa0IsQ0FBQyxNQUFNLEVBQUUsT0FBTyxDQUFDLENBQUM7SUFDckQsSUFBSSxTQUFTLEtBQUssQ0FBQyxFQUFFO1FBQ25CLE9BQU8sT0FBTyxDQUFDLGNBQWMsQ0FBQyxXQUFXLEVBQUUsTUFBTSxDQUFDLEtBQUssRUFBRSxFQUFFLENBQUMsQ0FBQztLQUM5RDtJQUVELE1BQU0sV0FBVyxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLE9BQU8sQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDO0lBQzFFLE1BQU0sU0FBUyxHQUFHLE9BQU8sQ0FBQyxVQUFVLENBQWtCLE1BQU0sQ0FBQyxDQUFDO0lBQzlELE1BQU0sTUFBTSxHQUFHLFlBQVksQ0FDdkIsV0FBVyxFQUFFLFNBQVMsRUFBRSxNQUFNLENBQUMsS0FBSyxFQUFFLFNBQVMsRUFBRSxTQUFTLEVBQUUsU0FBUyxFQUNyRSxPQUFPLEVBQUUsTUFBTSxDQUFDLEtBQUssRUFBRSxVQUFVLENBQUMsQ0FBQztJQUV2QyxPQUFPLE9BQU8sQ0FBQyxjQUFjLENBQUMsV0FBVyxFQUFFLE1BQU0sQ0FBQyxLQUFLLEVBQUUsTUFBTSxDQUFDLE1BQU0sQ0FBQyxDQUFDO0FBQzFFLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxjQUFjLEdBQWlCO0lBQzFDLFVBQVUsRUFBRSxRQUFRO0lBQ3BCLFdBQVcsRUFBRSxLQUFLO0lBQ2xCLFVBQVUsRUFBRSxRQUFpQztDQUM5QyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKi
B5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgR2F0aGVyTmQsIEdhdGhlck5kSW5wdXRzLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFJhbmssIFRlbnNvckluZm8sIFR5cGVkQXJyYXksIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcblxuaW1wb3J0IHtnYXRoZXJOZEltcGx9IGZyb20gJy4vR2F0aGVyTmRfSW1wbCc7XG5cbmV4cG9ydCBmdW5jdGlvbiBnYXRoZXJOZChcbiAgICBhcmdzOiB7aW5wdXRzOiBHYXRoZXJOZElucHV0cywgYmFja2VuZDogTWF0aEJhY2tlbmRDUFV9KTogVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmR9ID0gYXJncztcbiAgY29uc3Qge3BhcmFtcywgaW5kaWNlc30gPSBpbnB1dHM7XG5cbiAgY29uc3QgcGFyYW1zU2l6ZSA9IHV0aWwuc2l6ZUZyb21TaGFwZShwYXJhbXMuc2hhcGUpO1xuXG4gIGNvbnN0IGluZGljZXNTaGFwZSA9IGluZGljZXMuc2hhcGU7XG4gIGNvbnN0IHNsaWNlUmFuayA9IGluZGljZXNTaGFwZVtpbmRpY2VzU2hhcGUubGVuZ3RoIC0gMV07XG5cbiAgY29uc3QgW3Jlc3VsdFNoYXBlLCBudW1TbGljZXMsIHNsaWNlU2l6ZSwgc3RyaWRlc10gPVxuICAgICAgYmFja2VuZF91dGlsLnByZXBhcmVBbmRWYWxpZGF0ZShwYXJhbXMsIGluZGljZXMpO1xuICBpZiAobnVtU2xpY2VzID09PSAwKSB7XG4gICAgcmV0dXJuIGJhY2tlbmQubWFrZVRlbnNvckluZm8ocmVzdWx0U2hhcGUsIHBhcmFtcy5kdHlwZSwgW10pO1xuICB9XG5cbiAgY29uc3QgaW5kaWNlc0RhdGEgPSBiYWNrZW5kLmRhdGEuZ2V0KGluZGljZXMuZGF0YUlkKS52YWx1ZXMgYXMgVHlwZWRBcnJheTtcbiAgY29uc3QgcGFyYW1zQnVmID0gYmFja2VuZC5idWZmZXJTeW5jPFJhbmssICdmbG
9hdDMyJz4ocGFyYW1zKTtcbiAgY29uc3Qgb3V0QnVmID0gZ2F0aGVyTmRJbXBsKFxuICAgICAgaW5kaWNlc0RhdGEsIHBhcmFtc0J1ZiwgcGFyYW1zLmR0eXBlLCBudW1TbGljZXMsIHNsaWNlUmFuaywgc2xpY2VTaXplLFxuICAgICAgc3RyaWRlcywgcGFyYW1zLnNoYXBlLCBwYXJhbXNTaXplKTtcblxuICByZXR1cm4gYmFja2VuZC5tYWtlVGVuc29ySW5mbyhyZXN1bHRTaGFwZSwgcGFyYW1zLmR0eXBlLCBvdXRCdWYudmFsdWVzKTtcbn1cblxuZXhwb3J0IGNvbnN0IGdhdGhlck5kQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IEdhdGhlck5kLFxuICBiYWNrZW5kTmFtZTogJ2NwdScsXG4gIGtlcm5lbEZ1bmM6IGdhdGhlck5kIGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==
|