/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { useShapeUniforms } from './gpgpu_math';

/**
 * Program description carrying a GLSL snippet (`userCode`) that reads two
 * packed inputs, `dy` and `W`, and writes one packed vec4 per invocation.
 * Judging by the class name this is the conv2d gradient with respect to the
 * input; that is an inference from the name, not something this file states.
 *
 * Fields set by the constructor:
 *  - `variableNames`: `['dy', 'W']`.
 *  - `packedInputs` / `packedOutput`: both `true`.
 *  - `customUniforms`: a single `vec2` uniform named `strides`.
 *  - `outputShape`: `convInfo.inShape`.
 *  - `enableShapeUniforms`: result of `useShapeUniforms(outputShape.length)`.
 *  - `userCode`: the generated shader text.
 */
export class Conv2DDerInputPackedProgram {
    /**
     * @param convInfo Convolution description. The constructor reads
     *     `inShape`, `filterHeight`, `filterWidth`, `padInfo.top`,
     *     `padInfo.left`, `outHeight`, `outWidth` and `outChannels` from it.
     */
    constructor(convInfo) {
        const {
            inShape,
            filterHeight: kernelRows,
            filterWidth: kernelCols,
            padInfo,
            outHeight: dyRows,
            outWidth: dyCols,
            outChannels: dyDepth,
        } = convInfo;

        // Properties are assigned in the same order as separate statements
        // would assign them, so the instance's key order is unaffected.
        Object.assign(this, {
            variableNames: ['dy', 'W'],
            packedInputs: true,
            packedOutput: true,
            customUniforms: [
                { name: 'strides', type: 'vec2' },
            ],
            outputShape: inShape,
            enableShapeUniforms: useShapeUniforms(inShape.length),
        });

        // Offsets subtracted from output coords 1 and 2 to obtain `dyCorner`.
        const rowPad = kernelRows - 1 - padInfo.top;
        const colPad = kernelCols - 1 - padInfo.left;

        // Shader outline:
        //  * Every filter tap (wR, wC) is visited; the filter is sampled at
        //    the mirrored index (filterHeight - 1 - wR, filterWidth - 1 - wC).
        //  * A tap contributes only when (corner + tap) / stride is a whole
        //    number inside [0, outHeight) for rows and [0, outWidth) for
        //    columns.
        //  * Two neighbouring columns are handled per invocation: `dyC`
        //    accumulates into result.xy and `dyC2` (one column further) into
        //    result.zw.
        //  * The channel loop advances `d2` by 2; each getW / getDy call
        //    yields a vec4 from which an .xy or .zw pair is selected.
        // The literal below is assigned verbatim to `userCode`, so its spacing
        // and line breaks are left exactly as they were found.
        this.userCode = ` const ivec2 pads = ivec2(${rowPad}, ${colPad}); void main() { ivec4 coords = getOutputCoords(); int batch = coords[0]; int d1 = coords[3]; ivec2 dyCorner = ivec2(coords[1], coords[2]) - pads; int dyRCorner = dyCorner.x; int dyCCorner = dyCorner.y; vec4 result = vec4(0.); for (int wR = 0; wR < ${kernelRows}; wR++) { float dyR = float(dyRCorner + wR) / strides[0]; if (dyR < 0.0 || dyR >= ${dyRows}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); int wRPerm = ${kernelRows} - 1 - wR; for (int wC = 0; wC < ${kernelCols}; wC++) { int wCPerm = ${kernelCols} - 1 - wC; float dyC = float(dyCCorner + wC) / strides[1]; bool idyCVal = (dyC >= 0.0) && (dyC < ${dyCols}.0) && (fract(dyC) == 0.0); int idyC = 
int(dyC); float dyC2 = float(dyCCorner + wC + 1) / strides[1]; bool idyCVal2 = (dyC2 >= 0.0) && (dyC2 < ${dyCols}.0) && (fract(dyC2) == 0.0); int idyC2 = int(dyC2); if (idyCVal && idyCVal2) { for (int d2 = 0; d2 < ${dyDepth}; d2 += 2) { vec4 wValue = getW(wRPerm, wCPerm, d1, d2); vec4 dySample = getDy(batch, idyR, idyC, d2); vec4 dySample2 = (idyC / 2 == idyC2 / 2) ? dySample : getDy(batch, idyR, idyC2, d2); vec2 dyValue = mod(float(idyC), 2.) == 0. ? dySample.xy : dySample.zw; result.xy += vec2(dot(dyValue, wValue.xy), dot(dyValue, wValue.zw)); dyValue = mod(float(idyC2), 2.) == 0. ? dySample2.xy : dySample2.zw; result.zw += vec2(dot(dyValue, wValue.xy), dot(dyValue, wValue.zw)); } } else if (idyCVal) { for (int d2 = 0; d2 < ${dyDepth}; d2 += 2) { vec4 wValue = getW(wRPerm, wCPerm, d1, d2); vec4 dySample = getDy(batch, idyR, idyC, d2); vec2 dyValue = mod(float(idyC), 2.) == 0. ? dySample.xy : dySample.zw; result.xy += vec2(dot(dyValue, wValue.xy), dot(dyValue, wValue.zw)); } } else if (idyCVal2) { for (int d2 = 0; d2 < ${dyDepth}; d2 += 2) { vec4 wValue = getW(wRPerm, wCPerm, d1, d2); vec4 dySample = getDy(batch, idyR, idyC2, d2); vec2 dyValue = mod(float(idyC2), 2.) == 0. ? 
dySample.xy : dySample.zw; result.zw += vec2(dot(dyValue, wValue.xy), dot(dyValue, wValue.zw)); } } } } setOutput(result); } `;
    }
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiY29udl9iYWNrcHJvcF9wYWNrZWRfZ3B1LmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLXdlYmdsL3NyYy9jb252X2JhY2twcm9wX3BhY2tlZF9ncHUudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBR0gsT0FBTyxFQUFlLGdCQUFnQixFQUFDLE1BQU0sY0FBYyxDQUFDO0FBRTVELE1BQU0sT0FBTywyQkFBMkI7SUFXdEMsWUFBWSxRQUFpQztRQVY3QyxrQkFBYSxHQUFHLENBQUMsSUFBSSxFQUFFLEdBQUcsQ0FBQyxDQUFDO1FBQzVCLGlCQUFZLEdBQUcsSUFBSSxDQUFDO1FBQ3BCLGlCQUFZLEdBQUcsSUFBSSxDQUFDO1FBSXBCLG1CQUFjLEdBQUc7WUFDZixFQUFDLElBQUksRUFBRSxTQUFTLEVBQUUsSUFBSSxFQUFFLE1BQWUsRUFBRTtTQUMxQyxDQUFDO1FBR0EsSUFBSSxDQUFDLFdBQVcsR0FBRyxRQUFRLENBQUMsT0FBTyxDQUFDO1FBQ3BDLElBQUksQ0FBQyxtQkFBbUIsR0FBRyxnQkFBZ0IsQ0FBQyxJQUFJLENBQUMsV0FBVyxDQUFDLE1BQU0sQ0FBQyxDQUFDO1FBRXJFLE1BQU0sWUFBWSxHQUFHLFFBQVEsQ0FBQyxZQUFZLENBQUM7UUFDM0MsTUFBTSxXQUFXLEdBQUcsUUFBUSxDQUFDLFdBQVcsQ0FBQztRQUV6QyxNQUFNLE1BQU0sR0FBRyxZQUFZLEdBQUcsQ0FBQyxHQUFHLFFBQVEsQ0FBQyxPQUFPLENBQUMsR0FBRyxDQUFDO1FBQ3ZELE1BQU0sT0FBTyxHQUFHLFdBQVcsR0FBRyxDQUFDLEdBQUcsUUFBUSxDQUFDLE9BQU8sQ0FBQyxJQUFJLENBQUM7UUFFeEQsSUFBSSxDQUFDLFFBQVEsR0FBRztpQ0FDYSxNQUFNLEtBQUssT0FBTzs7Ozs7Ozs7Ozs7O2dDQVluQixZQUFZOztvQ0FFUixRQUFRLENBQUMsU0FBUzs7Ozt5QkFJN0IsWUFBWTs7a0NBRUgsV0FBVzsyQkFDbEIsV0FBVzs7O29EQUdjLFFBQVEsQ0FBQyxRQUFROzs7Ozt1REFLZCxRQUFRLENBQUMsUUFBUTs7Ozs7c0NBS2xDLFFBQVEsQ0FBQyxXQUFXOzs7Ozs7Ozs7Ozs7Ozs7OztzQ0FpQnBCLFFBQVEsQ0FBQyxXQUFXOzs7Ozs7Ozs7c0NBU3BCLFFBQVEsQ0FBQyxXQUFXOzs7Ozs7Ozs7Ozs7O0tBYXJELENBQUM7SUFDSixDQUFDO0NBQ0YiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMyBHb29nbGUgTExDLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuIC
pcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7YmFja2VuZF91dGlsfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuaW1wb3J0IHtHUEdQVVByb2dyYW0sIHVzZVNoYXBlVW5pZm9ybXN9IGZyb20gJy4vZ3BncHVfbWF0aCc7XG5cbmV4cG9ydCBjbGFzcyBDb252MkREZXJJbnB1dFBhY2tlZFByb2dyYW0gaW1wbGVtZW50cyBHUEdQVVByb2dyYW0ge1xuICB2YXJpYWJsZU5hbWVzID0gWydkeScsICdXJ107XG4gIHBhY2tlZElucHV0cyA9IHRydWU7XG4gIHBhY2tlZE91dHB1dCA9IHRydWU7XG4gIG91dHB1dFNoYXBlOiBudW1iZXJbXTtcbiAgdXNlckNvZGU6IHN0cmluZztcbiAgZW5hYmxlU2hhcGVVbmlmb3JtczogYm9vbGVhbjtcbiAgY3VzdG9tVW5pZm9ybXMgPSBbXG4gICAge25hbWU6ICdzdHJpZGVzJywgdHlwZTogJ3ZlYzInIGFzIGNvbnN0IH0sXG4gIF07XG5cbiAgY29uc3RydWN0b3IoY29udkluZm86IGJhY2tlbmRfdXRpbC5Db252MkRJbmZvKSB7XG4gICAgdGhpcy5vdXRwdXRTaGFwZSA9IGNvbnZJbmZvLmluU2hhcGU7XG4gICAgdGhpcy5lbmFibGVTaGFwZVVuaWZvcm1zID0gdXNlU2hhcGVVbmlmb3Jtcyh0aGlzLm91dHB1dFNoYXBlLmxlbmd0aCk7XG5cbiAgICBjb25zdCBmaWx0ZXJIZWlnaHQgPSBjb252SW5mby5maWx0ZXJIZWlnaHQ7XG4gICAgY29uc3QgZmlsdGVyV2lkdGggPSBjb252SW5mby5maWx0ZXJXaWR0aDtcblxuICAgIGNvbnN0IHBhZFRvcCA9IGZpbHRlckhlaWdodCAtIDEgLSBjb252SW5mby5wYWRJbmZvLnRvcDtcbiAgICBjb25zdCBwYWRMZWZ0ID0gZmlsdGVyV2lkdGggLSAxIC0gY29udkluZm8ucGFkSW5mby5sZWZ0O1xuXG4gICAgdGhpcy51c2VyQ29kZSA9IGBcbiAgICAgIGNvbnN0IGl2ZWMyIHBhZHMgPSBpdmVjMigke3BhZFRvcH0sICR7cGFkTGVmdH0pO1xuXG4gICAgICB2b2lkIG1haW4oKSB7XG4gICAgICAgIGl2ZWM0IGNvb3JkcyA9IGdldE91dHB1dENvb3JkcygpO1xuICAgICAgICBpbnQgYmF0Y2ggPSBjb29yZHNbMF07XG4gICAgICAgIGludCBkMSA9IGNvb3Jkc1szXTtcblxuICAgIC
AgICBpdmVjMiBkeUNvcm5lciA9IGl2ZWMyKGNvb3Jkc1sxXSwgY29vcmRzWzJdKSAtIHBhZHM7XG4gICAgICAgIGludCBkeVJDb3JuZXIgPSBkeUNvcm5lci54O1xuICAgICAgICBpbnQgZHlDQ29ybmVyID0gZHlDb3JuZXIueTtcblxuICAgICAgICB2ZWM0IHJlc3VsdCA9IHZlYzQoMC4pO1xuICAgICAgICBmb3IgKGludCB3UiA9IDA7IHdSIDwgJHtmaWx0ZXJIZWlnaHR9OyB3UisrKSB7XG4gICAgICAgICAgZmxvYXQgZHlSID0gZmxvYXQoZHlSQ29ybmVyICsgd1IpIC8gc3RyaWRlc1swXTtcbiAgICAgICAgICBpZiAoZHlSIDwgMC4wIHx8IGR5UiA+PSAke2NvbnZJbmZvLm91dEhlaWdodH0uMCB8fCBmcmFjdChkeVIpID4gMC4wKSB7XG4gICAgICAgICAgICBjb250aW51ZTtcbiAgICAgICAgICB9XG4gICAgICAgICAgaW50IGlkeVIgPSBpbnQoZHlSKTtcbiAgICAgICAgICBpbnQgd1JQZXJtID0gJHtmaWx0ZXJIZWlnaHR9IC0gMSAtIHdSO1xuXG4gICAgICAgICAgZm9yIChpbnQgd0MgPSAwOyB3QyA8ICR7ZmlsdGVyV2lkdGh9OyB3QysrKSB7XG4gICAgICAgICAgICBpbnQgd0NQZXJtID0gJHtmaWx0ZXJXaWR0aH0gLSAxIC0gd0M7XG5cbiAgICAgICAgICAgIGZsb2F0IGR5QyA9IGZsb2F0KGR5Q0Nvcm5lciArIHdDKSAvIHN0cmlkZXNbMV07XG4gICAgICAgICAgICBib29sIGlkeUNWYWwgPSAoZHlDID49IDAuMCkgJiYgKGR5QyA8ICR7Y29udkluZm8ub3V0V2lkdGh9LjApXG4gICAgICAgICAgICAgICYmIChmcmFjdChkeUMpID09IDAuMCk7XG4gICAgICAgICAgICBpbnQgaWR5QyA9IGludChkeUMpO1xuXG4gICAgICAgICAgICBmbG9hdCBkeUMyID0gZmxvYXQoZHlDQ29ybmVyICsgd0MgKyAxKSAvIHN0cmlkZXNbMV07XG4gICAgICAgICAgICBib29sIGlkeUNWYWwyID0gKGR5QzIgPj0gMC4wKSAmJiAoZHlDMiA8ICR7Y29udkluZm8ub3V0V2lkdGh9LjApXG4gICAgICAgICAgICAgICYmIChmcmFjdChkeUMyKSA9PSAwLjApO1xuICAgICAgICAgICAgaW50IGlkeUMyID0gaW50KGR5QzIpO1xuXG4gICAgICAgICAgICBpZiAoaWR5Q1ZhbCAmJiBpZHlDVmFsMikge1xuICAgICAgICAgICAgICBmb3IgKGludCBkMiA9IDA7IGQyIDwgJHtjb252SW5mby5vdXRDaGFubmVsc307IGQyICs9IDIpIHtcbiAgICAgICAgICAgICAgICB2ZWM0IHdWYWx1ZSA9IGdldFcod1JQZXJtLCB3Q1Blcm0sIGQxLCBkMik7XG4gICAgICAgICAgICAgICAgdmVjNCBkeVNhbXBsZSA9IGdldER5KGJhdGNoLCBpZHlSLCBpZHlDLCBkMik7XG4gICAgICAgICAgICAgICAgdmVjNCBkeVNhbXBsZTIgPSAoaWR5QyAvIDIgPT0gaWR5QzIgLyAyKSA/XG4gICAgICAgICAgICAgICAgICBkeVNhbXBsZSA6IGdldER5KGJhdGNoLCBpZHlSLCBpZHlDMiwgZDIpO1xuXG4gICAgICAgICAgICAgICAgdmVjMiBkeVZhbHVlID0gbW9kKGZsb2F0KGlkeUMpLCAyLikgPT0gMC4gP1xuICAgICAgICAgICAgICAgICAgZHlTYW1wbGUueHkgOiBkeVNhbXBsZS56dztcbiAgICAgIC
AgICAgICAgICByZXN1bHQueHkgKz0gdmVjMihkb3QoZHlWYWx1ZSwgd1ZhbHVlLnh5KSxcbiAgICAgICAgICAgICAgICAgIGRvdChkeVZhbHVlLCB3VmFsdWUuencpKTtcblxuICAgICAgICAgICAgICAgIGR5VmFsdWUgPSBtb2QoZmxvYXQoaWR5QzIpLCAyLikgPT0gMC4gP1xuICAgICAgICAgICAgICAgICAgZHlTYW1wbGUyLnh5IDogZHlTYW1wbGUyLnp3O1xuICAgICAgICAgICAgICAgIHJlc3VsdC56dyArPSB2ZWMyKGRvdChkeVZhbHVlLCB3VmFsdWUueHkpLFxuICAgICAgICAgICAgICAgICAgZG90KGR5VmFsdWUsIHdWYWx1ZS56dykpO1xuICAgICAgICAgICAgICB9XG4gICAgICAgICAgICB9IGVsc2UgaWYgKGlkeUNWYWwpIHtcbiAgICAgICAgICAgICAgZm9yIChpbnQgZDIgPSAwOyBkMiA8ICR7Y29udkluZm8ub3V0Q2hhbm5lbHN9OyBkMiArPSAyKSB7XG4gICAgICAgICAgICAgICAgdmVjNCB3VmFsdWUgPSBnZXRXKHdSUGVybSwgd0NQZXJtLCBkMSwgZDIpO1xuICAgICAgICAgICAgICAgIHZlYzQgZHlTYW1wbGUgPSBnZXREeShiYXRjaCwgaWR5UiwgaWR5QywgZDIpO1xuICAgICAgICAgICAgICAgIHZlYzIgZHlWYWx1ZSA9IG1vZChmbG9hdChpZHlDKSwgMi4pID09IDAuID9cbiAgICAgICAgICAgICAgICAgIGR5U2FtcGxlLnh5IDogZHlTYW1wbGUuenc7XG4gICAgICAgICAgICAgICAgcmVzdWx0Lnh5ICs9IHZlYzIoZG90KGR5VmFsdWUsIHdWYWx1ZS54eSksXG4gICAgICAgICAgICAgICAgICBkb3QoZHlWYWx1ZSwgd1ZhbHVlLnp3KSk7XG4gICAgICAgICAgICAgIH1cbiAgICAgICAgICAgIH0gZWxzZSBpZiAoaWR5Q1ZhbDIpIHtcbiAgICAgICAgICAgICAgZm9yIChpbnQgZDIgPSAwOyBkMiA8ICR7Y29udkluZm8ub3V0Q2hhbm5lbHN9OyBkMiArPSAyKSB7XG4gICAgICAgICAgICAgICAgdmVjNCB3VmFsdWUgPSBnZXRXKHdSUGVybSwgd0NQZXJtLCBkMSwgZDIpO1xuICAgICAgICAgICAgICAgIHZlYzQgZHlTYW1wbGUgPSBnZXREeShiYXRjaCwgaWR5UiwgaWR5QzIsIGQyKTtcbiAgICAgICAgICAgICAgICB2ZWMyIGR5VmFsdWUgPSBtb2QoZmxvYXQoaWR5QzIpLCAyLikgPT0gMC4gP1xuICAgICAgICAgICAgICAgICAgZHlTYW1wbGUueHkgOiBkeVNhbXBsZS56dztcbiAgICAgICAgICAgICAgICByZXN1bHQuencgKz0gdmVjMihkb3QoZHlWYWx1ZSwgd1ZhbHVlLnh5KSxcbiAgICAgICAgICAgICAgICAgIGRvdChkeVZhbHVlLCB3VmFsdWUuencpKTtcbiAgICAgICAgICAgICAgfVxuICAgICAgICAgICAgfVxuICAgICAgICAgIH1cbiAgICAgICAgfVxuICAgICAgICBzZXRPdXRwdXQocmVzdWx0KTtcbiAgICAgIH1cbiAgICBgO1xuICB9XG59XG4iXX0=