/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { ENGINE } from '../../engine';
|
import { SparseReshape } from '../../kernel_names';
|
import { convertToTensor } from '../../tensor_util_env';
|
import { op } from '../operation';
|
/**
 * This operation has the same semantics as reshape on the represented dense
 * tensor. The `inputIndices` are recomputed based on the requested `newShape`.
 * If one component of `newShape` is the special value -1, the size of that
 * dimension is computed so that the total dense size remains constant. At most
 * one component of `newShape` can be -1. The number of dense elements implied
 * by `newShape` must be the same as the number of dense elements originally
 * implied by `inputShape`. Reshaping does not affect the order of values in the
 * SparseTensor. If the input tensor has rank R_in and N non-empty values, and
 * `newShape` has length R_out, then `inputIndices` has shape [N, R_in],
 * `inputShape` has length R_in, `outputIndices` has shape [N, R_out], and
 * `outputShape` has length R_out.
 *
 * ```js
 * const result = tf.sparse.sparseReshape(
 *   [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]],
 *   [2, 3, 6], [9, -1]);
 * console.log(result);
 * result['outputIndices'].print(); // [[0, 0], [0, 1], [1, 2], [4, 2], [8, 1]]
 * result['outputShape'].print(); // [9, 4]
 * ```
 *
 * @param inputIndices - 2-D. N x R_in matrix with the indices of non-empty
 *     values in a SparseTensor.
 * @param inputShape - 1-D. R_in vector with the input SparseTensor's dense
 *     shape.
 * @param newShape - 1-D. R_out vector with the requested new dense shape.
 * @returns A map with the following properties:
 *     - outputIndices: 2-D. N x R_out matrix with the updated indices of
 *       non-empty values in the output SparseTensor.
 *     - outputShape: 1-D. R_out vector with the full dense shape of the output
 *       SparseTensor. This is the same as newShape but with any -1 dimensions
 *       filled in.
 * @throws {Error} If `inputIndices` does not have rank 2.
 * @throws {Error} If `inputShape` does not have rank 1.
 * @throws {Error} If `newShape` does not have rank 1.
 *
 * @doc {heading: 'Operations', subheading: 'Sparse'}
 */
