gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { ENGINE } from '../engine';
import { DepthwiseConv2dNative } from '../kernel_names';
import { convertToTensor } from '../tensor_util_env';
import * as util from '../util';
import * as conv_util from './conv_util';
import { op } from './operation';
import { reshape } from './reshape';
/**
 * Depthwise 2D convolution.
 *
 * Given a 4D `input` array and a `filter` array of shape
 * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing
 * `inChannels` convolutional filters of depth 1, this op applies a
 * different filter to each input channel (expanding from 1 channel to
 * `channelMultiplier` channels for each), then concatenates the results
 * together. The output has `inChannels * channelMultiplier` channels.
 *
 * See
 * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](
 *     https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)
 * for more details.
 *
 * @param x The input tensor, of rank 4 or rank 3, of shape
 *     `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
 * assumed.
 * @param filter The filter tensor, rank 4, of shape
 *     `[filterHeight, filterWidth, inChannels, channelMultiplier]`.
 * @param strides The strides of the convolution: `[strideHeight,
 * strideWidth]`. If strides is a single number, then `strideHeight ==
 * strideWidth`.
 * @param pad The type of padding algorithm.
 *   - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *   - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 *   - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
 *     in which we sample input values across the height and width dimensions
 *     in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
 *     number, then `dilationHeight == dilationWidth`. If it is greater than
 *     1, then all values of `strides` must be 1.
 * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
 *     "NHWC". Specify the data format of the input and output data. With the
 *     default format "NHWC", the data is stored in the order of: [batch,
 *     height, width, channels]. Only "NHWC" is currently supported.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 *
 * @doc {heading: 'Operations', subheading: 'Convolution'}
 */
