gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { FusedConv2D } from '@tensorflow/tfjs-core';
import { applyActivation } from '../utils/fused_utils';
import { add } from './Add';
import { conv2D } from './Conv2D';
import { reshape } from './Reshape';
/**
 * CPU implementation of FusedConv2D: a 2-D convolution, optionally followed
 * by a bias add and then an activation, each delegated to an existing kernel
 * (`conv2D`, `add`, `applyActivation`).
 *
 * Channel alignment: when `dataFormat` is 'NCHW' and the bias (or, for
 * 'prelu', the activation weights) is a 1-D tensor with more than one element,
 * it is reshaped to [C, 1, 1] so it lines up with the channel axis of the
 * convolution output. In every other case (NHWC, scalars, length-1 vectors)
 * the tensor is passed through as-is and combined by ordinary broadcasting.
 *
 * Every intermediate TensorInfo produced here is released with
 * `backend.disposeIntermediateTensorInfo` as soon as the next stage has
 * consumed it. This covers the raw conv output, the pre-activation sum, and
 * any reshaped bias/alpha.
 *
 * @param args.inputs  `x` and `filter`, plus an optional `bias` (skipped when
 *                     falsy) and `preluActivationWeights` (forwarded to the
 *                     activation step).
 * @param args.backend Backend whose kernels and disposal routine are used.
 * @param args.attrs   Convolution attributes (`strides`, `pad`, `dataFormat`,
 *                     `dilations`, `dimRoundingMode`) plus `activation` and
 *                     `leakyreluAlpha`.
 * @returns The TensorInfo holding the final (biased / activated) result.
 */
export function fusedConv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    // A per-channel 1-D tensor must be expanded before it can broadcast
    // against an NCHW result. The shape is only inspected for NCHW data.
    const needsChannelExpansion = (t) => dataFormat === 'NCHW' && t.shape.length === 1 && t.shape[0] !== 1;
    // Reshape a length-C vector to [C, 1, 1]. The caller owns (and disposes)
    // the returned TensorInfo.
    const expandToChannels = (t) => reshape({ inputs: { x: t }, backend, attrs: { shape: [t.shape[0], 1, 1] } });
    let current = conv2D({
        inputs: { x, filter },
        backend,
        attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
    });
    if (bias) {
        const convOut = current;
        const expandBias = needsChannelExpansion(bias);
        const addend = expandBias ? expandToChannels(bias) : bias;
        current = add({ inputs: { a: convOut, b: addend }, backend });
        if (expandBias) {
            // Only the reshaped view was created here; the caller's bias is untouched.
            backend.disposeIntermediateTensorInfo(addend);
        }
        backend.disposeIntermediateTensorInfo(convOut);
    }
    if (activation) {
        const preActivation = current;
        // Only PReLU weights can need channel alignment; other activations
        // receive the weights argument unchanged.
        const expandAlpha = activation === 'prelu' && needsChannelExpansion(preluActivationWeights);
        const alpha = expandAlpha ? expandToChannels(preluActivationWeights) : preluActivationWeights;
        current = applyActivation(backend, preActivation, activation, alpha, leakyreluAlpha);
        if (expandAlpha) {
            backend.disposeIntermediateTensorInfo(alpha);
        }
        backend.disposeIntermediateTensorInfo(preActivation);
    }
    return current;
}
/**
 * Kernel config pairing the `FusedConv2D` kernel name with the
 * `fusedConv2D` implementation for the 'cpu' backend. It is presumably
 * passed to the kernel registry elsewhere; verify against the backend's
 * registration list.
 */
