/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { SparseReshape } from '@tensorflow/tfjs-core';
|
import { sparseReshapeImpl } from './SparseReshape_impl';
|
/**
 * CPU kernel function for the SparseReshape op.
 *
 * Checks the ranks of the three inputs, reads their values out of the
 * backend's data storage, delegates the actual computation to
 * `sparseReshapeImpl`, and wraps the results as tensor infos.
 *
 * @param {Object} args
 * @param {Object} args.inputs Holds `inputIndices` (must be rank 2),
 *     `inputShape` (must be rank 1) and `newShape` (must be rank 1).
 * @param {Object} args.backend Backend that exposes
 *     `data.get(dataId).values` and `makeTensorInfo(shape, dtype, values)`.
 * @returns {Array} Two tensor infos: the new indices (same dtype as
 *     `inputIndices`) and the output shape vector (dtype of `newShape`).
 * @throws {Error} If `inputIndices` is not rank 2, or `inputShape` /
 *     `newShape` is not rank 1.
 */
export function sparseReshape(args) {
  const { inputs, backend } = args;
  const { inputIndices, inputShape, newShape } = inputs;

  // Each message is kept on a single source line: a line break inside a
  // template literal becomes part of the thrown message.
  if (inputIndices.shape.length !== 2) {
    throw new Error(`Input indices should be a matrix but received shape ${inputIndices.shape}`);
  }
  if (inputShape.shape.length !== 1) {
    throw new Error(`Input shape should be a vector but received shape ${inputShape.shape}`);
  }
  if (newShape.shape.length !== 1) {
    throw new Error(`Target shape should be a vector but received shape ${newShape.shape}`);
  }

  // The two shape vectors are copied into plain arrays; the indices buffer is
  // handed to the implementation as stored.
  const $inputShape = Array.from(backend.data.get(inputShape.dataId).values);
  const $inputIndices = backend.data.get(inputIndices.dataId).values;
  const targetShape = Array.from(backend.data.get(newShape.dataId).values);

  const [newIndices, indicesShape, outputShape] = sparseReshapeImpl(
    $inputIndices,
    inputIndices.shape,
    inputIndices.dtype,
    $inputShape,
    targetShape,
  );

  return [
    backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices),
    backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape)),
  ];
}
|
/**
 * Kernel configuration that pairs the `SparseReshape` kernel name with the
 * `sparseReshape` kernel function for the 'cpu' backend.
 */
