gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { SparseReshape } from '@tensorflow/tfjs-core';
import { sparseReshapeImpl } from './SparseReshape_impl';
/**
 * CPU kernel function for SparseReshape.
 *
 * Checks the ranks of the three input tensors, reads their values from the
 * backend's data storage, delegates the index computation to
 * `sparseReshapeImpl`, and wraps the results in tensor infos.
 *
 * @param {{inputs: Object, backend: Object}} args
 *     `inputs` must provide `inputIndices`, `inputShape` and `newShape`;
 *     `backend` must provide `data.get(dataId).values` and `makeTensorInfo`.
 * @returns {Array} A two-element array: the tensor info for the new indices
 *     (shape taken from `sparseReshapeImpl`, dtype of `inputIndices`), and a
 *     rank-1 tensor info holding the output shape as an `Int32Array` (dtype
 *     of `newShape`).
 * @throws {Error} If `inputIndices` is not rank 2, or if `inputShape` or
 *     `newShape` is not rank 1.
 */
export function sparseReshape(args) {
    const { inputs, backend } = args;
    const { inputIndices, inputShape, newShape } = inputs;
    // Each message is written on a single source line: a template literal
    // that wraps across lines embeds the newline and the indentation in the
    // text of the thrown error.
    if (inputIndices.shape.length !== 2) {
        throw new Error(`Input indices should be a matrix but received shape ${inputIndices.shape}`);
    }
    if (inputShape.shape.length !== 1) {
        throw new Error(`Input shape should be a vector but received shape ${inputShape.shape}`);
    }
    if (newShape.shape.length !== 1) {
        throw new Error(`Target shape should be a vector but received shape ${newShape.shape}`);
    }
    // The two shape vectors are copied into plain arrays; the indices are
    // passed through as stored by the backend.
    const $inputShape = Array.from(backend.data.get(inputShape.dataId).values);
    const $inputIndices = backend.data.get(inputIndices.dataId).values;
    const targetShape = Array.from(backend.data.get(newShape.dataId).values);
    const [newIndices, indicesShape, outputShape] = sparseReshapeImpl($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape);
    return [
        backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices),
        backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape)),
    ];
}
/**
 * Kernel configuration object pairing the `SparseReshape` kernel name with
 * the `sparseReshape` function under the 'cpu' backend name.
 */
