/**
|
* @license
|
* Copyright 2018 Google Inc. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
|
import {Tensor4D} from '../../tensor';
|
import {GPGPUProgram} from './gpgpu_math';
|
|
/**
 * Computes the sampling parameters for one spatial axis of the bilinear-resize
 * backward pass. Every returned value is finite, so it can be interpolated
 * directly into GLSL source as a literal.
 *
 * @param xSize Size of the axis in the forward-pass input (x).
 * @param ySize Size of the axis in the forward-pass output (dy).
 * @param alignCorners Whether the forward pass aligned corner pixels.
 * @returns `scale` maps a dy index to a (fractional) dx index, `invScale` is
 *     its reciprocal, and `window` is the number of dy indices the shader scans
 *     around each dx index.
 */
function resizeAxisParams(
    xSize: number, ySize: number,
    alignCorners: boolean): {scale: number, invScale: number, window: number} {
  // Corner alignment only takes effect when the output has more than one
  // sample along this axis.
  const aligned = alignCorners && ySize > 1;
  const effectiveXSize = aligned ? xSize - 1 : xSize;
  const effectiveYSize = aligned ? ySize - 1 : ySize;

  if (effectiveXSize === 0 || effectiveYSize === 0) {
    // Degenerate axis. One case is that x collapses to a single sample (e.g.
    // xSize == 1 with alignCorners), so every dy index maps onto dx index 0.
    // The other is that dy is empty. The reciprocal scale would be
    // Infinity/NaN, which is not a valid GLSL literal. Use 0 instead, and make
    // the window wide enough that, centred on index 0, it spans all of dy.
    return {scale: 0, invScale: 0, window: 2 * ySize + 2};
  }

  const scale = effectiveXSize / effectiveYSize;
  const invScale = 1 / scale;
  return {scale, invScale, window: Math.ceil(invScale) * 2 + 2};
}

/**
 * WebGL program that computes the gradient of a bilinear resize with respect
 * to its input.
 *
 * For every dx element, the shader gathers the dy values whose bilinear taps
 * touched that element in the forward pass. It weights each one by the
 * matching interpolation coefficient.
 */
export class ResizeBilinearBackpropProgram implements GPGPUProgram {
  variableNames = ['dy'];
  outputShape: number[] = [];
  userCode: string;

  /**
   * @param dy Gradient w.r.t. the resize output. Only its height (axis 1) and
   *     width (axis 2) are read.
   * @param x The original resize input. Its shape becomes the output shape.
   * @param alignCorners Must match the value used in the forward pass.
   */
  constructor(dy: Tensor4D, x: Tensor4D, alignCorners: boolean) {
    this.outputShape = x.shape;
    const [, xHeight, xWidth, ] = x.shape;
    const [, yHeight, yWidth] = dy.shape;

    // In the backwards pass, we want to find the pixels that were generated
    // for each pixel in the input image in the forward pass, and add the
    // corresponding coefficient from dy to the gradient (with some
    // interpolation).
    const rows = resizeAxisParams(xHeight, yHeight, alignCorners);
    const cols = resizeAxisParams(xWidth, yWidth, alignCorners);

    const heightScale = rows.scale;
    const widthScale = cols.scale;

    const invHeightScale = rows.invScale;
    const invWidthScale = cols.invScale;

    // This defines the size of the window of values around a particular
    // index in dy that we want to search for contributions to dx.
    const winHeight = rows.window;
    const winWidth = cols.window;

    this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        int r = coords[1];
        int c = coords[2];

        float accumulator = 0.0;

        const float heightScale = float(${heightScale});
        const float widthScale = float(${widthScale});

        const float invHeightScale = float(${invHeightScale});
        const float invWidthScale = float(${invWidthScale});

        const int winHeight = int(${winHeight});
        const int winWidth = int(${winWidth});

        // Compute bounds for where in dy we will look
        float startRLerp = floor(float(r) * invHeightScale);
        int startDyR = int(startRLerp - float(winHeight / 2));

        float startCLerp = floor(float(c) * invWidthScale);
        int startDyC = int(startCLerp - float(winWidth / 2));

        // Loop over dy
        for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {
          int dyR = dyROffset + startDyR;

          // Guard against the window exceeding the bounds of dy
          if (dyR < 0 || dyR >= ${yHeight}) {
            continue;
          }

          for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {
            int dyC = dyCOffset + startDyC;

            // Guard against the window exceeding the bounds of dy
            if (dyC < 0 || dyC >= ${yWidth}) {
              continue;
            }

            float dxR = float(dyR) * heightScale;
            int topDxRIndex = int(floor(dxR));
            int bottomDxRIndex = int(min(ceil(dxR), ${xHeight - 1}.0));
            float dxRLerp = dxR - float(topDxRIndex);
            float inverseDxRLerp = 1.0 - dxRLerp;

            float dxC = float(dyC) * widthScale;
            int leftDxCIndex = int(floor(dxC));
            int rightDxCIndex = int(min(ceil(dxC), ${xWidth - 1}.0));
            float dxCLerp = dxC - float(leftDxCIndex);
            float inverseDxCLerp = 1.0 - dxCLerp;

            if (r == topDxRIndex && c == leftDxCIndex) {
              // topLeft
              accumulator +=
                getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;
            }

            if (r == topDxRIndex && c == rightDxCIndex) {
              // topRight
              accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;
            }

            if (r == bottomDxRIndex && c == leftDxCIndex) {
              // bottomLeft
              accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;
            }

            if (r == bottomDxRIndex && c == rightDxCIndex) {
              // bottomRight
              accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;
            }
          }
        }
        // End loop over dy

        setOutput(accumulator);
      }
    `;
  }
}
|