gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { FusedConv2D } from '@tensorflow/tfjs-core';
import { applyActivation } from '../utils/fused_utils';
import { add } from './Add';
import { conv2D } from './Conv2D';
import { reshape } from './Reshape';
/**
 * Runs a 2-D convolution and then, when requested, a bias addition and an
 * activation, disposing each superseded intermediate tensor along the way.
 *
 * @param args Kernel arguments.
 * @param args.inputs `x` and `filter`, plus the optional `bias` and
 *     `preluActivationWeights` tensors.
 * @param args.backend Backend handed to every sub-kernel and used to dispose
 *     intermediates.
 * @param args.attrs `strides`, `pad`, `dataFormat`, `dilations` and
 *     `dimRoundingMode` (forwarded to conv2D), together with `activation` and
 *     `leakyreluAlpha` (forwarded to applyActivation).
 * @returns The tensor info produced by the last stage that ran.
 */
export function fusedConv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    const channelsFirst = dataFormat === 'NCHW';

    let output = conv2D({
        inputs: { x, filter },
        backend,
        attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
    });

    if (bias) {
        const convOutput = output;
        // Only a 1-D bias with more than one entry under NCHW has to be
        // reshaped to [C, 1, 1] so that it lines up with the channel axis.
        // NHWC, and NCHW with a single-entry bias, are added as they are.
        const perChannelBias = channelsFirst && bias.shape.length === 1 && bias.shape[0] !== 1;
        const addend = perChannelBias ?
            reshape({ inputs: { x: bias }, backend, attrs: { shape: [bias.shape[0], 1, 1] } }) :
            bias;
        output = add({ inputs: { a: convOutput, b: addend }, backend });
        if (perChannelBias) {
            // The reshaped view was created here, so it is released here; the
            // caller's own bias tensor is never disposed.
            backend.disposeIntermediateTensorInfo(addend);
        }
        backend.disposeIntermediateTensorInfo(convOutput);
    }

    if (activation) {
        const preActivation = output;
        // Same alignment rule as for the bias, applied to the PReLU weights:
        // a 1-D, multi-entry weight tensor under NCHW becomes [C, 1, 1]. Every
        // other activation / layout combination uses the weights unchanged.
        const perChannelAlpha = channelsFirst && activation === 'prelu' &&
            preluActivationWeights.shape.length === 1 &&
            preluActivationWeights.shape[0] !== 1;
        const alpha = perChannelAlpha ?
            reshape({
                inputs: { x: preluActivationWeights },
                backend,
                attrs: { shape: [preluActivationWeights.shape[0], 1, 1] }
            }) :
            preluActivationWeights;
        output = applyActivation(backend, preActivation, activation, alpha, leakyreluAlpha);
        if (perChannelAlpha) {
            backend.disposeIntermediateTensorInfo(alpha);
        }
        backend.disposeIntermediateTensorInfo(preActivation);
    }

    return output;
}
/**
 * Kernel config tying the `FusedConv2D` kernel name (imported from
 * `@tensorflow/tfjs-core`) to the `fusedConv2D` function for the backend
 * named 'cpu'.
 * NOTE(review): presumably handed to tfjs-core's kernel registration; confirm
 * at the site that consumes this export.
 */