export const sparseReshapeConfig = {
    kernelName: SparseReshape,
    backendName: 'cpu',
    kernelFunc: sparseReshape,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3BhcnNlUmVzaGFwZS5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC1jcHUvc3JjL2tlcm5lbHMvU3BhcnNlUmVzaGFwZS50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQWUsYUFBYSxFQUE4QyxNQUFNLHVCQUF1QixDQUFDO0FBSS9HLE9BQU8sRUFBQyxpQkFBaUIsRUFBQyxNQUFNLHNCQUFzQixDQUFDO0FBRXZELE1BQU0sVUFBVSxhQUFhLENBQ3pCLElBQTREO0lBRTlELE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQy9CLE1BQU0sRUFBQyxZQUFZLEVBQUUsVUFBVSxFQUFFLFFBQVEsRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUNwRCxJQUFJLFlBQVksQ0FBQyxLQUFLLENBQUMsTUFBTSxLQUFLLENBQUMsRUFBRTtRQUNuQyxNQUFNLElBQUksS0FBSyxDQUFDO1VBQ1YsWUFBWSxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7S0FDN0I7SUFDRCxJQUFJLFVBQVUsQ0FBQyxLQUFLLENBQUMsTUFBTSxLQUFLLENBQUMsRUFBRTtRQUNqQyxNQUFNLElBQUksS0FBSyxDQUFDO1VBQ1YsVUFBVSxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7S0FDM0I7SUFFRCxJQUFJLFFBQVEsQ0FBQyxLQUFLLENBQUMsTUFBTSxLQUFLLENBQUMsRUFBRTtRQUMvQixNQUFNLElBQUksS0FBSyxDQUNYLHNEQUFzRCxRQUFRLENBQUMsS0FBSyxFQUFFLENBQUMsQ0FBQztLQUM3RTtJQUVELE1BQU0sV0FBVyxHQUNiLEtBQUssQ0FBQyxJQUFJLENBQUMsT0FBTyxDQUFDLElBQUksQ0FBQyxHQUFHLENBQUMsVUFBVSxDQUFDLE1BQU0sQ0FBQyxDQUFDLE1BQW9CLENBQUMsQ0FBQztJQUN6RSxNQUFNLGFBQWEsR0FDZixPQUFPLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxZQUFZLENBQUMsTUFBTSxDQUFDLENBQUMsTUFBb0IsQ0FBQztJQUMvRCxNQUFNLFdBQVcsR0FDYixLQUFLLENBQUMsSUFBSSxDQUFDLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLFFBQVEsQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDLENBQUM7SUFFdkUsTUFBTSxDQUFDLFVBQVUsRUFBRSxZQUFZLEVBQUUsV0FBVyxDQUFDLEdBQUcsaUJBQWlCLENBQzdELGFBQWEsRUFBRSxZQUFZLENBQUMsS0FBSyxFQUFFLFlBQVksQ0FBQyxLQUFLLEVBQUUsV0FBVyxFQUNsRSxXQUFXLENBQUMsQ0FBQztJQUNqQixPQUFPO1FBQ0wsT0FBTyxDQUFDLGNBQWMsQ0FBQyxZQUFZLEVBQUUsWUFBWSxDQUFDLEtBQUssRUFBRSxVQUFVLENBQUM7UUFDcEUsT0FBTyxDQUFDLGNBQWMsQ0FDbEIsQ0FBQyxXQUFXLENBQUMsTUFBTSxDQUFDLEVBQUUsUUFBUSxDQUFDLEtBQUssRUFBRSxJQUFJLFVBQVUsQ0FBQyxXQUFXLENBQUMsQ0FBQztLQUN2RSxDQUFDO0FBQ0osQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLG1CQUFtQixHQUFpQjtJQUMvQyxVQU
FVLEVBQUUsYUFBYTtJQUN6QixXQUFXLEVBQUUsS0FBSztJQUNsQixVQUFVLEVBQUUsYUFBYTtDQUMxQixDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjEgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge0tlcm5lbENvbmZpZywgU3BhcnNlUmVzaGFwZSwgU3BhcnNlUmVzaGFwZUlucHV0cywgVGVuc29ySW5mbywgVHlwZWRBcnJheX0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtNYXRoQmFja2VuZENQVX0gZnJvbSAnLi4vYmFja2VuZF9jcHUnO1xuXG5pbXBvcnQge3NwYXJzZVJlc2hhcGVJbXBsfSBmcm9tICcuL1NwYXJzZVJlc2hhcGVfaW1wbCc7XG5cbmV4cG9ydCBmdW5jdGlvbiBzcGFyc2VSZXNoYXBlKFxuICAgIGFyZ3M6IHtpbnB1dHM6IFNwYXJzZVJlc2hhcGVJbnB1dHMsIGJhY2tlbmQ6IE1hdGhCYWNrZW5kQ1BVfSk6XG4gICAgW1RlbnNvckluZm8sIFRlbnNvckluZm9dIHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZH0gPSBhcmdzO1xuICBjb25zdCB7aW5wdXRJbmRpY2VzLCBpbnB1dFNoYXBlLCBuZXdTaGFwZX0gPSBpbnB1dHM7XG4gIGlmIChpbnB1dEluZGljZXMuc2hhcGUubGVuZ3RoICE9PSAyKSB7XG4gICAgdGhyb3cgbmV3IEVycm9yKGBJbnB1dCBpbmRpY2VzIHNob3VsZCBiZSBhIG1hdHJpeCBidXQgcmVjZWl2ZWQgc2hhcGVcbiAgICAgICAgJHtpbnB1dEluZGljZXMuc2hhcGV9YCk7XG4gIH1cbiAgaWYgKGlucHV0U2hhcGUuc2hhcGUubGVuZ3RoICE9PSAxKSB7XG4gICAgdGhyb3cgbmV3IEVycm9yKGBJbnB1dCBzaGFwZSBzaG91bGQgYmUgYSB2ZW
N0b3IgYnV0IHJlY2VpdmVkIHNoYXBlXG4gICAgICAgICR7aW5wdXRTaGFwZS5zaGFwZX1gKTtcbiAgfVxuXG4gIGlmIChuZXdTaGFwZS5zaGFwZS5sZW5ndGggIT09IDEpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgIGBUYXJnZXQgc2hhcGUgc2hvdWxkIGJlIGEgdmVjdG9yIGJ1dCByZWNlaXZlZCBzaGFwZSAke25ld1NoYXBlLnNoYXBlfWApO1xuICB9XG5cbiAgY29uc3QgJGlucHV0U2hhcGUgPVxuICAgICAgQXJyYXkuZnJvbShiYWNrZW5kLmRhdGEuZ2V0KGlucHV0U2hhcGUuZGF0YUlkKS52YWx1ZXMgYXMgVHlwZWRBcnJheSk7XG4gIGNvbnN0ICRpbnB1dEluZGljZXMgPVxuICAgICAgYmFja2VuZC5kYXRhLmdldChpbnB1dEluZGljZXMuZGF0YUlkKS52YWx1ZXMgYXMgVHlwZWRBcnJheTtcbiAgY29uc3QgdGFyZ2V0U2hhcGUgPVxuICAgICAgQXJyYXkuZnJvbShiYWNrZW5kLmRhdGEuZ2V0KG5ld1NoYXBlLmRhdGFJZCkudmFsdWVzIGFzIFR5cGVkQXJyYXkpO1xuXG4gIGNvbnN0IFtuZXdJbmRpY2VzLCBpbmRpY2VzU2hhcGUsIG91dHB1dFNoYXBlXSA9IHNwYXJzZVJlc2hhcGVJbXBsKFxuICAgICAgJGlucHV0SW5kaWNlcywgaW5wdXRJbmRpY2VzLnNoYXBlLCBpbnB1dEluZGljZXMuZHR5cGUsICRpbnB1dFNoYXBlLFxuICAgICAgdGFyZ2V0U2hhcGUpO1xuICByZXR1cm4gW1xuICAgIGJhY2tlbmQubWFrZVRlbnNvckluZm8oaW5kaWNlc1NoYXBlLCBpbnB1dEluZGljZXMuZHR5cGUsIG5ld0luZGljZXMpLFxuICAgIGJhY2tlbmQubWFrZVRlbnNvckluZm8oXG4gICAgICAgIFtvdXRwdXRTaGFwZS5sZW5ndGhdLCBuZXdTaGFwZS5kdHlwZSwgbmV3IEludDMyQXJyYXkob3V0cHV0U2hhcGUpKSxcbiAgXTtcbn1cblxuZXhwb3J0IGNvbnN0IHNwYXJzZVJlc2hhcGVDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogU3BhcnNlUmVzaGFwZSxcbiAgYmFja2VuZE5hbWU6ICdjcHUnLFxuICBrZXJuZWxGdW5jOiBzcGFyc2VSZXNoYXBlLFxufTtcbiJdfQ==