/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { BroadcastTo } from '../kernel_names';
import { sum } from '../ops/sum';

/**
 * Gradient configuration for the `BroadcastTo` kernel.
 *
 * `gradFunc(dy, saved, attrs)` reads `attrs.inputShape` and `attrs.shape`
 * (the output shape) and returns `{x}`, where `x` is a thunk that sums `dy`
 * over every axis along which the output was expanded, keeping the reduced
 * dimensions.
 *
 * @throws {Error} If some input dimension neither equals the output dimension
 *     at the same index nor is 1.
 */
export const broadcastToGradConfig = {
  kernelName: BroadcastTo,
  gradFunc: (dy, saved, attrs) => {
    const { inputShape, shape: outputShape } = attrs;

    // Start from the output dimensions and collapse to 1 every axis on which
    // input and output already agree; what remains > 1 are the expanded axes.
    const repeats = Array.from(outputShape);
    for (let axis = 0; axis < inputShape.length; axis++) {
      const inputDim = inputShape[axis];
      if (inputDim === outputShape[axis]) {
        repeats[axis] = 1;
        continue;
      }
      // A mismatching axis is only acceptable when the input dimension is 1.
      if (inputDim !== 1) {
        throw new Error(`broadcastTo(): [${inputShape}] cannot be broadcast to [${outputShape}].`);
      }
    }

    // Collect the indices of the axes that were actually expanded.
    const axes = [];
    for (const [axis, count] of repeats.entries()) {
      if (count > 1) {
        axes.push(axis);
      }
    }

    const keepDims = true;
    return { x: () => sum(dy, axes, keepDims) };
  }
};
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQnJvYWRjYXN0VG9fZ3JhZC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtY29yZS9zcmMvZ3JhZGllbnRzL0Jyb2FkY2FzdFRvX2dyYWQudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFdBQVcsRUFBbUIsTUFBTSxpQkFBaUIsQ0FBQztBQUU5RCxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sWUFBWSxDQUFDO0FBRy9CLE1BQU0sQ0FBQyxNQUFNLHFCQUFxQixHQUFlO0lBQy9DLFVBQVUsRUFBRSxXQUFXO0lBQ3ZCLFFBQVEsRUFBRSxDQUFDLEVBQVUsRUFBRSxLQUFlLEVBQUUsS0FBbUIsRUFBRSxFQUFFO1FBQzdELE1BQU0sZ0JBQWdCLEdBQ2xCLEtBQW9DLENBQUM7UUFFekMsTUFBTSxVQUFVLEdBQUcsZ0JBQWdCLENBQUMsVUFBVSxDQUFDO1FBQy9DLE1BQU0sV0FBVyxHQUFHLGdCQUFnQixDQUFDLEtBQUssQ0FBQztRQUUzQyxNQUFNLElBQUksR0FBYSxLQUFLLENBQUMsSUFBSSxDQUFDLFdBQVcsQ0FBQyxDQUFDO1FBQy9DLEtBQUssSUFBSSxDQUFDLEdBQUcsVUFBVSxDQUFDLE1BQU0sR0FBRyxDQUFDLEVBQUUsQ0FBQyxJQUFJLENBQUMsRUFBRSxDQUFDLEVBQUUsRUFBRTtZQUMvQyxJQUFJLFVBQVUsQ0FBQyxDQUFDLENBQUMsS0FBSyxXQUFXLENBQUMsQ0FBQyxDQUFDLEVBQUU7Z0JBQ3BDLElBQUksQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUM7YUFDYjtpQkFBTSxJQUFJLFVBQVUsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEVBQUU7Z0JBQzlCLE1BQU0sSUFBSSxLQUFLLENBQUMsbUJBQ1osVUFBVSw2QkFBNkIsV0FBVyxJQUFJLENBQUMsQ0FBQzthQUM3RDtTQUNGO1FBQ0QsTUFBTSxJQUFJLEdBQWEsRUFBRSxDQUFDO1FBQzFCLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxJQUFJLENBQUMsTUFBTSxFQUFFLENBQUMsRUFBRSxFQUFFO1lBQ3BDLElBQUksSUFBSSxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsRUFBRTtnQkFDZixJQUFJLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxDQUFDO2FBQ2Q7U0FDRjtRQUVELE9BQU8sRUFBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLENBQUMsR0FBRyxDQUFDLEVBQUUsRUFBRSxJQUFJLEVBQUUsSUFBSSxDQUFDLGNBQWMsQ0FBQyxFQUFDLENBQUM7SUFDdkQsQ0FBQztDQUNGLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTG
ljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7QnJvYWRjYXN0VG8sIEJyb2FkQ2FzdFRvQXR0cnN9IGZyb20gJy4uL2tlcm5lbF9uYW1lcyc7XG5pbXBvcnQge0dyYWRDb25maWcsIE5hbWVkQXR0ck1hcH0gZnJvbSAnLi4va2VybmVsX3JlZ2lzdHJ5JztcbmltcG9ydCB7c3VtfSBmcm9tICcuLi9vcHMvc3VtJztcbmltcG9ydCB7VGVuc29yfSBmcm9tICcuLi90ZW5zb3InO1xuXG5leHBvcnQgY29uc3QgYnJvYWRjYXN0VG9HcmFkQ29uZmlnOiBHcmFkQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBCcm9hZGNhc3RUbyxcbiAgZ3JhZEZ1bmM6IChkeTogVGVuc29yLCBzYXZlZDogVGVuc29yW10sIGF0dHJzOiBOYW1lZEF0dHJNYXApID0+IHtcbiAgICBjb25zdCBicm9hZENhc3RUb0F0dHJzOiBCcm9hZENhc3RUb0F0dHJzID1cbiAgICAgICAgYXR0cnMgYXMgdW5rbm93biBhcyBCcm9hZENhc3RUb0F0dHJzO1xuXG4gICAgY29uc3QgaW5wdXRTaGFwZSA9IGJyb2FkQ2FzdFRvQXR0cnMuaW5wdXRTaGFwZTtcbiAgICBjb25zdCBvdXRwdXRTaGFwZSA9IGJyb2FkQ2FzdFRvQXR0cnMuc2hhcGU7XG5cbiAgICBjb25zdCByZXBzOiBudW1iZXJbXSA9IEFycmF5LmZyb20ob3V0cHV0U2hhcGUpO1xuICAgIGZvciAobGV0IGkgPSBpbnB1dFNoYXBlLmxlbmd0aCAtIDE7IGkgPj0gMDsgaS0tKSB7XG4gICAgICBpZiAoaW5wdXRTaGFwZVtpXSA9PT0gb3V0cHV0U2hhcGVbaV0pIHtcbiAgICAgICAgcmVwc1tpXSA9IDE7XG4gICAgICB9IGVsc2UgaWYgKGlucHV0U2hhcGVbaV0gIT09IDEpIHtcbiAgICAgICAgdGhyb3cgbmV3IEVycm9yKGBicm9hZGNhc3RUbygpOiBbJHtcbiAgICAgICAgICAgIGlucHV0U2hhcGV9XSBjYW5ub3QgYmUgYnJvYWRjYXN0IHRvIFske291dHB1dFNoYXBlfV0uYCk7XG4gICAgICB9XG4gICAgfVxuICAgIGNvbnN0IGF4ZXM6IG51bWJlcltdID0gW107XG4gICAgZm9yIChsZXQgaSA9IDA7IGkgPCByZXBzLmxlbmd0aDsgaSsrKSB7XG4gICAgICBpZiAocmVwc1tpXSA+IDEpIHtcbiAgICAgICAgYXhlcy5wdXNoKG
kpO1xuICAgICAgfVxuICAgIH1cblxuICAgIHJldHVybiB7eDogKCkgPT4gc3VtKGR5LCBheGVzLCB0cnVlIC8qIGtlZXBEaW1zICovKX07XG4gIH1cbn07XG4iXX0=