import { computeStrides, sizeFromShape } from '../util';
/**
 * Validates the arguments of tf.gatherND() and derives the bookkeeping
 * values needed to perform the gather.
 *
 * Checks run in this order:
 * 1. `tensor` has rank >= 1.
 * 2. `indices` has rank >= 1.
 * 3. `indices.dtype` is 'int32'.
 * 4. The innermost `indices` dimension (the slice rank) is <= the rank of `tensor`.
 * 5. `tensor` contains at least one element.
 *
 * @param tensor Holds the source values. Only `tensor.shape` is read.
 * @param indices Holds the index tuples used to slice `tensor`.
 *     `indices.shape` and `indices.dtype` are read.
 *
 * @returns [resultShape, numUpdates, sliceSize, strides], where k is
 *     indices.shape[-1] (the slice rank) and:
 *     - resultShape is indices.shape[:-1] followed by tensor.shape[k:];
 *     - numUpdates is the product of indices.shape[:-1], i.e. the number of
 *       index tuples;
 *     - sliceSize is the product of tensor.shape[k:], i.e. the element count
 *       of one gathered slice;
 *     - strides has min(k, rank) entries. It is computeStrides(tensor.shape)
 *       divided by sliceSize, with a trailing 1 appended, then truncated to k.
 * @throws {Error} If any of the checks listed above fails.
 */