export const sparseReshapeConfig = {
  kernelName: SparseReshape,
  backendName: 'cpu',
  kernelFunc: sparseReshape,
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3BhcnNlUmVzaGFwZS5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC1jcHUvc3JjL2tlcm5lbHMvU3BhcnNlUmVzaGFwZS50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQWUsYUFBYSxFQUE4QyxNQUFNLHVCQUF1QixDQUFDO0FBSS9HLE9BQU8sRUFBQyxpQkFBaUIsRUFBQyxNQUFNLHNCQUFzQixDQUFDO0FBRXZELE1BQU0sVUFBVSxhQUFhLENBQ3pCLElBQTREO0lBRTlELE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQy9CLE1BQU0sRUFBQyxZQUFZLEVBQUUsVUFBVSxFQUFFLFFBQVEsRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUNwRCxJQUFJLFlBQVksQ0FBQyxLQUFLLENBQUMsTUFBTSxLQUFLLENBQUMsRUFBRTtRQUNuQyxNQUFNLElBQUksS0FBSyxDQUFDO1VBQ1YsWUFBWSxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7S0FDN0I7SUFDRCxJQUFJLFVBQVUsQ0FBQyxLQUFLLENBQUMsTUFBTSxLQUFLLENBQUMsRUFBRTtRQUNqQyxNQUFNLElBQUksS0FBSyxDQUFDO1VBQ1YsVUFBVSxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7S0FDM0I7SUFFRCxJQUFJLFFBQVEsQ0FBQyxLQUFLLENBQUMsTUFBTSxLQUFLLENBQUMsRUFBRTtRQUMvQixNQUFNLElBQUksS0FBSyxDQUNYLHNEQUFzRCxRQUFRLENBQUMsS0FBSyxFQUFFLENBQUMsQ0FBQztLQUM3RTtJQUVELE1BQU0sV0FBVyxHQUNiLEtBQUssQ0FBQyxJQUFJLENBQUMsT0FBTyxDQUFDLElBQUksQ0FBQyxHQUFHLENBQUMsVUFBVSxDQUFDLE1BQU0sQ0FBQyxDQUFDLE1BQW9CLENBQUMsQ0FBQztJQUN6RSxNQUFNLGFBQWEsR0FDZixPQUFPLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxZQUFZLENBQUMsTUFBTSxDQUFDLENBQUMsTUFBb0IsQ0FBQztJQUMvRCxNQUFNLFdBQVcsR0FDYixLQUFLLENBQUMsSUFBSSxDQUFDLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLFFBQVEsQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDLENBQUM7SUFFdkUsTUFBTSxDQUFDLFVBQVUsRUFBRSxZQUFZLEVBQUUsV0FBVyxDQUFDLEdBQUcsaUJBQWlCLENBQzdELGFBQWEsRUFBRSxZQUFZLENBQUMsS0FBSyxFQUFFLFlBQVksQ0FBQyxLQUFLLEVBQUUsV0FBVyxFQUNsRSxXQUFXLENBQUMsQ0FBQztJQUNqQixPQUFPO1FBQ0wsT0FBTyxDQUFDLGNBQWMsQ0FBQyxZQUFZLEVBQUUsWUFBWSxDQUFDLEtBQUssRUFBRSxVQUFVLENBQUM7UUFDcEUsT0FBTyxDQUFDLGNBQWMsQ0FDbEIsQ0FBQyxXQUFXLENBQUMsTUFBTSxDQUFDLEVBQUUsUUFBUSxDQUFDLEtBQUssRUFBRSxJQUFJLFVBQVUsQ0FBQyxXQUFXLENBQUMsQ0FBQztLQUN2RSxDQUFDO0FBQ0osQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLG1CQUFtQixHQUFpQjtJQUMvQyxVQU
FVLEVBQUUsYUFBYTtJQUN6QixXQUFXLEVBQUUsS0FBSztJQUNsQixVQUFVLEVBQUUsYUFBYTtDQUMxQixDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjEgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge0tlcm5lbENvbmZpZywgU3BhcnNlUmVzaGFwZSwgU3BhcnNlUmVzaGFwZUlucHV0cywgVGVuc29ySW5mbywgVHlwZWRBcnJheX0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtNYXRoQmFja2VuZENQVX0gZnJvbSAnLi4vYmFja2VuZF9jcHUnO1xuXG5pbXBvcnQge3NwYXJzZVJlc2hhcGVJbXBsfSBmcm9tICcuL1NwYXJzZVJlc2hhcGVfaW1wbCc7XG5cbmV4cG9ydCBmdW5jdGlvbiBzcGFyc2VSZXNoYXBlKFxuICAgIGFyZ3M6IHtpbnB1dHM6IFNwYXJzZVJlc2hhcGVJbnB1dHMsIGJhY2tlbmQ6IE1hdGhCYWNrZW5kQ1BVfSk6XG4gICAgW1RlbnNvckluZm8sIFRlbnNvckluZm9dIHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZH0gPSBhcmdzO1xuICBjb25zdCB7aW5wdXRJbmRpY2VzLCBpbnB1dFNoYXBlLCBuZXdTaGFwZX0gPSBpbnB1dHM7XG4gIGlmIChpbnB1dEluZGljZXMuc2hhcGUubGVuZ3RoICE9PSAyKSB7XG4gICAgdGhyb3cgbmV3IEVycm9yKGBJbnB1dCBpbmRpY2VzIHNob3VsZCBiZSBhIG1hdHJpeCBidXQgcmVjZWl2ZWQgc2hhcGVcbiAgICAgICAgJHtpbnB1dEluZGljZXMuc2hhcGV9YCk7XG4gIH1cbiAgaWYgKGlucHV0U2hhcGUuc2hhcGUubGVuZ3RoICE9PSAxKSB7XG4gICAgdGhyb3cgbmV3IEVycm9yKGBJbnB1dCBzaGFwZSBzaG91bGQgYmUgYSB2ZW
N0b3IgYnV0IHJlY2VpdmVkIHNoYXBlXG4gICAgICAgICR7aW5wdXRTaGFwZS5zaGFwZX1gKTtcbiAgfVxuXG4gIGlmIChuZXdTaGFwZS5zaGFwZS5sZW5ndGggIT09IDEpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgIGBUYXJnZXQgc2hhcGUgc2hvdWxkIGJlIGEgdmVjdG9yIGJ1dCByZWNlaXZlZCBzaGFwZSAke25ld1NoYXBlLnNoYXBlfWApO1xuICB9XG5cbiAgY29uc3QgJGlucHV0U2hhcGUgPVxuICAgICAgQXJyYXkuZnJvbShiYWNrZW5kLmRhdGEuZ2V0KGlucHV0U2hhcGUuZGF0YUlkKS52YWx1ZXMgYXMgVHlwZWRBcnJheSk7XG4gIGNvbnN0ICRpbnB1dEluZGljZXMgPVxuICAgICAgYmFja2VuZC5kYXRhLmdldChpbnB1dEluZGljZXMuZGF0YUlkKS52YWx1ZXMgYXMgVHlwZWRBcnJheTtcbiAgY29uc3QgdGFyZ2V0U2hhcGUgPVxuICAgICAgQXJyYXkuZnJvbShiYWNrZW5kLmRhdGEuZ2V0KG5ld1NoYXBlLmRhdGFJZCkudmFsdWVzIGFzIFR5cGVkQXJyYXkpO1xuXG4gIGNvbnN0IFtuZXdJbmRpY2VzLCBpbmRpY2VzU2hhcGUsIG91dHB1dFNoYXBlXSA9IHNwYXJzZVJlc2hhcGVJbXBsKFxuICAgICAgJGlucHV0SW5kaWNlcywgaW5wdXRJbmRpY2VzLnNoYXBlLCBpbnB1dEluZGljZXMuZHR5cGUsICRpbnB1dFNoYXBlLFxuICAgICAgdGFyZ2V0U2hhcGUpO1xuICByZXR1cm4gW1xuICAgIGJhY2tlbmQubWFrZVRlbnNvckluZm8oaW5kaWNlc1NoYXBlLCBpbnB1dEluZGljZXMuZHR5cGUsIG5ld0luZGljZXMpLFxuICAgIGJhY2tlbmQubWFrZVRlbnNvckluZm8oXG4gICAgICAgIFtvdXRwdXRTaGFwZS5sZW5ndGhdLCBuZXdTaGFwZS5kdHlwZSwgbmV3IEludDMyQXJyYXkob3V0cHV0U2hhcGUpKSxcbiAgXTtcbn1cblxuZXhwb3J0IGNvbnN0IHNwYXJzZVJlc2hhcGVDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogU3BhcnNlUmVzaGFwZSxcbiAgYmFja2VuZE5hbWU6ICdjcHUnLFxuICBrZXJuZWxGdW5jOiBzcGFyc2VSZXNoYXBlLFxufTtcbiJdfQ==
|