"use strict";
|
Object.defineProperty(exports, "__esModule", { value: true });
|
/**
|
* @license
|
* Copyright 2018 Google Inc. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
var engine_1 = require("../engine");
|
var tensor_util_env_1 = require("../tensor_util_env");
|
var operation_1 = require("./operation");
|
/**
|
* Gather slices from input tensor into a Tensor with shape specified by
|
* `indices`.
|
*
|
* `indices` is a K-dimensional integer tensor, best thought of as a
|
* (K-1)-dimensional tensor of indices into input, where each element defines a
|
* slice of input:
|
* output[\\(i_0, ..., i_{K-2}\\)] = input[indices[\\(i_0, ..., i_{K-2}\\)]]
|
*
|
* Whereas in `tf.gather`, `indices` defines slices into the first dimension of
|
* input, in `tf.gatherND`, `indices` defines slices into the first N dimensions
|
* of input, where N = indices.shape[-1].
|
*
|
* The last dimension of indices can be at most the rank of input:
|
* indices.shape[-1] <= input.rank
|
*
|
* The last dimension of `indices` corresponds to elements
|
* (if indices.shape[-1] == input.rank) or slices
|
* (if indices.shape[-1] < input.rank) along dimension indices.shape[-1] of
|
* input.
|
* The output tensor has shape
|
* indices.shape[:-1] + input.shape[indices.shape[-1]:]
|
*
|
* Note that on CPU, if an out of bound index is found, an error is returned. On
|
* GPU, if an out of bound index is found, a 0 is stored in the corresponding
|
* output value.
|
*
|
* ```js
|
* const indices = tf.tensor2d([0, 1, 1, 0], [2,2], 'int32');
|
* const input = tf.tensor2d([9, 10, 11, 12], [2, 2]);
|
* tf.gatherND(input, indices).print() // [10, 11]
|
* ```
|
*
|
* @param x The tensor from which to gather values.
|
* @param indices Index tensor, must be of type int32.
|
*/
|
/** @doc {heading: 'Operations', subheading: 'Slicing and Joining'} */
|
function gatherND_(x, indices) {
    // Implementation behind the exported `gatherND` op.
    //
    // x:       value to gather from; passed through `convertToTensor`.
    // indices: index value; passed through `convertToTensor` with 'int32'
    //          requested as the dtype argument.
    // Returns whatever `ENGINE.runKernelFunc` returns for this kernel call.
    //
    // `indices` is converted before `x`, so a conversion failure on
    // `indices` surfaces first.
    var $indices = tensor_util_env_1.convertToTensor(indices, 'indices', 'gatherND', 'int32');
    var $x = tensor_util_env_1.convertToTensor(x, 'x', 'gatherND');
    // Named inputs handed to the engine alongside the forward function.
    var inputs = { x: $x, indices: $indices };
    // Forward pass: delegate to the backend's own gatherND implementation.
    var forward = function (backend) {
        return backend.gatherND($x, $indices);
    };
    // No backward function is supplied for this kernel (null is passed).
    return engine_1.ENGINE.runKernelFunc(forward, inputs, null /* backward */, 'GatherNd');
}
|
// Hand the implementation to `operation_1.op`, keyed by its own name
// (`gatherND_`), and publish the result as the module's `gatherND` export.
var gatherNDOp = operation_1.op({ gatherND_: gatherND_ });
exports.gatherND = gatherNDOp;
|
//# sourceMappingURL=gather_nd.js.map
|