/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, util } from '@tensorflow/tfjs-core';
|
/**
 * Returns the row-major (C-order) strides of `shape`: the last dimension has
 * stride 1 and each earlier stride is the product of all later dimensions.
 * A rank-0 shape yields an empty array.
 */
function computeRowMajorStrides(shape) {
  const rank = shape.length;
  const strides = [];
  if (rank > 0) {
    strides[rank - 1] = 1;
    for (let d = rank - 2; d >= 0; --d) {
      strides[d] = strides[d + 1] * shape[d + 1];
    }
  }
  return strides;
}

/**
 * CPU implementation of SparseReshape: remaps the indices of a sparse tensor
 * from `inputShape` to a new dense shape derived from `targetShape`.
 *
 * Each index row is linearised against the row-major strides of `inputShape`
 * and then re-expanded against the row-major strides of the resolved output
 * shape.
 *
 * @param inputIndices Flat indices buffer, read as a row-major
 *     [nnz, inputRank] matrix.
 * @param inputIndicesShape Shape of `inputIndices`; element 0 is nnz.
 * @param inputDType Data type used to allocate the returned indices buffer.
 * @param inputShape Dense shape of the input sparse tensor.
 * @param targetShape Requested dense shape. At most one entry may be -1; that
 *     dimension is inferred so the dense size is preserved.
 * @returns `[newIndices, [nnz, outputRank], outputShape]`, where `newIndices`
 *     is a flat row-major [nnz, outputRank] buffer.
 * @throws Error if `targetShape` contains more than one -1 or another
 *     negative entry, if the -1 dimension cannot be inferred, if the resolved
 *     output size differs from the input dense size, or if there are non-zero
 *     entries while the dense shape is empty.
 */
