/**
|
* @license
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Shader program description for 2D max / average pooling over one input
 * named `x`.
 *
 * The constructor only builds strings: it fills in `variableNames`,
 * `outputShape` and the GLSL `userCode`. Two shader variants exist:
 *  - a "positions" shader (max pool only) whose output is the position of the
 *    winning element, and
 *  - a "values" shader that outputs the pooled value, reading the filter
 *    width four taps at a time.
 */
export class Pool2DProgram {
  /**
   * @param {Object} convInfo Pooling geometry. Fields read here:
   *     `filterWidth`, `strideHeight`, `strideWidth`, `dilationHeight`,
   *     `dilationWidth`, `effectiveFilterHeight`, `effectiveFilterWidth`,
   *     `padInfo.top`, `padInfo.left`, `inHeight`, `inWidth`, `inChannels`
   *     and `outShape`.
   * @param {string} poolType `'max'` or `'avg'`.
   * @param {boolean} computePositions When true, emit the shader that outputs
   *     the position of the maximum instead of its value.
   * @param {boolean} [flattenPositions=false] Only used with
   *     `computePositions`. When true the position is a flat index into the
   *     input (`(xR * inWidth + xC) * inChannels + d`); otherwise it is the
   *     offset inside the effective filter window
   *     (`wR * effectiveFilterWidth + wC`).
   * @param {boolean} [includeBatchInIndex=false] Only used with
   *     `flattenPositions`. When true the flat index also accounts for the
   *     batch coordinate.
   * @throws {Error} If positions are requested for an average pool.
   */
  constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
    this.variableNames = ['x'];
    if (poolType === 'avg' && computePositions) {
      throw new Error('Cannot compute positions for average pool.');
    }

    const {
      filterWidth,
      strideHeight,
      strideWidth,
      dilationHeight,
      dilationWidth,
      effectiveFilterHeight,
      effectiveFilterWidth,
      inHeight,
      inWidth,
      inChannels,
    } = convInfo;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    this.outputShape = convInfo.outShape;

    const isAvgPool = poolType === 'avg';

    if (computePositions) {
      // `>=` means that on ties the last visited element wins.
      const compareOp = '>=';

      let positionExpr = `wR * ${effectiveFilterWidth} + wC`;
      if (flattenPositions) {
        positionExpr = includeBatchInIndex ?
            `((batch * ${inHeight} + xR) * ${inWidth} + xC) * ${inChannels} + d` :
            `(xR * ${inWidth} + xC) * ${inChannels} + d`;
      }

      this.userCode = `
        const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
        const ivec2 pads = ivec2(${padTop}, ${padLeft});

        void main() {
          ivec4 coords = getOutputCoords();
          int batch = coords[0];
          int d = coords[3];

          ivec2 xRCCorner = coords.yz * strides - pads;
          int xRCorner = xRCCorner.x;
          int xCCorner = xRCCorner.y;

          // max/min x(?, ?, d) to get y(yR, yC, d).
          // ? = to be determined
          float minMaxValue = 0.0;
          float minMaxValueFound = 0.0;
          int minMaxPosition = 0;

          for (int wR = 0; wR < ${effectiveFilterHeight};
              wR += ${dilationHeight}) {
            int xR = xRCorner + wR;

            if (xR < 0 || xR >= ${inHeight}) {
              continue;
            }

            for (int wC = 0; wC < ${effectiveFilterWidth};
                wC += ${dilationWidth}) {
              int xC = xCCorner + wC;

              if (xC < 0 || xC >= ${inWidth}) {
                continue;
              }

              float value = getX(batch, xR, xC, d);

              // If a min / max value has already been found, use it. If not,
              // use the current value.
              float currMinMaxValue = mix(
                  value, minMaxValue, minMaxValueFound);
              if (value ${compareOp} currMinMaxValue) {
                minMaxValue = value;
                minMaxValueFound = 1.0;
                minMaxPosition = ${positionExpr};
              }
            }
          }
          setOutput(float(minMaxPosition));
        }
      `;
      return;
    }

    // Out-of-bounds taps contribute this value: 0 for sums, a very negative
    // number for max.
    // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
    const initializationValue = isAvgPool ? '0.0' : '-1.0 / 1e-20';

    const compareOp = 'max';

    // `max(count, 1.0)` keeps the division defined when no tap was in bounds.
    const returnValue = isAvgPool ?
        'avgValue / max(count, 1.0)' :
        `${poolType}(${poolType}(${poolType}(` +
            'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';

    // The filter width is consumed four taps at a time, then a 1-3 tap tail.
    const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
    const filterWidthVec4Remainder = filterWidth % 4;
    // Taps are `dilationWidth` columns apart, so the tail starts at tap index
    // `filterWidthNearestVec4` scaled by the dilation (not at the raw index).
    const remainderOffset = filterWidthNearestVec4 * dilationWidth;

    const updateSnippet = `
      if (${isAvgPool}) {
        avgValue += dot(values, ones);
      } else {
        minMaxValue = ${compareOp}(values, minMaxValue);
      }
    `;

