import { computeStrides, sizeFromShape } from '../util';
/**
 * Check whether updates.shape = indices.shape[:batchDim] + shape[sliceDim:].
 *
 * `sliceDim` is the size of the innermost dimension of `indices` and
 * `batchDim` is the number of leading `indices` dimensions (`indices.rank - 1`).
 * A rank-1 `indices` is treated as `sliceDim = 1`, `batchDim = 1`.
 *
 * @param shape The shape of the output tensor.
 * @param indices The tensor containing the indices for the update values.
 * @param updates The tensor containing the update values.
 * @throws Error if `updates.shape` is not consistent with `indices.shape` and
 *     `shape`. The message lists all three shapes plus sliceDim/batchDim.
 */
export function validateUpdateShape(shape, indices, updates) {
    const sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1;
    const batchDim = (indices.rank > 1) ? indices.rank - 1 : 1;
    const shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' +
        `shape[sliceDim:], got updates.shape: ${updates.shape}` +
        `, indices.shape: ${indices.shape}, shape: ${shape}` +
        `, sliceDim: ${sliceDim}, and batchDim: ${batchDim}.`;
    if (updates.rank < batchDim) {
        throw new Error(shapeError + ` update.rank < ${batchDim}. `);
    }
    if (shape.length < sliceDim + (updates.rank - batchDim)) {
        throw new Error(shapeError +
            ` Output shape length < ${sliceDim + (updates.rank - batchDim)}`);
    }
    if (updates.rank !== batchDim + shape.length - sliceDim) {
        throw new Error(shapeError + ` update.rank != ${batchDim + shape.length - sliceDim}`);
    }
    // The leading `batchDim` dimensions of updates must equal those of indices.
    for (let d = 0; d < batchDim; ++d) {
        if (updates.shape[d] !== indices.shape[d]) {
            throw new Error(shapeError +
                ` updates.shape[${d}] (${updates.shape[d]}) != indices.shape[${d}] (${indices.shape[d]}).`);
        }
    }
    // The remaining dimensions of updates must equal shape[sliceDim:].
    for (let d = 0; d < updates.rank - batchDim; ++d) {
        const updateDim = d + batchDim;
        const outputDim = d + sliceDim;
        if (updates.shape[updateDim] !== shape[outputDim]) {
            // Report the output dimension that was actually compared
            // (shape[d + sliceDim]), not shape[d + batchDim].
            throw new Error(shapeError +
                ` updates.shape[${updateDim}] (${updates.shape[updateDim]}) != shape[${outputDim}] (${shape[outputDim]})`);
        }
    }
}
/**
 * Validate scatter nd inputs.
 *
 * @param updates The tensor containing the update values; must be rank >= 1.
 * @param indices The tensor containing the indices for the update values;
 *     must be int32 and rank >= 1.
 * @param shape The shape of the output tensor; must have length >= 1.
 * @throws Error if any of the constraints above is violated, or if the shapes
 *     are mutually inconsistent (checked by validateUpdateShape).
 */
export function validateInput(updates, indices, shape) {
    if (indices.rank < 1) {
        throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' +
            ` but the rank was ${indices.rank}.`);
    }
    if (updates.rank < 1) {
        throw new Error('tf.scatterND() expects the updates to be rank 1 or higher,' +
            ` but the rank was ${updates.rank}.`);
    }
    if (indices.dtype !== 'int32') {
        throw new Error(`The dtype of 'indices' should be int32, but got dtype: ${indices.dtype}`);
    }
    // A rank-0 output is rejected here, so no separate branch is needed for
    // `shape.length === 0`.
    // NOTE(review): outputs with zero elements (e.g. shape [0]) receive no
    // extra validation here. TensorFlow's ScatterNd rejects non-empty
    // indices/updates in that case. Confirm whether tfjs should do the same.
    if (shape.length < 1) {
        throw new Error(`Output rank must be greater or equal to 1, but got shape: ${shape}`);
    }
    validateUpdateShape(shape, indices, updates);
}
/**
 * Calculate the shape information for the output.
 *
 * @param updates The tensor containing the update values. It is not read by
 *     this function and is kept for signature compatibility.
 * @param indices The tensor containing the indices for the update values.
 * @param shape The shape of the output tensor.
 *
 * @returns ScatterShapeInfo with:
 *   - sliceRank: size of the last dimension of `indices`, or 1 when
 *     `indices` is rank 1;
 *   - numUpdates: element count of `indices` divided by sliceRank (a
 *     sliceRank of 0 is treated as 1 to avoid dividing by zero);
 *   - sliceSize: product of shape[sliceRank:] (1 if that range is empty);
 *   - strides: computeStrides(shape[:sliceRank]) with a trailing 1 appended;
 *   - outputSize: sizeFromShape(shape).
 */
