"use strict";
|
Object.defineProperty(exports, "__esModule", { value: true });
|
var util_1 = require("../util");
|
/**
 * Validates the inputs of tf.gatherND() and derives the values a gather
 * kernel needs to copy slices out of `tensor`.
 *
 * The innermost dimension of `indices` (called `sliceRank` below) is the
 * number of leading dimensions of `tensor` addressed by each index tuple;
 * the remaining trailing dimensions of `tensor` make up the slice that is
 * copied for every tuple.
 *
 * @param tensor The tensor holding the source values (its `rank`, `shape`
 *     and `size` are read).
 * @param indices The tensor holding the index tuples (its `rank`, `shape`
 *     and `dtype` are read; dtype must be 'int32').
 * @returns `[resultShape, nResult, sliceSize, strides]` where
 *     - resultShape = indices.shape[:-1] + tensor.shape[sliceRank:]
 *     - nResult     = product of indices.shape[:-1], i.e. how many slices
 *                     are gathered (1 when `indices` is rank 1)
 *     - sliceSize   = product of tensor.shape[sliceRank:] (1 when sliceRank
 *                     equals the tensor rank)
 *     - strides     = `sliceRank` entries: the strides returned by
 *                     util.computeStrides(tensor.shape) divided by
 *                     sliceSize, with a trailing 1 appended. Presumably
 *                     these map an index tuple to a slice number; TODO
 *                     confirm against util.computeStrides.
 * @throws Error if either input has rank < 1, `indices` is not int32, the
 *     index tuple length exceeds the tensor rank, or `tensor.size` is 0.
 */
function prepareAndValidate(tensor, indices) {
    var inputRank = tensor.rank;
    var indicesRank = indices.rank;
    if (inputRank < 1) {
        throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' +
            (" but the rank was " + inputRank + "."));
    }
    if (indicesRank < 1) {
        throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' +
            (" but the rank was " + indicesRank + "."));
    }
    if (indices.dtype !== 'int32') {
        throw new Error('tf.gatherND() expects the indices to be int32 type,' +
            (" but the dtype was " + indices.dtype + "."));
    }
    var indicesShape = indices.shape;
    var innermost = indicesShape[indicesRank - 1];
    if (innermost > inputRank) {
        throw new Error('index innermost dimension length must be <= tensor rank; saw: ' +
            (innermost + " vs. " + inputRank));
    }
    if (tensor.size === 0) {
        throw new Error('Requested more than 0 entries, but input is empty.' +
            (" Input shape: " + tensor.shape + "."));
    }
    var sliceRank = indicesShape[indicesShape.length - 1];
    // Leading index dimensions enumerate the lookups; the trailing tensor
    // dimensions (from sliceRank on) describe each gathered slice.
    var lookupDims = indicesShape.slice(0, indicesShape.length - 1);
    var sliceDims = tensor.shape.slice(sliceRank, inputRank);
    var multiply = function (acc, dim) { return acc * dim; };
    var nResult = lookupDims.reduce(multiply, 1);
    var sliceSize = sliceDims.reduce(multiply, 1);
    // Output shape: indices.shape[:-1] + params.shape[indices.shape[-1]:]
    var resultShape = lookupDims.concat(sliceDims);
    // `map` yields a fresh array, so appending the trailing 1 in place does
    // not touch whatever computeStrides returned.
    var scaledStrides = util_1.computeStrides(tensor.shape).map(function (elementStride) {
        return elementStride / sliceSize;
    });
    scaledStrides.push(1);
    return [resultShape, nResult, sliceSize, scaledStrides.slice(0, sliceRank)];
}
|
exports.prepareAndValidate = prepareAndValidate;
|
//# sourceMappingURL=gather_nd_util.js.map
|