    this.userCode = `
      const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
      const ivec2 pads = ivec2(${padTop}, ${padLeft});
      const float initializationValue = ${initializationValue};
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float count = 0.0;

      float getValue(int batch, int xR, int xC, int d) {
        if (xC < 0 || xC >= ${inWidth}) {
          return initializationValue;
        }
        count += 1.0;
        return getX(batch, xR, xC, d);
      }

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d = coords[3];

        ivec2 xRCCorner = coords.yz * strides - pads;
        int xRCorner = xRCCorner.x;
        int xCCorner = xRCCorner.y;

        // max/min x(?, ?, d) to get y(yR, yC, d).
        // ? = to be determined
        vec4 minMaxValue = vec4(${initializationValue});
        float avgValue = 0.0;
        count = 0.0;

        for (int wR = 0; wR < ${effectiveFilterHeight};
            wR += ${dilationHeight}) {
          int xR = xRCorner + wR;

          if (xR < 0 || xR >= ${inHeight}) {
            continue;
          }

          for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
            int xC = xCCorner + wC * ${dilationWidth};

            vec4 values = vec4(
              getValue(batch, xR, xC, d),
              getValue(batch, xR, xC + ${dilationWidth}, d),
              getValue(batch, xR, xC + 2 * ${dilationWidth}, d),
              getValue(batch, xR, xC + 3 * ${dilationWidth}, d)
            );

            ${updateSnippet}
          }

          int xC = xCCorner + ${remainderOffset};
          if (${filterWidthVec4Remainder === 1}) {
            vec4 values = vec4(
              getValue(batch, xR, xC, d),
              initializationValue,
              initializationValue,
              initializationValue
            );

            ${updateSnippet}
          } else if (${filterWidthVec4Remainder === 2}) {
            vec4 values = vec4(
              getValue(batch, xR, xC, d),
              getValue(batch, xR, xC + ${dilationWidth}, d),
              initializationValue,
              initializationValue
            );

            ${updateSnippet}
          } else if (${filterWidthVec4Remainder === 3}) {
            vec4 values = vec4(
              getValue(batch, xR, xC, d),
              getValue(batch, xR, xC + ${dilationWidth}, d),
              getValue(batch, xR, xC + 2 * ${dilationWidth}, d),
              initializationValue
            );

            ${updateSnippet}
          }
        }
        setOutput(${returnValue});
      }
    `;
  }
}
|
/**
 * Shader program description for 3D max / average pooling over one input
 * named `x`.
 *
 * The constructor only builds strings: it fills in `variableNames`,
 * `outputShape` and the GLSL `userCode`. Two shader variants exist:
 *  - a "positions" shader (max pool only) whose output is the position of the
 *    winning element, and
 *  - a "values" shader that outputs the pooled value, reading the filter
 *    width four taps at a time.
 */
export class Pool3DProgram {
  /**
   * @param {Object} convInfo Pooling geometry. Fields read here:
   *     `filterWidth`, `strideDepth`, `strideHeight`, `strideWidth`,
   *     `dilationDepth`, `dilationHeight`, `dilationWidth`,
   *     `effectiveFilterDepth`, `effectiveFilterHeight`,
   *     `effectiveFilterWidth`, `padInfo.front`, `padInfo.top`,
   *     `padInfo.left`, `inDepth`, `inHeight`, `inWidth`, `inChannels` and
   *     `outShape`.
   * @param {string} poolType `'max'` or `'avg'`.
   * @param {boolean} computePositions When true, emit the shader that outputs
   *     the position of the maximum instead of its value.
   * @param {boolean} [flattenPositions=false] Only used with
   *     `computePositions`. When true the position is a flat index into the
   *     input; otherwise it is the offset inside the effective filter window.
   * @param {boolean} [includeBatchInIndex=false] Only used with
   *     `flattenPositions`. When true the flat index also accounts for the
   *     batch coordinate.
   * @throws {Error} If positions are requested for an average pool.
   */
  constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
    this.variableNames = ['x'];
    if (poolType === 'avg' && computePositions) {
      throw new Error('Cannot compute positions for average pool.');
    }

    const {
      filterWidth,
      strideDepth,
      strideHeight,
      strideWidth,
      dilationDepth,
      dilationHeight,
      dilationWidth,
      effectiveFilterDepth,
      effectiveFilterHeight,
      effectiveFilterWidth,
      inDepth,
      inHeight,
      inWidth,
      inChannels,
    } = convInfo;
    const padFront = convInfo.padInfo.front;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    this.outputShape = convInfo.outShape;

    const isAvgPool = poolType === 'avg';