export const fusedConv2DConfig = { kernelName: FusedConv2D, backendName: 'cpu', kernelFunc: fusedConv2D };
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"FusedConv2D.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-cpu/src/kernels/FusedConv2D.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,WAAW,EAA4E,MAAM,uBAAuB,CAAC;AAG7H,OAAO,EAAC,eAAe,EAAC,MAAM,sBAAsB,CAAC;AACrD,OAAO,EAAC,GAAG,EAAC,MAAM,OAAO,CAAC;AAC1B,OAAO,EAAC,MAAM,EAAC,MAAM,UAAU,CAAC;AAChC,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAElC,MAAM,UAAU,WAAW,CAAC,IAI3B;IACC,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,sBAAsB,EAAC,GAAG,MAAM,CAAC;IACzD,MAAM,EACJ,OAAO,EACP,GAAG,EACH,UAAU,EACV,SAAS,EACT,eAAe,EACf,UAAU,EACV,cAAc,EACf,GAAG,KAAK,CAAC;IAEV,IAAI,MAAM,GAAG,MAAM,CAAC;QAClB,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAC;QACnB,OAAO;QACP,KAAK,EAAE,EAAC,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,SAAS,EAAE,eAAe,EAAC;KAC9D,CAAC,CAAC;IAEH,IAAI,IAAI,EAAE;QACR,MAAM,SAAS,GAAG,MAAM,CAAC;QACzB,yEAAyE;QACzE,sEAAsE;QACtE,0EAA0E;QAC1E,mBAAmB;QACnB,IAAI,UAAU,KAAK,MAAM,IAAI,IAAI,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC;YAChD,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE;YACvB,MAAM,YAAY,GAAG,OAAO,CACxB,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,IAAI,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAC,EAAC,CAAC,CAAC;YACzE,MAAM;gBACF,GAAG,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,YAAY,EAAC,EAAE,OAAO,EAAC,CAAe,CAAC;YACvE,OAAO,CAAC,6BAA6B,CAAC,YAAY,CAAC,CAAC;SACrD;aAAM;YACL,0EAA0E;YAC1E,uCAAuC;YACvC,MAAM,GAAG,GAAG,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,IAAI,EAAC,EAAE,OAAO,EAAC,CAAe,CAAC;SACrE;QACD,OAAO,CAAC,6BAA6B,CAAC,SAAS,CAAC,CAAC;KAClD;IAED,IAAI,UAAU,EAAE;QACd,MAAM,SAAS,GAAG,MAAM,CAAC;QACzB,sEAAsE;QACtE,4EAA4E;QAC5E,gEAAgE;QAChE,+CAA+C;QAC/C,IAAI,UAAU,KAAK,MAAM,IAAI,UAAU,KAAK,OAAO;YAC/C,sBAAsB,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC;YACzC,sBAAsB,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE;YACzC,MAAM,aAAa,GAAG,OAAO,CAAC;gBAC5B,MAAM,EAAE,EAAC,CAAC,EAAE,sBAAsB,EAAC;gBACnC,OAAO;gBACP,KAAK,EAAE,EAAC,KAAK,EAAE,CA
AC,sBAAsB,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAC;aACxD,CAAC,CAAC;YACH,MAAM,GAAG,eAAe,CACpB,OAAO,EAAE,MAAM,EAAE,UAAU,EAAE,aAAa,EAAE,cAAc,CAAC,CAAC;YAChE,OAAO,CAAC,6BAA6B,CAAC,aAAa,CAAC,CAAC;SACtD;aAAM;YACL,MAAM,GAAG,eAAe,CACpB,OAAO,EAAE,MAAM,EAAE,UAAU,EAAE,sBAAsB,EAAE,cAAc,CAAC,CAAC;SAC1E;QACD,OAAO,CAAC,6BAA6B,CAAC,SAAS,CAAC,CAAC;KAClD;IAED,OAAO,MAAM,CAAC;AAChB,CAAC;AAED,MAAM,CAAC,MAAM,iBAAiB,GAAiB;IAC7C,UAAU,EAAE,WAAW;IACvB,WAAW,EAAE,KAAK;IAClB,UAAU,EAAE,WAAoC;CACjD,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {FusedConv2D, FusedConv2DAttrs, FusedConv2DInputs, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core';\n\nimport {MathBackendCPU} from '../backend_cpu';\nimport {applyActivation} from '../utils/fused_utils';\nimport {add} from './Add';\nimport {conv2D} from './Conv2D';\nimport {reshape} from './Reshape';\n\nexport function fusedConv2D(args: {\n  inputs: FusedConv2DInputs,\n  backend: MathBackendCPU,\n  attrs: FusedConv2DAttrs\n}): TensorInfo {\n  const {inputs, backend, attrs} = args;\n  const {x, filter, bias, preluActivationWeights} = inputs;\n  const {\n    strides,\n    pad,\n    dataFormat,\n    dilations,\n    dimRoundingMode,\n    activation,\n    leakyreluAlpha\n  } = attrs;\n\n  let result = conv2D({\n    inputs: {x, 
filter},\n    backend,\n    attrs: {strides, pad, dataFormat, dilations, dimRoundingMode}\n  });\n\n  if (bias) {\n    const resultOld = result;\n    // For NCHW format, if bias is a 1-D tensor, it is supposed to be aligned\n    // to the channel of the conv2d's result; if the bias is a scalar, the\n    // bias_add is computed as if the bias was broadcasted to the shape of the\n    // conv2d's result.\n    if (dataFormat === 'NCHW' && bias.shape.length === 1 &&\n        bias.shape[0] !== 1) {\n      const reshapedBias = reshape(\n          {inputs: {x: bias}, backend, attrs: {shape: [bias.shape[0], 1, 1]}});\n      result =\n          add({inputs: {a: result, b: reshapedBias}, backend}) as TensorInfo;\n      backend.disposeIntermediateTensorInfo(reshapedBias);\n    } else {\n      // This condition handles NHWC and NCHW (scalar case). The only other case\n      // for NCHW (1D case) is handled above.\n      result = add({inputs: {a: result, b: bias}, backend}) as TensorInfo;\n    }\n    backend.disposeIntermediateTensorInfo(resultOld);\n  }\n\n  if (activation) {\n    const resultOld = result;\n    // For NCHW format, if PReLu activation weights is a 1-D tensor, it is\n    // supposed to be aligned with the channel of the conv2d's result. 
For other\n    // cases, whether NCHW or NHWC data format, the conv2d result is\n    // already aligned with the activation weights.\n    if (dataFormat === 'NCHW' && activation === 'prelu' &&\n        preluActivationWeights.shape.length === 1 &&\n        preluActivationWeights.shape[0] !== 1) {\n      const reshapedAlpha = reshape({\n        inputs: {x: preluActivationWeights},\n        backend,\n        attrs: {shape: [preluActivationWeights.shape[0], 1, 1]}\n      });\n      result = applyActivation(\n          backend, result, activation, reshapedAlpha, leakyreluAlpha);\n      backend.disposeIntermediateTensorInfo(reshapedAlpha);\n    } else {\n      result = applyActivation(\n          backend, result, activation, preluActivationWeights, leakyreluAlpha);\n    }\n    backend.disposeIntermediateTensorInfo(resultOld);\n  }\n\n  return result;\n}\n\nexport const fusedConv2DConfig: KernelConfig = {\n  kernelName: FusedConv2D,\n  backendName: 'cpu',\n  kernelFunc: fusedConv2D as unknown as KernelFunc\n};\n"]}