/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, util } from '@tensorflow/tfjs-core';
|
/**
 * Computes row-major strides for `shape`: the stride of dimension d is the
 * product of every dimension after d, and the last dimension has stride 1.
 * A rank-0 shape yields an empty array.
 *
 * @param {number[]} shape
 * @returns {number[]}
 */
function rowMajorStrides(shape) {
  const rank = shape.length;
  const strides = new Array(rank).fill(1);
  for (let d = rank - 2; d >= 0; --d) {
    strides[d] = strides[d + 1] * shape[d + 1];
  }
  return strides;
}

/**
 * Remaps the indices of a sparse tensor from `inputShape` to `targetShape`.
 *
 * Each input index is linearized with the row-major strides of `inputShape`.
 * It is then de-linearized with the row-major strides of the resolved output
 * shape.
 *
 * @param {ArrayLike<number>} inputIndices Flat index buffer laid out as
 *     [nnz, inputRank].
 * @param {number[]} inputIndicesShape Shape of `inputIndices`; only element 0
 *     (the number of non-zero entries, nnz) is read.
 * @param {string} inputDType Dtype used to allocate the returned index buffer.
 * @param {number[]} inputShape Dense shape of the input sparse tensor.
 * @param {number[]} targetShape Requested dense shape. At most one entry may
 *     be -1; that entry is inferred from the remaining sizes.
 * @returns {[ArrayLike<number>, number[], number[]]} The new flat index
 *     buffer, its shape `[nnz, outputRank]`, and the resolved output shape.
 * @throws {Error} If `targetShape` has more than one -1 or another negative
 *     entry. Also if the -1 cannot be inferred (a specified size is 0, or the
 *     dense size is not a multiple of the specified sizes), or if the element
 *     counts differ. Also if there are non-zero entries but the dense shape
 *     is empty.
 */