    if (computePositions) {
      // `>=` means that on ties the last visited element wins.
      const compareOp = '>=';

      let positionExpr =
          `wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} + ` +
          `wR * ${effectiveFilterWidth} + wC`;
      if (flattenPositions) {
        positionExpr = includeBatchInIndex ?
            `(((batch * ${inDepth} + xD) * ${inHeight} + xR) * ${inWidth} + xC) * ${inChannels} + ch` :
            `((xD * ${inHeight} + xR) * ${inWidth} + xC) * ${inChannels} + ch`;
      }

      this.userCode = `
        const ivec3 strides =
            ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
        const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});

        void main() {
          ivec5 coords = getOutputCoords();
          int batch = coords.x;
          int ch = coords.u;

          ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
          int xDCorner = xCorner.x;
          int xRCorner = xCorner.y;
          int xCCorner = xCorner.z;

          // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).
          // ? = to be determined
          float minMaxValue = 0.0;
          float minMaxValueFound = 0.0;
          int minMaxPosition = 0;

          for (int wD = 0; wD < ${effectiveFilterDepth};
              wD += ${dilationDepth}) {
            int xD = xDCorner + wD;

            if (xD < 0 || xD >= ${inDepth}) {
              continue;
            }

            for (int wR = 0; wR < ${effectiveFilterHeight};
                wR += ${dilationHeight}) {
              int xR = xRCorner + wR;

              if (xR < 0 || xR >= ${inHeight}) {
                continue;
              }

              for (int wC = 0; wC < ${effectiveFilterWidth};
                  wC += ${dilationWidth}) {
                int xC = xCCorner + wC;

                if (xC < 0 || xC >= ${inWidth}) {
                  continue;
                }

                float value = getX(batch, xD, xR, xC, ch);

                // If a min / max value has already been found, use it. If not,
                // use the current value.
                float currMinMaxValue = mix(
                    value, minMaxValue, minMaxValueFound);
                if (value ${compareOp} currMinMaxValue) {
                  minMaxValue = value;
                  minMaxValueFound = 1.0;
                  minMaxPosition = ${positionExpr};
                }
              }
            }
          }
          setOutput(float(minMaxPosition));
        }
      `;
      return;
    }

    // Out-of-bounds taps contribute this value: 0 for sums, a very negative
    // number for max.
    // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
    const initializationValue = isAvgPool ? '0.0' : '-1.0 / 1e-20';

    const compareOp = 'max';

    // Use `max(count, 1.0)` instead of `count` in case count === 0.0.
    // If count === 0.0, `avgValue` is always 0.0 and we change `count`'s
    // value to avoid dividing zero.
    const returnValue = isAvgPool ?
        'avgValue / max(count, 1.0)' :
        `${poolType}(${poolType}(${poolType}(` +
            'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';

    // The filter width is consumed four taps at a time, then a 1-3 tap tail.
    const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
    const filterWidthVec4Remainder = filterWidth % 4;
    // Taps are `dilationWidth` columns apart, so the tail starts at tap index
    // `filterWidthNearestVec4` scaled by the dilation (not at the raw index).
    const remainderOffset = filterWidthNearestVec4 * dilationWidth;

    const updateSnippet = `
      if (${isAvgPool}) {
        avgValue += dot(values, ones);
      } else {
        minMaxValue = ${compareOp}(values, minMaxValue);
      }
    `;

