/**
|
* @license
|
* Copyright 2022 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { util } from '@tensorflow/tfjs-core';
|
/**
 * Verifies that every gather index lies in the half-open range
 * `[0, numParams)`.
 *
 * @param indices Flat buffer of gather indices.
 * @param indicesShape Shape of `indices`; used only to turn the flat position
 *     of an offending index into a multi-dimensional location for the error.
 * @param numParams Exclusive upper bound for a valid index.
 * @throws Error naming the first out-of-range index and its location.
 */
function validateIndices(indices, indicesShape, numParams) {
  for (let flatPos = 0; flatPos < indices.length; ++flatPos) {
    const index = indices[flatPos];
    if (index < 0 || index >= numParams) {
      const strides = util.computeStrides(indicesShape);
      const location =
          util.indexToLoc(flatPos, indicesShape.length, strides);
      const locString = location.join(',');
      throw new Error(
          `indices[${locString}] = ${index} is not in [0, ${numParams})`);
    }
  }
}
|
/**
 * Checks that each splits vector in `paramsNestedSplits` is a well-formed
 * ragged partition:
 * - it is non-empty;
 * - it starts at a non-negative offset;
 * - it is sorted in ascending order;
 * - it never points past the rows (or dense values) it indexes.
 *
 * @param paramsNestedSplits Splits vectors of params, outermost first.
 * @param numParamsDenseValues Size of the first dimension of the dense values
 *     that the innermost splits vector indexes into.
 * @throws Error describing the first malformed splits vector found.
 */
function validateSplits(paramsNestedSplits, numParamsDenseValues) {
  for (let dim = 0; dim < paramsNestedSplits.length; ++dim) {
    const splits = paramsNestedSplits[dim];
    const isInnermost = dim === paramsNestedSplits.length - 1;
    // `lastSplit` is the largest value a split in this dimension may take,
    // i.e. an exclusive bound on the rows it indexes.
    // - The innermost splits index the dense values directly.
    // - Every other splits vector indexes rows of the next splits vector,
    //   which has only `length - 1` rows. Allowing `length` would let
    //   makeSplits read one element past the end of that vector.
    // The clamp to 0 keeps an empty next vector reported as "may not be
    // empty" on the next iteration, rather than as "point past values" here.
    const lastSplit = isInnermost ?
        numParamsDenseValues :
        Math.max(paramsNestedSplits[dim + 1].length - 1, 0);
    if (splits.length === 0) {
      throw new Error('Ragged splits may not be empty');
    }
    if (splits[0] < 0) {
      throw new Error('Ragged splits must be non-negative');
    }
    if (splits[splits.length - 1] > lastSplit) {
      throw new Error('Ragged splits must not point past values');
    }
    for (let i = 1; i < splits.length; ++i) {
      if (splits[i - 1] > splits[i]) {
        throw new Error('Ragged splits must be sorted in ascending order');
      }
    }
  }
}
|
/**
 * Builds the output `splits` vectors for a ragged gather. It also records
 * which contiguous row ranges of the dense values must be copied.
 *
 * The result holds `indicesShape.length - 1 + paramsNestedSplits.length`
 * splits vectors. Each is a plain JS array that starts with 0.
 *
 * @param indices Flat buffer of gather indices. These are expected to be
 *     validated already (raggedGatherImpl calls validateIndices first).
 * @param indicesShape Shape of `indices`.
 * @param paramsNestedSplits Splits vectors of params, outermost first.
 * @param numParamsDenseValues Size of the first dimension of params' dense
 *     values; forwarded to validateSplits.
 * @returns An object with three fields:
 *     - `outSplits`: the output splits vectors.
 *     - `valueSlices`: half-open `[start, limit)` row ranges of the dense
 *       values to copy, in output order, with empty ranges omitted.
 *     - `numValues`: the total number of rows covered by `valueSlices`,
 *       needed to size the output values.
 * @throws Error from validateSplits if `paramsNestedSplits` is malformed.
 */