export function sparseReshapeImpl(inputIndices, inputIndicesShape, inputDType, inputShape, targetShape) {
  const denseSize = util.sizeFromShape(inputShape);
  const nnz = inputIndicesShape[0];
  const outputRank = targetShape.length;

  // Compute the output shape. Determine the product of the specified
  // dimensions and find the index of the unspecified (-1) one.
  const outputShape = [];
  let product = 1;
  let unknownIndex = -1;
  for (let d = 0; d < outputRank; ++d) {
    const size = targetShape[d];
    if (size === -1) {
      if (unknownIndex !== -1) {
        throw new Error(backend_util
          .getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(unknownIndex, d));
      }
      unknownIndex = d;
      // Placeholder; replaced below once the missing size is known.
      outputShape.push(1);
    } else {
      if (size < 0) {
        throw new Error(backend_util.getSparseReshapeNegativeOutputDimErrorMessage(d, size));
      }
      product *= size;
      outputShape.push(size);
    }
  }
  if (unknownIndex !== -1) {
    // A zero-sized specified dimension makes the inferred size ambiguous.
    if (product <= 0) {
      throw new Error(backend_util.getSparseReshapeEmptyTensorZeroOutputDimErrorMessage());
    }
    const missing = Math.trunc(denseSize / product);
    if (product * missing !== denseSize) {
      throw new Error(backend_util.getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape));
    }
    outputShape[unknownIndex] = missing;
  }
  const outputSize = util.sizeFromShape(outputShape);
  if (outputSize !== denseSize) {
    throw new Error(backend_util.getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape));
  }

  // An empty dense shape cannot hold any entry. Without this check, a zero
  // output dimension yields a zero stride, and the division/modulo below
  // would silently write NaN-derived garbage into the result.
  if (nnz > 0 && denseSize === 0) {
    throw new Error(`Input tensor has ${nnz} non zero elements but input shape ` +
      `([${inputShape}]) or output shape ([${outputShape}]) is empty`);
  }

  const inputRank = inputShape.length;
  const inputStrides = computeRowMajorStrides(inputShape);
  const outputStrides = computeRowMajorStrides(outputShape);

  const newIndices = util.getArrayFromDType(inputDType, nnz * outputRank);
  for (let i = 0; i < nnz; ++i) {
    // Flatten the i-th input index row into a single linear offset.
    let id = 0;
    for (let j = 0; j < inputRank; ++j) {
      id += inputIndices[i * inputRank + j] * inputStrides[j];
    }
    // Re-expand the linear offset into output coordinates.
    for (let j = 0; j < outputRank; ++j) {
      newIndices[i * outputRank + j] = Math.trunc(id / outputStrides[j]);
      id %= outputStrides[j];
    }
  }
  return [newIndices, [nnz, outputRank], outputShape];
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"SparseReshape_impl.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-cpu/src/kernels/SparseReshape_impl.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAwB,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAE/E,MAAM,UAAU,iBAAiB,CAC7B,YAAwB,EAAE,iBAA2B,EAAE,UAAoB,EAC3E,UAAoB,EACpB,WAAqB;IACvB,MAAM,SAAS,GAAG,IAAI,CAAC,aAAa,CAAC,UAAU,CAAC,CAAC;IACjD,MAAM,GAAG,GAAG,iBAAiB,CAAC,CAAC,CAAC,CAAC;IACjC,MAAM,UAAU,GAAG,WAAW,CAAC,MAAM,CAAC;IAEtC,2EAA2E;IAC3E,yCAAyC;IACzC,MAAM,WAAW,GAAa,EAAE,CAAC;IACjC,IAAI,OAAO,GAAG,CAAC,CAAC;IAChB,IAAI,YAAY,GAAG,CAAC,CAAC,CAAC;IACtB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,EAAE,EAAE,CAAC,EAAE;QACnC,MAAM,IAAI,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC;QAC5B,IAAI,IAAI,KAAK,CAAC,CAAC,EAAE;YACf,IAAI,YAAY,KAAK,CAAC,CAAC,EAAE;gBACvB,MAAM,IAAI,KAAK,CACX,YAAY;qBACP,wDAAwD,CACrD,YAAY,EAAE,CAAC,CAAC,CAAC,CAAC;aAC/B;YACD,YAAY,GAAG,CAAC,CAAC;YACjB,WAAW,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;SACrB;aAAM;YACL,IAAI,IAAI,GAAG,CAAC,EAAE;gBACZ,MAAM,IAAI,KAAK,CACX,YAAY,CAAC,6CAA6C,CACtD,CAAC,EAAE,IAAI,CAAC,CAAC,CAAC;aACnB;YACD,OAAO,IAAI,IAAI,CAAC;YAChB,WAAW,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;SACxB;KACF;IACD,IAAI,YAAY,KAAK,CAAC,CAAC,EAAE;QACvB,IAAI,OAAO,IAAI,CAAC,EAAE;YAChB,MAAM,IAAI,KAAK,CACX,YAAY,CAAC,oDAAoD,EAAE,CAAC,CAAC;SAC1E;QACD,MAAM,OAAO,GAAG,IAAI,CAAC,KAAK,CAAC,SAAS,GAAG,OAAO,CAAC,CAAC;QAChD,IAAI,OAAO,GAAG,OAAO,KAAK,SAAS,EAAE;YACnC,MAAM,IAAI,KAAK,CACX,YAAY,CAAC,+CAA+C,CACxD,UAAU,EAAE,WAAW,CAAC,CAAC,CAAC;SACnC;QAED,WAAW,CAAC,YAAY,CAAC,GAAG,OAAO,CAAC;KACrC;IACD,MAAM,UAAU,GAAG,IAAI,CAAC,aAAa,CAAC,WAAW,CAAC,CAAC;IACnD,IAAI,UAAU,KAAK,SAAS,EAAE;QAC5B,MAAM,IAAI,KAAK,CACX,YAAY,CAAC,+CAA+C,CACxD,UAAU,EAAE,WAAW,CAAC,CAAC,CAAC;KACnC;IAED,MAAM,SAAS,GAAG,UAAU,CAAC,MAAM,CAAC;IACpC,MAAM,YAAY,GAAa,EAAE,CAAC;IAClC,IAAI,SAAS,GAAG,CAAC,EAAE;QACjB,YAAY,CAAC,SAAS,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;QAChC,KAAK,IAAI,CAAC,GAAG,SAAS,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,EAAE,CAAC,EAAE;YACvC,YAAY,CAAC,CAAC,CAAC,GAAG,YAAY,CAAC
,CAAC,GAAG,CAAC,CAAC,GAAG,UAAU,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;SAC3D;KACF;IAED,MAAM,aAAa,GAAa,EAAE,CAAC;IACnC,IAAI,UAAU,GAAG,CAAC,EAAE;QAClB,aAAa,CAAC,UAAU,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;QAClC,KAAK,IAAI,CAAC,GAAG,UAAU,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,EAAE,CAAC,EAAE;YACxC,aAAa,CAAC,CAAC,CAAC,GAAG,aAAa,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;SAC9D;KACF;IAED,MAAM,UAAU,GACZ,IAAI,CAAC,iBAAiB,CAAC,UAAU,EAAE,GAAG,GAAG,UAAU,CAAe,CAAC;IACvE,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,EAAE,EAAE,CAAC,EAAE;QAC5B,IAAI,EAAE,GAAG,CAAC,CAAC;QACX,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,EAAE,CAAC,EAAE;YAClC,6DAA6D;YAC7D,EAAE,IAAI,YAAY,CAAC,CAAC,GAAG,SAAS,GAAG,CAAC,CAAC,GAAG,YAAY,CAAC,CAAC,CAAC,CAAC;SACzD;QACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,EAAE,EAAE,CAAC,EAAE;YACnC,4DAA4D;YAC5D,UAAU,CAAC,CAAC,GAAG,UAAU,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,EAAE,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC,CAAC;YACnE,EAAE,IAAI,aAAa,CAAC,CAAC,CAAC,CAAC;SACxB;KACF;IACD,OAAO,CAAC,UAAU,EAAE,CAAC,GAAG,EAAE,UAAU,CAAC,EAAE,WAAW,CAAC,CAAC;AACtD,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2021 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, DataType, TypedArray, util} from '@tensorflow/tfjs-core';\n\nexport function sparseReshapeImpl(\n    inputIndices: TypedArray, inputIndicesShape: number[], inputDType: DataType,\n    inputShape: number[],\n    targetShape: number[]): [TypedArray, number[], number[]] {\n  const denseSize = util.sizeFromShape(inputShape);\n  const nnz = inputIndicesShape[0];\n  const outputRank = targetShape.length;\n\n  // Compute the output shape. 
Determine product of specified dimensions, and\n  // find the index of the unspecified one.\n  const outputShape: number[] = [];\n  let product = 1;\n  let unknownIndex = -1;\n  for (let d = 0; d < outputRank; ++d) {\n    const size = targetShape[d];\n    if (size === -1) {\n      if (unknownIndex !== -1) {\n        throw new Error(\n            backend_util\n                .getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(\n                    unknownIndex, d));\n      }\n      unknownIndex = d;\n      outputShape.push(1);\n    } else {\n      if (size < 0) {\n        throw new Error(\n            backend_util.getSparseReshapeNegativeOutputDimErrorMessage(\n                d, size));\n      }\n      product *= size;\n      outputShape.push(size);\n    }\n  }\n  if (unknownIndex !== -1) {\n    if (product <= 0) {\n      throw new Error(\n          backend_util.getSparseReshapeEmptyTensorZeroOutputDimErrorMessage());\n    }\n    const missing = Math.trunc(denseSize / product);\n    if (product * missing !== denseSize) {\n      throw new Error(\n          backend_util.getSparseReshapeInputOutputMultipleErrorMessage(\n              inputShape, outputShape));\n    }\n\n    outputShape[unknownIndex] = missing;\n  }\n  const outputSize = util.sizeFromShape(outputShape);\n  if (outputSize !== denseSize) {\n    throw new Error(\n        backend_util.getSparseReshapeInputOutputMismatchErrorMessage(\n            inputShape, outputShape));\n  }\n\n  const inputRank = inputShape.length;\n  const inputStrides: number[] = [];\n  if (inputRank > 0) {\n    inputStrides[inputRank - 1] = 1;\n    for (let d = inputRank - 2; d >= 0; --d) {\n      inputStrides[d] = inputStrides[d + 1] * inputShape[d + 1];\n    }\n  }\n\n  const outputStrides: number[] = [];\n  if (outputRank > 0) {\n    outputStrides[outputRank - 1] = 1;\n    for (let d = outputRank - 2; d >= 0; --d) {\n      outputStrides[d] = outputStrides[d + 1] * outputShape[d + 1];\n    }\n  }\n\n  const newIndices =\n    
  util.getArrayFromDType(inputDType, nnz * outputRank) as TypedArray;\n  for (let i = 0; i < nnz; ++i) {\n    let id = 0;\n    for (let j = 0; j < inputRank; ++j) {\n      // inputIndices is a 2d tensor with shape of [nnz, inputRank]\n      id += inputIndices[i * inputRank + j] * inputStrides[j];\n    }\n    for (let j = 0; j < outputRank; ++j) {\n      // newIndices is a 2d tensor with shape of [nnz, outputRank]\n      newIndices[i * outputRank + j] = Math.trunc(id / outputStrides[j]);\n      id %= outputStrides[j];\n    }\n  }\n  return [newIndices, [nnz, outputRank], outputShape];\n}\n"]}
|