/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { Tile } from '../kernel_names';
|
import { add } from '../ops/add';
|
import { slice } from '../ops/slice';
|
import { zerosLike } from '../ops/zeros_like';
|
/**
 * Gradient configuration for the Tile kernel.
 *
 * Tile repeats `x` `reps[d]` times along every axis `d`. The upstream
 * gradient `dy` is therefore made of x-shaped blocks laid out on a grid whose
 * extent along axis `d` is `reps[d]`. The gradient with respect to `x` is the
 * element-wise sum of all of those blocks.
 *
 * Blocks are visited in row-major order (last axis fastest), for inputs of
 * any rank.
 */
export const tileGradConfig = {
  kernelName: Tile,
  inputsToSave: ['x'],
  /**
   * @param dy Gradient flowing back from the tiled output.
   * @param saved `[x]`, the original (un-tiled) input.
   * @param attrs Kernel attributes; `reps` holds the per-axis repeat counts.
   * @returns An object whose `x` thunk computes the gradient for `x`.
   */
  gradFunc: (dy, saved, attrs) => {
    const [x] = saved;
    const { reps } = attrs;

    const derX = () => {
      let xGrad = zerosLike(x);
      const rank = x.rank;

      // A scalar has exactly one "block": dy itself.
      if (rank === 0) {
        return add(xGrad, dy);
      }

      // A non-positive (or missing) repeat count on any axis means the grid
      // is empty, so nothing flows back and the gradient stays zero.
      for (let axis = 0; axis < rank; ++axis) {
        if (!(reps[axis] > 0)) {
          return xGrad;
        }
      }

      // TODO(cais): Maybe reduce memory footprint by avoiding repeated
      // slicing.
      const size = [...x.shape];
      const tileIndex = new Array(rank).fill(0);
      let axis;
      do {
        // Offset of the current block inside dy.
        const begin = tileIndex.map((t, d) => t * x.shape[d]);
        xGrad = add(xGrad, slice(dy, begin, size));

        // Advance the odometer: bump the last axis, carrying leftwards when
        // an axis reaches its repeat count. Finished once the carry falls
        // off axis 0.
        axis = rank - 1;
        while (axis >= 0 && ++tileIndex[axis] >= reps[axis]) {
          tileIndex[axis] = 0;
          --axis;
        }
      } while (axis >= 0);

      return xGrad;
    };

    return { x: derX };
  },
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiVGlsZV9ncmFkLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9ncmFkaWVudHMvVGlsZV9ncmFkLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxJQUFJLEVBQVksTUFBTSxpQkFBaUIsQ0FBQztBQUVoRCxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sWUFBWSxDQUFDO0FBQy9CLE9BQU8sRUFBQyxLQUFLLEVBQUMsTUFBTSxjQUFjLENBQUM7QUFDbkMsT0FBTyxFQUFDLFNBQVMsRUFBQyxNQUFNLG1CQUFtQixDQUFDO0FBRzVDLE1BQU0sQ0FBQyxNQUFNLGNBQWMsR0FBZTtJQUN4QyxVQUFVLEVBQUUsSUFBSTtJQUNoQixZQUFZLEVBQUUsQ0FBQyxHQUFHLENBQUM7SUFDbkIsUUFBUSxFQUFFLENBQUMsRUFBVSxFQUFFLEtBQWUsRUFBRSxLQUFtQixFQUFFLEVBQUU7UUFDN0QsTUFBTSxDQUFDLENBQUMsQ0FBQyxHQUFHLEtBQUssQ0FBQztRQUNsQixNQUFNLEVBQUMsSUFBSSxFQUFDLEdBQUcsS0FBNkIsQ0FBQztRQUU3QyxNQUFNLElBQUksR0FBRyxHQUFHLEVBQUU7WUFDaEIsSUFBSSxLQUFLLEdBQUcsU0FBUyxDQUFDLENBQUMsQ0FBQyxDQUFDO1lBQ3pCLGlFQUFpRTtZQUNqRSxXQUFXO1lBQ1gsSUFBSSxDQUFDLENBQUMsSUFBSSxLQUFLLENBQUMsRUFBRTtnQkFDaEIsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLElBQUksQ0FBQyxDQUFDLENBQUMsRUFBRSxFQUFFLENBQUMsRUFBRTtvQkFDaEMsS0FBSyxHQUFHLEdBQUcsQ0FBQyxLQUFLLEVBQUUsS0FBSyxDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO2lCQUMvRDthQUNGO2lCQUFNLElBQUksQ0FBQyxDQUFDLElBQUksS0FBSyxDQUFDLEVBQUU7Z0JBQ3ZCLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxJQUFJLENBQUMsQ0FBQyxDQUFDLEVBQUUsRUFBRSxDQUFDLEVBQUU7b0JBQ2hDLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxJQUFJLENBQUMsQ0FBQyxDQUFDLEVBQUUsRUFBRSxDQUFDLEVBQUU7d0JBQ2hDLEtBQUssR0FBRyxHQUFHLENBQUMsS0FBSyxFQUFFLEtBQUssQ0FBQyxFQUFFLEVBQUUsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEdBQUcsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsQ0FBQyxFQUFFOzRCQUNqRCxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDO3lCQUN2QixDQUFDLENBQUMsQ0FBQztxQkFDakI7aUJBQ0Y7YUFDRjtpQkFBTSxJQUFJLENBQU
MsQ0FBQyxJQUFJLEtBQUssQ0FBQyxFQUFFO2dCQUN2QixLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsSUFBSSxDQUFDLENBQUMsQ0FBQyxFQUFFLEVBQUUsQ0FBQyxFQUFFO29CQUNoQyxLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsSUFBSSxDQUFDLENBQUMsQ0FBQyxFQUFFLEVBQUUsQ0FBQyxFQUFFO3dCQUNoQyxLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsSUFBSSxDQUFDLENBQUMsQ0FBQyxFQUFFLEVBQUUsQ0FBQyxFQUFFOzRCQUNoQyxLQUFLO2dDQUNELEdBQUcsQ0FBQyxLQUFLLEVBQ0wsS0FBSyxDQUNELEVBQUUsRUFBRSxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsR0FBRyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsR0FBRyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLEVBQ3BELENBQUMsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7eUJBQ3BEO3FCQUNGO2lCQUNGO2FBQ0Y7aUJBQU0sSUFBSSxDQUFDLENBQUMsSUFBSSxLQUFLLENBQUMsRUFBRTtnQkFDdkIsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLElBQUksQ0FBQyxDQUFDLENBQUMsRUFBRSxFQUFFLENBQUMsRUFBRTtvQkFDaEMsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLElBQUksQ0FBQyxDQUFDLENBQUMsRUFBRSxFQUFFLENBQUMsRUFBRTt3QkFDaEMsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLElBQUksQ0FBQyxDQUFDLENBQUMsRUFBRSxFQUFFLENBQUMsRUFBRTs0QkFDaEMsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLElBQUksQ0FBQyxDQUFDLENBQUMsRUFBRSxFQUFFLENBQUMsRUFBRTtnQ0FDaEMsS0FBSztvQ0FDRCxHQUFHLENBQUMsS0FBSyxFQUNMLEtBQUssQ0FDRCxFQUFFLEVBQ0Y7d0NBQ0UsQ0FBQyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDO3dDQUM5QyxDQUFDLEdBQUcsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUM7cUNBQ2YsRUFDRCxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7NkJBQ2hFO3lCQUNGO3FCQUNGO2lCQUNGO2FBQ0Y7aUJBQU07Z0JBQ0wsTUFBTSxJQUFJLEtBQUssQ0
FDWCwwREFBMEQ7b0JBQzFELEdBQUcsQ0FBQyxDQUFDLElBQUksZUFBZSxDQUFDLENBQUM7YUFDL0I7WUFDRCxPQUFPLEtBQUssQ0FBQztRQUNmLENBQUMsQ0FBQztRQUNGLE9BQU8sRUFBQyxDQUFDLEVBQUUsSUFBSSxFQUFDLENBQUM7SUFDbkIsQ0FBQztDQUNGLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7VGlsZSwgVGlsZUF0dHJzfSBmcm9tICcuLi9rZXJuZWxfbmFtZXMnO1xuaW1wb3J0IHtHcmFkQ29uZmlnLCBOYW1lZEF0dHJNYXB9IGZyb20gJy4uL2tlcm5lbF9yZWdpc3RyeSc7XG5pbXBvcnQge2FkZH0gZnJvbSAnLi4vb3BzL2FkZCc7XG5pbXBvcnQge3NsaWNlfSBmcm9tICcuLi9vcHMvc2xpY2UnO1xuaW1wb3J0IHt6ZXJvc0xpa2V9IGZyb20gJy4uL29wcy96ZXJvc19saWtlJztcbmltcG9ydCB7VGVuc29yfSBmcm9tICcuLi90ZW5zb3InO1xuXG5leHBvcnQgY29uc3QgdGlsZUdyYWRDb25maWc6IEdyYWRDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IFRpbGUsXG4gIGlucHV0c1RvU2F2ZTogWyd4J10sXG4gIGdyYWRGdW5jOiAoZHk6IFRlbnNvciwgc2F2ZWQ6IFRlbnNvcltdLCBhdHRyczogTmFtZWRBdHRyTWFwKSA9PiB7XG4gICAgY29uc3QgW3hdID0gc2F2ZWQ7XG4gICAgY29uc3Qge3JlcHN9ID0gYXR0cnMgYXMgdW5rbm93biBhcyBUaWxlQXR0cnM7XG5cbiAgICBjb25zdCBkZXJYID0gKCkgPT4ge1xuICAgICAgbGV0IHhHcmFkID0gemVyb3NMaWtlKHgpO1xuICAgICAgLy8gVE9ETyhjYWlzKTogTWF5YmUgcmVkdWNlIG1lbW9yeS
Bmb290cHJpbnQgYnkgYXZvaWRpbmcgcmVwZWF0ZWRcbiAgICAgIC8vIHNsaWNpbmcuXG4gICAgICBpZiAoeC5yYW5rID09PSAxKSB7XG4gICAgICAgIGZvciAobGV0IGkgPSAwOyBpIDwgcmVwc1swXTsgKytpKSB7XG4gICAgICAgICAgeEdyYWQgPSBhZGQoeEdyYWQsIHNsaWNlKGR5LCBbaSAqIHguc2hhcGVbMF1dLCBbeC5zaGFwZVswXV0pKTtcbiAgICAgICAgfVxuICAgICAgfSBlbHNlIGlmICh4LnJhbmsgPT09IDIpIHtcbiAgICAgICAgZm9yIChsZXQgaSA9IDA7IGkgPCByZXBzWzBdOyArK2kpIHtcbiAgICAgICAgICBmb3IgKGxldCBqID0gMDsgaiA8IHJlcHNbMV07ICsraikge1xuICAgICAgICAgICAgeEdyYWQgPSBhZGQoeEdyYWQsIHNsaWNlKGR5LCBbaSAqIHguc2hhcGVbMF0sIGogKiB4LnNoYXBlWzFdXSwgW1xuICAgICAgICAgICAgICAgICAgICAgICAgICB4LnNoYXBlWzBdLCB4LnNoYXBlWzFdXG4gICAgICAgICAgICAgICAgICAgICAgICBdKSk7XG4gICAgICAgICAgfVxuICAgICAgICB9XG4gICAgICB9IGVsc2UgaWYgKHgucmFuayA9PT0gMykge1xuICAgICAgICBmb3IgKGxldCBpID0gMDsgaSA8IHJlcHNbMF07ICsraSkge1xuICAgICAgICAgIGZvciAobGV0IGogPSAwOyBqIDwgcmVwc1sxXTsgKytqKSB7XG4gICAgICAgICAgICBmb3IgKGxldCBrID0gMDsgayA8IHJlcHNbMl07ICsraykge1xuICAgICAgICAgICAgICB4R3JhZCA9XG4gICAgICAgICAgICAgICAgICBhZGQoeEdyYWQsXG4gICAgICAgICAgICAgICAgICAgICAgc2xpY2UoXG4gICAgICAgICAgICAgICAgICAgICAgICAgIGR5LCBbaSAqIHguc2hhcGVbMF0sIGogKiB4LnNoYXBlWzFdLCBrICogeC5zaGFwZVsyXV0sXG4gICAgICAgICAgICAgICAgICAgICAgICAgIFt4LnNoYXBlWzBdLCB4LnNoYXBlWzFdLCB4LnNoYXBlWzJdXSkpO1xuICAgICAgICAgICAgfVxuICAgICAgICAgIH1cbiAgICAgICAgfVxuICAgICAgfSBlbHNlIGlmICh4LnJhbmsgPT09IDQpIHtcbiAgICAgICAgZm9yIChsZXQgaSA9IDA7IGkgPCByZXBzWzBdOyArK2kpIHtcbiAgICAgICAgICBmb3IgKGxldCBqID0gMDsgaiA8IHJlcHNbMV07ICsraikge1xuICAgICAgICAgICAgZm9yIChsZXQgayA9IDA7IGsgPCByZXBzWzJdOyArK2spIHtcbiAgICAgICAgICAgICAgZm9yIChsZXQgbCA9IDA7IGwgPCByZXBzWzNdOyArK2wpIHtcbiAgICAgICAgICAgICAgICB4R3JhZCA9XG4gICAgICAgICAgICAgICAgICAgIGFkZCh4R3JhZCxcbiAgICAgICAgICAgICAgICAgICAgICAgIHNsaWNlKFxuICAgICAgICAgICAgICAgICAgICAgICAgICAgIGR5LFxuICAgICAgICAgICAgICAgICAgICAgICAgICAgIFtcbiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIGkgKiB4LnNoYXBlWzBdLCBqICogeC5zaGFwZVsxXSwgayAqIHguc2hhcGVbMl0sXG4gICAgICAgICAgICAgICAgICAgICAgICAgICAgICBsICogeC5zaGFwZVszXVxuICAgICAgICAgICAgICAgICAgICAgICAgICAgIF0sXG4gIC
AgICAgICAgICAgICAgICAgICAgICAgICAgW3guc2hhcGVbMF0sIHguc2hhcGVbMV0sIHguc2hhcGVbMl0sIHguc2hhcGVbM11dKSk7XG4gICAgICAgICAgICAgIH1cbiAgICAgICAgICAgIH1cbiAgICAgICAgICB9XG4gICAgICAgIH1cbiAgICAgIH0gZWxzZSB7XG4gICAgICAgIHRocm93IG5ldyBFcnJvcihcbiAgICAgICAgICAgIGBHcmFkaWVudCBmb3IgdGlsZSBvcGVyYXRpb24gaXMgbm90IGltcGxlbWVudGVkIGZvciByYW5rLWAgK1xuICAgICAgICAgICAgYCR7eC5yYW5rfSB0ZW5zb3JzIHlldC5gKTtcbiAgICAgIH1cbiAgICAgIHJldHVybiB4R3JhZDtcbiAgICB9O1xuICAgIHJldHVybiB7eDogZGVyWH07XG4gIH0sXG59O1xuIl19
|