function depthwiseConv2d_(x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode) {
    const $x = convertToTensor(x, 'x', 'depthwiseConv2d', 'float32');
    const $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d', 'float32');
    let x4D = $x;
    let reshapedTo4D = false;
    if ($x.rank === 3) {
        reshapedTo4D = true;
        x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
    }
    util.assert(x4D.rank === 4, () => `Error in depthwiseConv2d: input must be rank 4, but got ` +
        `rank ${x4D.rank}.`);
    util.assert($filter.rank === 4, () => `Error in depthwiseConv2d: filter must be rank 4, but got rank ` +
        `${$filter.rank}.`);
    const inChannels = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
    util.assert(inChannels === $filter.shape[2], () => `Error in depthwiseConv2d: number of input channels ` +
        `(${inChannels}) must match the inChannels dimension in ` +
        `filter ${$filter.shape[2]}.`);
    conv_util.checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode);
    const inputs = { x: x4D, filter: $filter };
    const attrs = { strides, pad, dataFormat, dilations, dimRoundingMode };
    // tslint:disable-next-line: no-unnecessary-type-assertion
    const res = ENGINE.runKernel(DepthwiseConv2dNative, inputs, attrs);
    if (reshapedTo4D) {
        return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return res;
}
// Exported op: the `depthwiseConv2d_` implementation is handed to `op`
// (imported from './operation') inside an object keyed by the function name,
// and whatever `op` returns is exported as `depthwiseConv2d`.
// NOTE(review): the `@__PURE__` annotation is presumably there so bundlers can
// drop this call when the export is unused; confirm against the build setup.
export const depthwiseConv2d = /* @__PURE__ */ op({ depthwiseConv2d_ });
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"depthwise_conv2d.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/depthwise_conv2d.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AACH,OAAO,EAAC,MAAM,EAAC,MAAM,WAAW,CAAC;AACjC,OAAO,EAAC,qBAAqB,EAA0D,MAAM,iBAAiB,CAAC;AAI/G,OAAO,EAAC,eAAe,EAAC,MAAM,oBAAoB,CAAC;AAEnD,OAAO,KAAK,IAAI,MAAM,SAAS,CAAC;AAEhC,OAAO,KAAK,SAAS,MAAM,aAAa,CAAC;AACzC,OAAO,EAAC,EAAE,EAAC,MAAM,aAAa,CAAC;AAC/B,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAElC;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA4CG;AACH,SAAS,gBAAgB,CACrB,CAAe,EAAE,MAA2B,EAC5C,OAAgC,EAChC,GAAoD,EACpD,aAA4B,MAAM,EAClC,YAAqC,CAAC,CAAC,EAAE,CAAC,CAAC,EAC3C,eAAwC;IAC1C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,GAAG,EAAE,iBAAiB,EAAE,SAAS,CAAC,CAAC;IACjE,MAAM,OAAO,GACT,eAAe,CAAC,MAAM,EAAE,QAAQ,EAAE,iBAAiB,EAAE,SAAS,CAAC,CAAC;IAEpE,IAAI,GAAG,GAAG,EAAc,CAAC;IACzB,IAAI,YAAY,GAAG,KAAK,CAAC;IACzB,IAAI,EAAE,CAAC,IAAI,KAAK,CAAC,EAAE;QACjB,YAAY,GAAG,IAAI,CAAC;QACpB,GAAG,GAAG,OAAO,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;KAC/D;IACD,IAAI,CAAC,MAAM,CACP,GAAG,CAAC,IAAI,KAAK,CAAC,EACd,GAAG,EAAE,CAAC,0DAA0D;QAC5D,QAAQ,GAAG,CAAC,IAAI,GAAG,CAAC,CAAC;IAC7B,IAAI,CAAC,MAAM,CACP,OAAO,CAAC,IAAI,KAAK,CAAC,EAClB,GAAG,EAAE,CAAC,gEAAgE;QAClE,GAAG,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC;IAC5B,MAAM,UAAU,GAAG,UAAU,KAAK,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IACvE,IAAI,CAAC,MAAM,CACP,UAAU,KAAK,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,EAC/B,GAAG,EAAE,CAAC,qDAAqD;QACvD,IAAI,UAAU,2CAA2C;QACzD,UAAU,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;IACvC,SAAS,CAAC,yBAAyB,CAAC,iBAAiB,EAAE,GAAG,EAAE,eAAe,CAAC,CAAC;IAC7E,MAAM,MAAM,GAAgC,EAAC,CAAC,EAAE,GAAG,EAAE,MAAM,EAAE,OAAO,EAAC,CAAC;IACtE,MAAM,KAAK,GACP,EAAC,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,SAAS,EAAE,eAAe,EAAC,CAAC;IAE3D,0DAA0D;IAC1D,MAAM,GAAG,GAAG,MAAM,CAAC,SAAS,CACZ,qBAAqB,EAAE,MAAmC,EAC1D,KAAgC,CAAM,CA
AC;IAEvD,IAAI,YAAY,EAAE;QAChB,OAAO,OAAO,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAM,CAAC;KACtE;IACD,OAAO,GAAG,CAAC;AACb,CAAC;AAED,MAAM,CAAC,MAAM,eAAe,GAAG,eAAe,CAAC,EAAE,CAAC,EAAC,gBAAgB,EAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {ENGINE} from '../engine';\nimport {DepthwiseConv2dNative, DepthwiseConv2dNativeAttrs, DepthwiseConv2dNativeInputs} from '../kernel_names';\nimport {NamedAttrMap} from '../kernel_registry';\nimport {Tensor3D, Tensor4D} from '../tensor';\nimport {NamedTensorMap} from '../tensor_types';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport * as conv_util from './conv_util';\nimport {op} from './operation';\nimport {reshape} from './reshape';\n\n/**\n * Depthwise 2D convolution.\n *\n * Given a 4D `input` array and a `filter` array of shape\n * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing\n * `inChannels` convolutional filters of depth 1, this op applies a\n * different filter to each input channel (expanding from 1 channel to\n * `channelMultiplier` channels for each), then concatenates the results\n * together. 
The output has `inChannels * channelMultiplier` channels.\n *\n * See\n * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](\n *     https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)\n * for more details.\n *\n * @param x The input tensor, of rank 4 or rank 3, of shape\n *     `[batch, height, width, inChannels]`. If rank 3, batch of 1 is\n * assumed.\n * @param filter The filter tensor, rank 4, of shape\n *     `[filterHeight, filterWidth, inChannels, channelMultiplier]`.\n * @param strides The strides of the convolution: `[strideHeight,\n * strideWidth]`. If strides is a single number, then `strideHeight ==\n * strideWidth`.\n * @param pad The type of padding algorithm.\n *   - `same` and stride 1: output will be of same size as input,\n *       regardless of filter size.\n *   - `valid`: output will be smaller than input if filter is larger\n *       than 1x1.\n *   - For more info, see this guide:\n *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](\n *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)\n * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`\n *     in which we sample input values across the height and width dimensions\n *     in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single\n *     number, then `dilationHeight == dilationWidth`. If it is greater than\n *     1, then all values of `strides` must be 1.\n * @param dataFormat: An optional string from: \"NHWC\", \"NCHW\". Defaults to\n *     \"NHWC\". Specify the data format of the input and output data. With the\n *     default format \"NHWC\", the data is stored in the order of: [batch,\n *     height, width, channels]. Only \"NHWC\" is currently supported.\n * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. 
If none is\n *     provided, it will default to truncate.\n *\n * @doc {heading: 'Operations', subheading: 'Convolution'}\n */\nfunction depthwiseConv2d_<T extends Tensor3D|Tensor4D>(\n    x: T|TensorLike, filter: Tensor4D|TensorLike,\n    strides: [number, number]|number,\n    pad: 'valid'|'same'|number|conv_util.ExplicitPadding,\n    dataFormat: 'NHWC'|'NCHW' = 'NHWC',\n    dilations: [number, number]|number = [1, 1],\n    dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n  const $x = convertToTensor(x, 'x', 'depthwiseConv2d', 'float32');\n  const $filter =\n      convertToTensor(filter, 'filter', 'depthwiseConv2d', 'float32');\n\n  let x4D = $x as Tensor4D;\n  let reshapedTo4D = false;\n  if ($x.rank === 3) {\n    reshapedTo4D = true;\n    x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);\n  }\n  util.assert(\n      x4D.rank === 4,\n      () => `Error in depthwiseConv2d: input must be rank 4, but got ` +\n          `rank ${x4D.rank}.`);\n  util.assert(\n      $filter.rank === 4,\n      () => `Error in depthwiseConv2d: filter must be rank 4, but got rank ` +\n          `${$filter.rank}.`);\n  const inChannels = dataFormat === 'NHWC' ? 
x4D.shape[3] : x4D.shape[1];\n  util.assert(\n      inChannels === $filter.shape[2],\n      () => `Error in depthwiseConv2d: number of input channels ` +\n          `(${inChannels}) must match the inChannels dimension in ` +\n          `filter ${$filter.shape[2]}.`);\n  conv_util.checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode);\n  const inputs: DepthwiseConv2dNativeInputs = {x: x4D, filter: $filter};\n  const attrs: DepthwiseConv2dNativeAttrs =\n      {strides, pad, dataFormat, dilations, dimRoundingMode};\n\n  // tslint:disable-next-line: no-unnecessary-type-assertion\n  const res = ENGINE.runKernel(\n                  DepthwiseConv2dNative, inputs as unknown as NamedTensorMap,\n                  attrs as unknown as NamedAttrMap) as T;\n\n  if (reshapedTo4D) {\n    return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]) as T;\n  }\n  return res;\n}\n\nexport const depthwiseConv2d = /* @__PURE__ */ op({depthwiseConv2d_});\n"]}