/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { ENGINE } from '../engine';
|
import { AvgPool3DGrad } from '../kernel_names';
|
import { convertToTensor } from '../tensor_util_env';
|
import * as util from '../util';
|
import { checkPadOnDimRoundingMode } from './conv_util';
|
import { op } from './operation';
|
import { reshape } from './reshape';
|
/**
 * Computes the backprop of a 3d avg pool.
 *
 * @param dy The dy error, of rank 5 or rank 4 of shape
 *     [batchSize, depth, height, width, channels]. If rank 4, a batch of 1 is
 *     assumed, and `input` must also be rank 4.
 * @param input The original input image, of rank 5 or rank 4 of shape
 *     [batchSize, depth, height, width, channels]. If rank 4, a batch of 1 is
 *     assumed.
 * @param filterSize The filter size:
 *     `[filterDepth, filterHeight, filterWidth]`.
 *     If `filterSize` is a single number,
 *     then `filterDepth == filterHeight == filterWidth`.
 * @param strides The strides of the pooling:
 *     `[strideDepth, strideHeight, strideWidth]`. If
 *     `strides` is a single number,
 *     then `strideDepth == strideHeight == strideWidth`.
 * @param pad The type of padding algorithm used in the forward prop of the
 *     op: 'same', 'valid', or a number.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate. Its combination with `pad` is
 *     validated by `checkPadOnDimRoundingMode`.
 * @returns The gradient produced by the `AvgPool3DGrad` kernel. It is rank 4
 *     when `input` was rank 4 (the temporary batch dimension is removed) and
 *     rank 5 otherwise.
 */