function makeSplits(
    indices, indicesShape, paramsNestedSplits, numParamsDenseValues) {
  const indicesRank = indicesShape.length;
  // One splits vector per non-final dimension of `indices`, followed by one
  // per ragged dimension of params. Each begins with a 0 split point.
  const outSplits =
      [...new Array(indicesRank - 1 + paramsNestedSplits.length)].map(
          () => [0]);

  validateSplits(paramsNestedSplits, numParamsDenseValues);

  // Uniform splits for all but the last dimension of the dense `indices`
  // tensor. For dimension D they are
  //   range(reduceProd(indices.shape[:D+1]) + 1) * indices.shape[D+1]
  // For example, indices.shape = [2, 3, 4] yields
  //   [0, 3, 6]                    # length=2+1, stride=3
  //   [0, 4, 8, 12, 16, 20, 24]    # length=2*3+1, stride=4
  let rowsSoFar = 1;
  for (let axis = 0; axis + 1 < indicesRank; ++axis) {
    rowsSoFar *= indicesShape[axis];
    const stride = indicesShape[axis + 1];
    const target = outSplits[axis];
    for (let row = 1; row < rowsSoFar + 1; ++row) {
      target.push(row * stride);
    }
  }

  // Ragged splits copied from params. For each gathered index, walk from the
  // outermost to the innermost ragged dimension, narrowing the half-open row
  // range [begin, end) at each level.
  // The *lengths* of the copied rows are preserved: each new split point is
  // the source split shifted so that the first copied row starts at the
  // current last split point of the matching output dimension.
  const valueSlices = [];
  let numValues = 0;
  const raggedOutDimOffset = indicesRank - 1;
  for (const index of indices) {
    let begin = index;
    let end = index + 1;
    for (let dim = 0; dim < paramsNestedSplits.length; ++dim) {
      const splits = paramsNestedSplits[dim];
      const outDim = raggedOutDimOffset + dim;
      // With scalar `indices`, the outermost params dimension has no output
      // splits vector (outDim is -1), so nothing is appended for it.
      if (outDim >= 0) {
        const target = outSplits[outDim];
        const shift = target[target.length - 1] - splits[begin];
        for (let row = begin; row < end; ++row) {
          target.push(splits[row + 1] + shift);
        }
      }
      // Map the row range down to the next level (or to dense value rows).
      [begin, end] = [splits[begin], splits[end]];
    }
    if (end !== begin) {
      valueSlices.push([begin, end]);
      numValues += end - begin;
    }
  }

  return {outSplits, valueSlices, numValues};
}
|
/**
 * Packs each nested JS array of split points into a buffer allocated with
 * `util.getArrayFromDType('int32', ...)`.
 *
 * @param outSplits Output splits vectors as plain arrays.
 * @returns One packed buffer per entry of `outSplits`, in the same order.
 */
function getSplits(outSplits) {
  return outSplits.map((rowSplits) => {
    const packed = util.getArrayFromDType('int32', rowSplits.length);
    for (let k = 0; k < rowSplits.length; ++k) {
      packed[k] = rowSplits[k];
    }
    return packed;
  });
}
|
/**
 * Reshapes `orig` to exactly `numOutDims` dimensions without copying data.
 * - Missing trailing dimensions are padded with 1.
 * - Any dimensions beyond `numOutDims` are multiplied into the last kept one.
 *
 * For example, `[2, 3, 4]` with `numOutDims = 2` becomes `[2, 12]`, and `[5]`
 * becomes `[5, 1]`. `orig` itself is not modified.
 *
 * @param orig Source shape.
 * @param numOutDims Number of dimensions in the result.
 * @returns A new shape array of length `numOutDims`.
 */
function computeFlatOuterDims(orig, numOutDims) {
  const flattened = orig.slice(0, numOutDims);
  for (let padded = flattened.length; padded < numOutDims; ++padded) {
    flattened.push(1);
  }

  const lastAxis = numOutDims - 1;
  for (let axis = numOutDims; axis < orig.length; ++axis) {
    flattened[lastAxis] = flattened[lastAxis] * orig[axis];
  }
  return flattened;
}
|
/**
 * Copies rows of `paramsDenseValues` into `values`. For each half-open slice
 * `[start, limit)` in `valueSlices`, rows `start .. limit - 1` are appended
 * back to back, starting at row 0 of `values`.
 *
 * Row strides are taken from both shapes by collapsing every dimension after
 * the first (via computeFlatOuterDims).
 *
 * @param paramsDenseValues Flat source buffer.
 * @param paramsDenseValuesShape Shape of the source buffer.
 * @param valueSlices Half-open `[start, limit)` row ranges to copy, in order.
 * @param valueSize Number of scalars copied from each row.
 * @param values Flat destination buffer, written in place.
 * @param valuesShape Shape of the destination buffer.
 */
