/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { FusedBatchNorm } from '../kernel_names';
|
import { add } from '../ops/add';
|
import { getReductionAxes } from '../ops/broadcast_util';
|
import { mul } from '../ops/mul';
|
import { reshape } from '../ops/reshape';
|
import { rsqrt } from '../ops/rsqrt';
|
import { scalar } from '../ops/scalar';
|
import { sub } from '../ops/sub';
|
import { sum } from '../ops/sum';
|
import { tile } from '../ops/tile';
|
/**
 * Gradient configuration for the FusedBatchNorm kernel.
 *
 * The derivatives returned by `gradFunc` correspond to the forward computation
 *
 *   r = rsqrt(variance + varianceEpsilon)
 *   y = (x - mean) * r * scale + offset
 *
 * where `scale` is treated as 1 when it was not supplied.
 *
 * `mean` (and the other per-channel parameters, which are reshaped to
 * `mean.shape` below) may be broadcast against `x`. Examples are a rank-1
 * tensor of shape [C] or a same-rank tensor such as [1, 1, 1, C]. Every
 * parameter gradient is summed over the axes along which `mean` was broadcast
 * and then reshaped back to `mean.shape`.
 */
export const fusedBatchNormGradConfig = {
  kernelName: FusedBatchNorm,
  inputsToSave: ['x', 'mean', 'variance', 'scale'],
  /**
   * @param dy Upstream gradient; expected to have the same shape as `x`.
   * @param saved Tensors saved at forward time: [x, mean, variance, scale].
   *     `scale` may be null/undefined.
   * @param attrs Kernel attributes; only `varianceEpsilon` is read.
   * @returns An object mapping each input name to a thunk that lazily
   *     computes that input's gradient.
   */
  gradFunc: (dy, saved, attrs) => {
    const { varianceEpsilon } = attrs;
    const [x, mean, variance, scale] = saved;

    const scaleValue = scale == null ? scalar(1) : scale;
    // Axes of `x` along which `mean` was broadcast. This is empty when `mean`
    // already has x's full shape.
    const reductionAxes = getReductionAxes(mean.shape, x.shape);

    const xMinusMean = sub(x, mean);
    const dyTimesScaleValue = mul(dy, scaleValue);
    // r = 1 / sqrt(variance + epsilon)
    const oneOverSqrtVariance = rsqrt(add(variance, scalar(varianceEpsilon)));
    // dr/dvariance = -0.5 * r^3
    const minusHalfRCube = mul(
        mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance),
        scalar(-0.5));

    // Collapses a full-shape gradient onto the parameter shape. It sums over
    // the broadcast axes (if any) and then reshapes to `mean.shape`. Reducing
    // whenever axes were broadcast, rather than only for rank-1 `mean`, lets
    // same-rank broadcast parameters such as [1, 1, 1, C] reshape correctly.
    const reduceToParamShape = (grad) => {
      const reduced =
          reductionAxes.length > 0 ? sum(grad, reductionAxes) : grad;
      return reshape(reduced, mean.shape);
    };

    const derX = () => {
      if (mean.rank === 1) {
        // Lift the per-channel r to shape [1, ..., 1, C], matching x's rank.
        // Then tile it across every leading dimension of x so that it lines
        // up element-wise with dy.
        const leadingDims = x.shape.slice(0, -1);
        const rShape = [...leadingDims.map(() => 1), mean.shape[0]];
        const tileShape = [...leadingDims, 1];
        const tiledR = tile(reshape(oneOverSqrtVariance, rShape), tileShape);
        return reshape(mul(mul(dy, tiledR), scaleValue), x.shape);
      }
      return reshape(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape);
    };

    // dL/dmean = -r * scale * dy
    const derMean = () => reduceToParamShape(
        mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue));

    // dL/dvariance = -0.5 * r^3 * (x - mean) * scale * dy
    const derVariance = () => reduceToParamShape(
        mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue));

    // dL/dscale = dy * (x - mean) * r
    const derScale = () => reduceToParamShape(
        mul(dy, mul(xMinusMean, oneOverSqrtVariance)));

    // dL/doffset = dy
    const derOffset = () => reduceToParamShape(dy);

    return {
      x: derX,
      mean: derMean,
      variance: derVariance,
      scale: derScale,
      offset: derOffset,
    };
  },
};
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"FusedBatchNorm_grad.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/gradients/FusedBatchNorm_grad.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AACH,OAAO,EAAC,cAAc,EAAsB,MAAM,iBAAiB,CAAC;AAEpE,OAAO,EAAC,GAAG,EAAC,MAAM,YAAY,CAAC;AAC/B,OAAO,EAAC,gBAAgB,EAAC,MAAM,uBAAuB,CAAC;AACvD,OAAO,EAAC,GAAG,EAAC,MAAM,YAAY,CAAC;AAC/B,OAAO,EAAC,OAAO,EAAC,MAAM,gBAAgB,CAAC;AACvC,OAAO,EAAC,KAAK,EAAC,MAAM,cAAc,CAAC;AACnC,OAAO,EAAC,MAAM,EAAC,MAAM,eAAe,CAAC;AACrC,OAAO,EAAC,GAAG,EAAC,MAAM,YAAY,CAAC;AAC/B,OAAO,EAAC,GAAG,EAAC,MAAM,YAAY,CAAC;AAC/B,OAAO,EAAC,IAAI,EAAC,MAAM,aAAa,CAAC;AAIjC,MAAM,CAAC,MAAM,wBAAwB,GAAe;IAClD,UAAU,EAAE,cAAc;IAC1B,YAAY,EAAE,CAAC,GAAG,EAAE,MAAM,EAAE,UAAU,EAAE,OAAO,CAAC;IAChD,QAAQ,EAAE,CACN,EAAU,EAAE,KAAe,EAAE,KAAmB,EAAE,EAAE;QACtD,MAAM,EAAC,eAAe,EAAC,GAAG,KAAuC,CAAC;QAClE,MAAM,CAAC,CAAC,EAAE,IAAI,EAAE,QAAQ,EAAE,KAAK,CAAC,GAAG,KAAK,CAAC;QAEzC,MAAM,UAAU,GAAG,KAAK,IAAI,IAAI,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC;QACrD,MAAM,aAAa,GAAG,gBAAgB,CAAC,IAAI,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC;QAC5D,MAAM,SAAS,GAAa,EAAE,CAAC;QAC/B,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;YACnB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,EAAE,CAAC,EAAE;gBAC3C,SAAS,CAAC,IAAI,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;aAC5B;YACD,SAAS,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;SACnB;QAED,MAAM,UAAU,GAAG,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC;QAChC,MAAM,iBAAiB,GAAG,GAAG,CAAC,EAAE,EAAE,UAAU,CAAC,CAAC;QAC9C,MAAM,mBAAmB,GAAG,KAAK,CAAC,GAAG,CAAC,QAAQ,EAAE,MAAM,CAAC,eAAe,CAAC,CAAC,CAAC,CAAC;QAC1E,MAAM,cAAc,GAAG,GAAG,CACtB,GAAG,CAAC,GAAG,CAAC,mBAAmB,EAAE,mBAAmB,CAAC,EAAE,mBAAmB,CAAC,EACvE,MAAM,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;QAElB,MAAM,IAAI,GAAG,GAAG,EAAE;YAChB,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;gBACnB,OAAO,OAAO,CACV,GAAG,CAAC,GAAG,CAAC,EAAE,EACF,IAAI,CACA,OAAO,CAAC,mBAAmB,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,EACtD,SAAS,CAAC,CAAC,EACnB,U
AAU,CAAC,EACf,CAAC,CAAC,KAAK,CAAC,CAAC;aACd;iBAAM;gBACL,OAAO,OAAO,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,EAAE,mBAAmB,CAAC,EAAE,UAAU,CAAC,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC;aACxE;QACH,CAAC,CAAC;QACF,MAAM,OAAO,GAAG,GAAG,EAAE;YACnB,IAAI,OAAO,GACP,GAAG,CAAC,GAAG,CAAC,mBAAmB,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,iBAAiB,CAAC,CAAC;YACjE,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;gBACnB,OAAO,GAAG,GAAG,CAAC,OAAO,EAAE,aAAa,CAAC,CAAC;aACvC;YACD,OAAO,OAAO,CAAC,OAAO,EAAE,IAAI,CAAC,KAAoB,CAAC,CAAC;QACrD,CAAC,CAAC;QACF,MAAM,WAAW,GAAG,GAAG,EAAE;YACvB,IAAI,WAAW,GAAG,GAAG,CAAC,GAAG,CAAC,cAAc,EAAE,UAAU,CAAC,EAAE,iBAAiB,CAAC,CAAC;YAE1E,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;gBACnB,WAAW,GAAG,GAAG,CAAC,WAAW,EAAE,aAAa,CAAC,CAAC;aAC/C;YACD,OAAO,OAAO,CAAC,WAAW,EAAE,IAAI,CAAC,KAAoB,CAAC,CAAC;QACzD,CAAC,CAAC;QACF,MAAM,QAAQ,GAAG,GAAG,EAAE;YACpB,MAAM,qBAAqB,GAAG,GAAG,CAAC,UAAU,EAAE,mBAAmB,CAAC,CAAC;YAEnE,IAAI,QAAQ,GAAG,GAAG,CAAC,EAAE,EAAE,qBAAqB,CAAC,CAAC;YAC9C,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;gBACnB,QAAQ,GAAG,GAAG,CAAC,QAAQ,EAAE,aAAa,CAAC,CAAC;aACzC;YACD,OAAO,OAAO,CAAC,QAAQ,EAAE,IAAI,CAAC,KAAoB,CAAC,CAAC;QACtD,CAAC,CAAC;QACF,MAAM,SAAS,GAAG,GAAG,EAAE;YACrB,IAAI,SAAS,GAAG,EAAE,CAAC;YACnB,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;gBACnB,SAAS,GAAG,GAAG,CAAC,SAAS,EAAE,aAAa,CAAC,CAAC;aAC3C;YACD,OAAO,OAAO,CAAC,SAAS,EAAE,IAAI,CAAC,KAAoB,CAAC,CAAC;QACvD,CAAC,CAAC;QAEF,OAAO;YACL,CAAC,EAAE,IAAI;YACP,IAAI,EAAE,OAAO;YACb,QAAQ,EAAE,WAAW;YACrB,KAAK,EAAE,QAAQ;YACf,MAAM,EAAE,SAAS;SAClB,CAAC;IACJ,CAAC;CACF,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {FusedBatchNorm, FusedBatchNormAttrs} from '../kernel_names';\nimport {GradConfig, NamedAttrMap} from '../kernel_registry';\nimport {add} from '../ops/add';\nimport {getReductionAxes} from '../ops/broadcast_util';\nimport {mul} from '../ops/mul';\nimport {reshape} from '../ops/reshape';\nimport {rsqrt} from '../ops/rsqrt';\nimport {scalar} from '../ops/scalar';\nimport {sub} from '../ops/sub';\nimport {sum} from '../ops/sum';\nimport {tile} from '../ops/tile';\nimport {Tensor} from '../tensor';\nimport {Rank, ShapeMap} from '../types';\n\nexport const fusedBatchNormGradConfig: GradConfig = {\n  kernelName: FusedBatchNorm,\n  inputsToSave: ['x', 'mean', 'variance', 'scale'],\n  gradFunc: <R extends Rank>(\n      dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {\n    const {varianceEpsilon} = attrs as unknown as FusedBatchNormAttrs;\n    const [x, mean, variance, scale] = saved;\n\n    const scaleValue = scale == null ? 
scalar(1) : scale;\n    const reductionAxes = getReductionAxes(mean.shape, x.shape);\n    const tileShape: number[] = [];\n    if (mean.rank === 1) {\n      for (let i = 0; i < x.shape.length - 1; ++i) {\n        tileShape.push(x.shape[i]);\n      }\n      tileShape.push(1);\n    }\n\n    const xMinusMean = sub(x, mean);\n    const dyTimesScaleValue = mul(dy, scaleValue);\n    const oneOverSqrtVariance = rsqrt(add(variance, scalar(varianceEpsilon)));\n    const minusHalfRCube = mul(\n        mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance),\n        scalar(-0.5));\n\n    const derX = () => {\n      if (mean.rank === 1) {\n        return reshape(\n            mul(mul(dy,\n                    tile(\n                        reshape(oneOverSqrtVariance, [1, 1, 1, mean.shape[0]]),\n                        tileShape)),\n                scaleValue),\n            x.shape);\n      } else {\n        return reshape(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape);\n      }\n    };\n    const derMean = () => {\n      let meanDer =\n          mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue);\n      if (mean.rank === 1) {\n        meanDer = sum(meanDer, reductionAxes);\n      }\n      return reshape(meanDer, mean.shape as ShapeMap[R]);\n    };\n    const derVariance = () => {\n      let varianceDer = mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue);\n\n      if (mean.rank === 1) {\n        varianceDer = sum(varianceDer, reductionAxes);\n      }\n      return reshape(varianceDer, mean.shape as ShapeMap[R]);\n    };\n    const derScale = () => {\n      const xMinusMean2TimesRsqrt = mul(xMinusMean, oneOverSqrtVariance);\n\n      let scaleDer = mul(dy, xMinusMean2TimesRsqrt);\n      if (mean.rank === 1) {\n        scaleDer = sum(scaleDer, reductionAxes);\n      }\n      return reshape(scaleDer, mean.shape as ShapeMap[R]);\n    };\n    const derOffset = () => {\n      let offsetDer = dy;\n      if (mean.rank === 1) {\n        offsetDer 
= sum(offsetDer, reductionAxes);\n      }\n      return reshape(offsetDer, mean.shape as ShapeMap[R]);\n    };\n\n    return {\n      x: derX,\n      mean: derMean,\n      variance: derVariance,\n      scale: derScale,\n      offset: derOffset\n    };\n  }\n};\n"]}
|