export function calculateShapes(updates, indices, shape) {
    // Calculate the number of dimensions in indices
    const indicesRank = indices.shape.length;
    const sliceRank = (indicesRank > 1) ? indices.shape[indicesRank - 1] : 1;
    // Calculate the number of elements that make up each slice of our updated
    // tensor. This allows us to work with flattened tensors and copy over whole
    // slices at a time.
    const totalNd = shape.length;
    let sliceSize = 1;
    for (let i = sliceRank; i < totalNd; ++i) {
        sliceSize *= shape[i];
    }
    // Avoid dividing by zero when the last dimension of indices is 0.
    const safeSliceDim = (sliceRank < 1) ? 1 : sliceRank;
    const numUpdates = sizeFromShape(indices.shape) / safeSliceDim;
    const strides = [...computeStrides(shape.slice(0, sliceRank)), 1];
    const outputSize = sizeFromShape(shape);
    return { sliceRank, numUpdates, sliceSize, strides, outputSize };
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"scatter_nd_util.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/scatter_nd_util.ts"],"names":[],"mappings":"AAkBA,OAAO,EAAC,cAAc,EAAE,aAAa,EAAC,MAAM,SAAS,CAAC;AAEtD;;;;;GAKG;AACH,MAAM,UAAU,mBAAmB,CAC/B,KAAe,EAAE,OAAe,EAAE,OAAe;IACnD,MAAM,QAAQ,GAAG,CAAC,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,KAAK,CAAC,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAC1E,MAAM,QAAQ,GAAG,CAAC,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAE3D,MAAM,UAAU,GAAG,uDAAuD;QACtE,wCAAwC,OAAO,CAAC,KAAK,EAAE;QACvD,oBAAoB,OAAO,CAAC,KAAK,YAAY,KAAK,EAAE;QACpD,eAAe,QAAQ,mBAAmB,QAAQ,GAAG,CAAC;IAE1D,IAAI,OAAO,CAAC,IAAI,GAAG,QAAQ,EAAE;QAC3B,MAAM,IAAI,KAAK,CAAC,UAAU,GAAG,kBAAkB,QAAQ,IAAI,CAAC,CAAC;KAC9D;IACD,IAAI,KAAK,CAAC,MAAM,GAAG,QAAQ,GAAG,CAAC,OAAO,CAAC,IAAI,GAAG,QAAQ,CAAC,EAAE;QACvD,MAAM,IAAI,KAAK,CACX,UAAU;YACV,0BAA0B,QAAQ,GAAG,CAAC,OAAO,CAAC,IAAI,GAAG,QAAQ,CAAC,EAAE,CAAC,CAAC;KACvE;IACD,IAAI,OAAO,CAAC,IAAI,KAAK,QAAQ,GAAG,KAAK,CAAC,MAAM,GAAG,QAAQ,EAAE;QACvD,MAAM,IAAI,KAAK,CACX,UAAU,GAAG,mBAAmB,QAAQ,GAAG,KAAK,CAAC,MAAM,GAAG,QAAQ,EAAE,CAAC,CAAC;KAC3E;IACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,EAAE,CAAC,EAAE;QACjC,IAAI,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE;YACzC,MAAM,IAAI,KAAK,CACX,UAAU;gBACV,kBAAkB,CAAC,MAAM,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,sBAAsB,CAAC,MAC5D,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;SAC/B;KACF;IACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,CAAC,IAAI,GAAG,QAAQ,EAAE,EAAE,CAAC,EAAE;QAChD,IAAI,OAAO,CAAC,KAAK,CAAC,CAAC,GAAG,QAAQ,CAAC,KAAK,KAAK,CAAC,CAAC,GAAG,QAAQ,CAAC,EAAE;YACvD,MAAM,IAAI,KAAK,CACX,UAAU;gBACV,kBAAkB,CAAC,GAAG,QAAQ,M
AC1B,OAAO,CAAC,KAAK,CAAC,CAAC,GAAG,QAAQ,CAAC,cAAc,CAAC,GAAG,QAAQ,MACrD,KAAK,CAAC,CAAC,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC;SACjC;KACF;AACH,CAAC;AASD;;;;;;GAMG;AACH,MAAM,UAAU,aAAa,CACzB,OAAe,EAAE,OAAe,EAAE,KAAe;IACnD,IAAI,OAAO,CAAC,IAAI,GAAG,CAAC,EAAE;QACpB,MAAM,IAAI,KAAK,CACX,4DAA4D;YAC5D,qBAAqB,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC;KAC3C;IACD,IAAI,OAAO,CAAC,IAAI,GAAG,CAAC,EAAE;QACpB,MAAM,IAAI,KAAK,CACX,4DAA4D;YAC5D,qBAAqB,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC;KAC3C;IACD,IAAI,OAAO,CAAC,KAAK,KAAK,OAAO,EAAE;QAC7B,MAAM,IAAI,KAAK,CAAC,0DACZ,OAAO,CAAC,KAAK,EAAE,CAAC,CAAC;KACtB;IACD,IAAI,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE;QACpB,MAAM,IAAI,KAAK,CACX,6DAA6D,KAAK,EAAE,CAAC,CAAC;KAC3E;IAED,IAAI,KAAK,CAAC,MAAM,KAAK,CAAC,EAAE;QACtB,IAAI,OAAO,CAAC,IAAI,KAAK,CAAC,EAAE;YACtB,MAAM,IAAI,KAAK,CAAC,sDACZ,OAAO,CAAC,KAAK,EAAE,CAAC,CAAC;SACtB;QACD,IAAI,OAAO,CAAC,IAAI,KAAK,CAAC,EAAE;YACtB,MAAM,IAAI,KAAK,CAAC,sDACZ,OAAO,CAAC,KAAK,EAAE,CAAC,CAAC;SACtB;KACF;IAED,mBAAmB,CAAC,KAAK,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC;AAC/C,CAAC;AAED;;;;;;;;GAQG;AACH,MAAM,UAAU,eAAe,CAC3B,OAAmB,EAAE,OAAmB,EACxC,KAAe;IACjB,gDAAgD;IAChD,MAAM,WAAW,GAAG,OAAO,CAAC,KAAK,CAAC,MAAM,CAAC;IACzC,MAAM,SAAS,GAAG,CAAC,WAAW,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,KAAK,CAAC,WAAW,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAEzE,0EAA0E;IAC1E,4EAA4E;IAC5E,oBAAoB;IACpB,MAAM,OAAO,GAAG,KAAK,CAAC,MAAM,CAAC;IAE7B,IAAI,SAAS,GAAG,CAAC,CAAC;IAClB,KAAK,IAAI,CAAC,GAAG,SAAS,EAAE,CAAC,GAAG,OAAO,EAAE,EAAE,CAAC,EAAE;QACxC,SAAS,IAAI,KAAK,CAAC,CAAC,CAAC,CAAC;KACvB;IAED,MAAM,YAAY,GAAG,CAAC,SAAS,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC;IACrD,MAAM,UAAU,GAAG,aAAa,CAAC,OAAO,CAAC,KAAK,CAAC,GAAG,YAAY,CAAC;IAE/D,MAAM,OAAO,GAAG,CAAC,GAAG,cAAc,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;IAClE,MAAM,UAAU,GAAG,aAAa,CAAC,KAAK,CAAC,CAAC;IACxC,OAAO,EAAC,SAAS,EAAE,UAAU,EAAE,SAAS,EAAE,OAAO,EAAE,UAAU,EAAC,CAAC;AACjE,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport { TensorInfo } from '../tensor_info';\nimport {Tensor} from '../tensor';\nimport {computeStrides, sizeFromShape} from '../util';\n\n/**\n * Check whether updates.shape = indices.shape[:batchDim] +\n * shape[sliceDim:]\n *\n * @param x The input tensor.\n */\nexport function validateUpdateShape(\n    shape: number[], indices: Tensor, updates: Tensor) {\n  const sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1;\n  const batchDim = (indices.rank > 1) ? indices.rank - 1 : 1;\n\n  const shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' +\n      `shape[sliceDim:], got updates.shape: ${updates.shape}` +\n      `, indices.shape: ${indices.shape}, shape: ${shape}` +\n      `, sliceDim: ${sliceDim}, and batchDim: ${batchDim}.`;\n\n  if (updates.rank < batchDim) {\n    throw new Error(shapeError + ` update.rank < ${batchDim}. 
`);\n  }\n  if (shape.length < sliceDim + (updates.rank - batchDim)) {\n    throw new Error(\n        shapeError +\n        ` Output shape length < ${sliceDim + (updates.rank - batchDim)}`);\n  }\n  if (updates.rank !== batchDim + shape.length - sliceDim) {\n    throw new Error(\n        shapeError + ` update.rank != ${batchDim + shape.length - sliceDim}`);\n  }\n  for (let d = 0; d < batchDim; ++d) {\n    if (updates.shape[d] !== indices.shape[d]) {\n      throw new Error(\n          shapeError +\n          ` updates.shape[${d}] (${updates.shape[d]}) != indices.shape[${d}] (${\n              indices.shape[d]}).`);\n    }\n  }\n  for (let d = 0; d < updates.rank - batchDim; ++d) {\n    if (updates.shape[d + batchDim] !== shape[d + sliceDim]) {\n      throw new Error(\n          shapeError +\n          ` updates.shape[${d + batchDim}] (${\n              updates.shape[d + batchDim]}) != shape[${d + batchDim}] (${\n              shape[d + batchDim]})`);\n    }\n  }\n}\n\nexport interface ScatterShapeInfo {\n  sliceRank: number;\n  numUpdates: number;\n  sliceSize: number;\n  strides: number[];\n  outputSize: number;\n}\n/**\n * Validate scatter nd inputs.\n *\n * @param update The tensor contains the update values.\n * @param indices The tensor contains the indices for the update values.\n * @param shape The shape of the output tensor.\n */\nexport function validateInput(\n    updates: Tensor, indices: Tensor, shape: number[]) {\n  if (indices.rank < 1) {\n    throw new Error(\n        'tf.scatterND() expects the indices to be rank 1 or higher,' +\n        ` but the rank was ${indices.rank}.`);\n  }\n  if (updates.rank < 1) {\n    throw new Error(\n        'tf.scatterND() expects the updates to be rank 1 or higher,' +\n        ` but the rank was ${updates.rank}.`);\n  }\n  if (indices.dtype !== 'int32') {\n    throw new Error(`The dtype of 'indices' should be int32, but got dtype: ${\n        indices.dtype}`);\n  }\n  if (shape.length < 1) {\n    throw new Error(\n    
    `Output rank must be greater or equal to 1, but got shape: ${shape}`);\n  }\n\n  if (shape.length === 0) {\n    if (indices.size === 0) {\n      throw new Error(`Indices specified for empty output. indices shape: ${\n          indices.shape}`);\n    }\n    if (updates.size === 0) {\n      throw new Error(`Updates specified for empty output. updates shape: ${\n          updates.shape}`);\n    }\n  }\n\n  validateUpdateShape(shape, indices, updates);\n}\n\n/**\n * Calculate the shape information for the output.\n *\n * @param update The tensor contains the update values.\n * @param indices The tensor contains the indices for the update values.\n * @param shape The shape of the output tensor.\n *\n * @returns ScatterShapeInfo\n */\nexport function calculateShapes(\n    updates: TensorInfo, indices: TensorInfo,\n    shape: number[]): ScatterShapeInfo {\n  // Calculate the number of dimensions in indices\n  const indicesRank = indices.shape.length;\n  const sliceRank = (indicesRank > 1) ? indices.shape[indicesRank - 1] : 1;\n\n  // Calculate the number of elements that make up each slice of our updated\n  // tensor. This allows us to work with flattened tensors and copy over whole\n  // slices at a time.\n  const totalNd = shape.length;\n\n  let sliceSize = 1;\n  for (let i = sliceRank; i < totalNd; ++i) {\n    sliceSize *= shape[i];\n  }\n\n  const safeSliceDim = (sliceRank < 1) ? 1 : sliceRank;\n  const numUpdates = sizeFromShape(indices.shape) / safeSliceDim;\n\n  const strides = [...computeStrides(shape.slice(0, sliceRank)), 1];\n  const outputSize = sizeFromShape(shape);\n  return {sliceRank, numUpdates, sliceSize, strides, outputSize};\n}\n"]}