    this.userCode = `
      const ivec3 strides =
        ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
      const float initializationValue = ${initializationValue};
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float count = 0.0;

      float getValue(int batch, int xD, int xR, int xC, int ch) {
        if (xC < 0 || xC >= ${inWidth}) {
          return initializationValue;
        }
        count += 1.0;
        return getX(batch, xD, xR, xC, ch);
      }

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int ch = coords.u;

        ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
        int xDCorner = xCorner.x;
        int xRCorner = xCorner.y;
        int xCCorner = xCorner.z;

        // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).
        // ? = to be determined
        vec4 minMaxValue = vec4(${initializationValue});
        float avgValue = 0.0;
        count = 0.0;

        for (int wD = 0; wD < ${effectiveFilterDepth};
            wD += ${dilationDepth}) {
          int xD = xDCorner + wD;

          if (xD < 0 || xD >= ${inDepth}) {
            continue;
          }

          for (int wR = 0; wR < ${effectiveFilterHeight};
            wR += ${dilationHeight}) {
            int xR = xRCorner + wR;

            if (xR < 0 || xR >= ${inHeight}) {
              continue;
            }

            for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
              int xC = xCCorner + wC * ${dilationWidth};

              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
                getValue(batch, xD, xR, xC + 3 * ${dilationWidth}, ch)
              );

              ${updateSnippet}
            }

            int xC = xCCorner + ${remainderOffset};
            if (${filterWidthVec4Remainder === 1}) {
              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                initializationValue,
                initializationValue,
                initializationValue
              );

              ${updateSnippet}
            } else if (${filterWidthVec4Remainder === 2}) {
              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                initializationValue,
                initializationValue
              );

              ${updateSnippet}
            } else if (${filterWidthVec4Remainder === 3}) {
              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
                initializationValue
              );

              ${updateSnippet}
            }
          }
        }
        setOutput(${returnValue});
      }
    `;
  }
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"pool_gpu.js","sourceRoot":"","sources":["../../../../../tfjs-backend-webgl/src/pool_gpu.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAKH,MAAM,OAAO,aAAa;IAKxB,YACI,QAAiC,EAAE,QAAqB,EACxD,gBAAyB,EAAE,gBAAgB,GAAG,KAAK,EACnD,mBAAmB,GAAG,KAAK;QAP/B,kBAAa,GAAG,CAAC,GAAG,CAAC,CAAC;QAQpB,IAAI,QAAQ,KAAK,KAAK,IAAI,gBAAgB,EAAE;YAC1C,MAAM,IAAI,KAAK,CAAC,4CAA4C,CAAC,CAAC;SAC/D;QAED,MAAM,WAAW,GAAG,QAAQ,CAAC,WAAW,CAAC;QACzC,MAAM,YAAY,GAAG,QAAQ,CAAC,YAAY,CAAC;QAC3C,MAAM,WAAW,GAAG,QAAQ,CAAC,WAAW,CAAC;QACzC,MAAM,cAAc,GAAG,QAAQ,CAAC,cAAc,CAAC;QAC/C,MAAM,aAAa,GAAG,QAAQ,CAAC,aAAa,CAAC;QAC7C,MAAM,qBAAqB,GAAG,QAAQ,CAAC,qBAAqB,CAAC;QAC7D,MAAM,oBAAoB,GAAG,QAAQ,CAAC,oBAAoB,CAAC;QAE3D,MAAM,MAAM,GAAG,QAAQ,CAAC,OAAO,CAAC,GAAG,CAAC;QACpC,MAAM,OAAO,GAAG,QAAQ,CAAC,OAAO,CAAC,IAAI,CAAC;QACtC,IAAI,CAAC,WAAW,GAAG,QAAQ,CAAC,QAAQ,CAAC;QAErC,MAAM,SAAS,GAAG,QAAQ,KAAK,KAAK,CAAC;QACrC,MAAM,uBAAuB,GAAG,cAAc,QAAQ,CAAC,QAAQ,YAC3D,QAAQ,CAAC,OAAO,YAAY,QAAQ,CAAC,UAAU,MAAM,CAAC;QAC1D,MAAM,kBAAkB,GACpB,SAAS,QAAQ,CAAC,OAAO,YAAY,QAAQ,CAAC,UAAU,MAAM,CAAC;QAEnE,IAAI,mBAAmB,GAAG,KAAK,CAAC;QAChC,IAAI,CAAC,SAAS,EAAE;YACd,2DAA2D;YAC3D,mBAAmB,GAAG,cAAc,CAAC;SACtC;QAED,IAAI,gBAAgB,EAAE;YACpB,MAAM,SAAS,GAAG,IAAI,CAAC;YAEvB,IAAI,CAAC,QAAQ,GAAG;sCACgB,YAAY,KAAK,WAAW;mCAC/B,MAAM,KAAK,OAAO;;;;;;;;;;;;;;;;;;kCAkBnB,qBAAqB;sBACjC,cAAc;;;kCAGF,QAAQ,CAAC,QAAQ;;;;oCAIf,oBAAoB;wBAChC,aAAa;;;oCAGD,QAAQ,CAAC,OAAO;;;;;;;;;;0BAU1B,SAAS;;;mCAIzB,gBAAgB,CAAC,CAAC,CAAC,CAAC,mBAAmB,CAAC,CAAC,CAAC,uBAAuB,CAAC,CAAC;gBACzB,kBAAkB,CAAC,CAAC,CAAC;gBAC5C,QAAQ,oBAAoB,OAAO;;;;;;OAMzD,CAAC;YACF,OAAO;SACR;QAED,MAAM,SAAS,GAAG,KAAK,CAAC;QAExB,IAAI,WAAW,GAAG,GAAG,QAAQ,IAAI,QAAQ,IAAI,QAAQ,GAAG;YACpD,mEAAmE,CAAC;QACxE,IAAI,QAAQ,KAAK,KAAK,EAAE;YACtB,WAAW,GAAG,4BAA4B,CAAC;SAC5C;QAED,MAAM,sBAAsB,GAAG,IAAI,CAAC,KAAK,CAAC,WAAW,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;QAC/D,MAAM,wBAAwB,GAAG,WAAW,GAAG,CAAC,CAAC;QAEjD,MAAM,aAAa,GAAG;YACd,SAAS;;;wBAGG,SAAS;;KAE5B,CAAC;QAEF,IAAI,CAAC,QAAQ,GAAG;oCACgB,
YAAY,KAAK,WAAW;iCAC/B,MAAM,KAAK,OAAO;0CACT,mBAAmB;;;;;;8BAM/B,QAAQ,CAAC,OAAO;;;;;;;;;;;;;;;;;;kCAkBZ,mBAAmB;;;;gCAIrB,qBAAqB;oBACjC,cAAc;;;gCAGF,QAAQ,CAAC,QAAQ;;;;kCAIf,sBAAsB;uCACjB,aAAa;;;;yCAIX,aAAa;6CACT,aAAa;6CACb,aAAa;;;cAG5C,aAAa;;;gCAGK,sBAAsB;gBACtC,wBAAwB,KAAK,CAAC;;;;;;;;cAQhC,aAAa;uBACJ,wBAAwB,KAAK,CAAC;;;yCAGZ,aAAa;;;;;cAKxC,aAAa;uBACJ,wBAAwB,KAAK,CAAC;;;yCAGZ,aAAa;6CACT,aAAa;;;;cAI5C,aAAa;;;oBAGP,WAAW;;KAE1B,CAAC;IACJ,CAAC;CACF;AAED,MAAM,OAAO,aAAa;IAKxB,YACI,QAAiC,EAAE,QAAqB,EACxD,gBAAyB,EAAE,gBAAgB,GAAG,KAAK,EACnD,mBAAmB,GAAG,KAAK;QAP/B,kBAAa,GAAG,CAAC,GAAG,CAAC,CAAC;QAQpB,IAAI,QAAQ,KAAK,KAAK,IAAI,gBAAgB,EAAE;YAC1C,MAAM,IAAI,KAAK,CAAC,4CAA4C,CAAC,CAAC;SAC/D;QAED,MAAM,WAAW,GAAG,QAAQ,CAAC,WAAW,CAAC;QACzC,MAAM,WAAW,GAAG,QAAQ,CAAC,WAAW,CAAC;QACzC,MAAM,YAAY,GAAG,QAAQ,CAAC,YAAY,CAAC;QAC3C,MAAM,WAAW,GAAG,QAAQ,CAAC,WAAW,CAAC;QACzC,MAAM,aAAa,GAAG,QAAQ,CAAC,aAAa,CAAC;QAC7C,MAAM,cAAc,GAAG,QAAQ,CAAC,cAAc,CAAC;QAC/C,MAAM,aAAa,GAAG,QAAQ,CAAC,aAAa,CAAC;QAC7C,MAAM,oBAAoB,GAAG,QAAQ,CAAC,oBAAoB,CAAC;QAC3D,MAAM,qBAAqB,GAAG,QAAQ,CAAC,qBAAqB,CAAC;QAC7D,MAAM,oBAAoB,GAAG,QAAQ,CAAC,oBAAoB,CAAC;QAE3D,MAAM,QAAQ,GAAG,QAAQ,CAAC,OAAO,CAAC,KAAK,CAAC;QACxC,MAAM,MAAM,GAAG,QAAQ,CAAC,OAAO,CAAC,GAAG,CAAC;QACpC,MAAM,OAAO,GAAG,QAAQ,CAAC,OAAO,CAAC,IAAI,CAAC;QACtC,IAAI,CAAC,WAAW,GAAG,QAAQ,CAAC,QAAQ,CAAC;QAErC,MAAM,SAAS,GAAG,QAAQ,KAAK,KAAK,CAAC;QAErC,IAAI,mBAAmB,GAAG,KAAK,CAAC;QAChC,IAAI,CAAC,SAAS,EAAE;YACd,2DAA2D;YAC3D,mBAAmB,GAAG,cAAc,CAAC;SACtC;QAED,IAAI,gBAAgB,EAAE;YACpB,MAAM,SAAS,GAAG,IAAI,CAAC;YAEvB,IAAI,CAAC,QAAQ,GAAG;;oBAEF,WAAW,KAAK,YAAY,KAAK,WAAW;mCAC7B,QAAQ,KAAK,MAAM,KAAK,OAAO;;;;;;;;;;;;;;;;;;kCAkBhC,oBAAoB;sBAChC,aAAa;;;kCAGD,QAAQ,CAAC,OAAO;;;;oCAId,qBAAqB;wBACjC,cAAc;;;oCAGF,QAAQ,CAAC,QAAQ;;;;sCAIf,oBAAoB;0BAChC,aAAa;;;sCAGD,QAAQ,CAAC,OAAO;;;;;;;;;;4BAU1B,SAAS;;;qCAI3B,gBAAgB,CAAC,CAAC;gBACd,CAAC,mBAAmB,CAAC,CAAC;oBACjB,cAAc,QAAQ,CAAC,OAAO,YAC1B,QAAQ,CAAC,QAAQ,YAAY,QAAQ,CAAC,OAAO,YAC7C,QAAQ,CAAC,UAAU,OAAO,CAAC,CAAC;oBAChC,UAAU,QAAQ,CAAC,QAAQ,YACvB,QAAQ,CAAC,OAAO,YAAY,QAAQ,CAAC
,UAAU,OAAO,CAAC,CAAC,CAAC;gBAClE,QAAQ,qBAAqB,MAAM,oBAAoB;6BACxC,oBAAoB,OAAO;;;;;;;OAOjD,CAAC;YACF,OAAO;SACR;QAED,MAAM,SAAS,GAAG,KAAK,CAAC;QAExB,IAAI,WAAW,GAAG,GAAG,QAAQ,IAAI,QAAQ,IAAI,QAAQ,GAAG;YACpD,mEAAmE,CAAC;QACxE,IAAI,QAAQ,KAAK,KAAK,EAAE;YACtB,kEAAkE;YAClE,qEAAqE;YACrE,gCAAgC;YAChC,WAAW,GAAG,4BAA4B,CAAC;SAC5C;QAED,MAAM,sBAAsB,GAAG,IAAI,CAAC,KAAK,CAAC,WAAW,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;QAC/D,MAAM,wBAAwB,GAAG,WAAW,GAAG,CAAC,CAAC;QAEjD,MAAM,aAAa,GAAG;YACd,SAAS;;;wBAGG,SAAS;;KAE5B,CAAC;QAEF,IAAI,CAAC,QAAQ,GAAG;;gBAEJ,WAAW,KAAK,YAAY,KAAK,WAAW;iCAC3B,QAAQ,KAAK,MAAM,KAAK,OAAO;0CACtB,mBAAmB;;;;;;8BAM/B,QAAQ,CAAC,OAAO;;;;;;;;;;;;;;;;;;;kCAmBZ,mBAAmB;;;;gCAIrB,oBAAoB;oBAChC,aAAa;;;gCAGD,QAAQ,CAAC,OAAO;;;;kCAId,qBAAqB;oBACnC,cAAc;;;kCAGA,QAAQ,CAAC,QAAQ;;;;oCAIf,sBAAsB;yCACjB,aAAa;;;;+CAIP,aAAa;mDACT,aAAa;mDACb,aAAa;;;gBAGhD,aAAa;;;kCAGK,sBAAsB;kBACtC,wBAAwB,KAAK,CAAC;;;;;;;;gBAQhC,aAAa;yBACJ,wBAAwB,KAAK,CAAC;;;+CAGR,aAAa;;;;;gBAK5C,aAAa;yBACJ,wBAAwB,KAAK,CAAC;;;+CAGR,aAAa;mDACT,aAAa;;;;gBAIhD,aAAa;;;;oBAIT,WAAW;;KAE1B,CAAC;IACJ,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 2017 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util} from '@tensorflow/tfjs-core';\nimport {GPGPUProgram} from './gpgpu_math';\n\nexport class Pool2DProgram implements GPGPUProgram {\n  variableNames = ['x'];\n  outputShape: number[];\n  userCode: string;\n\n  constructor(\n      convInfo: backend_util.Conv2DInfo, poolType: 'max'|'avg',\n      computePositions: boolean, flattenPositions = false,\n      includeBatchInIndex = false) {\n    if (poolType === 'avg' && computePositions) {\n      throw new Error('Cannot compute positions for average pool.');\n    }\n\n    const filterWidth = convInfo.filterWidth;\n    const strideHeight = convInfo.strideHeight;\n    const strideWidth = convInfo.strideWidth;\n    const dilationHeight = convInfo.dilationHeight;\n    const dilationWidth = convInfo.dilationWidth;\n    const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n    const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n\n    const padTop = convInfo.padInfo.top;\n    const padLeft = convInfo.padInfo.left;\n    this.outputShape = convInfo.outShape;\n\n    const isAvgPool = poolType === 'avg';\n    const batchFlattenPositionStr = `((batch  * ${convInfo.inHeight} + xR) * ${\n        convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;\n    const flattenPositionStr =\n        `(xR * ${convInfo.inWidth} + xC) * 
${convInfo.inChannels} + d`;\n\n    let initializationValue = '0.0';\n    if (!isAvgPool) {\n      // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.\n      initializationValue = '-1.0 / 1e-20';\n    }\n\n    if (computePositions) {\n      const compareOp = '>=';\n\n      this.userCode = `\n        const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});\n        const ivec2 pads = ivec2(${padTop}, ${padLeft});\n\n        void main() {\n          ivec4 coords = getOutputCoords();\n          int batch = coords[0];\n          int d = coords[3];\n\n          ivec2 xRCCorner = coords.yz * strides - pads;\n          int xRCorner = xRCCorner.x;\n          int xCCorner = xRCCorner.y;\n\n          // max/min x(?, ?, d) to get y(yR, yC, d).\n          // ? = to be determined\n          float minMaxValue = 0.0;\n          float minMaxValueFound = 0.0;\n          int minMaxPosition = 0;\n          float avgValue = 0.0;\n\n          for (int wR = 0; wR < ${effectiveFilterHeight};\n              wR += ${dilationHeight}) {\n            int xR = xRCorner + wR;\n\n            if (xR < 0 || xR >= ${convInfo.inHeight}) {\n              continue;\n            }\n\n            for (int wC = 0; wC < ${effectiveFilterWidth};\n                wC += ${dilationWidth}) {\n              int xC = xCCorner + wC;\n\n              if (xC < 0 || xC >= ${convInfo.inWidth}) {\n                continue;\n              }\n\n              float value = getX(batch, xR, xC, d);\n\n              // If a min / max value has already been found, use it. If not,\n              // use the current value.\n              float currMinMaxValue = mix(\n                  value, minMaxValue, minMaxValueFound);\n              if (value ${compareOp} currMinMaxValue) {\n                minMaxValue = value;\n                minMaxValueFound = 1.0;\n                minMaxPosition = ${\n          flattenPositions ? (includeBatchInIndex ? 
batchFlattenPositionStr :\n                                                    flattenPositionStr) :\n                             `wR * ${effectiveFilterWidth} + wC`};\n              }\n            }\n          }\n          setOutput(float(minMaxPosition));\n        }\n      `;\n      return;\n    }\n\n    const compareOp = 'max';\n\n    let returnValue = `${poolType}(${poolType}(${poolType}(` +\n        'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';\n    if (poolType === 'avg') {\n      returnValue = `avgValue / max(count, 1.0)`;\n    }\n\n    const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;\n    const filterWidthVec4Remainder = filterWidth % 4;\n\n    const updateSnippet = `\n      if (${isAvgPool}) {\n        avgValue += dot(values, ones);\n      } else {\n        minMaxValue = ${compareOp}(values, minMaxValue);\n      }\n    `;\n\n    this.userCode = `\n      const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});\n      const ivec2 pads = ivec2(${padTop}, ${padLeft});\n      const float initializationValue = ${initializationValue};\n      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n      float count = 0.0;\n\n      float getValue(int batch, int xR, int xC, int d) {\n        if (xC < 0 || xC >= ${convInfo.inWidth}) {\n          return initializationValue;\n        }\n        count += 1.0;\n        return getX(batch, xR, xC, d);\n      }\n\n      void main() {\n        ivec4 coords = getOutputCoords();\n        int batch = coords[0];\n        int d = coords[3];\n\n        ivec2 xRCCorner = coords.yz * strides - pads;\n        int xRCorner = xRCCorner.x;\n        int xCCorner = xRCCorner.y;\n\n        // max/min x(?, ?, d) to get y(yR, yC, d).\n        // ? 
= to be determined\n        vec4 minMaxValue = vec4(${initializationValue});\n        float avgValue = 0.0;\n        count = 0.0;\n\n        for (int wR = 0; wR < ${effectiveFilterHeight};\n            wR += ${dilationHeight}) {\n          int xR = xRCorner + wR;\n\n          if (xR < 0 || xR >= ${convInfo.inHeight}) {\n            continue;\n          }\n\n          for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {\n            int xC = xCCorner + wC * ${dilationWidth};\n\n            vec4 values = vec4(\n              getValue(batch, xR, xC, d),\n              getValue(batch, xR, xC + ${dilationWidth}, d),\n              getValue(batch, xR, xC + 2 * ${dilationWidth}, d),\n              getValue(batch, xR, xC + 3 * ${dilationWidth}, d)\n            );\n\n            ${updateSnippet}\n          }\n\n          int xC = xCCorner + ${filterWidthNearestVec4};\n          if (${filterWidthVec4Remainder === 1}) {\n            vec4 values = vec4(\n              getValue(batch, xR, xC, d),\n              initializationValue,\n              initializationValue,\n              initializationValue\n            );\n\n            ${updateSnippet}\n          } else if (${filterWidthVec4Remainder === 2}) {\n            vec4 values = vec4(\n              getValue(batch, xR, xC, d),\n              getValue(batch, xR, xC + ${dilationWidth}, d),\n              initializationValue,\n              initializationValue\n            );\n\n            ${updateSnippet}\n          } else if (${filterWidthVec4Remainder === 3}) {\n            vec4 values = vec4(\n              getValue(batch, xR, xC, d),\n              getValue(batch, xR, xC + ${dilationWidth}, d),\n              getValue(batch, xR, xC + 2 * ${dilationWidth}, d),\n              initializationValue\n            );\n\n            ${updateSnippet}\n          }\n        }\n        setOutput(${returnValue});\n      }\n    `;\n  }\n}\n\nexport class Pool3DProgram implements GPGPUProgram {\n  variableNames = ['x'];\n  
outputShape: number[];\n  userCode: string;\n\n  constructor(\n      convInfo: backend_util.Conv3DInfo, poolType: 'max'|'avg',\n      computePositions: boolean, flattenPositions = false,\n      includeBatchInIndex = false) {\n    if (poolType === 'avg' && computePositions) {\n      throw new Error('Cannot compute positions for average pool.');\n    }\n\n    const filterWidth = convInfo.filterWidth;\n    const strideDepth = convInfo.strideDepth;\n    const strideHeight = convInfo.strideHeight;\n    const strideWidth = convInfo.strideWidth;\n    const dilationDepth = convInfo.dilationDepth;\n    const dilationHeight = convInfo.dilationHeight;\n    const dilationWidth = convInfo.dilationWidth;\n    const effectiveFilterDepth = convInfo.effectiveFilterDepth;\n    const effectiveFilterHeight = convInfo.effectiveFilterHeight;\n    const effectiveFilterWidth = convInfo.effectiveFilterWidth;\n\n    const padFront = convInfo.padInfo.front;\n    const padTop = convInfo.padInfo.top;\n    const padLeft = convInfo.padInfo.left;\n    this.outputShape = convInfo.outShape;\n\n    const isAvgPool = poolType === 'avg';\n\n    let initializationValue = '0.0';\n    if (!isAvgPool) {\n      // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.\n      initializationValue = '-1.0 / 1e-20';\n    }\n\n    if (computePositions) {\n      const compareOp = '>=';\n\n      this.userCode = `\n        const ivec3 strides =\n            ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});\n        const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});\n\n        void main() {\n          ivec5 coords = getOutputCoords();\n          int batch = coords.x;\n          int ch = coords.u;\n\n          ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n          int xDCorner = xCorner.x;\n          int xRCorner = xCorner.y;\n          int xCCorner = xCorner.z;\n\n          // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).\n          // ? 
= to be determined\n          float minMaxValue = 0.0;\n          float minMaxValueFound = 0.0;\n          int minMaxPosition = 0;\n\n          for (int wD = 0; wD < ${effectiveFilterDepth};\n              wD += ${dilationDepth}) {\n            int xD = xDCorner + wD;\n\n            if (xD < 0 || xD >= ${convInfo.inDepth}) {\n              continue;\n            }\n\n            for (int wR = 0; wR < ${effectiveFilterHeight};\n                wR += ${dilationHeight}) {\n              int xR = xRCorner + wR;\n\n              if (xR < 0 || xR >= ${convInfo.inHeight}) {\n                continue;\n              }\n\n              for (int wC = 0; wC < ${effectiveFilterWidth};\n                  wC += ${dilationWidth}) {\n                int xC = xCCorner + wC;\n\n                if (xC < 0 || xC >= ${convInfo.inWidth}) {\n                  continue;\n                }\n\n                float value = getX(batch, xD, xR, xC, ch);\n\n                // If a min / max value has already been found, use it. 
If not,\n                // use the current value.\n                float currMinMaxValue = mix(\n                    value, minMaxValue, minMaxValueFound);\n                if (value ${compareOp} currMinMaxValue) {\n                  minMaxValue = value;\n                  minMaxValueFound = 1.0;\n                  minMaxPosition = ${\n          flattenPositions ?\n              (includeBatchInIndex ?\n                   `(((batch * ${convInfo.inDepth} + xD) * ${\n                       convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${\n                       convInfo.inChannels} + ch` :\n                   `((xD * ${convInfo.inHeight} + xR) * ${\n                       convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch`) :\n              `wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +\n                      wR * ${effectiveFilterWidth} + wC`};\n                }\n              }\n            }\n          }\n          setOutput(float(minMaxPosition));\n        }\n      `;\n      return;\n    }\n\n    const compareOp = 'max';\n\n    let returnValue = `${poolType}(${poolType}(${poolType}(` +\n        'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';\n    if (poolType === 'avg') {\n      // Use `max(count, 1.0)` instead of `count` in case count === 0.0.\n      // If count === 0.0, `avgValue` is always 0.0 and we change `count`'s\n      // value to avoid dividing zero.\n      returnValue = `avgValue / max(count, 1.0)`;\n    }\n\n    const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;\n    const filterWidthVec4Remainder = filterWidth % 4;\n\n    const updateSnippet = `\n      if (${isAvgPool}) {\n        avgValue += dot(values, ones);\n      } else {\n        minMaxValue = ${compareOp}(values, minMaxValue);\n      }\n    `;\n\n    this.userCode = `\n      const ivec3 strides =\n        ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});\n      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});\n      
const float initializationValue = ${initializationValue};\n      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n      float count = 0.0;\n\n      float getValue(int batch, int xD, int xR, int xC, int ch) {\n        if (xC < 0 || xC >= ${convInfo.inWidth}) {\n          return initializationValue;\n        }\n        count += 1.0;\n        return getX(batch, xD, xR, xC, ch);\n      }\n\n      void main() {\n        ivec5 coords = getOutputCoords();\n        int batch = coords.x;\n        int ch = coords.u;\n\n        ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n        int xDCorner = xCorner.x;\n        int xRCorner = xCorner.y;\n        int xCCorner = xCorner.z;\n\n        // max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch).\n        // ? = to be determined\n        vec4 minMaxValue = vec4(${initializationValue});\n        float avgValue = 0.0;\n        count = 0.0;\n\n        for (int wD = 0; wD < ${effectiveFilterDepth};\n            wD += ${dilationDepth}) {\n          int xD = xDCorner + wD;\n\n          if (xD < 0 || xD >= ${convInfo.inDepth}) {\n            continue;\n          }\n\n          for (int wR = 0; wR < ${effectiveFilterHeight};\n            wR += ${dilationHeight}) {\n            int xR = xRCorner + wR;\n\n            if (xR < 0 || xR >= ${convInfo.inHeight}) {\n              continue;\n            }\n\n            for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {\n              int xC = xCCorner + wC * ${dilationWidth};\n\n              vec4 values = vec4(\n                getValue(batch, xD, xR, xC, ch),\n                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),\n                getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),\n                getValue(batch, xD, xR, xC + 3 * ${dilationWidth}, ch)\n              );\n\n              ${updateSnippet}\n            }\n\n            int xC = xCCorner + ${filterWidthNearestVec4};\n            if (${filterWidthVec4Remainder === 1}) {\n              vec4 
values = vec4(\n                getValue(batch, xD, xR, xC, ch),\n                initializationValue,\n                initializationValue,\n                initializationValue\n              );\n\n              ${updateSnippet}\n            } else if (${filterWidthVec4Remainder === 2}) {\n              vec4 values = vec4(\n                getValue(batch, xD, xR, xC, ch),\n                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),\n                initializationValue,\n                initializationValue\n              );\n\n              ${updateSnippet}\n            } else if (${filterWidthVec4Remainder === 3}) {\n              vec4 values = vec4(\n                getValue(batch, xD, xR, xC, ch),\n                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),\n                getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),\n                initializationValue\n              );\n\n              ${updateSnippet}\n            }\n          }\n        }\n        setOutput(${returnValue});\n      }\n    `;\n  }\n}\n"]}
|