function writeValueSlices(
    paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, values,
    valuesShape) {
  const srcRowStride = computeFlatOuterDims(paramsDenseValuesShape, 2)[1];
  const dstRowStride = computeFlatOuterDims(valuesShape, 2)[1];

  let dstRow = 0;
  for (const [begin, end] of valueSlices) {
    for (let srcRow = begin; srcRow < end; ++srcRow, ++dstRow) {
      const srcBase = srcRow * srcRowStride;
      const dstBase = dstRow * dstRowStride;
      for (let k = 0; k < valueSize; ++k) {
        values[dstBase + k] = paramsDenseValues[srcBase + k];
      }
    }
  }
}
|
/**
 * Allocates the output dense values and fills them from `valueSlices`.
 *
 * @param paramsDenseValues Flat buffer of params' dense values.
 * @param paramsDenseValuesShape Shape of `paramsDenseValues`.
 * @param paramsDenseValuesDType Dtype used to allocate the output buffer.
 * @param valueSlices Half-open `[start, limit)` row ranges to copy, in order.
 * @param numValues Total number of rows covered by `valueSlices`.
 * @returns `[values, valuesShape]`, where `valuesShape` is
 *     `paramsDenseValuesShape` with its first dimension replaced by
 *     `numValues`.
 */
function getValues(
    paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType,
    valueSlices, numValues) {
  const [, ...rowShape] = paramsDenseValuesShape;
  const valuesShape = [numValues, ...rowShape];

  const values = util.getArrayFromDType(
      paramsDenseValuesDType, util.sizeFromShape(valuesShape));

  // Scalars per row of params. The zero guard avoids computing 0 / 0 (NaN)
  // when params holds no values at all.
  const totalScalars = paramsDenseValues.length;
  const scalarsPerRow =
      totalScalars !== 0 ? totalScalars / paramsDenseValuesShape[0] : 0;

  writeValueSlices(
      paramsDenseValues, paramsDenseValuesShape, valueSlices, scalarsPerRow,
      values, valuesShape);
  return [values, valuesShape];
}
|
/**
 * CPU implementation of RaggedGather. It gathers rows of a ragged tensor,
 * given as nested splits plus dense values, at the positions in `indices`.
 *
 * @param paramsNestedSplits Splits vectors of params, outermost first.
 * @param paramsNestedSplitsShapes Shapes of the splits tensors. Only the
 *     first entry is read, to derive the number of outer rows.
 * @param paramsDenseValues Flat buffer of params' dense values.
 * @param paramsDenseValuesShape Shape of params' dense values.
 * @param paramsDenseValuesDType Dtype used for the output values buffer.
 * @param indices Flat buffer of gather indices.
 * @param indicesShape Shape of `indices`.
 * @param outputRaggedRank Accepted for interface compatibility; not read here.
 * @returns `[outputNestedSplits, outputDenseValues, outputDenseValuesShape]`.
 * @throws Error if the splits, indices or dense values are invalid.
 */
