gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Conv2D, TensorBuffer, util } from '@tensorflow/tfjs-core';
import { assertNotComplex } from '../cpu_util';
/**
 * CPU kernel function for Conv2D.
 *
 * Resolves the convolution geometry through `backend_util.computeConv2DInfo`,
 * allocates an output `TensorBuffer`, and accumulates `input * filter`
 * products directly into it. Input and output layouts are reduced to per-axis
 * element steps up front, so the same loop nest handles 'channelsLast' and
 * the alternative data format.
 *
 * Filter values are addressed as
 * `row * filterStrides[0] + col * filterStrides[1] + inChannel * outChannels + outChannel`.
 *
 * @param {Object} args
 * @param {Object} args.inputs - `{ x, filter }`; each is read through its
 *     `shape` and `dataId` (plus `x.dtype` for the output buffer).
 * @param {Object} args.backend - Must provide `data.get(dataId).values` and
 *     `makeTensorInfo(shape, dtype, values)`.
 * @param {Object} args.attrs - `{ strides, pad, dataFormat, dilations,
 *     dimRoundingMode }`, forwarded to the `backend_util` helpers.
 * @returns The result of `backend.makeTensorInfo` for the output buffer's
 *     shape, dtype and filled values.
 */
export function conv2D(args) {
    const { backend, attrs } = args;
    const { x, filter } = args.inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode } = attrs;

    // NOTE(review): presumably throws when either tensor is complex-valued —
    // confirm in cpu_util.
    assertNotComplex([x, filter], 'conv2d');

    const layout = backend_util.convertConv2DDataFormat(dataFormat);
    const convInfo = backend_util.computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, layout);
    const {
        batchSize,
        inHeight,
        inWidth,
        inChannels,
        outHeight,
        outWidth,
        outChannels,
        strideHeight,
        strideWidth,
        dilationHeight,
        dilationWidth,
        filterHeight,
        filterWidth,
    } = convInfo;
    const { top: padTop, left: padLeft } = convInfo.padInfo;
    const channelsLast = convInfo.dataFormat === 'channelsLast';

    const out = new TensorBuffer(convInfo.outShape, x.dtype);
    const inStrides = util.computeStrides(x.shape);
    const kernelStrides = util.computeStrides(filter.shape);
    const outStrides = out.strides;

    // Element steps per logical axis. In the channels-last layout the channel
    // axis is innermost (step 1); otherwise the column axis is innermost.
    const inBatchStep = inStrides[0];
    const [inRowStep, inColStep, inChanStep] = channelsLast ?
        [inStrides[1], inStrides[2], 1] :
        [inStrides[2], 1, inStrides[1]];
    const outBatchStep = outStrides[0];
    const [outRowStep, outColStep, outChanStep] = channelsLast ?
        [outStrides[1], outStrides[2], 1] :
        [outStrides[2], 1, outStrides[1]];
    const kernelRowStep = kernelStrides[0];
    const kernelColStep = kernelStrides[1];

    const inData = backend.data.get(x.dataId).values;
    const kernelData = backend.data.get(filter.dataId).values;
    const outData = out.values;

    // The nesting order (batch, out row, filter row, out col, filter col,
    // in channel, out channel) fixes the order in which products are summed;
    // keep it stable so results stay bit-for-bit reproducible.
    for (let batch = 0; batch < batchSize; ++batch) {
        const inBatchBase = batch * inBatchStep;
        const outBatchBase = batch * outBatchStep;
        for (let outRow = 0; outRow < outHeight; ++outRow) {
            const outRowBase = outBatchBase + outRow * outRowStep;
            const inRowOrigin = outRow * strideHeight - padTop;
            for (let tapRow = 0; tapRow < filterHeight; ++tapRow) {
                const inRow = inRowOrigin + tapRow * dilationHeight;
                // Taps that land in the padding contribute nothing.
                if (inRow < 0 || inRow >= inHeight) {
                    continue;
                }
                const kernelRowBase = tapRow * kernelRowStep;
                const inRowBase = inBatchBase + inRow * inRowStep;
                for (let outCol = 0; outCol < outWidth; ++outCol) {
                    const outPixelBase = outRowBase + outCol * outColStep;
                    const inColOrigin = outCol * strideWidth - padLeft;
                    for (let tapCol = 0; tapCol < filterWidth; ++tapCol) {
                        const inCol = inColOrigin + tapCol * dilationWidth;
                        if (inCol < 0 || inCol >= inWidth) {
                            continue;
                        }
                        const kernelTapBase = kernelRowBase + tapCol * kernelColStep;
                        const inPixelBase = inRowBase + inCol * inColStep;
                        for (let inCh = 0; inCh < inChannels; ++inCh) {
                            const sample = inData[inPixelBase + inCh * inChanStep];
                            // Each input channel owns a run of `outChannels`
                            // consecutive filter weights.
                            const kernelChanBase = kernelTapBase + inCh * outChannels;
                            for (let outCh = 0; outCh < outChannels; ++outCh) {
                                outData[outPixelBase + outCh * outChanStep] +=
                                    sample * kernelData[kernelChanBase + outCh];
                            }
                        }
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(out.shape, out.dtype, outData);
}
/**
 * Kernel configuration for Conv2D on the 'cpu' backend: pairs the `Conv2D`
 * kernel name imported from tfjs-core with the `conv2D` function defined in
 * this file.
 * NOTE(review): presumably consumed by the backend's kernel registration —
 * confirm at the registration site.
 */
export const conv2DConfig = {
    kernelName: Conv2D,
    backendName: 'cpu',
    kernelFunc: conv2D
};
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"Conv2D.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-cpu/src/kernels/Conv2D.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAE,MAAM,EAAuD,YAAY,EAA0B,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAG5J,OAAO,EAAC,gBAAgB,EAAC,MAAM,aAAa,CAAC;AAE7C,MAAM,UAAU,MAAM,CAClB,IAAyE;IAE3E,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAE,MAAM,EAAC,GAAG,MAAM,CAAC;IAC3B,MAAM,EAAC,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,SAAS,EAAE,eAAe,EAAC,GAAG,KAAK,CAAC;IAErE,gBAAgB,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,EAAE,QAAQ,CAAC,CAAC;IAExC,MAAM,WAAW,GAAG,YAAY,CAAC,uBAAuB,CAAC,UAAU,CAAC,CAAC;IACrE,MAAM,QAAQ,GAAG,YAAY,CAAC,iBAAiB,CAC3C,CAAC,CAAC,KAAyC,EAC3C,MAAM,CAAC,KAAyC,EAAE,OAAO,EAAE,SAAS,EAAE,GAAG,EACzE,eAAe,EAAE,KAAK,CAAC,eAAe,EAAE,WAAW,CAAC,CAAC;IAEzD,MAAM,YAAY,GAAG,QAAQ,CAAC,YAAY,CAAC;IAC3C,MAAM,WAAW,GAAG,QAAQ,CAAC,WAAW,CAAC;IACzC,MAAM,cAAc,GAAG,QAAQ,CAAC,cAAc,CAAC;IAC/C,MAAM,aAAa,GAAG,QAAQ,CAAC,aAAa,CAAC;IAC7C,MAAM,OAAO,GAAG,QAAQ,CAAC,OAAO,CAAC,IAAI,CAAC;IACtC,MAAM,MAAM,GAAG,QAAQ,CAAC,OAAO,CAAC,GAAG,CAAC;IACpC,MAAM,cAAc,GAAG,QAAQ,CAAC,UAAU,KAAK,cAAc,CAAC;IAE9D,MAAM,CAAC,GAAG,IAAI,YAAY,CAAC,QAAQ,CAAC,QAAQ,EAAE,CAAC,CAAC,KAAkB,CAAC,CAAC;IAEpE,MAAM,QAAQ,GAAG,IAAI,CAAC,cAAc,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;IAC9C,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;IAExD,MAAM,YAAY,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAC;IACjC,MAAM,UAAU,GAAG,cAAc,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;IAC9D,MAAM,UAAU,GAAG,cAAc,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACpD,MAAM,cAAc,GAAG,cAAc,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;IACxD,MAAM,YAAY,GAAG,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAClC,MAAM,UAAU,GAAG,cAAc,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAChE,MAAM,UAAU,GAAG,cAAc,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACrD,MAAM,cAAc,GAAG,cAAc,CAAC,CAA
C,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAEzD,MAAM,KAAK,GAAG,OAAO,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,MAAoB,CAAC;IAC9D,MAAM,KAAK,GAAG,OAAO,CAAC,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC,MAAoB,CAAC;IACnE,MAAM,KAAK,GAAG,CAAC,CAAC,MAAM,CAAC;IAEvB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,CAAC,SAAS,EAAE,EAAE,CAAC,EAAE;QAC3C,MAAM,QAAQ,GAAG,CAAC,GAAG,YAAY,CAAC;QAClC,MAAM,QAAQ,GAAG,CAAC,GAAG,YAAY,CAAC;QAClC,KAAK,IAAI,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,QAAQ,CAAC,SAAS,EAAE,EAAE,EAAE,EAAE;YAC9C,MAAM,QAAQ,GAAG,QAAQ,GAAG,EAAE,GAAG,UAAU,CAAC;YAC5C,MAAM,QAAQ,GAAG,EAAE,GAAG,QAAQ,CAAC,YAAY,GAAG,MAAM,CAAC;YACrD,KAAK,IAAI,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,YAAY,EAAE,EAAE,EAAE,EAAE;gBACxC,MAAM,EAAE,GAAG,QAAQ,GAAG,EAAE,GAAG,cAAc,CAAC;gBAC1C,IAAI,EAAE,GAAG,CAAC,IAAI,EAAE,IAAI,QAAQ,CAAC,QAAQ,EAAE;oBACrC,SAAS;iBACV;gBACD,MAAM,QAAQ,GAAG,EAAE,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC;gBACvC,MAAM,QAAQ,GAAG,QAAQ,GAAG,EAAE,GAAG,UAAU,CAAC;gBAC5C,KAAK,IAAI,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,QAAQ,CAAC,QAAQ,EAAE,EAAE,EAAE,EAAE;oBAC7C,MAAM,QAAQ,GAAG,QAAQ,GAAG,EAAE,GAAG,UAAU,CAAC;oBAC5C,MAAM,QAAQ,GAAG,EAAE,GAAG,QAAQ,CAAC,WAAW,GAAG,OAAO,CAAC;oBACrD,KAAK,IAAI,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,WAAW,EAAE,EAAE,EAAE,EAAE;wBACvC,MAAM,EAAE,GAAG,QAAQ,GAAG,EAAE,GAAG,aAAa,CAAC;wBACzC,IAAI,EAAE,GAAG,CAAC,IAAI,EAAE,IAAI,QAAQ,CAAC,OAAO,EAAE;4BACpC,SAAS;yBACV;wBACD,MAAM,QAAQ,GAAG,QAAQ,GAAG,EAAE,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC;wBAClD,MAAM,QAAQ,GAAG,QAAQ,GAAG,EAAE,GAAG,UAAU,CAAC;wBAC5C,IAAI,QAAQ,GAAG,QAAQ,CAAC;wBACxB,KAAK,IAAI,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,QAAQ,CAAC,UAAU,EAAE,EAAE,EAAE,EAAE;4BAC/C,MAAM,IAAI,GAAG,KAAK,CAAC,QAAQ,GAAG,EAAE,GAAG,cAAc,CAAC,CAAC;4BACnD,KAAK,IAAI,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,QAAQ,CAAC,WAAW,EAAE,EAAE,EAAE,EAAE;gCAChD,KAAK,CAAC,QAAQ,GAAG,EAAE,GAAG,cAAc,CAAC;oCACjC,IAAI,GAAG,KAAK,CAAC,QAAQ,GAAG,EAAE,CAAC,CAAC;6BACjC;4BACD,QAAQ,IAAI,QAAQ,CAAC,WAAW,CAAC;yBAClC;qBACF;iBACF;aACF;SACF;KACF;IAED,OAAO,OAAO,CAAC,cAAc,CAAC,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC;AACzD,CAAC;AAE
D,MAAM,CAAC,MAAM,YAAY,GAAiB;IACxC,UAAU,EAAE,MAAM;IAClB,WAAW,EAAE,KAAK;IAClB,UAAU,EAAE,MAA+B;CAC5C,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, Conv2D, Conv2DAttrs, Conv2DInputs, KernelConfig, KernelFunc, TensorBuffer, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core';\n\nimport {MathBackendCPU} from '../backend_cpu';\nimport {assertNotComplex} from '../cpu_util';\n\nexport function conv2D(\n    args: {inputs: Conv2DInputs, backend: MathBackendCPU, attrs: Conv2DAttrs}):\n    TensorInfo {\n  const {inputs, backend, attrs} = args;\n  const {x, filter} = inputs;\n  const {strides, pad, dataFormat, dilations, dimRoundingMode} = attrs;\n\n  assertNotComplex([x, filter], 'conv2d');\n\n  const $dataFormat = backend_util.convertConv2DDataFormat(dataFormat);\n  const convInfo = backend_util.computeConv2DInfo(\n      x.shape as [number, number, number, number],\n      filter.shape as [number, number, number, number], strides, dilations, pad,\n      dimRoundingMode, false /* depthwise */, $dataFormat);\n\n  const filterHeight = convInfo.filterHeight;\n  const filterWidth = convInfo.filterWidth;\n  const dilationHeight = convInfo.dilationHeight;\n  const dilationWidth = convInfo.dilationWidth;\n  const padLeft = convInfo.padInfo.left;\n  const padTop = 
convInfo.padInfo.top;\n  const isChannelsLast = convInfo.dataFormat === 'channelsLast';\n\n  const y = new TensorBuffer(convInfo.outShape, x.dtype as 'float32');\n\n  const xStrides = util.computeStrides(x.shape);\n  const filterStrides = util.computeStrides(filter.shape);\n\n  const xBatchStride = xStrides[0];\n  const xRowStride = isChannelsLast ? xStrides[1] : xStrides[2];\n  const xColStride = isChannelsLast ? xStrides[2] : 1;\n  const xChannelStride = isChannelsLast ? 1 : xStrides[1];\n  const yBatchStride = y.strides[0];\n  const yRowStride = isChannelsLast ? y.strides[1] : y.strides[2];\n  const yColStride = isChannelsLast ? y.strides[2] : 1;\n  const yChannelStride = isChannelsLast ? 1 : y.strides[1];\n\n  const xVals = backend.data.get(x.dataId).values as TypedArray;\n  const wVals = backend.data.get(filter.dataId).values as TypedArray;\n  const yVals = y.values;\n\n  for (let b = 0; b < convInfo.batchSize; ++b) {\n    const xOffset1 = b * xBatchStride;\n    const yOffset1 = b * yBatchStride;\n    for (let yR = 0; yR < convInfo.outHeight; ++yR) {\n      const yOffset2 = yOffset1 + yR * yRowStride;\n      const xRCorner = yR * convInfo.strideHeight - padTop;\n      for (let wR = 0; wR < filterHeight; ++wR) {\n        const xR = xRCorner + wR * dilationHeight;\n        if (xR < 0 || xR >= convInfo.inHeight) {\n          continue;\n        }\n        const wOffset1 = wR * filterStrides[0];\n        const xOffset2 = xOffset1 + xR * xRowStride;\n        for (let yC = 0; yC < convInfo.outWidth; ++yC) {\n          const yOffset3 = yOffset2 + yC * yColStride;\n          const xCCorner = yC * convInfo.strideWidth - padLeft;\n          for (let wC = 0; wC < filterWidth; ++wC) {\n            const xC = xCCorner + wC * dilationWidth;\n            if (xC < 0 || xC >= convInfo.inWidth) {\n              continue;\n            }\n            const wOffset2 = wOffset1 + wC * filterStrides[1];\n            const xOffset3 = xOffset2 + xC * xColStride;\n            let 
wOffset3 = wOffset2;\n            for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {\n              const xVal = xVals[xOffset3 + d1 * xChannelStride];\n              for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {\n                yVals[yOffset3 + d2 * yChannelStride] +=\n                    xVal * wVals[wOffset3 + d2];\n              }\n              wOffset3 += convInfo.outChannels;\n            }\n          }\n        }\n      }\n    }\n  }\n\n  return backend.makeTensorInfo(y.shape, y.dtype, yVals);\n}\n\nexport const conv2DConfig: KernelConfig = {\n  kernelName: Conv2D,\n  backendName: 'cpu',\n  kernelFunc: conv2D as unknown as KernelFunc\n};\n"]}