function sparseReshape_(inputIndices, inputShape, newShape) {
  // All three arguments are converted up front, each requested as 'int32',
  // before any rank validation runs.
  const $inputIndices = convertToTensor(inputIndices, 'inputIndices', 'sparseReshape', 'int32');
  const $inputShape = convertToTensor(inputShape, 'inputShape', 'sparseReshape', 'int32');
  const $newShape = convertToTensor(newShape, 'newShape', 'sparseReshape', 'int32');

  if ($inputIndices.rank !== 2) {
    // Kept on a single line so the message carries no embedded newline or
    // indentation, matching the two rank checks below.
    throw new Error(`Input indices should be Tensor2D but received shape ${$inputIndices.shape}`);
  }
  if ($inputShape.rank !== 1) {
    throw new Error(`Input shape should be Tensor1D but received shape ${$inputShape.shape}`);
  }
  if ($newShape.rank !== 1) {
    throw new Error(`New shape should be Tensor1D but received shape ${$newShape.shape}`);
  }

  const inputs = {
    inputIndices: $inputIndices,
    inputShape: $inputShape,
    newShape: $newShape
  };

  // The kernel result is indexed positionally: element 0 is exposed as
  // `outputIndices` and element 1 as `outputShape`.
  const result = ENGINE.runKernel(SparseReshape, inputs);
  return { outputIndices: result[0], outputShape: result[1] };
}
|
// Public `sparseReshape` export: the value returned by passing an object
// holding `sparseReshape_` to the `op` helper imported from '../operation'.
// The `@__PURE__` annotation marks the call as side-effect free so that
// bundlers/minifiers may drop it when this export is unused.
// NOTE(review): presumably `op` derives the op's name from the
// `sparseReshape_` key — confirm in '../operation'.
export const sparseReshape = /* @__PURE__ */ op({ sparseReshape_ });
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoic3BhcnNlX3Jlc2hhcGUuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL29wcy9zcGFyc2Uvc3BhcnNlX3Jlc2hhcGUudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLGNBQWMsQ0FBQztBQUNwQyxPQUFPLEVBQUMsYUFBYSxFQUFzQixNQUFNLG9CQUFvQixDQUFDO0FBR3RFLE9BQU8sRUFBQyxlQUFlLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUV0RCxPQUFPLEVBQUMsRUFBRSxFQUFDLE1BQU0sY0FBYyxDQUFDO0FBRWhDOzs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7R0FpQ0c7QUFDSCxTQUFTLGNBQWMsQ0FDbkIsWUFBaUMsRUFBRSxVQUErQixFQUNsRSxRQUE2QjtJQUMvQixNQUFNLGFBQWEsR0FDZixlQUFlLENBQUMsWUFBWSxFQUFFLGNBQWMsRUFBRSxlQUFlLEVBQUUsT0FBTyxDQUFDLENBQUM7SUFDNUUsTUFBTSxXQUFXLEdBQ2IsZUFBZSxDQUFDLFVBQVUsRUFBRSxZQUFZLEVBQUUsZUFBZSxFQUFFLE9BQU8sQ0FBQyxDQUFDO0lBQ3hFLE1BQU0sU0FBUyxHQUNYLGVBQWUsQ0FBQyxRQUFRLEVBQUUsVUFBVSxFQUFFLGVBQWUsRUFBRSxPQUFPLENBQUMsQ0FBQztJQUVwRSxJQUFJLGFBQWEsQ0FBQyxJQUFJLEtBQUssQ0FBQyxFQUFFO1FBQzVCLE1BQU0sSUFBSSxLQUFLLENBQUM7VUFDVixhQUFhLENBQUMsS0FBSyxFQUFFLENBQUMsQ0FBQztLQUM5QjtJQUNELElBQUksV0FBVyxDQUFDLElBQUksS0FBSyxDQUFDLEVBQUU7UUFDMUIsTUFBTSxJQUFJLEtBQUssQ0FBQyxxREFDWixXQUFXLENBQUMsS0FBSyxFQUFFLENBQUMsQ0FBQztLQUMxQjtJQUNELElBQUksU0FBUyxDQUFDLElBQUksS0FBSyxDQUFDLEVBQUU7UUFDeEIsTUFBTSxJQUFJLEtBQUssQ0FDWCxtREFBbUQsU0FBUyxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7S0FDM0U7SUFFRCxNQUFNLE1BQU0sR0FBd0I7UUFDbEMsWUFBWSxFQUFFLGFBQWE7UUFDM0IsVUFBVSxFQUFFLFdBQVc7UUFDdkIsUUFBUSxFQUFFLFNBQVM7S0FDcEIsQ0FBQztJQUNGLE1BQU0sTUFBTSxHQUFhLE1BQU0sQ0FBQyxTQUFTLENBQUMsYUFBYSxFQUFFLE1BQVksQ0FBQyxDQUFDO0lBQ3ZFLE9BQU8sRUFBQyxhQUFhLEVBQUUsTUFBTSxDQUFDLENBQUMsQ0FBQyxFQUFFLFdBQVcsRUFBRSxNQUFNLENBQUMsQ0FBQyxDQUFDLEVBQUMsQ0FBQztBQUM1RCxDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sYUFBYSxHQUFHLGVBQWUsQ0FBQyxFQUFFLENBQUMsRUFBQyxjQUFjLEVBQUMsQ0FBQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjEgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVm
Vyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge0VOR0lORX0gZnJvbSAnLi4vLi4vZW5naW5lJztcbmltcG9ydCB7U3BhcnNlUmVzaGFwZSwgU3BhcnNlUmVzaGFwZUlucHV0c30gZnJvbSAnLi4vLi4va2VybmVsX25hbWVzJztcbmltcG9ydCB7VGVuc29yLCBUZW5zb3IxRCwgVGVuc29yMkR9IGZyb20gJy4uLy4uL3RlbnNvcic7XG5pbXBvcnQge05hbWVkVGVuc29yTWFwfSBmcm9tICcuLi8uLi90ZW5zb3JfdHlwZXMnO1xuaW1wb3J0IHtjb252ZXJ0VG9UZW5zb3J9IGZyb20gJy4uLy4uL3RlbnNvcl91dGlsX2Vudic7XG5pbXBvcnQge1RlbnNvckxpa2V9IGZyb20gJy4uLy4uL3R5cGVzJztcbmltcG9ydCB7b3B9IGZyb20gJy4uL29wZXJhdGlvbic7XG5cbi8qKlxuICogVGhpcyBvcGVyYXRpb24gaGFzIHRoZSBzYW1lIHNlbWFudGljcyBhcyByZXNoYXBlIG9uIHRoZSByZXByZXNlbnRlZCBkZW5zZVxuICogdGVuc29yLiBUaGUgYGlucHV0SW5kaWNlc2AgYXJlIHJlY29tcHV0ZWQgYmFzZWQgb24gdGhlIHJlcXVlc3RlZCBgbmV3U2hhcGVgLlxuICogSWYgb25lIGNvbXBvbmVudCBvZiBgbmV3U2hhcGVgIGlzIHRoZSBzcGVjaWFsIHZhbHVlIC0xLCB0aGUgc2l6ZSBvZiB0aGF0XG4gKiBkaW1lbnNpb24gaXMgY29tcHV0ZWQgc28gdGhhdCB0aGUgdG90YWwgZGVuc2Ugc2l6ZSByZW1haW5zIGNvbnN0YW50LiBBdCBtb3N0XG4gKiBvbmUgY29tcG9uZW50IG9mIGBuZXdTaGFwZWAgY2FuIGJlIC0xLiBUaGUgbnVtYmVyIG9mIGRlbnNlIGVsZW1lbnRzIGltcGxpZWRcbiAqIGJ5IGBuZXdTaGFwZWAgbXVzdCBiZSB0aGUgc2FtZSBhcyB0aGUgbnVtYmVyIG9mIGRlbnNlIGVsZW1lbnRzIG9yaWdpbmFsbHlcbiAqIGltcGxpZWQgYnkgYGlucHV0U2hhcGVgLiBSZXNoYXBpbmcgZG9lcyBub3QgYWZmZW
N0IHRoZSBvcmRlciBvZiB2YWx1ZXMgaW4gdGhlXG4gKiBTcGFyc2VUZW5zb3IuIElmIHRoZSBpbnB1dCB0ZW5zb3IgaGFzIHJhbmsgUl9pbiBhbmQgTiBub24tZW1wdHkgdmFsdWVzLCBhbmRcbiAqIGBuZXdTaGFwZWAgaGFzIGxlbmd0aCBSX291dCwgdGhlbiBgaW5wdXRJbmRpY2VzYCBoYXMgc2hhcGUgW04sIFJfaW5dLFxuICogYGlucHV0U2hhcGVgIGhhcyBsZW5ndGggUl9pbiwgYG91dHB1dEluZGljZXNgIGhhcyBzaGFwZSBbTiwgUl9vdXRdLCBhbmRcbiAqIGBvdXRwdXRTaGFwZWAgaGFzIGxlbmd0aCBSX291dC5cbiAqXG4gKiBgYGBqc1xuICogY29uc3QgcmVzdWx0ID0gdGYuc3BhcnNlLnNwYXJzZVJlc2hhcGUoXG4gKiAgIFtbMCwgMCwgMF0sIFswLCAwLCAxXSwgWzAsIDEsIDBdLCBbMSwgMCwgMF0sIFsxLCAyLCAzXV0sXG4gKiAgIFsyLCAzLCA2XSwgWzksIC0xXSk7XG4gKiBjb25zb2xlLmxvZyhyZXN1bHQpO1xuICogcmVzdWx0WydvdXRwdXRJbmRpY2VzJ10ucHJpbnQoKTsgLy9bWzAsIDBdLCBbMCwgMV0sIFsxLCAyXSwgWzQsIDJdLCBbOCwgMV1dXG4gKiByZXN1bHRbJ291dHB1dFNoYXBlJ10ucHJpbnQoKTsgLy8gWzksIDRdXG4gKiBgYGBcbiAqIEBwYXJhbSBpbnB1dEluZGljZXM6IDItRC4gTiB4IFJfaW4gbWF0cml4IHdpdGggdGhlIGluZGljZXMgb2Ygbm9uLWVtcHR5XG4gKiB2YWx1ZXMgaW4gYSBTcGFyc2VUZW5zb3IuXG4gKiBAcGFyYW0gaW5wdXRTaGFwZTogMS1ELiBSX2luIFRlbnNvcjFEIHdpdGggdGhlIGlucHV0IFNwYXJzZVRlbnNvcidzIGRlbnNlXG4gKiBzaGFwZS5cbiAqIEBwYXJhbSBuZXdTaGFwZTogMS1ELiBSX291dCBUZW5zb3IxRCB3aXRoIHRoZSByZXF1ZXN0ZWQgbmV3IGRlbnNlIHNoYXBlLlxuICogQHJldHVybiBBIG1hcCB3aXRoIHRoZSBmb2xsb3dpbmcgcHJvcGVydGllczpcbiAqICAgICAtIG91dHB1dEluZGljZXM6IDItRC4gTiB4IFJfb3V0IG1hdHJpeCB3aXRoIHRoZSB1cGRhdGVkIGluZGljZXMgb2ZcbiAqICAgICAgIG5vbi1lbXB0eSB2YWx1ZXMgaW4gdGhlIG91dHB1dCBTcGFyc2VUZW5zb3IuXG4gKiAgICAgLSBvdXRwdXRTaGFwZTogMS1ELiBSX291dCB2ZWN0b3Igd2l0aCB0aGUgZnVsbCBkZW5zZSBzaGFwZSBvZiB0aGUgb3V0cHV0XG4gKiAgICAgICBTcGFyc2VUZW5zb3IuIFRoaXMgaXMgdGhlIHNhbWUgYXMgbmV3U2hhcGUgYnV0IHdpdGggYW55IC0xIGRpbWVuc2lvbnNcbiAqICAgICAgICBmaWxsZWQgaW4uXG4gKiBAZG9jIHtoZWFkaW5nOiAnT3BlcmF0aW9ucycsIHN1YmhlYWRpbmc6ICdTcGFyc2UnfVxuICovXG5mdW5jdGlvbiBzcGFyc2VSZXNoYXBlXyhcbiAgICBpbnB1dEluZGljZXM6IFRlbnNvcjJEfFRlbnNvckxpa2UsIGlucHV0U2hhcGU6IFRlbnNvcjFEfFRlbnNvckxpa2UsXG4gICAgbmV3U2hhcGU6IFRlbnNvcjFEfFRlbnNvckxpa2UpOiBOYW1lZFRlbnNvck1hcCB7XG4gIGNvbnN0ICRpbnB1dEluZGljZXMgPVxuICAgICAgY29udmVydFRvVGVuc29yKGlucH
V0SW5kaWNlcywgJ2lucHV0SW5kaWNlcycsICdzcGFyc2VSZXNoYXBlJywgJ2ludDMyJyk7XG4gIGNvbnN0ICRpbnB1dFNoYXBlID1cbiAgICAgIGNvbnZlcnRUb1RlbnNvcihpbnB1dFNoYXBlLCAnaW5wdXRTaGFwZScsICdzcGFyc2VSZXNoYXBlJywgJ2ludDMyJyk7XG4gIGNvbnN0ICRuZXdTaGFwZSA9XG4gICAgICBjb252ZXJ0VG9UZW5zb3IobmV3U2hhcGUsICduZXdTaGFwZScsICdzcGFyc2VSZXNoYXBlJywgJ2ludDMyJyk7XG5cbiAgaWYgKCRpbnB1dEluZGljZXMucmFuayAhPT0gMikge1xuICAgIHRocm93IG5ldyBFcnJvcihgSW5wdXQgaW5kaWNlcyBzaG91bGQgYmUgVGVuc29yMkQgYnV0IHJlY2VpdmVkIHNoYXBlXG4gICAgICAgICR7JGlucHV0SW5kaWNlcy5zaGFwZX1gKTtcbiAgfVxuICBpZiAoJGlucHV0U2hhcGUucmFuayAhPT0gMSkge1xuICAgIHRocm93IG5ldyBFcnJvcihgSW5wdXQgc2hhcGUgc2hvdWxkIGJlIFRlbnNvcjFEIGJ1dCByZWNlaXZlZCBzaGFwZSAke1xuICAgICAgICAkaW5wdXRTaGFwZS5zaGFwZX1gKTtcbiAgfVxuICBpZiAoJG5ld1NoYXBlLnJhbmsgIT09IDEpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgIGBOZXcgc2hhcGUgc2hvdWxkIGJlIFRlbnNvcjFEIGJ1dCByZWNlaXZlZCBzaGFwZSAkeyRuZXdTaGFwZS5zaGFwZX1gKTtcbiAgfVxuXG4gIGNvbnN0IGlucHV0czogU3BhcnNlUmVzaGFwZUlucHV0cyA9IHtcbiAgICBpbnB1dEluZGljZXM6ICRpbnB1dEluZGljZXMsXG4gICAgaW5wdXRTaGFwZTogJGlucHV0U2hhcGUsXG4gICAgbmV3U2hhcGU6ICRuZXdTaGFwZVxuICB9O1xuICBjb25zdCByZXN1bHQ6IFRlbnNvcltdID0gRU5HSU5FLnJ1bktlcm5lbChTcGFyc2VSZXNoYXBlLCBpbnB1dHMgYXMge30pO1xuICByZXR1cm4ge291dHB1dEluZGljZXM6IHJlc3VsdFswXSwgb3V0cHV0U2hhcGU6IHJlc3VsdFsxXX07XG59XG5cbmV4cG9ydCBjb25zdCBzcGFyc2VSZXNoYXBlID0gLyogQF9fUFVSRV9fICovIG9wKHtzcGFyc2VSZXNoYXBlX30pO1xuIl19
|