/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { ENGINE } from '../engine';
|
import { Conv3DBackpropInputV2 } from '../kernel_names';
|
import * as util from '../util';
|
import { op } from './operation';
|
import { reshape } from './reshape';
|
/**
 * Computes the derivative of the input of a 3D convolution.
 *
 * @param xShape The shape of the input: `[batch, depth, height, width,
 *     in_channels]`. If length of 4, batch of 1 is assumed. Its length must
 *     equal the rank of `dy`.
 * @param dy The derivative of the output, of rank 5 or rank 4 of shape
 *     `[batch, outDepth, outHeight, outWidth, out_channels]`.
 *     If rank 4, batch of 1 is assumed.
 * @param filter The filter, rank 5, of shape
 *     `[filterDepth, filterHeight, filterWidth, inDepth, outDepth]`.
 * @param strides The strides of the convolution: `[strideDepth, strideHeight,
 *     strideWidth]`.
 * @param pad The type of padding algorithm used:
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 * @returns The result of the `Conv3DBackpropInputV2` kernel. When `dy` is
 *     rank 4, the leading batch axis added internally is removed again, so
 *     the result has rank 4 as well.
 */
function conv3DBackpropInput_(xShape, dy, filter, strides, pad) {
  util.assert(xShape.length === dy.rank, () => `Length of inShape ` +
      `(${xShape.length}) and rank of dy (${dy.rank}) must match`);

  let xShape5D = xShape;
  let dy5D = dy;
  let reshapedTo5D = false;
  if (dy.rank === 4) {
    // Promote unbatched inputs to 5D with a leading batch of 1; the kernel
    // result is reshaped back to rank 4 before returning.
    reshapedTo5D = true;
    dy5D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
    xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]];
  }

  // Validate ranks before indexing into the channel axis (index 4).
  util.assert(xShape5D.length === 5, () => `Error in conv3dDerInput: inShape must be length 5, but got length ` +
      `${xShape5D.length}.`);
  util.assert(dy5D.rank === 5, () => `Error in conv3dDerInput: dy must be rank 5, but got ` +
      `rank ${dy5D.rank}`);
  util.assert(filter.rank === 5, () => `Error in conv3dDerInput: filter must be rank 5, but got ` +
      `rank ${filter.rank}`);

  // Channel counts of the input / output gradient must agree with the
  // filter's inDepth (axis 3) and outDepth (axis 4).
  const inDepth = xShape5D[4];
  const outDepth = dy5D.shape[4];
  util.assert(inDepth === filter.shape[3], () => `Error in conv3dDerInput: depth of input (${inDepth}) must ` +
      `match input depth for filter ${filter.shape[3]}.`);
  util.assert(outDepth === filter.shape[4], () => `Error in conv3dDerInput: depth of output (${outDepth}) must ` +
      `match output depth for filter ${filter.shape[4]}.`);

  const inputs = { dy: dy5D, filter };
  const attrs = { pad, strides, inputShape: xShape5D };
  const res = ENGINE.runKernel(Conv3DBackpropInputV2, inputs, attrs);

  if (reshapedTo5D) {
    // Drop the batch axis of size 1 that was added above.
    return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
  }
  return res;
}
|
/**
 * Public op wrapping `conv3DBackpropInput_`; the object key passed to `op`
 * is the implementation's name, and the call is marked pure for bundlers.
 */