export function raggedGatherImpl(
    paramsNestedSplits, paramsNestedSplitsShapes, paramsDenseValues,
    paramsDenseValuesShape, paramsDenseValuesDType, indices, indicesShape,
    outputRaggedRank) {
  if (paramsNestedSplits.length === 0) {
    throw new Error('paramsNestedSplits must be non empty');
  }

  const outerSplitsShape = paramsNestedSplitsShapes[0];
  if (outerSplitsShape.length === 0) {
    throw new Error('Split tensors must not be scalars');
  }
  // The outermost splits vector has one more entry than there are rows.
  validateIndices(indices, indicesShape, outerSplitsShape[0] - 1);

  if (paramsDenseValuesShape.length === 0) {
    throw new Error('params.rank must be nonzero');
  }

  // Compute the output splits and the value rows that must be copied.
  const splitsResult = makeSplits(
      indices, indicesShape, paramsNestedSplits, paramsDenseValuesShape[0]);

  // Materialize the output tensors.
  const outputNestedSplits = getSplits(splitsResult.outSplits);
  const [outputDenseValues, outputDenseValuesShape] = getValues(
      paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType,
      splitsResult.valueSlices, splitsResult.numValues);

  return [outputNestedSplits, outputDenseValues, outputDenseValuesShape];
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"RaggedGather_impl.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-cpu/src/kernels/RaggedGather_impl.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAuB,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAEjE,SAAS,eAAe,CACpB,OAAmB,EAAE,YAAsB,EAAE,SAAiB;IAChE,OAAO,CAAC,OAAO,CAAC,CAAC,KAAa,EAAE,CAAS,EAAE,EAAE;QAC3C,IAAI,KAAK,GAAG,CAAC,IAAI,KAAK,IAAI,SAAS,EAAE;YACnC,MAAM,SAAS,GACX,IAAI,CAAC,UAAU,CACP,CAAC,EAAE,YAAY,CAAC,MAAM,EAAE,IAAI,CAAC,cAAc,CAAC,YAAY,CAAC,CAAC;iBAC7D,IAAI,CAAC,GAAG,CAAC,CAAC;YACnB,MAAM,IAAI,KAAK,CACX,WAAW,SAAS,OAAO,KAAK,kBAAkB,SAAS,GAAG,CAAC,CAAC;SACrE;IACH,CAAC,CAAC,CAAC;AACL,CAAC;AAED,SAAS,cAAc,CACnB,kBAAgC,EAAE,oBAA4B;IAChE,WAAW;IACX,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,GAAG,kBAAkB,CAAC,MAAM,EAAE,EAAE,GAAG,EAAE;QACxD,MAAM,MAAM,GAAG,kBAAkB,CAAC,GAAG,CAAC,CAAC;QACvC,MAAM,SAAS,GAAG,CAAC,GAAG,KAAK,kBAAkB,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC;YACvD,oBAAoB,CAAC,CAAC;YACtB,kBAAkB,CAAC,GAAG,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC;QACvC,IAAI,MAAM,CAAC,MAAM,KAAK,CAAC,EAAE;YACvB,MAAM,IAAI,KAAK,CAAC,gCAAgC,CAAC,CAAC;SACnD;QACD,IAAI,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE;YACjB,MAAM,IAAI,KAAK,CAAC,oCAAoC,CAAC,CAAC;SACvD;QACD,IAAI,MAAM,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,GAAG,SAAS,EAAE;YACzC,MAAM,IAAI,KAAK,CAAC,0CAA0C,CAAC,CAAC;SAC7D;QACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;YACtC,IAAI,MAAM,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,EAAE;gBAC7B,MAAM,IAAI,KAAK,CAAC,iDAAiD,CAAC,CAAC;aACpE;SACF;KACF;AACH,CAAC;AAED,wEAAwE;AACxE,wEAAwE;AACxE,2EAA2E;AAC3E,6EAA6E;AAC7E,SAAS,UAAU,CACf,OAAmB,EAAE,YAAsB,EAC3C,kBAAgC,EAAE,oBAA4B;IAChE,MAAM,WAAW,GAA4B,EAAE,CAAC;IAChD,IAAI,SAAS,GAAG,CAAC,CAAC;IAElB,MAAM,SAAS,GAAG,YAAY,CAAC,MAAM,GAAG,CAAC,GAAG,kBAAkB,CAAC,MAAM,CAAC;IACtE,MAAM,SAAS,GAAG,IAAI,KAAK,CAAC,SAAS,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAEjE,cAAc,CAAC,kBAAkB,EAAE,oBAAoB,CAAC,CAAC;IAEzD,sEAAsE;IACtE,mEAAmE;IACnE,kCAAkC;I
AClC,gEAAgE;IAChE,oEAAoE;IACpE,wDAAwD;IACxD,0DAA0D;IAC1D,IAAI,KAAK,GAAG,CAAC,CAAC;IACd,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,GAAG,YAAY,CAAC,MAAM,GAAG,CAAC,EAAE,EAAE,GAAG,EAAE;QACtD,KAAK,IAAI,YAAY,CAAC,GAAG,CAAC,CAAC;QAC3B,MAAM,SAAS,GAAG,YAAY,CAAC,GAAG,GAAG,CAAC,CAAC,CAAC;QACxC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,GAAG,CAAC,EAAE,EAAE,CAAC,EAAE;YAClC,SAAS,CAAC,GAAG,CAAC,CAAC,IAAI,CAAC,CAAC,GAAG,SAAS,CAAC,CAAC;SACpC;KACF;IAED,uEAAuE;IACvE,wEAAwE;IACxE,wEAAwE;IACxE,4EAA4E;IAC5E,0EAA0E;IAC1E,yEAAyE;IACzE,uEAAuE;IACvE,uEAAuE;IACvE,qCAAqC;IACrC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QACvC,IAAI,KAAK,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC;QACvB,IAAI,KAAK,GAAG,OAAO,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;QAE3B,eAAe;QACf,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,GAAG,kBAAkB,CAAC,MAAM,EAAE,EAAE,GAAG,EAAE;YACxD,MAAM,MAAM,GAAG,kBAAkB,CAAC,GAAG,CAAC,CAAC;YACvC,MAAM,MAAM,GAAG,GAAG,GAAG,YAAY,CAAC,MAAM,GAAG,CAAC,CAAC;YAC7C,IAAI,MAAM,IAAI,CAAC,EAAE;gBACf,MAAM,eAAe,GAAG,SAAS,CAAC,MAAM,CAAC,CAAC;gBAC1C,MAAM,KAAK,GACP,eAAe,CAAC,eAAe,CAAC,MAAM,GAAG,CAAC,CAAC,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC;gBAChE,KAAK,IAAI,CAAC,GAAG,KAAK,EAAE,CAAC,GAAG,KAAK,EAAE,EAAE,CAAC,EAAE;oBAClC,SAAS,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,KAAK,CAAC,CAAC;iBAC/C;aACF;YACD,KAAK,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC;YACtB,KAAK,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC;SACvB;QACD,IAAI,KAAK,KAAK,KAAK,EAAE;YACnB,WAAW,CAAC,IAAI,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC;YACjC,SAAS,IAAI,KAAK,GAAG,KAAK,CAAC;SAC5B;KACF;IAED,OAAO,EAAC,SAAS,EAAE,WAAW,EAAE,SAAS,EAAC,CAAC;AAC7C,CAAC;AAED,SAAS,SAAS,CAAC,SAAqB;IACtC,MAAM,SAAS,GAAiB,EAAE,CAAC;IACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QACzC,MAAM,SAAS,GAAG,SAAS,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;QACtC,MAAM,MAAM,GAAG,IAAI,CAAC,iBAAiB,CAAC,OAAO,EAAE,SAAS,CAAe,CAAC;QACxE,SAAS,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;QAEvB,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,KAAK,EAAE,CAAS,EAAE,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,GAAG,KAAK,CAAC,CAAC;KAC/D;IAED,OAAO,SAAS,
CAAC;AACnB,CAAC;AAED,SAAS,oBAAoB,CAAC,IAAc,EAAE,UAAkB;IAC9D,MAAM,OAAO,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC;IAC1C,OAAO,OAAO,CAAC,MAAM,GAAG,UAAU,EAAE;QAClC,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;KACjB;IAED,KAAK,IAAI,KAAK,GAAG,UAAU,EAAE,KAAK,GAAG,IAAI,CAAC,MAAM,EAAE,KAAK,EAAE,EAAE;QACzD,OAAO,CAAC,UAAU,GAAG,CAAC,CAAC,IAAI,IAAI,CAAC,KAAK,CAAC,CAAC;KACxC;IAED,OAAO,OAAO,CAAC;AACjB,CAAC;AACD,8DAA8D;AAC9D,0EAA0E;AAC1E,sEAAsE;AACtE,SAAS,gBAAgB,CACrB,iBAA6B,EAAE,sBAAgC,EAC/D,WAAoC,EAAE,SAAiB,EAAE,MAAkB,EAC3E,WAAqB;IACvB,MAAM,MAAM,GAAG,oBAAoB,CAAC,sBAAsB,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAClE,MAAM,OAAO,GAAG,oBAAoB,CAAC,WAAW,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAExD,IAAI,MAAM,GAAG,CAAC,CAAC;IACf,KAAK,MAAM,KAAK,IAAI,WAAW,EAAE;QAC/B,KAAK,IAAI,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE;YACxC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,EAAE,CAAC,EAAE;gBAClC,MAAM,CAAC,MAAM,GAAG,OAAO,GAAG,CAAC,CAAC,GAAG,iBAAiB,CAAC,CAAC,GAAG,MAAM,GAAG,CAAC,CAAC,CAAC;aAClE;YACD,EAAE,MAAM,CAAC;SACV;KACF;AACH,CAAC;AAED,SAAS,SAAS,CACd,iBAA6B,EAAE,sBAAgC,EAC/D,sBAAgC,EAAE,WAAoC,EACtE,SAAiB;IACnB,MAAM,WAAW,GAAG,sBAAsB,CAAC,KAAK,EAAE,CAAC;IACnD,WAAW,CAAC,CAAC,CAAC,GAAG,SAAS,CAAC;IAE3B,MAAM,SAAS,GAAG,IAAI,CAAC,iBAAiB,CAClB,sBAAsB,EACtB,IAAI,CAAC,aAAa,CAAC,WAAW,CAAC,CAAe,CAAC;IAErE,MAAM,WAAW,GAAG,iBAAiB,CAAC,MAAM,CAAC;IAC7C,MAAM,SAAS,GACX,WAAW,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,WAAW,GAAG,sBAAsB,CAAC,CAAC,CAAC,CAAC,CAAC;IACtE,gBAAgB,CACZ,iBAAiB,EAAE,sBAAsB,EAAE,WAAW,EAAE,SAAS,EACjE,SAAS,EAAE,WAAW,CAAC,CAAC;IAE5B,OAAO,CAAC,SAAS,EAAE,WAAW,CAAC,CAAC;AAClC,CAAC;AACD,MAAM,UAAU,gBAAgB,CAC5B,kBAAgC,EAAE,wBAAoC,EACtE,iBAA6B,EAAE,sBAAgC,EAC/D,sBAAgC,EAAE,OAAmB,EACrD,YAAsB,EACtB,gBAAwB;IAC1B,IAAI,kBAAkB,CAAC,MAAM,KAAK,CAAC,EAAE;QACnC,MAAM,IAAI,KAAK,CAAC,sCAAsC,CAAC,CAAC;KACzD;IAED,IAAI,wBAAwB,CAAC,CAAC,CAAC,CAAC,MAAM,KAAK,CAAC,EAAE;QAC5C,MAAM,IAAI,KAAK,CAAC,mCAAmC,CAAC,CAAC;KACtD;IACD,MAAM,SAAS,GAAG,wBAAwB,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,CAA
C;IACrD,eAAe,CAAC,OAAO,EAAE,YAAY,EAAE,SAAS,CAAC,CAAC;IAElD,IAAI,sBAAsB,CAAC,MAAM,KAAK,CAAC,EAAE;QACvC,MAAM,IAAI,KAAK,CAAC,6BAA6B,CAAC,CAAC;KAChD;IACD,MAAM,oBAAoB,GAAG,sBAAsB,CAAC,CAAC,CAAC,CAAC;IAEvD,qEAAqE;IACrE,yBAAyB;IACzB,MAAM,EAAC,SAAS,EAAE,WAAW,EAAE,SAAS,EAAC,GAAG,UAAU,CAClD,OAAO,EAAE,YAAY,EAAE,kBAAkB,EAAE,oBAAoB,CAAC,CAAC;IAErE,4BAA4B;IAC5B,MAAM,kBAAkB,GAAG,SAAS,CAAC,SAAS,CAAC,CAAC;IAChD,MAAM,iBAAiB,GAAG,SAAS,CAC/B,iBAAiB,EAAE,sBAAsB,EAAE,sBAAsB,EACjE,WAAW,EAAE,SAAS,CAAC,CAAC;IAE5B,OAAO,CAAC,kBAAkB,EAAE,iBAAiB,CAAC,CAAC,CAAC,EAAE,iBAAiB,CAAC,CAAC,CAAC,CAAC,CAAC;AAC1E,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2022 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {DataType, TypedArray, util} from '@tensorflow/tfjs-core';\n\nfunction validateIndices(\n    indices: TypedArray, indicesShape: number[], numParams: number) {\n  indices.forEach((index: number, i: number) => {\n    if (index < 0 || index >= numParams) {\n      const locString =\n          util.indexToLoc(\n                  i, indicesShape.length, util.computeStrides(indicesShape))\n              .join(',');\n      throw new Error(\n          `indices[${locString}] = ${index} is not in [0, ${numParams})`);\n    }\n  });\n}\n\nfunction validateSplits(\n    paramsNestedSplits: TypedArray[], numParamsDenseValues: number) {\n  // Validate\n  for (let dim = 
0; dim < paramsNestedSplits.length; ++dim) {\n    const splits = paramsNestedSplits[dim];\n    const lastSplit = (dim === paramsNestedSplits.length - 1) ?\n        numParamsDenseValues :\n        paramsNestedSplits[dim + 1].length;\n    if (splits.length === 0) {\n      throw new Error('Ragged splits may not be empty');\n    }\n    if (splits[0] < 0) {\n      throw new Error('Ragged splits must be non-negative');\n    }\n    if (splits[splits.length - 1] > lastSplit) {\n      throw new Error('Ragged splits must not point past values');\n    }\n    for (let i = 1; i < splits.length; ++i) {\n      if (splits[i - 1] > splits[i]) {\n        throw new Error('Ragged splits must be sorted in ascending order');\n      }\n    }\n  }\n}\n\n// Construct the `splits` output tensors, encoded using a nested vector.\n// Also find the slices of values that need to be copied, and store them\n// in `valueSlices`.  The total number of values that will be copied (which\n// we need for allocating the output values tensor) is stored in `numValues`.\nfunction makeSplits(\n    indices: TypedArray, indicesShape: number[],\n    paramsNestedSplits: TypedArray[], numParamsDenseValues: number) {\n  const valueSlices: Array<[number, number]> = [];\n  let numValues = 0;\n\n  const numSplits = indicesShape.length - 1 + paramsNestedSplits.length;\n  const outSplits = new Array(numSplits).fill(null).map(() => [0]);\n\n  validateSplits(paramsNestedSplits, numParamsDenseValues);\n\n  // Add `splits` that come from all but the last dimension of the dense\n  // Tensor `indices`.  
In particular, for each dimension D, we add a\n  // splits tensor whose values are:\n  //   range(reduceProd(splits.shape[:D]) + 1) * splits.shape[D+1]\n  // E.g., if indices.shape=[2, 3, 4] then we will add splits tensors:\n  //   [0, 3, 6]                    # length=2+1, stride=3\n  //   [0, 4, 8, 12, 16, 20, 24]    # length=2*3+1, stride=4\n  let nrows = 1;\n  for (let dim = 0; dim < indicesShape.length - 1; ++dim) {\n    nrows *= indicesShape[dim];\n    const rowLength = indicesShape[dim + 1];\n    for (let i = 1; i < nrows + 1; ++i) {\n      outSplits[dim].push(i * rowLength);\n    }\n  }\n\n  // Add `splits` that come from `paramsNestedSplits`.  Starting with the\n  // outermost ragged dimension (i.e., the first `splits` tensor), we work\n  // our way in, finding the range of values that should be copied.  As we\n  // go, we update the output `splits` for each dimension with the appropriate\n  // values.  In particular, the *lengths* of the slices from `param_splits`\n  // should be copied to generate corresponding slice lengths in the output\n  // splits.  
E.g., if we are copying a ragged row with length 4, then we\n  // should add a new split point to outSplits that is 4 greater than the\n  // previous split point in outSplits.\n  for (let i = 0; i < indices.length; ++i) {\n    let start = indices[i];\n    let limit = indices[i] + 1;\n\n    // Copy splits.\n    for (let dim = 0; dim < paramsNestedSplits.length; ++dim) {\n      const splits = paramsNestedSplits[dim];\n      const outDim = dim + indicesShape.length - 1;\n      if (outDim >= 0) {\n        const outSplitsOutDim = outSplits[outDim];\n        const delta =\n            outSplitsOutDim[outSplitsOutDim.length - 1] - splits[start];\n        for (let j = start; j < limit; ++j) {\n          outSplits[outDim].push(splits[j + 1] + delta);\n        }\n      }\n      start = splits[start];\n      limit = splits[limit];\n    }\n    if (limit !== start) {\n      valueSlices.push([start, limit]);\n      numValues += limit - start;\n    }\n  }\n\n  return {outSplits, valueSlices, numValues};\n}\n\nfunction getSplits(outSplits: number[][]) {\n  const splitsOut: TypedArray[] = [];\n  for (let i = 0; i < outSplits.length; ++i) {\n    const numSplits = outSplits[i].length;\n    const splits = util.getArrayFromDType('int32', numSplits) as TypedArray;\n    splitsOut.push(splits);\n\n    outSplits[i].forEach((value, j: number) => splits[j] = value);\n  }\n\n  return splitsOut;\n}\n\nfunction computeFlatOuterDims(orig: number[], numOutDims: number) {\n  const outDims = orig.slice(0, numOutDims);\n  while (outDims.length < numOutDims) {\n    outDims.push(1);\n  }\n\n  for (let inDim = numOutDims; inDim < orig.length; inDim++) {\n    outDims[numOutDims - 1] *= orig[inDim];\n  }\n\n  return outDims;\n}\n// For each slice in `(start, limit)` in `valueSlices`, append\n// `paramsDenseValues[start,...,limit] to `values`.  
`valueSize` indicates\n// the number of scalars contained in each value paramsDenseValues[i].\nfunction writeValueSlices(\n    paramsDenseValues: TypedArray, paramsDenseValuesShape: number[],\n    valueSlices: Array<[number, number]>, valueSize: number, values: TypedArray,\n    valuesShape: number[]) {\n  const denseM = computeFlatOuterDims(paramsDenseValuesShape, 2)[1];\n  const valuesM = computeFlatOuterDims(valuesShape, 2)[1];\n\n  let outPos = 0;\n  for (const slice of valueSlices) {\n    for (let i = slice[0]; i < slice[1]; ++i) {\n      for (let j = 0; j < valueSize; ++j) {\n        values[outPos * valuesM + j] = paramsDenseValues[i * denseM + j];\n      }\n      ++outPos;\n    }\n  }\n}\n\nfunction getValues(\n    paramsDenseValues: TypedArray, paramsDenseValuesShape: number[],\n    paramsDenseValuesDType: DataType, valueSlices: Array<[number, number]>,\n    numValues: number): [TypedArray, number[]] {\n  const valuesShape = paramsDenseValuesShape.slice();\n  valuesShape[0] = numValues;\n\n  const valuesOut = util.getArrayFromDType(\n                        paramsDenseValuesDType,\n                        util.sizeFromShape(valuesShape)) as TypedArray;\n\n  const numElements = paramsDenseValues.length;\n  const valueSize =\n      numElements === 0 ? 
0 : (numElements / paramsDenseValuesShape[0]);\n  writeValueSlices(\n      paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize,\n      valuesOut, valuesShape);\n\n  return [valuesOut, valuesShape];\n}\nexport function raggedGatherImpl(\n    paramsNestedSplits: TypedArray[], paramsNestedSplitsShapes: number[][],\n    paramsDenseValues: TypedArray, paramsDenseValuesShape: number[],\n    paramsDenseValuesDType: DataType, indices: TypedArray,\n    indicesShape: number[],\n    outputRaggedRank: number): [TypedArray[], TypedArray, number[]] {\n  if (paramsNestedSplits.length === 0) {\n    throw new Error('paramsNestedSplits must be non empty');\n  }\n\n  if (paramsNestedSplitsShapes[0].length === 0) {\n    throw new Error('Split tensors must not be scalars');\n  }\n  const numParams = paramsNestedSplitsShapes[0][0] - 1;\n  validateIndices(indices, indicesShape, numParams);\n\n  if (paramsDenseValuesShape.length === 0) {\n    throw new Error('params.rank must be nonzero');\n  }\n  const numParamsDenseValues = paramsDenseValuesShape[0];\n\n  // Calculate the `splits`, and store the value slices that we need to\n  // copy in `valueSlices`.\n  const {outSplits, valueSlices, numValues} = makeSplits(\n      indices, indicesShape, paramsNestedSplits, numParamsDenseValues);\n\n  // Write the output tensors.\n  const outputNestedSplits = getSplits(outSplits);\n  const outputDenseValues = getValues(\n      paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType,\n      valueSlices, numValues);\n\n  return [outputNestedSplits, outputDenseValues[0], outputDenseValues[1]];\n}\n"]}
|