/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { BroadcastTo } from '../kernel_names';
|
import { sum } from '../ops/sum';
|
/**
 * Gradient configuration registered under the `BroadcastTo` kernel name.
 *
 * `gradFunc` reads `inputShape` and `shape` (the output shape) from `attrs`,
 * compares them index-for-index, and collects every axis whose output size
 * differs from the input size and is greater than 1. The gradient for `x` is
 * returned lazily as `sum(dy, axes, true)`.
 *
 * @param dy    Upstream gradient; forwarded untouched to `sum`.
 * @param saved Unused by this gradient.
 * @param attrs Kernel attributes; must expose `inputShape` and `shape`.
 * @returns An object `{x}` where `x` is a thunk producing the input gradient.
 * @throws {Error} When an input dimension neither equals the output dimension
 *     at the same index nor is 1.
 */
export const broadcastToGradConfig = {
    kernelName: BroadcastTo,
    gradFunc: (dy, saved, attrs) => {
        const { inputShape, shape: outputShape } = attrs;

        // Start from the output sizes; an entry is reset to 1 wherever the
        // input already has the same size at that index.
        const repeatCounts = Array.from(outputShape);
        for (let dim = inputShape.length - 1; dim >= 0; dim -= 1) {
            const inputDim = inputShape[dim];
            if (inputDim === outputShape[dim]) {
                repeatCounts[dim] = 1;
                continue;
            }
            if (inputDim !== 1) {
                throw new Error(`broadcastTo(): [${inputShape}] cannot be broadcast to [${outputShape}].`);
            }
        }

        // Axes still carrying a count above 1 are the ones reduced in the
        // backward pass.
        const axes = [];
        for (const [axis, count] of repeatCounts.entries()) {
            if (count > 1) {
                axes.push(axis);
            }
        }

        return { x: () => sum(dy, axes, true /* keepDims */) };
    }
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQnJvYWRjYXN0VG9fZ3JhZC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtY29yZS9zcmMvZ3JhZGllbnRzL0Jyb2FkY2FzdFRvX2dyYWQudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFdBQVcsRUFBbUIsTUFBTSxpQkFBaUIsQ0FBQztBQUU5RCxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sWUFBWSxDQUFDO0FBRy9CLE1BQU0sQ0FBQyxNQUFNLHFCQUFxQixHQUFlO0lBQy9DLFVBQVUsRUFBRSxXQUFXO0lBQ3ZCLFFBQVEsRUFBRSxDQUFDLEVBQVUsRUFBRSxLQUFlLEVBQUUsS0FBbUIsRUFBRSxFQUFFO1FBQzdELE1BQU0sZ0JBQWdCLEdBQ2xCLEtBQW9DLENBQUM7UUFFekMsTUFBTSxVQUFVLEdBQUcsZ0JBQWdCLENBQUMsVUFBVSxDQUFDO1FBQy9DLE1BQU0sV0FBVyxHQUFHLGdCQUFnQixDQUFDLEtBQUssQ0FBQztRQUUzQyxNQUFNLElBQUksR0FBYSxLQUFLLENBQUMsSUFBSSxDQUFDLFdBQVcsQ0FBQyxDQUFDO1FBQy9DLEtBQUssSUFBSSxDQUFDLEdBQUcsVUFBVSxDQUFDLE1BQU0sR0FBRyxDQUFDLEVBQUUsQ0FBQyxJQUFJLENBQUMsRUFBRSxDQUFDLEVBQUUsRUFBRTtZQUMvQyxJQUFJLFVBQVUsQ0FBQyxDQUFDLENBQUMsS0FBSyxXQUFXLENBQUMsQ0FBQyxDQUFDLEVBQUU7Z0JBQ3BDLElBQUksQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUM7YUFDYjtpQkFBTSxJQUFJLFVBQVUsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEVBQUU7Z0JBQzlCLE1BQU0sSUFBSSxLQUFLLENBQUMsbUJBQ1osVUFBVSw2QkFBNkIsV0FBVyxJQUFJLENBQUMsQ0FBQzthQUM3RDtTQUNGO1FBQ0QsTUFBTSxJQUFJLEdBQWEsRUFBRSxDQUFDO1FBQzFCLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxJQUFJLENBQUMsTUFBTSxFQUFFLENBQUMsRUFBRSxFQUFFO1lBQ3BDLElBQUksSUFBSSxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsRUFBRTtnQkFDZixJQUFJLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxDQUFDO2FBQ2Q7U0FDRjtRQUVELE9BQU8sRUFBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLENBQUMsR0FBRyxDQUFDLEVBQUUsRUFBRSxJQUFJLEVBQUUsSUFBSSxDQUFDLGNBQWMsQ0FBQyxFQUFDLENBQUM7SUFDdkQsQ0FBQztDQUNGLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aG
UgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7QnJvYWRjYXN0VG8sIEJyb2FkQ2FzdFRvQXR0cnN9IGZyb20gJy4uL2tlcm5lbF9uYW1lcyc7XG5pbXBvcnQge0dyYWRDb25maWcsIE5hbWVkQXR0ck1hcH0gZnJvbSAnLi4va2VybmVsX3JlZ2lzdHJ5JztcbmltcG9ydCB7c3VtfSBmcm9tICcuLi9vcHMvc3VtJztcbmltcG9ydCB7VGVuc29yfSBmcm9tICcuLi90ZW5zb3InO1xuXG5leHBvcnQgY29uc3QgYnJvYWRjYXN0VG9HcmFkQ29uZmlnOiBHcmFkQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBCcm9hZGNhc3RUbyxcbiAgZ3JhZEZ1bmM6IChkeTogVGVuc29yLCBzYXZlZDogVGVuc29yW10sIGF0dHJzOiBOYW1lZEF0dHJNYXApID0+IHtcbiAgICBjb25zdCBicm9hZENhc3RUb0F0dHJzOiBCcm9hZENhc3RUb0F0dHJzID1cbiAgICAgICAgYXR0cnMgYXMgdW5rbm93biBhcyBCcm9hZENhc3RUb0F0dHJzO1xuXG4gICAgY29uc3QgaW5wdXRTaGFwZSA9IGJyb2FkQ2FzdFRvQXR0cnMuaW5wdXRTaGFwZTtcbiAgICBjb25zdCBvdXRwdXRTaGFwZSA9IGJyb2FkQ2FzdFRvQXR0cnMuc2hhcGU7XG5cbiAgICBjb25zdCByZXBzOiBudW1iZXJbXSA9IEFycmF5LmZyb20ob3V0cHV0U2hhcGUpO1xuICAgIGZvciAobGV0IGkgPSBpbnB1dFNoYXBlLmxlbmd0aCAtIDE7IGkgPj0gMDsgaS0tKSB7XG4gICAgICBpZiAoaW5wdXRTaGFwZVtpXSA9PT0gb3V0cHV0U2hhcGVbaV0pIHtcbiAgICAgICAgcmVwc1tpXSA9IDE7XG4gICAgICB9IGVsc2UgaWYgKGlucHV0U2hhcGVbaV0gIT09IDEpIHtcbiAgICAgICAgdGhyb3cgbmV3IEVycm9yKGBicm9hZGNhc3RUbygpOiBbJHtcbiAgICAgICAgICAgIGlucHV0U2hhcGV9XSBjYW5ub3QgYmUgYnJvYWRjYXN0IHRvIFske291dHB1dFNoYXBlfV0uYCk7XG4gICAgICB9XG4gICAgfVxuICAgIGNvbnN0IGF4ZXM6IG51bWJlcltdID0gW107XG4gICAgZm9yIChsZXQgaSA9IDA7IGkgPCByZXBzLmxlbmd0aDsgaSsrKSB7XG4gICAgICBpZiAocmVwc1tpXSA+IDEpIHtcbiAgICAgICAgYXhlcy5wdX
NoKGkpO1xuICAgICAgfVxuICAgIH1cblxuICAgIHJldHVybiB7eDogKCkgPT4gc3VtKGR5LCBheGVzLCB0cnVlIC8qIGtlZXBEaW1zICovKX07XG4gIH1cbn07XG4iXX0=
|