gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { BroadcastTo } from '../kernel_names';
import { sum } from '../ops/sum';
/**
 * Gradient config for the BroadcastTo kernel.
 *
 * Broadcasting replicates each size-1 input dimension along the matching
 * output dimension, so the gradient w.r.t. `x` is `dy` summed over every
 * broadcast axis, with `keepDims` so each reduced axis is kept as size 1.
 *
 * NOTE(review): dimensions are paired by index from the left, which assumes
 * `attrs.inputShape` already has the same rank as `attrs.shape` (any leading
 * 1s added before the kernel ran) — confirm against the broadcastTo op.
 */
export const broadcastToGradConfig = {
    kernelName: BroadcastTo,
    /**
     * @param dy Upstream gradient (presumably shaped like `attrs.shape`).
     * @param saved Unused by this gradient.
     * @param attrs Kernel attributes; reads `inputShape` and `shape`.
     * @returns {{x: () => *}} Lazily computed gradient for the input `x`.
     * @throws {Error} If an input dimension differs from the output dimension
     *     and is not 1, i.e. the shapes are not broadcast-compatible.
     */
    gradFunc: (dy, saved, attrs) => {
        const { inputShape, shape: outputShape } = attrs;
        // After this loop reps[i] is 1 for dimensions that were not broadcast
        // and still holds the output size for dimensions that were.
        const reps = Array.from(outputShape);
        for (let i = inputShape.length - 1; i >= 0; i--) {
            if (inputShape[i] === outputShape[i]) {
                reps[i] = 1;
            }
            else if (inputShape[i] !== 1) {
                throw new Error(`broadcastTo(): [${inputShape}] cannot be broadcast to [${outputShape}].`);
            }
        }
        // Reduce every broadcast axis. `!== 1` (not `> 1`) also covers a size-1
        // dimension broadcast to size 0: summing that empty axis with keepDims
        // yields size 1, so the gradient shape matches the input again.
        const axes = [...reps.keys()].filter((axis) => reps[axis] !== 1);
        return { x: () => sum(dy, axes, true /* keepDims */) };
    }
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQnJvYWRjYXN0VG9fZ3JhZC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtY29yZS9zcmMvZ3JhZGllbnRzL0Jyb2FkY2FzdFRvX2dyYWQudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFdBQVcsRUFBbUIsTUFBTSxpQkFBaUIsQ0FBQztBQUU5RCxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sWUFBWSxDQUFDO0FBRy9CLE1BQU0sQ0FBQyxNQUFNLHFCQUFxQixHQUFlO0lBQy9DLFVBQVUsRUFBRSxXQUFXO0lBQ3ZCLFFBQVEsRUFBRSxDQUFDLEVBQVUsRUFBRSxLQUFlLEVBQUUsS0FBbUIsRUFBRSxFQUFFO1FBQzdELE1BQU0sZ0JBQWdCLEdBQ2xCLEtBQW9DLENBQUM7UUFFekMsTUFBTSxVQUFVLEdBQUcsZ0JBQWdCLENBQUMsVUFBVSxDQUFDO1FBQy9DLE1BQU0sV0FBVyxHQUFHLGdCQUFnQixDQUFDLEtBQUssQ0FBQztRQUUzQyxNQUFNLElBQUksR0FBYSxLQUFLLENBQUMsSUFBSSxDQUFDLFdBQVcsQ0FBQyxDQUFDO1FBQy9DLEtBQUssSUFBSSxDQUFDLEdBQUcsVUFBVSxDQUFDLE1BQU0sR0FBRyxDQUFDLEVBQUUsQ0FBQyxJQUFJLENBQUMsRUFBRSxDQUFDLEVBQUUsRUFBRTtZQUMvQyxJQUFJLFVBQVUsQ0FBQyxDQUFDLENBQUMsS0FBSyxXQUFXLENBQUMsQ0FBQyxDQUFDLEVBQUU7Z0JBQ3BDLElBQUksQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUM7YUFDYjtpQkFBTSxJQUFJLFVBQVUsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEVBQUU7Z0JBQzlCLE1BQU0sSUFBSSxLQUFLLENBQUMsbUJBQ1osVUFBVSw2QkFBNkIsV0FBVyxJQUFJLENBQUMsQ0FBQzthQUM3RDtTQUNGO1FBQ0QsTUFBTSxJQUFJLEdBQWEsRUFBRSxDQUFDO1FBQzFCLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxJQUFJLENBQUMsTUFBTSxFQUFFLENBQUMsRUFBRSxFQUFFO1lBQ3BDLElBQUksSUFBSSxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsRUFBRTtnQkFDZixJQUFJLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxDQUFDO2FBQ2Q7U0FDRjtRQUVELE9BQU8sRUFBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLENBQUMsR0FBRyxDQUFDLEVBQUUsRUFBRSxJQUFJLEVBQUUsSUFBSSxDQUFDLGNBQWMsQ0FBQyxFQUFDLENBQUM7SUFDdkQsQ0FBQztDQUNGLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aG
UgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7QnJvYWRjYXN0VG8sIEJyb2FkQ2FzdFRvQXR0cnN9IGZyb20gJy4uL2tlcm5lbF9uYW1lcyc7XG5pbXBvcnQge0dyYWRDb25maWcsIE5hbWVkQXR0ck1hcH0gZnJvbSAnLi4va2VybmVsX3JlZ2lzdHJ5JztcbmltcG9ydCB7c3VtfSBmcm9tICcuLi9vcHMvc3VtJztcbmltcG9ydCB7VGVuc29yfSBmcm9tICcuLi90ZW5zb3InO1xuXG5leHBvcnQgY29uc3QgYnJvYWRjYXN0VG9HcmFkQ29uZmlnOiBHcmFkQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBCcm9hZGNhc3RUbyxcbiAgZ3JhZEZ1bmM6IChkeTogVGVuc29yLCBzYXZlZDogVGVuc29yW10sIGF0dHJzOiBOYW1lZEF0dHJNYXApID0+IHtcbiAgICBjb25zdCBicm9hZENhc3RUb0F0dHJzOiBCcm9hZENhc3RUb0F0dHJzID1cbiAgICAgICAgYXR0cnMgYXMgdW5rbm93biBhcyBCcm9hZENhc3RUb0F0dHJzO1xuXG4gICAgY29uc3QgaW5wdXRTaGFwZSA9IGJyb2FkQ2FzdFRvQXR0cnMuaW5wdXRTaGFwZTtcbiAgICBjb25zdCBvdXRwdXRTaGFwZSA9IGJyb2FkQ2FzdFRvQXR0cnMuc2hhcGU7XG5cbiAgICBjb25zdCByZXBzOiBudW1iZXJbXSA9IEFycmF5LmZyb20ob3V0cHV0U2hhcGUpO1xuICAgIGZvciAobGV0IGkgPSBpbnB1dFNoYXBlLmxlbmd0aCAtIDE7IGkgPj0gMDsgaS0tKSB7XG4gICAgICBpZiAoaW5wdXRTaGFwZVtpXSA9PT0gb3V0cHV0U2hhcGVbaV0pIHtcbiAgICAgICAgcmVwc1tpXSA9IDE7XG4gICAgICB9IGVsc2UgaWYgKGlucHV0U2hhcGVbaV0gIT09IDEpIHtcbiAgICAgICAgdGhyb3cgbmV3IEVycm9yKGBicm9hZGNhc3RUbygpOiBbJHtcbiAgICAgICAgICAgIGlucHV0U2hhcGV9XSBjYW5ub3QgYmUgYnJvYWRjYXN0IHRvIFske291dHB1dFNoYXBlfV0uYCk7XG4gICAgICB9XG4gICAgfVxuICAgIGNvbnN0IGF4ZXM6IG51bWJlcltdID0gW107XG4gICAgZm9yIChsZXQgaSA9IDA7IGkgPCByZXBzLmxlbmd0aDsgaSsrKSB7XG4gICAgICBpZiAocmVwc1tpXSA+IDEpIHtcbiAgICAgICAgYXhlcy5wdX
NoKGkpO1xuICAgICAgfVxuICAgIH1cblxuICAgIHJldHVybiB7eDogKCkgPT4gc3VtKGR5LCBheGVzLCB0cnVlIC8qIGtlZXBEaW1zICovKX07XG4gIH1cbn07XG4iXX0=