export const conv3DBackpropInput = /* @__PURE__ */ op({
  conv3DBackpropInput_: conv3DBackpropInput_,
});
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"conv3d_backprop_input.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/conv3d_backprop_input.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AACH,OAAO,EAAC,MAAM,EAAC,MAAM,WAAW,CAAC;AACjC,OAAO,EAAC,qBAAqB,EAA0D,MAAM,iBAAiB,CAAC;AAI/G,OAAO,KAAK,IAAI,MAAM,SAAS,CAAC;AAEhC,OAAO,EAAC,EAAE,EAAC,MAAM,aAAa,CAAC;AAC/B,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAElC;;;;;;;;;;;;;;;;;GAiBG;AACH,SAAS,oBAAoB,CACzB,MAE6C,EAC7C,EAAK,EAAE,MAAgB,EAAE,OAAwC,EACjE,GAAmB;IACrB,IAAI,CAAC,MAAM,CACP,MAAM,CAAC,MAAM,KAAK,EAAE,CAAC,IAAI,EACzB,GAAG,EAAE,CAAC,oBAAoB;QACtB,IAAI,MAAM,CAAC,MAAM,qBAAqB,EAAE,CAAC,IAAI,cAAc,CAAC,CAAC;IAErE,IAAI,QAAQ,GAAG,MAAkD,CAAC;IAClE,IAAI,IAAI,GAAG,EAAc,CAAC;IAC1B,IAAI,YAAY,GAAG,KAAK,CAAC;IACzB,IAAI,EAAE,CAAC,IAAI,KAAK,CAAC,EAAE;QACjB,YAAY,GAAG,IAAI,CAAC;QACpB,IAAI,GAAG,OAAO,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC5E,QAAQ,GAAG,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;KAC5D;IAED,MAAM,OAAO,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAC;IAC5B,MAAM,QAAQ,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IAC/B,IAAI,CAAC,MAAM,CACP,QAAQ,CAAC,MAAM,KAAK,CAAC,EACrB,GAAG,EAAE,CACD,oEAAoE;QACpE,GAAG,QAAQ,CAAC,MAAM,GAAG,CAAC,CAAC;IAC/B,IAAI,CAAC,MAAM,CACP,IAAI,CAAC,IAAI,KAAK,CAAC,EACf,GAAG,EAAE,CAAC,sDAAsD;QACxD,QAAQ,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC;IAC7B,IAAI,CAAC,MAAM,CACP,MAAM,CAAC,IAAI,KAAK,CAAC,EACjB,GAAG,EAAE,CAAC,0DAA0D;QAC5D,QAAQ,MAAM,CAAC,IAAI,EAAE,CAAC,CAAC;IAC/B,IAAI,CAAC,MAAM,CACP,OAAO,KAAK,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAC3B,GAAG,EAAE,CAAC,4CAA4C,OAAO,SAAS;QAC9D,gCAAgC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;IAC5D,IAAI,CAAC,MAAM,CACP,QAAQ,KAAK,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAC5B,GAAG,EAAE,CAAC,6CAA6C,QAAQ,SAAS;QAChE,iCAAiC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;IAE7D,MAAM,MAAM,GAA
gC,EAAC,EAAE,EAAE,IAAI,EAAE,MAAM,EAAC,CAAC;IAE/D,MAAM,KAAK,GACsB,EAAC,GAAG,EAAE,OAAO,EAAE,UAAU,EAAE,QAAQ,EAAC,CAAC;IAEtE,0DAA0D;IAC1D,MAAM,GAAG,GAAG,MAAM,CAAC,SAAS,CACZ,qBAAqB,EAAE,MAAmC,EAC1D,KAAgC,CAAM,CAAC;IAEvD,IAAI,YAAY,EAAE;QAChB,OAAO,OAAO,CACH,GAAG,EAAE,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CACnE,CAAC;KACP;IACD,OAAO,GAAG,CAAC;AACb,CAAC;AAED,MAAM,CAAC,MAAM,mBAAmB,GAAG,eAAe,CAAC,EAAE,CAAC,EAAC,oBAAoB,EAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {ENGINE} from '../engine';\nimport {Conv3DBackpropInputV2, Conv3DBackpropInputV2Attrs, Conv3DBackpropInputV2Inputs} from '../kernel_names';\nimport {NamedAttrMap} from '../kernel_registry';\nimport {Tensor4D, Tensor5D} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport * as util from '../util';\n\nimport {op} from './operation';\nimport {reshape} from './reshape';\n\n/**\n * Computes the derivative of the input of a 3D convolution.\n *\n * @param xShape The shape of the input: [batch, depth, height, width,\n * in_channels]. 
If length of 4, batch of 1 is assumed.\n * @param dy The derivative of the output, of rank 5 or rank 4 of shape\n *   `[batch, outDepth, outHeight, outWidth, in_channels]`.\n * If rank 4, batch of 1 is assumed.\n * @param filter The filter, rank 5, of shape\n *     `[filterDepth, filterHeight, filterWidth, inDepth, outDepth]`.\n * @param strides The strides of the convolution: `[strideDepth, strideHeight,\n * strideWidth]`.\n * @param pad The type of padding algorithm used:\n *    - `same` and stride 1: output will be of same size as input,\n *       regardless of filter size.\n *    - `valid`: output will be smaller than input if filter is larger\n *       than 1x1.\n */\nfunction conv3DBackpropInput_<T extends Tensor4D|Tensor5D>(\n    xShape:\n        [number, number, number, number,\n         number]|[number, number, number, number],\n    dy: T, filter: Tensor5D, strides: [number, number, number]|number,\n    pad: 'valid'|'same'): T {\n  util.assert(\n      xShape.length === dy.rank,\n      () => `Length of inShape ` +\n          `(${xShape.length}) and rank of dy (${dy.rank}) must match`);\n\n  let xShape5D = xShape as [number, number, number, number, number];\n  let dy5D = dy as Tensor5D;\n  let reshapedTo5D = false;\n  if (dy.rank === 4) {\n    reshapedTo5D = true;\n    dy5D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);\n    xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]];\n  }\n\n  const inDepth = xShape5D[4];\n  const outDepth = dy5D.shape[4];\n  util.assert(\n      xShape5D.length === 5,\n      () =>\n          `Error in conv3dDerInput: inShape must be length 5, but got length ` +\n          `${xShape5D.length}.`);\n  util.assert(\n      dy5D.rank === 5,\n      () => `Error in conv3dDerInput: dy must be rank 5, but got ` +\n          `rank ${dy5D.rank}`);\n  util.assert(\n      filter.rank === 5,\n      () => `Error in conv3dDerInput: filter must be rank 5, but got ` +\n          `rank ${filter.rank}`);\n  
util.assert(\n      inDepth === filter.shape[3],\n      () => `Error in conv3dDerInput: depth of input (${inDepth}) must ` +\n          `match input depth for filter ${filter.shape[3]}.`);\n  util.assert(\n      outDepth === filter.shape[4],\n      () => `Error in conv3dDerInput: depth of output (${outDepth}) must ` +\n          `match output depth for filter ${filter.shape[4]}.`);\n\n  const inputs: Conv3DBackpropInputV2Inputs = {dy: dy5D, filter};\n\n  const attrs:\n      Conv3DBackpropInputV2Attrs = {pad, strides, inputShape: xShape5D};\n\n  // tslint:disable-next-line: no-unnecessary-type-assertion\n  const res = ENGINE.runKernel(\n                  Conv3DBackpropInputV2, inputs as unknown as NamedTensorMap,\n                  attrs as unknown as NamedAttrMap) as T;\n\n  if (reshapedTo5D) {\n    return reshape(\n               res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]) as\n        T;\n  }\n  return res;\n}\n\nexport const conv3DBackpropInput = /* @__PURE__ */ op({conv3DBackpropInput_});\n"]}
|