export function prepareAndValidate(tensor, indices) {
  const inputShape = tensor.shape;
  const tensorRank = inputShape.length;
  const indicesShape = indices.shape;
  const indicesRank = indicesShape.length;

  if (tensorRank < 1) {
    throw new Error(
        'tf.gatherND() expects the input to be rank 1 or higher,' +
        ` but the rank was ${tensorRank}.`);
  }
  if (indicesRank < 1) {
    throw new Error(
        'tf.gatherND() expects the indices to be rank 1 or higher,' +
        ` but the rank was ${indicesRank}.`);
  }
  if (indices.dtype !== 'int32') {
    throw new Error(
        'tf.gatherND() expects the indices to be int32 type,' +
        ` but the dtype was ${indices.dtype}.`);
  }

  // Each index tuple addresses the leading `sliceRank` dimensions of the
  // input. The remaining (trailing) dimensions are taken whole.
  const sliceRank = indicesShape[indicesRank - 1];
  if (sliceRank > tensorRank) {
    throw new Error(
        'index innermost dimension length must be <= tensor rank; saw: ' +
        `${sliceRank} vs. ${tensorRank}`);
  }
  if (sizeFromShape(inputShape) === 0) {
    throw new Error(
        'Requested more than 0 entries, but input is empty.' +
        ` Input shape: ${inputShape}.`);
  }

  const multiply = (acc, dim) => acc * dim;

  // Result shape = indices.shape[:-1] ++ input.shape[sliceRank:].
  const batchDims = indicesShape.slice(0, -1);
  const sliceDims = inputShape.slice(sliceRank, tensorRank);
  const nResult = batchDims.reduce(multiply, 1);
  const sliceSize = sliceDims.reduce(multiply, 1);
  const resultShape = [...batchDims, ...sliceDims];

  // Rescale the input's strides into units of whole slices and append 1 for
  // the innermost dimension. Keep only one entry per index-tuple component.
  const sliceStrides =
      computeStrides(inputShape).map(stride => stride / sliceSize);
  const strides = [...sliceStrides, 1].slice(0, sliceRank);

  return [resultShape, nResult, sliceSize, strides];
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiZ2F0aGVyX25kX3V0aWwuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL29wcy9nYXRoZXJfbmRfdXRpbC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFpQkEsT0FBTyxFQUFDLGNBQWMsRUFBRSxhQUFhLEVBQUMsTUFBTSxTQUFTLENBQUM7QUFFdEQ7Ozs7Ozs7R0FPRztBQUNILE1BQU0sVUFBVSxrQkFBa0IsQ0FBQyxNQUFrQixFQUFFLE9BQW1CO0lBRXhFLE1BQU0sVUFBVSxHQUFHLE1BQU0sQ0FBQyxLQUFLLENBQUMsTUFBTSxDQUFDO0lBQ3ZDLE1BQU0sV0FBVyxHQUFHLE9BQU8sQ0FBQyxLQUFLLENBQUMsTUFBTSxDQUFDO0lBQ3pDLElBQUksVUFBVSxHQUFHLENBQUMsRUFBRTtRQUNsQixNQUFNLElBQUksS0FBSyxDQUNYLHlEQUF5RDtZQUN6RCxxQkFBcUIsVUFBVSxHQUFHLENBQUMsQ0FBQztLQUN6QztJQUNELElBQUksV0FBVyxHQUFHLENBQUMsRUFBRTtRQUNuQixNQUFNLElBQUksS0FBSyxDQUNYLDJEQUEyRDtZQUMzRCxxQkFBcUIsV0FBVyxHQUFHLENBQUMsQ0FBQztLQUMxQztJQUNELElBQUksT0FBTyxDQUFDLEtBQUssS0FBSyxPQUFPLEVBQUU7UUFDN0IsTUFBTSxJQUFJLEtBQUssQ0FDWCxxREFBcUQ7WUFDckQsc0JBQXNCLE9BQU8sQ0FBQyxLQUFLLEdBQUcsQ0FBQyxDQUFDO0tBQzdDO0lBQ0QsSUFBSSxPQUFPLENBQUMsS0FBSyxDQUFDLFdBQVcsR0FBRyxDQUFDLENBQUMsR0FBRyxVQUFVLEVBQUU7UUFDL0MsTUFBTSxJQUFJLEtBQUssQ0FDWCxnRUFBZ0U7WUFDaEUsR0FBRyxPQUFPLENBQUMsS0FBSyxDQUFDLFdBQVcsR0FBRyxDQUFDLENBQUMsUUFBUSxVQUFVLEVBQUUsQ0FBQyxDQUFDO0tBQzVEO0lBRUQsSUFBSSxhQUFhLENBQUMsTUFBTSxDQUFDLEtBQUssQ0FBQyxLQUFLLENBQUMsRUFBRTtRQUNyQyxNQUFNLElBQUksS0F
BSyxDQUNYLG9EQUFvRDtZQUNwRCxpQkFBaUIsTUFBTSxDQUFDLEtBQUssR0FBRyxDQUFDLENBQUM7S0FDdkM7SUFFRCxNQUFNLFlBQVksR0FBRyxPQUFPLENBQUMsS0FBSyxDQUFDO0lBQ25DLE1BQU0sU0FBUyxHQUFHLFlBQVksQ0FBQyxZQUFZLENBQUMsTUFBTSxHQUFHLENBQUMsQ0FBQyxDQUFDO0lBRXhELHNCQUFzQjtJQUN0QiwwREFBMEQ7SUFDMUQsSUFBSSxPQUFPLEdBQUcsQ0FBQyxDQUFDO0lBQ2hCLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxZQUFZLENBQUMsTUFBTSxHQUFHLENBQUMsRUFBRSxFQUFFLENBQUMsRUFBRTtRQUNoRCxPQUFPLElBQUksWUFBWSxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQzVCO0lBRUQsTUFBTSxVQUFVLEdBQUcsTUFBTSxDQUFDLEtBQUssQ0FBQztJQUVoQyxNQUFNLFdBQVcsR0FBRyxZQUFZLENBQUMsS0FBSyxFQUFFLENBQUM7SUFDekMsV0FBVyxDQUFDLEdBQUcsRUFBRSxDQUFDO0lBRWxCLElBQUksU0FBUyxHQUFHLENBQUMsQ0FBQztJQUNsQixLQUFLLElBQUksQ0FBQyxHQUFHLFNBQVMsRUFBRSxDQUFDLEdBQUcsVUFBVSxFQUFFLEVBQUUsQ0FBQyxFQUFFO1FBQzNDLFNBQVMsSUFBSSxVQUFVLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDM0IsV0FBVyxDQUFDLElBQUksQ0FBQyxVQUFVLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztLQUNqQztJQUVELE1BQU0sT0FBTyxHQUNULENBQUMsR0FBRyxjQUFjLENBQUMsTUFBTSxDQUFDLEtBQUssQ0FBQyxDQUFDLEdBQUcsQ0FBQyxNQUFNLENBQUMsRUFBRSxDQUFDLE1BQU0sR0FBRyxTQUFTLENBQUM7UUFDakUsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsRUFBRSxTQUFTLENBQUMsQ0FBQztJQUU1QixPQUFPLENBQUMsV0FBVyxFQUFFLE9BQU8sRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFDLENBQUM7QUFDcEQsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDE4IEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9
yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cbmltcG9ydCB7IFRlbnNvckluZm8gfSBmcm9tICcuLi90ZW5zb3JfaW5mbyc7XG5pbXBvcnQge2NvbXB1dGVTdHJpZGVzLCBzaXplRnJvbVNoYXBlfSBmcm9tICcuLi91dGlsJztcblxuLyoqXG4gKiBWYWxpZGF0ZSBnYXRoZXIgbmQgaW5wdXRzLlxuICpcbiAqIEBwYXJhbSB0ZW5zb3IgVGhlIHRlbnNvciBjb250YWlucyB0aGUgc291cmNlIHZhbHVlcy5cbiAqIEBwYXJhbSBpbmRpY2VzIFRoZSB0ZW5zb3IgY29udGFpbnMgdGhlIGluZGljZXMgdG8gc2xpY2UgdGhlIHNvdXJjZS5cbiAqXG4gKiBAcmV0dXJucyBbcmVzdWx0U2hhcGUsIG51bVVwZGF0ZXMsIHNsaWNlU2l6ZSwgc3RyaWRlc11cbiAqL1xuZXhwb3J0IGZ1bmN0aW9uIHByZXBhcmVBbmRWYWxpZGF0ZSh0ZW5zb3I6IFRlbnNvckluZm8sIGluZGljZXM6IFRlbnNvckluZm8pOlxuICAgIFtudW1iZXJbXSwgbnVtYmVyLCBudW1iZXIsIG51bWJlcltdXSB7XG4gIGNvbnN0IHRlbnNvclJhbmsgPSB0ZW5zb3Iuc2hhcGUubGVuZ3RoO1xuICBjb25zdCBpbmRpY2VzUmFuayA9IGluZGljZXMuc2hhcGUubGVuZ3RoO1xuICBpZiAodGVuc29yUmFuayA8IDEpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgICd0Zi5nYXRoZXJORCgpIGV4cGVjdHMgdGhlIGlucHV0IHRvIGJlIHJhbmsgMSBvciBoaWdoZXIsJyArXG4gICAgICAgIGAgYnV0IHRoZSByYW5rIHdhcyAke3RlbnNvclJhbmt9LmApO1xuICB9XG4gIGlmIChpbmRpY2VzUmFuayA8IDEpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgICd0Zi5nYXRoZXJORCgpIGV4cGVjdHMgdGhlIGluZGljZXMgdG8gYmUgcmFuayAxIG9yIGhpZ2hlciwnICtcbiAgICAgICAgYCBidXQgdGhlIHJhbmsgd2FzICR7aW5kaWNlc1Jhbmt9LmApO1xuICB9XG4gIGlmIChpbmRpY2VzLmR0eXBlICE9PSAnaW50MzInKSB7XG4gICAgdGhyb3cgbmV3IEVycm9yKFxuICAgICAgICAndGYuZ2F0aGVyTkQoKSBleHBlY3RzIHRoZSBpbmRpY2VzIHRvIGJlIGludDMyIHR5cGUsJyArXG4gICAgICAgIGAgYnV0IHRoZSBkdHlwZSB3YXMgJHtpbmRpY2VzLmR0eXBlfS5gKTtcbiAgfVxuICBpZiAoaW5kaWNlcy5zaGFwZVtpbmRpY2VzUmFuayAtIDFdID4gdGVuc29yUmFuaykge1xuICAgIHRocm93IG5ldyBFcnJvcihcbiAgICAgICAgJ2luZGV4IGlubmVybW9zdCBkaW1lbnNpb24gbGVuZ3RoIG11c3QgYmUgPD0gdGVuc29yIHJhbms7IHNhdzogJyArXG4gICAgICAgIGAke2luZGljZXMuc2hhcGVbaW5kaWNlc1JhbmsgLSAxXX0gdnMuICR7dGVuc29yUmFua31gKTtcbiAgfVxuXG4gIGlmIChzaXplRnJvbVNoYXBlKHRlbnNvci5zaGFwZSkgPT0
9IDApIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgICdSZXF1ZXN0ZWQgbW9yZSB0aGFuIDAgZW50cmllcywgYnV0IGlucHV0IGlzIGVtcHR5LicgK1xuICAgICAgICBgIElucHV0IHNoYXBlOiAke3RlbnNvci5zaGFwZX0uYCk7XG4gIH1cblxuICBjb25zdCBpbmRpY2VzU2hhcGUgPSBpbmRpY2VzLnNoYXBlO1xuICBjb25zdCBzbGljZVJhbmsgPSBpbmRpY2VzU2hhcGVbaW5kaWNlc1NoYXBlLmxlbmd0aCAtIDFdO1xuXG4gIC8vIFRoZSByZXN1bHQgc2hhcGUgaXNcbiAgLy8gICBpbmRpY2VzLnNoYXBlWzotMV0gKyBwYXJhbXMuc2hhcGVbaW5kaWNlcy5zaGFwZVstMV06XVxuICBsZXQgblJlc3VsdCA9IDE7XG4gIGZvciAobGV0IGkgPSAwOyBpIDwgaW5kaWNlc1NoYXBlLmxlbmd0aCAtIDE7ICsraSkge1xuICAgIG5SZXN1bHQgKj0gaW5kaWNlc1NoYXBlW2ldO1xuICB9XG5cbiAgY29uc3QgaW5wdXRTaGFwZSA9IHRlbnNvci5zaGFwZTtcblxuICBjb25zdCByZXN1bHRTaGFwZSA9IGluZGljZXNTaGFwZS5zbGljZSgpO1xuICByZXN1bHRTaGFwZS5wb3AoKTtcblxuICBsZXQgc2xpY2VTaXplID0gMTtcbiAgZm9yIChsZXQgaSA9IHNsaWNlUmFuazsgaSA8IHRlbnNvclJhbms7ICsraSkge1xuICAgIHNsaWNlU2l6ZSAqPSBpbnB1dFNoYXBlW2ldO1xuICAgIHJlc3VsdFNoYXBlLnB1c2goaW5wdXRTaGFwZVtpXSk7XG4gIH1cblxuICBjb25zdCBzdHJpZGVzID1cbiAgICAgIFsuLi5jb21wdXRlU3RyaWRlcyh0ZW5zb3Iuc2hhcGUpLm1hcChzdHJpZGUgPT4gc3RyaWRlIC8gc2xpY2VTaXplKSxcbiAgICAgICAxXS5zbGljZSgwLCBzbGljZVJhbmspO1xuXG4gIHJldHVybiBbcmVzdWx0U2hhcGUsIG5SZXN1bHQsIHNsaWNlU2l6ZSwgc3RyaWRlc107XG59XG4iXX0=