export const fusedConv2DConfig = {
    kernelName: FusedConv2D,
    backendName: 'cpu',
    kernelFunc: fusedConv2D
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiRnVzZWRDb252MkQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy9rZXJuZWxzL0Z1c2VkQ29udjJELnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxXQUFXLEVBQTRFLE1BQU0sdUJBQXVCLENBQUM7QUFHN0gsT0FBTyxFQUFDLGVBQWUsRUFBQyxNQUFNLHNCQUFzQixDQUFDO0FBQ3JELE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxPQUFPLENBQUM7QUFDMUIsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLFVBQVUsQ0FBQztBQUNoQyxPQUFPLEVBQUMsT0FBTyxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBRWxDLE1BQU0sVUFBVSxXQUFXLENBQUMsSUFJM0I7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBRSxNQUFNLEVBQUUsSUFBSSxFQUFFLHNCQUFzQixFQUFDLEdBQUcsTUFBTSxDQUFDO0lBQ3pELE1BQU0sRUFDSixPQUFPLEVBQ1AsR0FBRyxFQUNILFVBQVUsRUFDVixTQUFTLEVBQ1QsZUFBZSxFQUNmLFVBQVUsRUFDVixjQUFjLEVBQ2YsR0FBRyxLQUFLLENBQUM7SUFFVixJQUFJLE1BQU0sR0FBRyxNQUFNLENBQUM7UUFDbEIsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBQztRQUNuQixPQUFPO1FBQ1AsS0FBSyxFQUFFLEVBQUMsT0FBTyxFQUFFLEdBQUcsRUFBRSxVQUFVLEVBQUUsU0FBUyxFQUFFLGVBQWUsRUFBQztLQUM5RCxDQUFDLENBQUM7SUFFSCxJQUFJLElBQUksRUFBRTtRQUNSLE1BQU0sU0FBUyxHQUFHLE1BQU0sQ0FBQztRQUN6Qix5RUFBeUU7UUFDekUsc0VBQXNFO1FBQ3RFLDBFQUEwRTtRQUMxRSxtQkFBbUI7UUFDbkIsSUFBSSxVQUFVLEtBQUssTUFBTSxJQUFJLElBQUksQ0FBQyxLQUFLLENBQUMsTUFBTSxLQUFLLENBQUM7WUFDaEQsSUFBSSxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEVBQUU7WUFDdkIsTUFBTSxZQUFZLEdBQUcsT0FBTyxDQUN4QixFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxJQUFJLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLENBQUMsSUFBSSxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUMsRUFBQyxDQUFDLENBQUM7WUFDekUsTUFBTTtnQkFDRixHQUFHLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsTUFBTSxFQUFFLENBQUMsRUFBRSxZQUFZLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBZSxDQUFDO1lBQ3ZFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxZQUFZLENBQUMsQ0FBQztTQUNyRDthQUFNO1lBQ0wsMEVBQTBFO1lBQzFFLHVDQUF1QztZQUN2QyxNQUFNLEdBQUcsR0FBRyxDQUFDLEVBQUMsTUFBTSxFQUFFLE
VBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBRSxDQUFDLEVBQUUsSUFBSSxFQUFDLEVBQUUsT0FBTyxFQUFDLENBQWUsQ0FBQztTQUNyRTtRQUNELE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxTQUFTLENBQUMsQ0FBQztLQUNsRDtJQUVELElBQUksVUFBVSxFQUFFO1FBQ2QsTUFBTSxTQUFTLEdBQUcsTUFBTSxDQUFDO1FBQ3pCLHNFQUFzRTtRQUN0RSw0RUFBNEU7UUFDNUUsZ0VBQWdFO1FBQ2hFLCtDQUErQztRQUMvQyxJQUFJLFVBQVUsS0FBSyxNQUFNLElBQUksVUFBVSxLQUFLLE9BQU87WUFDL0Msc0JBQXNCLENBQUMsS0FBSyxDQUFDLE1BQU0sS0FBSyxDQUFDO1lBQ3pDLHNCQUFzQixDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEVBQUU7WUFDekMsTUFBTSxhQUFhLEdBQUcsT0FBTyxDQUFDO2dCQUM1QixNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsc0JBQXNCLEVBQUM7Z0JBQ25DLE9BQU87Z0JBQ1AsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLENBQUMsc0JBQXNCLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBQzthQUN4RCxDQUFDLENBQUM7WUFDSCxNQUFNLEdBQUcsZUFBZSxDQUNwQixPQUFPLEVBQUUsTUFBTSxFQUFFLFVBQVUsRUFBRSxhQUFhLEVBQUUsY0FBYyxDQUFDLENBQUM7WUFDaEUsT0FBTyxDQUFDLDZCQUE2QixDQUFDLGFBQWEsQ0FBQyxDQUFDO1NBQ3REO2FBQU07WUFDTCxNQUFNLEdBQUcsZUFBZSxDQUNwQixPQUFPLEVBQUUsTUFBTSxFQUFFLFVBQVUsRUFBRSxzQkFBc0IsRUFBRSxjQUFjLENBQUMsQ0FBQztTQUMxRTtRQUNELE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxTQUFTLENBQUMsQ0FBQztLQUNsRDtJQUVELE9BQU8sTUFBTSxDQUFDO0FBQ2hCLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxpQkFBaUIsR0FBaUI7SUFDN0MsVUFBVSxFQUFFLFdBQVc7SUFDdkIsV0FBVyxFQUFFLEtBQUs7SUFDbEIsVUFBVSxFQUFFLFdBQW9DO0NBQ2pELENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELC
BlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7RnVzZWRDb252MkQsIEZ1c2VkQ29udjJEQXR0cnMsIEZ1c2VkQ29udjJESW5wdXRzLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvckluZm99IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcbmltcG9ydCB7YXBwbHlBY3RpdmF0aW9ufSBmcm9tICcuLi91dGlscy9mdXNlZF91dGlscyc7XG5pbXBvcnQge2FkZH0gZnJvbSAnLi9BZGQnO1xuaW1wb3J0IHtjb252MkR9IGZyb20gJy4vQ29udjJEJztcbmltcG9ydCB7cmVzaGFwZX0gZnJvbSAnLi9SZXNoYXBlJztcblxuZXhwb3J0IGZ1bmN0aW9uIGZ1c2VkQ29udjJEKGFyZ3M6IHtcbiAgaW5wdXRzOiBGdXNlZENvbnYyRElucHV0cyxcbiAgYmFja2VuZDogTWF0aEJhY2tlbmRDUFUsXG4gIGF0dHJzOiBGdXNlZENvbnYyREF0dHJzXG59KTogVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHt4LCBmaWx0ZXIsIGJpYXMsIHByZWx1QWN0aXZhdGlvbldlaWdodHN9ID0gaW5wdXRzO1xuICBjb25zdCB7XG4gICAgc3RyaWRlcyxcbiAgICBwYWQsXG4gICAgZGF0YUZvcm1hdCxcbiAgICBkaWxhdGlvbnMsXG4gICAgZGltUm91bmRpbmdNb2RlLFxuICAgIGFjdGl2YXRpb24sXG4gICAgbGVha3lyZWx1QWxwaGFcbiAgfSA9IGF0dHJzO1xuXG4gIGxldCByZXN1bHQgPSBjb252MkQoe1xuICAgIGlucHV0czoge3gsIGZpbHRlcn0sXG4gICAgYmFja2VuZCxcbiAgICBhdHRyczoge3N0cmlkZXMsIHBhZCwgZGF0YUZvcm1hdCwgZGlsYXRpb25zLCBkaW1Sb3VuZGluZ01vZGV9XG4gIH0pO1xuXG4gIGlmIChiaWFzKSB7XG4gICAgY29uc3QgcmVzdWx0T2xkID0gcmVzdWx0O1xuICAgIC8vIEZvciBOQ0hXIGZvcm1hdCwgaWYgYmlhcyBpcyBhIDEtRCB0ZW5zb3IsIGl0IGlzIHN1cHBvc2VkIHRvIGJlIGFsaWduZWRcbiAgICAvLyB0byB0aGUgY2hhbm5lbCBvZiB0aGUgY29udjJkJ3MgcmVzdWx0OyBpZiB0aGUgYmlhcyBpcyBhIHNjYWxhciwgdGhlXG4gICAgLy8gYmlhc19hZGQgaXMgY29tcHV0ZWQgYXMgaWYgdGhlIGJpYXMgd2FzIGJyb2FkY2FzdGVkIHRvIHRoZSBzaGFwZSBvZiB0aGVcbiAgICAvLyBjb252MmQncyByZXN1bHQuXG4gICAgaWYgKGRhdGFGb3JtYXQgPT09ICdOQ0hXJyAmJiBiaWFzLnNoYXBlLmxlbmd0aCA9PT0gMSAmJlxuICAgICAgICBiaWFzLnNoYXBlWzBdICE9PSAxKSB7XG4gICAgICBjb25zdCByZXNoYX
BlZEJpYXMgPSByZXNoYXBlKFxuICAgICAgICAgIHtpbnB1dHM6IHt4OiBiaWFzfSwgYmFja2VuZCwgYXR0cnM6IHtzaGFwZTogW2JpYXMuc2hhcGVbMF0sIDEsIDFdfX0pO1xuICAgICAgcmVzdWx0ID1cbiAgICAgICAgICBhZGQoe2lucHV0czoge2E6IHJlc3VsdCwgYjogcmVzaGFwZWRCaWFzfSwgYmFja2VuZH0pIGFzIFRlbnNvckluZm87XG4gICAgICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHJlc2hhcGVkQmlhcyk7XG4gICAgfSBlbHNlIHtcbiAgICAgIC8vIFRoaXMgY29uZGl0aW9uIGhhbmRsZXMgTkhXQyBhbmQgTkNIVyAoc2NhbGFyIGNhc2UpLiBUaGUgb25seSBvdGhlciBjYXNlXG4gICAgICAvLyBmb3IgTkNIVyAoMUQgY2FzZSkgaXMgaGFuZGxlZCBhYm92ZS5cbiAgICAgIHJlc3VsdCA9IGFkZCh7aW5wdXRzOiB7YTogcmVzdWx0LCBiOiBiaWFzfSwgYmFja2VuZH0pIGFzIFRlbnNvckluZm87XG4gICAgfVxuICAgIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8ocmVzdWx0T2xkKTtcbiAgfVxuXG4gIGlmIChhY3RpdmF0aW9uKSB7XG4gICAgY29uc3QgcmVzdWx0T2xkID0gcmVzdWx0O1xuICAgIC8vIEZvciBOQ0hXIGZvcm1hdCwgaWYgUFJlTHUgYWN0aXZhdGlvbiB3ZWlnaHRzIGlzIGEgMS1EIHRlbnNvciwgaXQgaXNcbiAgICAvLyBzdXBwb3NlZCB0byBiZSBhbGlnbmVkIHdpdGggdGhlIGNoYW5uZWwgb2YgdGhlIGNvbnYyZCdzIHJlc3VsdC4gRm9yIG90aGVyXG4gICAgLy8gY2FzZXMsIHdoZXRoZXIgTkNIVyBvciBOSFdDIGRhdGEgZm9ybWF0LCB0aGUgY29udjJkIHJlc3VsdCBpc1xuICAgIC8vIGFscmVhZHkgYWxpZ25lZCB3aXRoIHRoZSBhY3RpdmF0aW9uIHdlaWdodHMuXG4gICAgaWYgKGRhdGFGb3JtYXQgPT09ICdOQ0hXJyAmJiBhY3RpdmF0aW9uID09PSAncHJlbHUnICYmXG4gICAgICAgIHByZWx1QWN0aXZhdGlvbldlaWdodHMuc2hhcGUubGVuZ3RoID09PSAxICYmXG4gICAgICAgIHByZWx1QWN0aXZhdGlvbldlaWdodHMuc2hhcGVbMF0gIT09IDEpIHtcbiAgICAgIGNvbnN0IHJlc2hhcGVkQWxwaGEgPSByZXNoYXBlKHtcbiAgICAgICAgaW5wdXRzOiB7eDogcHJlbHVBY3RpdmF0aW9uV2VpZ2h0c30sXG4gICAgICAgIGJhY2tlbmQsXG4gICAgICAgIGF0dHJzOiB7c2hhcGU6IFtwcmVsdUFjdGl2YXRpb25XZWlnaHRzLnNoYXBlWzBdLCAxLCAxXX1cbiAgICAgIH0pO1xuICAgICAgcmVzdWx0ID0gYXBwbHlBY3RpdmF0aW9uKFxuICAgICAgICAgIGJhY2tlbmQsIHJlc3VsdCwgYWN0aXZhdGlvbiwgcmVzaGFwZWRBbHBoYSwgbGVha3lyZWx1QWxwaGEpO1xuICAgICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhyZXNoYXBlZEFscGhhKTtcbiAgICB9IGVsc2Uge1xuICAgICAgcmVzdWx0ID0gYXBwbHlBY3RpdmF0aW9uKFxuICAgICAgICAgIGJhY2tlbmQsIHJlc3VsdCwgYWN0aXZhdGlvbiwgcHJlbHVBY3RpdmF0aW9uV2VpZ2h0cywgbGVha3
lyZWx1QWxwaGEpO1xuICAgIH1cbiAgICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHJlc3VsdE9sZCk7XG4gIH1cblxuICByZXR1cm4gcmVzdWx0O1xufVxuXG5leHBvcnQgY29uc3QgZnVzZWRDb252MkRDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogRnVzZWRDb252MkQsXG4gIGJhY2tlbmROYW1lOiAnY3B1JyxcbiAga2VybmVsRnVuYzogZnVzZWRDb252MkQgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19