export function sparseReshapeImpl(inputIndices, inputIndicesShape, inputDType, inputShape, targetShape) {
  const denseSize = util.sizeFromShape(inputShape);
  const nnz = inputIndicesShape[0];
  const outputRank = targetShape.length;

  // Resolve the output shape: multiply the explicitly specified dimensions
  // together and remember the position of the single -1 (if any).
  const outputShape = [];
  let product = 1;
  let unknownIndex = -1;
  for (let d = 0; d < outputRank; ++d) {
    const size = targetShape[d];
    if (size === -1) {
      if (unknownIndex !== -1) {
        throw new Error(backend_util
            .getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(unknownIndex, d));
      }
      unknownIndex = d;
      // Placeholder; replaced below once the missing size is inferred.
      outputShape.push(1);
    } else {
      if (size < 0) {
        throw new Error(backend_util.getSparseReshapeNegativeOutputDimErrorMessage(d, size));
      }
      product *= size;
      outputShape.push(size);
    }
  }

  if (unknownIndex !== -1) {
    // `product` is a product of non-negative sizes, so `<= 0` means some
    // specified dimension is 0 and the missing size cannot be inferred.
    if (product <= 0) {
      throw new Error(backend_util.getSparseReshapeEmptyTensorZeroOutputDimErrorMessage());
    }
    const missing = Math.trunc(denseSize / product);
    if (product * missing !== denseSize) {
      throw new Error(backend_util.getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape));
    }
    outputShape[unknownIndex] = missing;
  }

  const outputSize = util.sizeFromShape(outputShape);
  if (outputSize !== denseSize) {
    throw new Error(backend_util.getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape));
  }

  // Any index into a zero-element dense tensor is out of range. The stride
  // arithmetic below may also divide by a zero output stride (NaN/Infinity),
  // which would silently produce meaningless indices. Reject such input.
  if (nnz > 0 && denseSize === 0) {
    throw new Error(`Input tensor has ${nnz} non zero elements but input shape ([${inputShape}]) or output shape ([${outputShape}]) is empty`);
  }

  const inputRank = inputShape.length;
  const inputStrides = rowMajorStrides(inputShape);
  const outputStrides = rowMajorStrides(outputShape);

  const newIndices = util.getArrayFromDType(inputDType, nnz * outputRank);
  for (let i = 0; i < nnz; ++i) {
    // Linearize row i of the [nnz, inputRank] input index buffer.
    let id = 0;
    for (let j = 0; j < inputRank; ++j) {
      id += inputIndices[i * inputRank + j] * inputStrides[j];
    }
    // De-linearize into row i of the [nnz, outputRank] output index buffer.
    for (let j = 0; j < outputRank; ++j) {
      newIndices[i * outputRank + j] = Math.trunc(id / outputStrides[j]);
      id %= outputStrides[j];
    }
  }

  return [newIndices, [nnz, outputRank], outputShape];
}
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3BhcnNlUmVzaGFwZV9pbXBsLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLWNwdS9zcmMva2VybmVscy9TcGFyc2VSZXNoYXBlX2ltcGwudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBd0IsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFFL0UsTUFBTSxVQUFVLGlCQUFpQixDQUM3QixZQUF3QixFQUFFLGlCQUEyQixFQUFFLFVBQW9CLEVBQzNFLFVBQW9CLEVBQ3BCLFdBQXFCO0lBQ3ZCLE1BQU0sU0FBUyxHQUFHLElBQUksQ0FBQyxhQUFhLENBQUMsVUFBVSxDQUFDLENBQUM7SUFDakQsTUFBTSxHQUFHLEdBQUcsaUJBQWlCLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFDakMsTUFBTSxVQUFVLEdBQUcsV0FBVyxDQUFDLE1BQU0sQ0FBQztJQUV0QywyRUFBMkU7SUFDM0UseUNBQXlDO0lBQ3pDLE1BQU0sV0FBVyxHQUFhLEVBQUUsQ0FBQztJQUNqQyxJQUFJLE9BQU8sR0FBRyxDQUFDLENBQUM7SUFDaEIsSUFBSSxZQUFZLEdBQUcsQ0FBQyxDQUFDLENBQUM7SUFDdEIsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFVBQVUsRUFBRSxFQUFFLENBQUMsRUFBRTtRQUNuQyxNQUFNLElBQUksR0FBRyxXQUFXLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDNUIsSUFBSSxJQUFJLEtBQUssQ0FBQyxDQUFDLEVBQUU7WUFDZixJQUFJLFlBQVksS0FBSyxDQUFDLENBQUMsRUFBRTtnQkFDdkIsTUFBTSxJQUFJLEtBQUssQ0FDWCxZQUFZO3FCQUNQLHdEQUF3RCxDQUNyRCxZQUFZLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQzthQUMvQjtZQUNELFlBQVksR0FBRyxDQUFDLENBQUM7WUFDakIsV0FBVyxDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsQ0FBQztTQUNyQjthQUFNO1lBQ0wsSUFBSSxJQUFJLEdBQUcsQ0FBQyxFQUFFO2dCQUNaLE1BQU0sSUFBSSxLQUFLLENBQ1gsWUFBWSxDQUFDLDZDQUE2QyxDQUN0RCxDQUFDLEVBQUUsSUFBSSxDQUFDLENBQUMsQ0FBQzthQUNuQjtZQUNELE9BQU8sSUFBSSxJQUFJLENBQUM7WUFDaEIsV0FBVyxDQUFDLElBQUksQ0FBQyxJQUFJLENBQUMsQ0FBQztTQUN4QjtLQUNGO0lBQ0QsSUFBSSxZQUFZLEtBQUssQ0FBQyxDQUFDLEVBQUU7UUFDdkIsSUFBSSxPQUFPLElBQUksQ0FBQyxFQUFFO1lBQ2hCLE1BQU0sSUFBSSxLQUFLLENBQ1gsWUFBWSxDQUFDLG9EQUFvRCxFQUFFLENBQUMsQ0FBQztTQUMxRTtRQUNELE1BQU0sT0FBTyxHQUFHLElBQUksQ0FBQyxLQUFLLENBQUMsU0FBUyxHQUFHLE9BQU8sQ0FBQyxDQUFDO1FBQ2hELElBQUksT0FBTyxHQUFHLE9BQU8sS0FBSyxTQUFTLEVBQUU7WUFDbkMsTUFBTSxJQUFJLEtBQUssQ0FDWCxZQUFZLENBQUMsK0NBQStDLENBQ3hELFVBQVUsRUFBRSxXQUFXLENBQUMsQ0FBQyxDQUFDO1NBQ25DO1FBRU
QsV0FBVyxDQUFDLFlBQVksQ0FBQyxHQUFHLE9BQU8sQ0FBQztLQUNyQztJQUNELE1BQU0sVUFBVSxHQUFHLElBQUksQ0FBQyxhQUFhLENBQUMsV0FBVyxDQUFDLENBQUM7SUFDbkQsSUFBSSxVQUFVLEtBQUssU0FBUyxFQUFFO1FBQzVCLE1BQU0sSUFBSSxLQUFLLENBQ1gsWUFBWSxDQUFDLCtDQUErQyxDQUN4RCxVQUFVLEVBQUUsV0FBVyxDQUFDLENBQUMsQ0FBQztLQUNuQztJQUVELE1BQU0sU0FBUyxHQUFHLFVBQVUsQ0FBQyxNQUFNLENBQUM7SUFDcEMsTUFBTSxZQUFZLEdBQWEsRUFBRSxDQUFDO0lBQ2xDLElBQUksU0FBUyxHQUFHLENBQUMsRUFBRTtRQUNqQixZQUFZLENBQUMsU0FBUyxHQUFHLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQztRQUNoQyxLQUFLLElBQUksQ0FBQyxHQUFHLFNBQVMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxJQUFJLENBQUMsRUFBRSxFQUFFLENBQUMsRUFBRTtZQUN2QyxZQUFZLENBQUMsQ0FBQyxDQUFDLEdBQUcsWUFBWSxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUMsR0FBRyxVQUFVLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDO1NBQzNEO0tBQ0Y7SUFFRCxNQUFNLGFBQWEsR0FBYSxFQUFFLENBQUM7SUFDbkMsSUFBSSxVQUFVLEdBQUcsQ0FBQyxFQUFFO1FBQ2xCLGFBQWEsQ0FBQyxVQUFVLEdBQUcsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxDQUFDO1FBQ2xDLEtBQUssSUFBSSxDQUFDLEdBQUcsVUFBVSxHQUFHLENBQUMsRUFBRSxDQUFDLElBQUksQ0FBQyxFQUFFLEVBQUUsQ0FBQyxFQUFFO1lBQ3hDLGFBQWEsQ0FBQyxDQUFDLENBQUMsR0FBRyxhQUFhLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxHQUFHLFdBQVcsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUM7U0FDOUQ7S0FDRjtJQUVELE1BQU0sVUFBVSxHQUNaLElBQUksQ0FBQyxpQkFBaUIsQ0FBQyxVQUFVLEVBQUUsR0FBRyxHQUFHLFVBQVUsQ0FBZSxDQUFDO0lBQ3ZFLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxHQUFHLEVBQUUsRUFBRSxDQUFDLEVBQUU7UUFDNUIsSUFBSSxFQUFFLEdBQUcsQ0FBQyxDQUFDO1FBQ1gsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFNBQVMsRUFBRSxFQUFFLENBQUMsRUFBRTtZQUNsQyw2REFBNkQ7WUFDN0QsRUFBRSxJQUFJLFlBQVksQ0FBQyxDQUFDLEdBQUcsU0FBUyxHQUFHLENBQUMsQ0FBQyxHQUFHLFlBQVksQ0FBQyxDQUFDLENBQUMsQ0FBQztTQUN6RDtRQUNELEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxVQUFVLEVBQUUsRUFBRSxDQUFDLEVBQUU7WUFDbkMsNERBQTREO1lBQzVELFVBQVUsQ0FBQyxDQUFDLEdBQUcsVUFBVSxHQUFHLENBQUMsQ0FBQyxHQUFHLElBQUksQ0FBQyxLQUFLLENBQUMsRUFBRSxHQUFHLGFBQWEsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO1lBQ25FLEVBQUUsSUFBSSxhQUFhLENBQUMsQ0FBQyxDQUFDLENBQUM7U0FDeEI7S0FDRjtJQUNELE9BQU8sQ0FBQyxVQUFVLEVBQUUsQ0FBQyxHQUFHLE
VBQUUsVUFBVSxDQUFDLEVBQUUsV0FBVyxDQUFDLENBQUM7QUFDdEQsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIxIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWwsIERhdGFUeXBlLCBUeXBlZEFycmF5LCB1dGlsfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5leHBvcnQgZnVuY3Rpb24gc3BhcnNlUmVzaGFwZUltcGwoXG4gICAgaW5wdXRJbmRpY2VzOiBUeXBlZEFycmF5LCBpbnB1dEluZGljZXNTaGFwZTogbnVtYmVyW10sIGlucHV0RFR5cGU6IERhdGFUeXBlLFxuICAgIGlucHV0U2hhcGU6IG51bWJlcltdLFxuICAgIHRhcmdldFNoYXBlOiBudW1iZXJbXSk6IFtUeXBlZEFycmF5LCBudW1iZXJbXSwgbnVtYmVyW11dIHtcbiAgY29uc3QgZGVuc2VTaXplID0gdXRpbC5zaXplRnJvbVNoYXBlKGlucHV0U2hhcGUpO1xuICBjb25zdCBubnogPSBpbnB1dEluZGljZXNTaGFwZVswXTtcbiAgY29uc3Qgb3V0cHV0UmFuayA9IHRhcmdldFNoYXBlLmxlbmd0aDtcblxuICAvLyBDb21wdXRlIHRoZSBvdXRwdXQgc2hhcGUuIERldGVybWluZSBwcm9kdWN0IG9mIHNwZWNpZmllZCBkaW1lbnNpb25zLCBhbmRcbiAgLy8gZmluZCB0aGUgaW5kZXggb2YgdGhlIHVuc3BlY2lmaWVkIG9uZS5cbiAgY29uc3Qgb3V0cHV0U2hhcGU6IG51bWJlcltdID0gW107XG4gIGxldCBwcm9kdWN0ID0gMTtcbiAgbGV0IHVua25vd25JbmRleCA9IC0xO1xuICBmb3IgKGxldCBkID0gMDsgZCA8IG91dHB1dFJhbms7ICsrZCkge1xuICAgIGNvbnN0IHNpemUgPSB0YXJnZXRTaGFwZVtkXTtcbi
AgICBpZiAoc2l6ZSA9PT0gLTEpIHtcbiAgICAgIGlmICh1bmtub3duSW5kZXggIT09IC0xKSB7XG4gICAgICAgIHRocm93IG5ldyBFcnJvcihcbiAgICAgICAgICAgIGJhY2tlbmRfdXRpbFxuICAgICAgICAgICAgICAgIC5nZXRTcGFyc2VSZXNoYXBlTXVsdGlwbGVOZWdhdGl2ZU9uZU91dHB1dERpbUVycm9yTWVzc2FnZShcbiAgICAgICAgICAgICAgICAgICAgdW5rbm93bkluZGV4LCBkKSk7XG4gICAgICB9XG4gICAgICB1bmtub3duSW5kZXggPSBkO1xuICAgICAgb3V0cHV0U2hhcGUucHVzaCgxKTtcbiAgICB9IGVsc2Uge1xuICAgICAgaWYgKHNpemUgPCAwKSB7XG4gICAgICAgIHRocm93IG5ldyBFcnJvcihcbiAgICAgICAgICAgIGJhY2tlbmRfdXRpbC5nZXRTcGFyc2VSZXNoYXBlTmVnYXRpdmVPdXRwdXREaW1FcnJvck1lc3NhZ2UoXG4gICAgICAgICAgICAgICAgZCwgc2l6ZSkpO1xuICAgICAgfVxuICAgICAgcHJvZHVjdCAqPSBzaXplO1xuICAgICAgb3V0cHV0U2hhcGUucHVzaChzaXplKTtcbiAgICB9XG4gIH1cbiAgaWYgKHVua25vd25JbmRleCAhPT0gLTEpIHtcbiAgICBpZiAocHJvZHVjdCA8PSAwKSB7XG4gICAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgICAgYmFja2VuZF91dGlsLmdldFNwYXJzZVJlc2hhcGVFbXB0eVRlbnNvclplcm9PdXRwdXREaW1FcnJvck1lc3NhZ2UoKSk7XG4gICAgfVxuICAgIGNvbnN0IG1pc3NpbmcgPSBNYXRoLnRydW5jKGRlbnNlU2l6ZSAvIHByb2R1Y3QpO1xuICAgIGlmIChwcm9kdWN0ICogbWlzc2luZyAhPT0gZGVuc2VTaXplKSB7XG4gICAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgICAgYmFja2VuZF91dGlsLmdldFNwYXJzZVJlc2hhcGVJbnB1dE91dHB1dE11bHRpcGxlRXJyb3JNZXNzYWdlKFxuICAgICAgICAgICAgICBpbnB1dFNoYXBlLCBvdXRwdXRTaGFwZSkpO1xuICAgIH1cblxuICAgIG91dHB1dFNoYXBlW3Vua25vd25JbmRleF0gPSBtaXNzaW5nO1xuICB9XG4gIGNvbnN0IG91dHB1dFNpemUgPSB1dGlsLnNpemVGcm9tU2hhcGUob3V0cHV0U2hhcGUpO1xuICBpZiAob3V0cHV0U2l6ZSAhPT0gZGVuc2VTaXplKSB7XG4gICAgdGhyb3cgbmV3IEVycm9yKFxuICAgICAgICBiYWNrZW5kX3V0aWwuZ2V0U3BhcnNlUmVzaGFwZUlucHV0T3V0cHV0TWlzbWF0Y2hFcnJvck1lc3NhZ2UoXG4gICAgICAgICAgICBpbnB1dFNoYXBlLCBvdXRwdXRTaGFwZSkpO1xuICB9XG5cbiAgY29uc3QgaW5wdXRSYW5rID0gaW5wdXRTaGFwZS5sZW5ndGg7XG4gIGNvbnN0IGlucHV0U3RyaWRlczogbnVtYmVyW10gPSBbXTtcbiAgaWYgKGlucHV0UmFuayA+IDApIHtcbiAgICBpbnB1dFN0cmlkZXNbaW5wdXRSYW5rIC0gMV0gPSAxO1xuICAgIGZvciAobGV0IGQgPSBpbnB1dFJhbmsgLSAyOyBkID49IDA7IC0tZCkge1xuICAgICAgaW5wdXRTdHJpZGVzW2RdID0gaW5wdXRTdHJpZGVzW2QgKyAxXSAqIGlucHV0U2hhcGVbZCArIDFdO1xuICAgIH1cbiAgfVxuXG4gIGNvbnN0IG91dHB1dF
N0cmlkZXM6IG51bWJlcltdID0gW107XG4gIGlmIChvdXRwdXRSYW5rID4gMCkge1xuICAgIG91dHB1dFN0cmlkZXNbb3V0cHV0UmFuayAtIDFdID0gMTtcbiAgICBmb3IgKGxldCBkID0gb3V0cHV0UmFuayAtIDI7IGQgPj0gMDsgLS1kKSB7XG4gICAgICBvdXRwdXRTdHJpZGVzW2RdID0gb3V0cHV0U3RyaWRlc1tkICsgMV0gKiBvdXRwdXRTaGFwZVtkICsgMV07XG4gICAgfVxuICB9XG5cbiAgY29uc3QgbmV3SW5kaWNlcyA9XG4gICAgICB1dGlsLmdldEFycmF5RnJvbURUeXBlKGlucHV0RFR5cGUsIG5ueiAqIG91dHB1dFJhbmspIGFzIFR5cGVkQXJyYXk7XG4gIGZvciAobGV0IGkgPSAwOyBpIDwgbm56OyArK2kpIHtcbiAgICBsZXQgaWQgPSAwO1xuICAgIGZvciAobGV0IGogPSAwOyBqIDwgaW5wdXRSYW5rOyArK2opIHtcbiAgICAgIC8vIGlucHV0SW5kaWNlcyBpcyBhIDJkIHRlbnNvciB3aXRoIHNoYXBlIG9mIFtubnosIGlucHV0UmFua11cbiAgICAgIGlkICs9IGlucHV0SW5kaWNlc1tpICogaW5wdXRSYW5rICsgal0gKiBpbnB1dFN0cmlkZXNbal07XG4gICAgfVxuICAgIGZvciAobGV0IGogPSAwOyBqIDwgb3V0cHV0UmFuazsgKytqKSB7XG4gICAgICAvLyBuZXdJbmRpY2VzIGlzIGEgMmQgdGVuc29yIHdpdGggc2hhcGUgb2YgW25ueiwgb3V0cHV0UmFua11cbiAgICAgIG5ld0luZGljZXNbaSAqIG91dHB1dFJhbmsgKyBqXSA9IE1hdGgudHJ1bmMoaWQgLyBvdXRwdXRTdHJpZGVzW2pdKTtcbiAgICAgIGlkICU9IG91dHB1dFN0cmlkZXNbal07XG4gICAgfVxuICB9XG4gIHJldHVybiBbbmV3SW5kaWNlcywgW25ueiwgb3V0cHV0UmFua10sIG91dHB1dFNoYXBlXTtcbn1cbiJdfQ==
|