/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { FusedBatchNorm } from '../kernel_names';
import { add } from '../ops/add';
import { getReductionAxes } from '../ops/broadcast_util';
import { mul } from '../ops/mul';
import { reshape } from '../ops/reshape';
import { rsqrt } from '../ops/rsqrt';
import { scalar } from '../ops/scalar';
import { sub } from '../ops/sub';
import { sum } from '../ops/sum';
import { tile } from '../ops/tile';
/**
 * Gradient configuration for the FusedBatchNorm kernel.
 *
 * With r = rsqrt(variance + varianceEpsilon), the gradients returned are:
 *   d/dx        = dy * scale * r
 *   d/dmean     = -dy * scale * r
 *   d/dvariance = dy * scale * (x - mean) * (-0.5 * r^3)
 *   d/dscale    = dy * (x - mean) * r
 *   d/doffset   = dy
 * mean and variance are treated as independent kernel inputs (no dependence
 * on x is propagated through them).
 *
 * Every gradient except d/dx is computed at x's shape and then summed over
 * the axes along which the corresponding input was broadcast, so it comes
 * back in that input's own shape.
 */
export const fusedBatchNormGradConfig = {
  kernelName: FusedBatchNorm,
  inputsToSave: ['x', 'mean', 'variance', 'scale'],
  gradFunc: (dy, saved, attrs) => {
    const { varianceEpsilon } = attrs;
    const [x, mean, variance, scale] = saved;
    // An absent scale acts as a multiplicative identity.
    const scaleValue = scale == null ? scalar(1) : scale;

    // Reduces an x-shaped gradient back to `shape`: sum over the axes on
    // which an input of that shape was broadcast against x (if any), then
    // reshape so that size-1 / missing leading dims are restored.
    const reduceToShape = (grad, shape) => {
      const axes = getReductionAxes(shape, x.shape);
      const reduced = axes.length > 0 ? sum(grad, axes) : grad;
      return reshape(reduced, shape);
    };

    const xMinusMean = sub(x, mean);
    const dyTimesScaleValue = mul(dy, scaleValue);
    const oneOverSqrtVariance = rsqrt(add(variance, scalar(varianceEpsilon)));
    // d/dv (v + eps)^(-1/2) = -0.5 * (v + eps)^(-3/2) = -0.5 * r^3.
    const minusHalfRCube = mul(
        mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance),
        scalar(-0.5));

    // Broadcasting r and scale against dy handles any rank of x; no explicit
    // tiling to a fixed rank is needed.
    const derX = () =>
        reshape(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape);
    const derMean = () => reduceToShape(
        mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue),
        mean.shape);
    const derVariance = () => reduceToShape(
        mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue),
        variance.shape);
    const derScale = () => {
      const xMinusMeanTimesRsqrt = mul(xMinusMean, oneOverSqrtVariance);
      // Only meaningful when scale was supplied; fall back to mean's shape
      // otherwise (matches the previous behavior for a missing scale).
      const scaleShape = scale == null ? mean.shape : scale.shape;
      return reduceToShape(mul(dy, xMinusMeanTimesRsqrt), scaleShape);
    };
    // NOTE(review): offset is not in inputsToSave, so its shape is unknown
    // here; its gradient is assumed to share mean's shape.
    const derOffset = () => reduceToShape(dy, mean.shape);

    return {
      x: derX,
      mean: derMean,
      variance: derVariance,
      scale: derScale,
      offset: derOffset
    };
  }
};
//# 
sourceMappingURL=data:application/json;base64,{"version":3,"file":"FusedBatchNorm_grad.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/gradients/FusedBatchNorm_grad.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AACH,OAAO,EAAC,cAAc,EAAsB,MAAM,iBAAiB,CAAC;AAEpE,OAAO,EAAC,GAAG,EAAC,MAAM,YAAY,CAAC;AAC/B,OAAO,EAAC,gBAAgB,EAAC,MAAM,uBAAuB,CAAC;AACvD,OAAO,EAAC,GAAG,EAAC,MAAM,YAAY,CAAC;AAC/B,OAAO,EAAC,OAAO,EAAC,MAAM,gBAAgB,CAAC;AACvC,OAAO,EAAC,KAAK,EAAC,MAAM,cAAc,CAAC;AACnC,OAAO,EAAC,MAAM,EAAC,MAAM,eAAe,CAAC;AACrC,OAAO,EAAC,GAAG,EAAC,MAAM,YAAY,CAAC;AAC/B,OAAO,EAAC,GAAG,EAAC,MAAM,YAAY,CAAC;AAC/B,OAAO,EAAC,IAAI,EAAC,MAAM,aAAa,CAAC;AAIjC,MAAM,CAAC,MAAM,wBAAwB,GAAe;IAClD,UAAU,EAAE,cAAc;IAC1B,YAAY,EAAE,CAAC,GAAG,EAAE,MAAM,EAAE,UAAU,EAAE,OAAO,CAAC;IAChD,QAAQ,EAAE,CACN,EAAU,EAAE,KAAe,EAAE,KAAmB,EAAE,EAAE;QACtD,MAAM,EAAC,eAAe,EAAC,GAAG,KAAuC,CAAC;QAClE,MAAM,CAAC,CAAC,EAAE,IAAI,EAAE,QAAQ,EAAE,KAAK,CAAC,GAAG,KAAK,CAAC;QAEzC,MAAM,UAAU,GAAG,KAAK,IAAI,IAAI,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC;QACrD,MAAM,aAAa,GAAG,gBAAgB,CAAC,IAAI,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC;QAC5D,MAAM,SAAS,GAAa,EAAE,CAAC;QAC/B,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;YACnB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,EAAE,CAAC,EAAE;gBAC3C,SAAS,CAAC,IAAI,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;aAC5B;YACD,SAAS,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;SACnB;QAED,MAAM,UAAU,GAAG,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC;QAChC,MAAM,iBAAiB,GAAG,GAAG,CAAC,EAAE,EAAE,UAAU,CAAC,CAAC;QAC9C,MAAM,mBAAmB,GAAG,KAAK,CAAC,GAAG,CAAC,QAAQ,EAAE,MAAM,CAAC,eAAe,CAAC,CAAC,CAAC,CAAC;QAC1E,MAAM,cAAc,GAAG,GAAG,CACtB,GAAG,CAAC,GAAG,CAAC,mBAAmB,EAAE,mBAAmB,CAAC,EAAE,mBAAmB,CAAC,EACvE,MAAM,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;QAElB,MAAM,IAAI,GAAG,GAAG,EAAE;YAChB,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;gBACnB,OAAO,OAAO,CACV,GAAG,CAAC,GAAG,CAAC,EAAE,EACF,IAAI,CACA,OAAO,CAAC,mBAAmB,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,EACtD,SAAS,CAAC,CAAC,EACnB,UAAU,
CAAC,EACf,CAAC,CAAC,KAAK,CAAC,CAAC;aACd;iBAAM;gBACL,OAAO,OAAO,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,EAAE,mBAAmB,CAAC,EAAE,UAAU,CAAC,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC;aACxE;QACH,CAAC,CAAC;QACF,MAAM,OAAO,GAAG,GAAG,EAAE;YACnB,IAAI,OAAO,GACP,GAAG,CAAC,GAAG,CAAC,mBAAmB,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,iBAAiB,CAAC,CAAC;YACjE,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;gBACnB,OAAO,GAAG,GAAG,CAAC,OAAO,EAAE,aAAa,CAAC,CAAC;aACvC;YACD,OAAO,OAAO,CAAC,OAAO,EAAE,IAAI,CAAC,KAAoB,CAAC,CAAC;QACrD,CAAC,CAAC;QACF,MAAM,WAAW,GAAG,GAAG,EAAE;YACvB,IAAI,WAAW,GAAG,GAAG,CAAC,GAAG,CAAC,cAAc,EAAE,UAAU,CAAC,EAAE,iBAAiB,CAAC,CAAC;YAE1E,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;gBACnB,WAAW,GAAG,GAAG,CAAC,WAAW,EAAE,aAAa,CAAC,CAAC;aAC/C;YACD,OAAO,OAAO,CAAC,WAAW,EAAE,IAAI,CAAC,KAAoB,CAAC,CAAC;QACzD,CAAC,CAAC;QACF,MAAM,QAAQ,GAAG,GAAG,EAAE;YACpB,MAAM,qBAAqB,GAAG,GAAG,CAAC,UAAU,EAAE,mBAAmB,CAAC,CAAC;YAEnE,IAAI,QAAQ,GAAG,GAAG,CAAC,EAAE,EAAE,qBAAqB,CAAC,CAAC;YAC9C,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;gBACnB,QAAQ,GAAG,GAAG,CAAC,QAAQ,EAAE,aAAa,CAAC,CAAC;aACzC;YACD,OAAO,OAAO,CAAC,QAAQ,EAAE,IAAI,CAAC,KAAoB,CAAC,CAAC;QACtD,CAAC,CAAC;QACF,MAAM,SAAS,GAAG,GAAG,EAAE;YACrB,IAAI,SAAS,GAAG,EAAE,CAAC;YACnB,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;gBACnB,SAAS,GAAG,GAAG,CAAC,SAAS,EAAE,aAAa,CAAC,CAAC;aAC3C;YACD,OAAO,OAAO,CAAC,SAAS,EAAE,IAAI,CAAC,KAAoB,CAAC,CAAC;QACvD,CAAC,CAAC;QAEF,OAAO;YACL,CAAC,EAAE,IAAI;YACP,IAAI,EAAE,OAAO;YACb,QAAQ,EAAE,WAAW;YACrB,KAAK,EAAE,QAAQ;YACf,MAAM,EAAE,SAAS;SAClB,CAAC;IACJ,CAAC;CACF,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {FusedBatchNorm, FusedBatchNormAttrs} from '../kernel_names';\nimport {GradConfig, NamedAttrMap} from '../kernel_registry';\nimport {add} from '../ops/add';\nimport {getReductionAxes} from '../ops/broadcast_util';\nimport {mul} from '../ops/mul';\nimport {reshape} from '../ops/reshape';\nimport {rsqrt} from '../ops/rsqrt';\nimport {scalar} from '../ops/scalar';\nimport {sub} from '../ops/sub';\nimport {sum} from '../ops/sum';\nimport {tile} from '../ops/tile';\nimport {Tensor} from '../tensor';\nimport {Rank, ShapeMap} from '../types';\n\nexport const fusedBatchNormGradConfig: GradConfig = {\n  kernelName: FusedBatchNorm,\n  inputsToSave: ['x', 'mean', 'variance', 'scale'],\n  gradFunc: <R extends Rank>(\n      dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {\n    const {varianceEpsilon} = attrs as unknown as FusedBatchNormAttrs;\n    const [x, mean, variance, scale] = saved;\n\n    const scaleValue = scale == null ? 
scalar(1) : scale;\n    const reductionAxes = getReductionAxes(mean.shape, x.shape);\n    const tileShape: number[] = [];\n    if (mean.rank === 1) {\n      for (let i = 0; i < x.shape.length - 1; ++i) {\n        tileShape.push(x.shape[i]);\n      }\n      tileShape.push(1);\n    }\n\n    const xMinusMean = sub(x, mean);\n    const dyTimesScaleValue = mul(dy, scaleValue);\n    const oneOverSqrtVariance = rsqrt(add(variance, scalar(varianceEpsilon)));\n    const minusHalfRCube = mul(\n        mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance),\n        scalar(-0.5));\n\n    const derX = () => {\n      if (mean.rank === 1) {\n        return reshape(\n            mul(mul(dy,\n                    tile(\n                        reshape(oneOverSqrtVariance, [1, 1, 1, mean.shape[0]]),\n                        tileShape)),\n                scaleValue),\n            x.shape);\n      } else {\n        return reshape(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape);\n      }\n    };\n    const derMean = () => {\n      let meanDer =\n          mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue);\n      if (mean.rank === 1) {\n        meanDer = sum(meanDer, reductionAxes);\n      }\n      return reshape(meanDer, mean.shape as ShapeMap[R]);\n    };\n    const derVariance = () => {\n      let varianceDer = mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue);\n\n      if (mean.rank === 1) {\n        varianceDer = sum(varianceDer, reductionAxes);\n      }\n      return reshape(varianceDer, mean.shape as ShapeMap[R]);\n    };\n    const derScale = () => {\n      const xMinusMean2TimesRsqrt = mul(xMinusMean, oneOverSqrtVariance);\n\n      let scaleDer = mul(dy, xMinusMean2TimesRsqrt);\n      if (mean.rank === 1) {\n        scaleDer = sum(scaleDer, reductionAxes);\n      }\n      return reshape(scaleDer, mean.shape as ShapeMap[R]);\n    };\n    const derOffset = () => {\n      let offsetDer = dy;\n      if (mean.rank === 1) {\n        offsetDer 
= sum(offsetDer, reductionAxes);\n      }\n      return reshape(offsetDer, mean.shape as ShapeMap[R]);\n    };\n\n    return {\n      x: derX,\n      mean: derMean,\n      variance: derVariance,\n      scale: derScale,\n      offset: derOffset\n    };\n  }\n};\n"]}