/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, util } from '@tensorflow/tfjs-core';
|
/**
 * Computes a sparse segment sum, or a sparse segment mean when `isMean` is
 * set.
 *
 * `input` is viewed as a 2-D matrix of shape
 * `[inputShape[0], input.length / inputShape[0]]`. The row `indices[k]` of
 * that matrix is added into the output row `segmentIds[k]`. Segment ids are
 * expected to be sorted: equal ids must be contiguous and successive runs
 * must strictly increase. The number of output rows is the last segment id
 * plus one. Output rows that receive no contribution are set to
 * `defaultValue`.
 *
 * @param input Flat input values.
 * @param inputShape Shape of `input`; the output keeps every dimension except
 *     the first, which becomes the number of output rows.
 * @param inputDType Dtype passed to `util.getArrayFromDType` to allocate the
 *     output buffer.
 * @param indices Row indices into the flattened input.
 * @param segmentIds Output row for each entry of `indices`.
 * @param isMean When true, each output row is divided by the number of
 *     indices contributing to it.
 * @param defaultValue Value written into output rows with no contributions.
 * @returns `[outputValues, outputShape]`.
 * @throws Error if segment ids are negative, not increasing, or out of range,
 *     or if an index falls outside `[0, inputShape[0])`.
 */
export function sparseSegmentReductionImpl(input, inputShape, inputDType, indices, segmentIds, isMean = false, defaultValue = 0) {
  const indexCount = indices.length;

  // Treat the input as a [inputRows, rowWidth] matrix.
  const inputRows = inputShape[0];
  const rowWidth = input.length / inputRows;

  // Segment ids are assumed sorted, so the last one fixes the row count.
  const outputRows = indexCount === 0 ? 0 : segmentIds[indexCount - 1] + 1;
  if (outputRows < 0) {
    throw new Error(backend_util.getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
  }

  const outputShape = [outputRows, ...inputShape.slice(1)];
  let outputLength = 1;
  for (const dim of outputShape) {
    outputLength *= dim;
  }
  // The freshly allocated buffer starts zeroed, which the summation relies on.
  const output = util.getArrayFromDType(inputDType, outputLength);

  if (indexCount === 0) {
    if (outputRows > 0) {
      output.fill(defaultValue);
    }
    return [output, outputShape];
  }

  if (outputRows <= 0) {
    throw new Error(backend_util.getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
  }

  // First output row that has not been written by any segment yet.
  let nextUnwritten = 0;
  let segStart = 0;
  while (segStart < indexCount) {
    const segId = segmentIds[segStart];

    // Grow the run [segStart, segEnd) while the segment id stays the same.
    let segEnd = segStart + 1;
    while (segEnd < indexCount && segmentIds[segEnd] === segId) {
      segEnd++;
    }

    // The following run (if any) must use a strictly larger id. This is
    // checked before validating the current run, matching the error order
    // callers observe.
    if (segEnd < indexCount && segId >= segmentIds[segEnd]) {
      throw new Error(backend_util
        .getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage());
    }

    if (segId < 0 || segId >= outputRows) {
      throw new Error(backend_util.getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(segId, outputRows));
    }

    // Rows skipped between the previous run and this one get the default.
    if (segId > nextUnwritten) {
      output.fill(defaultValue, nextUnwritten * rowWidth, segId * rowWidth);
    }

    const outOffset = segId * rowWidth;
    for (let k = segStart; k < segEnd; ++k) {
      const row = indices[k];
      if (row < 0 || row >= inputRows) {
        throw new Error(backend_util.getSparseSegmentReductionIndicesOutOfRangeErrorMessage(k, row, inputRows));
      }
      const inOffset = row * rowWidth;
      for (let c = 0; c < rowWidth; c++) {
        output[outOffset + c] += input[inOffset + c];
      }
    }

    if (isMean) {
      const count = segEnd - segStart;
      for (let c = 0; c < rowWidth; c++) {
        output[outOffset + c] /= count;
      }
    }

    nextUnwritten = segId + 1;
    segStart = segEnd;
  }

  // Any rows after the last written segment get the default value.
  if (nextUnwritten < outputRows) {
    output.fill(defaultValue, nextUnwritten * rowWidth, outputRows * rowWidth);
  }

  return [output, outputShape];
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"SparseSegmentReduction_impl.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-cpu/src/kernels/SparseSegmentReduction_impl.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAwB,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAE/E,MAAM,UAAU,0BAA0B,CACtC,KAAiB,EAAE,UAAoB,EAAE,UAAoB,EAC7D,OAAmB,EAAE,UAAsB,EAAE,MAAM,GAAG,KAAK,EAC3D,YAAY,GAAG,CAAC;IAClB,MAAM,UAAU,GAAG,OAAO,CAAC,MAAM,CAAC;IAElC,sCAAsC;IACtC,MAAM,SAAS,GAAa,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,MAAM,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC;IAC1E,MAAM,MAAM,GAAG,SAAS,CAAC,CAAC,CAAC,CAAC;IAC5B,0EAA0E;IAC1E,UAAU;IACV,MAAM,oBAAoB,GACtB,UAAU,GAAG,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,UAAU,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACxD,MAAM,UAAU,GAAG,oBAAoB,CAAC;IAExC,IAAI,UAAU,GAAG,CAAC,EAAE;QAClB,MAAM,IAAI,KAAK,CACX,YAAY,CAAC,uDAAuD,EAAE,CAAC,CAAC;KAC7E;IAED,MAAM,WAAW,GAAG,UAAU,CAAC,KAAK,EAAE,CAAC;IACvC,WAAW,CAAC,CAAC,CAAC,GAAG,UAAU,CAAC;IAE5B,MAAM,YAAY,GACd,WAAW,CAAC,MAAM,CAAC,CAAC,OAAO,EAAE,KAAK,EAAE,EAAE,CAAC,OAAO,GAAG,KAAK,EAAE,CAAC,CAAC,CAAC;IAC/D,2DAA2D;IAC3D,MAAM,MAAM,GAAG,IAAI,CAAC,iBAAiB,CAAC,UAAU,EAAE,YAAY,CAAe,CAAC;IAE9E,4EAA4E;IAC5E,kEAAkE;IAClE,IAAI,UAAU,KAAK,CAAC,EAAE;QACpB,IAAI,UAAU,GAAG,CAAC,EAAE;YAClB,MAAM,CAAC,IAAI,CAAC,YAAY,CAAC,CAAC;SAC3B;QACD,OAAO,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;KAC9B;IAED,IAAI,UAAU,IAAI,CAAC,EAAE;QACnB,MAAM,IAAI,KAAK,CACX,YAAY,CAAC,uDAAuD,EAAE,CAAC,CAAC;KAC7E;IAED,IAAI,KAAK,GAAG,CAAC,EAAE,GAAG,GAAG,CAAC,CAAC;IACvB,kDAAkD;IAClD,IAAI,kBAAkB,GAAG,CAAC,CAAC;IAC3B,IAAI,QAAQ,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC;IAEjC,OAAO,IAAI,EAAE;QACX,qEAAqE;QACrE,IAAI,SAAS,GAAG,CAAC,CAAC;QAClB,IAAI,GAAG,GAAG,UAAU,EAAE;YACpB,SAAS,GAAG,UAAU,CAAC,GAAG,CAAC,CAAC;YAC5B,IAAI,QAAQ,KAAK,SAAS,EAAE;gBAC1B,EAAE,GAAG,CAAC;gBACN,SAAS;aACV;YACD,wEAAwE;YACxE,IAAI,QAAQ,IAAI,SAAS,EAAE;gBACzB,MAAM,IAAI,KAAK,CAAC,YAAY;qBACvB,4DAA4D,EAAE,CAAC,CAAC;aACtE;SACF;QAED,IAAI,QAAQ,GAAG,CAAC,IAAI,QAAQ,IAAI,UAAU,EAAE;YAC1C,MAAM,IAAI,KAAK,CACX,YAAY,C
AAC,wDAAwD,CACjE,QAAQ,EAAE,UAAU,CAAC,CAAC,CAAC;SAChC;QAED,wEAAwE;QACxE,iBAAiB;QACjB,IAAI,QAAQ,GAAG,kBAAkB,EAAE;YACjC,MAAM,CAAC,IAAI,CAAC,YAAY,EAAE,kBAAkB,GAAG,MAAM,EAAE,QAAQ,GAAG,MAAM,CAAC,CAAC;SAC3E;QAED,KAAK,IAAI,CAAC,GAAG,KAAK,EAAE,CAAC,GAAG,GAAG,EAAE,EAAE,CAAC,EAAE;YAChC,MAAM,KAAK,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC;YACzB,IAAI,KAAK,GAAG,CAAC,IAAI,KAAK,IAAI,SAAS,CAAC,CAAC,CAAC,EAAE;gBACtC,MAAM,IAAI,KAAK,CACX,YAAY,CAAC,sDAAsD,CAC/D,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aACvC;YACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,EAAE,CAAC,EAAE,EAAE;gBAC/B,MAAM,CAAC,QAAQ,GAAG,MAAM,GAAG,CAAC,CAAC,IAAI,KAAK,CAAC,KAAK,GAAG,MAAM,GAAG,CAAC,CAAC,CAAC;aAC5D;SACF;QAED,IAAI,MAAM,EAAE;YACV,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,EAAE,CAAC,EAAE,EAAE;gBAC/B,MAAM,CAAC,QAAQ,GAAG,MAAM,GAAG,CAAC,CAAC,IAAI,GAAG,GAAG,KAAK,CAAC;aAC9C;SACF;QAED,KAAK,GAAG,GAAG,CAAC;QACZ,EAAE,GAAG,CAAC;QACN,kBAAkB,GAAG,QAAQ,GAAG,CAAC,CAAC;QAClC,QAAQ,GAAG,SAAS,CAAC;QACrB,IAAI,GAAG,GAAG,UAAU,EAAE;YACpB,MAAM;SACP;KACF;IAED,kDAAkD;IAClD,IAAI,kBAAkB,GAAG,UAAU,EAAE;QACnC,MAAM,CAAC,IAAI,CAAC,YAAY,EAAE,kBAAkB,GAAG,MAAM,EAAE,UAAU,GAAG,MAAM,CAAC,CAAC;KAC7E;IAED,OAAO,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;AAC/B,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2021 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, DataType, TypedArray, util} from '@tensorflow/tfjs-core';\n\nexport function sparseSegmentReductionImpl(\n    input: TypedArray, inputShape: number[], inputDType: DataType,\n    indices: TypedArray, segmentIds: TypedArray, isMean = false,\n    defaultValue = 0): [TypedArray, number[]] {\n  const numIndices = indices.length;\n\n  // Flatten the array to two dimensions\n  const inputFlat: number[] = [inputShape[0], input.length / inputShape[0]];\n  const numCol = inputFlat[1];\n  // Note that the current implementation assumes that segmentIds values are\n  // sorted.\n  const lastSegmentIdPlusOne =\n      numIndices > 0 ? 
segmentIds[numIndices - 1] + 1 : 0;\n  const outputRows = lastSegmentIdPlusOne;\n\n  if (outputRows < 0) {\n    throw new Error(\n        backend_util.getSparseSegmentReductionNegativeSegmentIdsErrorMessage());\n  }\n\n  const outputShape = inputShape.slice();\n  outputShape[0] = outputRows;\n\n  const outputLength =\n      outputShape.reduce((product, value) => product * value, 1);\n  // Output array is initialized with the value 0 by default.\n  const output = util.getArrayFromDType(inputDType, outputLength) as TypedArray;\n\n  // Note that we do not initialize the output buffer with a default value, so\n  // we need to explicitly set missing indices to the default value.\n  if (numIndices === 0) {\n    if (outputRows > 0) {\n      output.fill(defaultValue);\n    }\n    return [output, outputShape];\n  }\n\n  if (outputRows <= 0) {\n    throw new Error(\n        backend_util.getSparseSegmentReductionNegativeSegmentIdsErrorMessage());\n  }\n\n  let start = 0, end = 1;\n  // Index from which the output is not initialized.\n  let uninitializedIndex = 0;\n  let outIndex = segmentIds[start];\n\n  while (true) {\n    // We initialize nextIndex to 0 to avoid may be uninitialized warning\n    let nextIndex = 0;\n    if (end < numIndices) {\n      nextIndex = segmentIds[end];\n      if (outIndex === nextIndex) {\n        ++end;\n        continue;\n      }\n      // We have a new segment here.  
Verify that the segment ids are growing.\n      if (outIndex >= nextIndex) {\n        throw new Error(backend_util\n            .getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage());\n      }\n    }\n\n    if (outIndex < 0 || outIndex >= outputRows) {\n      throw new Error(\n          backend_util.getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(\n              outIndex, outputRows));\n    }\n\n    // If there is a gap between two indices, we need to set that gap to the\n    // default value.\n    if (outIndex > uninitializedIndex) {\n      output.fill(defaultValue, uninitializedIndex * numCol, outIndex * numCol);\n    }\n\n    for (let i = start; i < end; ++i) {\n      const index = indices[i];\n      if (index < 0 || index >= inputFlat[0]) {\n        throw new Error(\n            backend_util.getSparseSegmentReductionIndicesOutOfRangeErrorMessage(\n                i, indices[i], inputFlat[0]));\n      }\n      for (let j = 0; j < numCol; j++) {\n        output[outIndex * numCol + j] += input[index * numCol + j];\n      }\n    }\n\n    if (isMean) {\n      for (let j = 0; j < numCol; j++) {\n        output[outIndex * numCol + j] /= end - start;\n      }\n    }\n\n    start = end;\n    ++end;\n    uninitializedIndex = outIndex + 1;\n    outIndex = nextIndex;\n    if (end > numIndices) {\n      break;\n    }\n  }\n\n  // Fill the gap at the end with the default value.\n  if (uninitializedIndex < outputRows) {\n    output.fill(defaultValue, uninitializedIndex * numCol, outputRows * numCol);\n  }\n\n  return [output, outputShape];\n}\n"]}
|