function avgPool3dGrad_(dy, input, filterSize, strides, pad, dimRoundingMode) {
  const $dy = convertToTensor(dy, 'dy', 'avgPool3dGrad');
  const $input = convertToTensor(input, 'input', 'avgPool3dGrad');

  let dy5D = $dy;
  let input5D = $input;
  let reshapedTo5D = false;

  if ($input.rank === 4) {
    // A rank-4 input is treated as a single-element batch, so dy must be
    // rank 4 too. Checking here gives a clear message instead of an opaque
    // reshape size mismatch, or a silently wrong shape when sizes coincide.
    util.assert(
        $dy.rank === 4,
        () => `Error in avgPool3dGrad: dy must be rank 4 when input is ` +
            `rank 4 but got rank ${$dy.rank}.`);
    reshapedTo5D = true;
    // Prepend a batch dimension of 1 to both tensors.
    dy5D = reshape($dy, [1, ...$dy.shape]);
    input5D = reshape($input, [1, ...$input.shape]);
  }

  util.assert(
      dy5D.rank === 5,
      () => `Error in avgPool3dGrad: dy must be rank 5 but got rank ` +
          `${dy5D.rank}.`);
  util.assert(
      input5D.rank === 5,
      () => `Error in avgPool3dGrad: input must be rank 5 but got rank ` +
          `${input5D.rank}.`);
  checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode);

  const inputs = { dy: dy5D, input: input5D };
  const attrs = { filterSize, strides, pad, dimRoundingMode };
  const res = ENGINE.runKernel(AvgPool3DGrad, inputs, attrs);

  if (reshapedTo5D) {
    // Drop the batch dimension of 1 added above so the output rank matches
    // the caller's rank-4 input.
    return reshape(res, res.shape.slice(1));
  }
  return res;
}
|
// Public entry point: `avgPool3dGrad_` wrapped by `op`. The `@__PURE__`
// annotation tells bundlers/minifiers that this call may be dropped if the
// export is unused.
export const avgPool3dGrad = /* @__PURE__ */ op({ avgPool3dGrad_: avgPool3dGrad_ });
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"avg_pool_3d_grad.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/avg_pool_3d_grad.ts"],"names":[],"mappings":"AACA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,MAAM,EAAC,MAAM,WAAW,CAAC;AACjC,OAAO,EAAC,aAAa,EAA0C,MAAM,iBAAiB,CAAC;AAIvF,OAAO,EAAC,eAAe,EAAC,MAAM,oBAAoB,CAAC;AAEnD,OAAO,KAAK,IAAI,MAAM,SAAS,CAAC;AAEhC,OAAO,EAAC,yBAAyB,EAAC,MAAM,aAAa,CAAC;AACtD,OAAO,EAAC,EAAE,EAAC,MAAM,aAAa,CAAC;AAC/B,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAElC;;;;;;;;;;;;;;;;;;;GAmBG;AACH,SAAS,cAAc,CACnB,EAAgB,EAAE,KAAmB,EACrC,UAA2C,EAC3C,OAAwC,EAAE,GAA0B,EACpE,eAAwC;IAC1C,MAAM,GAAG,GAAG,eAAe,CAAC,EAAE,EAAE,IAAI,EAAE,eAAe,CAAC,CAAC;IACvD,MAAM,MAAM,GAAG,eAAe,CAAC,KAAK,EAAE,OAAO,EAAE,eAAe,CAAC,CAAC;IAEhE,IAAI,IAAI,GAAG,GAAe,CAAC;IAC3B,IAAI,OAAO,GAAG,MAAkB,CAAC;IACjC,IAAI,YAAY,GAAG,KAAK,CAAC;IAEzB,IAAI,MAAM,CAAC,IAAI,KAAK,CAAC,EAAE;QACrB,YAAY,GAAG,IAAI,CAAC;QACpB,IAAI,GAAG,OAAO,CACV,GAAG,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QACtE,OAAO,GAAG,OAAO,CAAC,MAAM,EAAE;YACxB,CAAC,EAAE,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC;SACtE,CAAC,CAAC;KACJ;IAED,IAAI,CAAC,MAAM,CACP,IAAI,CAAC,IAAI,KAAK,CAAC,EACf,GAAG,EAAE,CAAC,yDAAyD;QAC3D,GAAG,IAAI,CAAC,IAAI,GAAG,CAAC,CAAC;IACzB,IAAI,CAAC,MAAM,CACP,OAAO,CAAC,IAAI,KAAK,CAAC,EAClB,GAAG,EAAE,CAAC,4DAA4D;QAC9D,GAAG,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC;IAC5B,yBAAyB,CAAC,eAAe,EAAE,GAAG,EAAE,eAAe,CAAC,CAAC;IACjE,MAAM,MAAM,GAAwB,EAAC,EAAE,EAAE,IAAI,EAAE,KAAK,EAAE,OAAO,EAAC,CAAC;IAC/D,MAAM,KAAK,GAAuB,EAAC,UAAU,EAAE,OAAO,EAAE,GAAG,EAAE,eAAe,EAAC,CAAC;IAE9E,0DAA0D;IAC1D,MAAM,GAAG,GAAG,MAAM,CAAC,SAAS,CACZ,aAAa,EAAE,MAAmC,EAClD,KAAgC,CAAM,CAAC;IAEvD,IAAI,YAAY,EAAE;QAChB,OAAO,OAAO,CACH,GAAG,EAAE,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,C
AAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CACnE,CAAC;KACP;IAED,OAAO,GAAG,CAAC;AACb,CAAC;AAED,MAAM,CAAC,MAAM,aAAa,GAAG,eAAe,CAAC,EAAE,CAAC,EAAC,cAAc,EAAC,CAAC,CAAC","sourcesContent":["\n/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../engine';\nimport {AvgPool3DGrad, AvgPool3DGradAttrs, AvgPool3DGradInputs} from '../kernel_names';\nimport {NamedAttrMap} from '../kernel_registry';\nimport {Tensor4D, Tensor5D} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {checkPadOnDimRoundingMode} from './conv_util';\nimport {op} from './operation';\nimport {reshape} from './reshape';\n\n/**\n * Computes the backprop of a 3d avg pool.\n *\n * @param dy The dy error, of rank 5 of shape\n *     [batchSize, depth, height, width, channels].\n * assumed.\n * @param input The original input image, of rank 5 or rank4 of shape\n *     [batchSize, depth, height, width, channels].\n * @param filterSize The filter size:\n *     `[filterDepth, filterHeight, filterWidth]`.\n *     `filterSize` is a single number,\n *     then `filterDepth == filterHeight == filterWidth`.\n * @param strides The strides of the pooling:\n *     `[strideDepth, 
strideHeight, strideWidth]`. If\n *     `strides` is a single number, then `strideHeight == strideWidth`.\n * @param pad A string from: 'same', 'valid'. The type of padding algorithm\n *     used in the forward prop of the op.\n * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is\n *     provided, it will default to truncate.\n */\nfunction avgPool3dGrad_<T extends Tensor4D|Tensor5D>(\n    dy: T|TensorLike, input: T|TensorLike,\n    filterSize: [number, number, number]|number,\n    strides: [number, number, number]|number, pad: 'valid'|'same'|number,\n    dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n  const $dy = convertToTensor(dy, 'dy', 'avgPool3dGrad');\n  const $input = convertToTensor(input, 'input', 'avgPool3dGrad');\n\n  let dy5D = $dy as Tensor5D;\n  let input5D = $input as Tensor5D;\n  let reshapedTo5D = false;\n\n  if ($input.rank === 4) {\n    reshapedTo5D = true;\n    dy5D = reshape(\n        $dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);\n    input5D = reshape($input, [\n      1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]\n    ]);\n  }\n\n  util.assert(\n      dy5D.rank === 5,\n      () => `Error in avgPool3dGrad: dy must be rank 5 but got rank ` +\n          `${dy5D.rank}.`);\n  util.assert(\n      input5D.rank === 5,\n      () => `Error in avgPool3dGrad: input must be rank 5 but got rank ` +\n          `${input5D.rank}.`);\n  checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode);\n  const inputs: AvgPool3DGradInputs = {dy: dy5D, input: input5D};\n  const attrs: AvgPool3DGradAttrs = {filterSize, strides, pad, dimRoundingMode};\n\n  // tslint:disable-next-line: no-unnecessary-type-assertion\n  const res = ENGINE.runKernel(\n                  AvgPool3DGrad, inputs as unknown as NamedTensorMap,\n                  attrs as unknown as NamedAttrMap) as T;\n\n  if (reshapedTo5D) {\n    return reshape(\n               res, [res.shape[1], res.shape[2], res.shape[3], 
res.shape[4]]) as\n        T;\n  }\n\n  return res;\n}\n\nexport const avgPool3dGrad = /* @__PURE__ */ op({avgPool3dGrad_});\n"]}
|