gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, DepthwiseConv2dNative, env, util } from '@tensorflow/tfjs-core';
import { DepthwiseConv2DProgram } from '../conv_gpu_depthwise';
import { DepthwiseConvPacked2DProgram } from '../conv_packed_gpu_depthwise';
/**
 * Kernel function for the `DepthwiseConv2dNative` op.
 *
 * Derives a conv-info descriptor from the shapes of `x` and `filter` plus the
 * supplied attrs, chooses between the packed and the unpacked depthwise
 * program, and hands that program to `backend.runWebGLProgram`.
 *
 * @param {Object} args
 * @param {Object} args.inputs Holds `x` and `filter`. Only their `shape` is
 *     read here; both objects are then forwarded to the program unchanged.
 * @param {Object} args.backend Object exposing `runWebGLProgram`.
 * @param {Object} args.attrs Holds `strides`, `pad`, `dilations` and
 *     `dimRoundingMode`. A null/undefined `dilations` is replaced by `[1, 1]`.
 * @returns Whatever `backend.runWebGLProgram` returns for the chosen program
 *     (it is called with `'float32'` as its third argument).
 */
export function depthwiseConv2dNative(args) {
    const { x, filter } = args.inputs;
    const { strides, pad, dilations, dimRoundingMode } = args.attrs;
    // `== null` is intentional: it matches both null and undefined.
    const effectiveDilations = dilations == null ? [1, 1] : dilations;
    const stridesOrDilationsAreOne = backend_util.eitherStridesOrDilationsAreOne(strides, effectiveDilations);
    util.assert(stridesOrDilationsAreOne, () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
        `1. Got strides ${strides} and dilations '${effectiveDilations}'`);
    const conv = backend_util.computeConv2DInfo(x.shape, filter.shape, strides, effectiveDilations, pad, dimRoundingMode, true /* depthwise */);
    // The packed program is selected only when the env flag is set, the
    // horizontal stride is at most 2, and the output and input channel counts
    // are equal.
    const packingEnabled = env().getBool('WEBGL_PACK_DEPTHWISECONV');
    const canUsePacked = packingEnabled && conv.strideWidth <= 2 &&
        conv.outChannels / conv.inChannels === 1;
    const ProgramClass = canUsePacked ? DepthwiseConvPacked2DProgram : DepthwiseConv2DProgram;
    const program = new ProgramClass(conv);
    // Read after the program is constructed, matching the original ordering.
    const { padInfo, strideHeight, strideWidth, dilationHeight, dilationWidth, inHeight, inWidth } = conv;
    return args.backend.runWebGLProgram(program, [x, filter], 'float32', [
        [padInfo.top, padInfo.left],
        [strideHeight, strideWidth],
        [dilationHeight, dilationWidth],
        [inHeight, inWidth],
    ]);
}
/**
 * Kernel config object pairing the `DepthwiseConv2dNative` kernel name with
 * the `depthwiseConv2dNative` kernel function under the backend name 'webgl'.
 * NOTE(review): presumably consumed by the tfjs kernel registry - the
 * registration call itself is not visible in this file.
 */
export const depthwiseConv2dNativeConfig = {
    kernelName: DepthwiseConv2dNative,
    backendName: 'webgl',
    kernelFunc: depthwiseConv2dNative,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiRGVwdGh3aXNlQ29udjJkTmF0aXZlLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLXdlYmdsL3NyYy9rZXJuZWxzL0RlcHRod2lzZUNvbnYyZE5hdGl2ZS50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsWUFBWSxFQUFFLHFCQUFxQixFQUEyRCxHQUFHLEVBQTRCLElBQUksRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBR3hLLE9BQU8sRUFBQyxzQkFBc0IsRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBQzdELE9BQU8sRUFBQyw0QkFBNEIsRUFBQyxNQUFNLDhCQUE4QixDQUFDO0FBRTFFLE1BQU0sVUFBVSxxQkFBcUIsQ0FBQyxJQUlyQztJQUNDLE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztJQUN0QyxNQUFNLEVBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUMzQixNQUFNLEVBQUMsT0FBTyxFQUFFLEdBQUcsRUFBRSxTQUFTLEVBQUUsZUFBZSxFQUFDLEdBQUcsS0FBSyxDQUFDO0lBRXpELElBQUksVUFBVSxHQUFHLFNBQVMsQ0FBQztJQUMzQixJQUFJLFVBQVUsSUFBSSxJQUFJLEVBQUU7UUFDdEIsVUFBVSxHQUFHLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDO0tBQ3JCO0lBRUQsSUFBSSxDQUFDLE1BQU0sQ0FDUCxZQUFZLENBQUMsOEJBQThCLENBQUMsT0FBTyxFQUFFLFVBQVUsQ0FBQyxFQUNoRSxHQUFHLEVBQUUsQ0FBQyxnRUFBZ0U7UUFDbEUsa0JBQWtCLE9BQU8sbUJBQW1CLFVBQVUsR0FBRyxDQUFDLENBQUM7SUFFbkUsTUFBTSxRQUFRLEdBQUcsWUFBWSxDQUFDLGlCQUFpQixDQUMzQyxDQUFDLENBQUMsS0FBeUMsRUFDM0MsTUFBTSxDQUFDLEtBQXlDLEVBQUUsT0FBTyxFQUFFLFVBQVUsRUFDckUsR0FBRyxFQUFFLGVBQWUsRUFBRSxJQUFJLENBQUMsZUFBZSxDQUFDLENBQUM7SUFFaEQsSUFBSSxPQUE0RCxDQUFDO0lBQ2pFLElBQUksR0FBRyxFQUFFLENBQUMsT0FBTyxDQUFDLDBCQUEwQixDQUFDLElBQUksUUFBUSxDQUFDLFdBQVcsSUFBSSxDQUFDO1FBQ3RFLFFBQVEsQ0FBQyxXQUFXLEdBQUcsUUFBUSxDQUFDLFVBQVUsS0FBSyxDQUFDLEVBQUU7UUFDcEQsT0FBTyxHQUFHLElBQUksNEJBQTRCLENBQUMsUUFBUSxDQUFDLENBQUM7S0FDdEQ7U0FBTTtRQUNMLE9BQU8sR0FBRyxJQUFJLHNCQUFzQixDQUFDLFFBQVEsQ0FBQyxDQUFDO0tBQ2hEO0lBQ0QsTUFBTSxZQUFZLEdBQUc7UUFDbkIsQ0FBQyxRQUFRLENBQUMsT0FBTyxDQUFDLEdBQUcsRUFBRSxRQUFRLENBQUMsT0FBTyxDQUFDLElBQUksQ0FBQztRQUM3QyxDQUFDLFFBQVEsQ0FBQyxZQUFZLEVBQUUsUUFBUSxDQUFDLFdBQVcsQ0FBQztRQUM3QyxDQUFDLFFBQVEsQ0FBQyxjQUFjLEVBQUUsUUFBUSxDQUFDLGFBQWEsQ0FBQztRQUNqRCxDQUFDLFFBQV
EsQ0FBQyxRQUFRLEVBQUUsUUFBUSxDQUFDLE9BQU8sQ0FBQztLQUN0QyxDQUFDO0lBQ0YsT0FBTyxPQUFPLENBQUMsZUFBZSxDQUFDLE9BQU8sRUFBRSxDQUFDLENBQUMsRUFBRSxNQUFNLENBQUMsRUFBRSxTQUFTLEVBQUUsWUFBWSxDQUFDLENBQUM7QUFDaEYsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLDJCQUEyQixHQUFpQjtJQUN2RCxVQUFVLEVBQUUscUJBQXFCO0lBQ2pDLFdBQVcsRUFBRSxPQUFPO0lBQ3BCLFVBQVUsRUFBRSxxQkFBOEM7Q0FDM0QsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWwsIERlcHRod2lzZUNvbnYyZE5hdGl2ZSwgRGVwdGh3aXNlQ29udjJkTmF0aXZlQXR0cnMsIERlcHRod2lzZUNvbnYyZE5hdGl2ZUlucHV0cywgZW52LCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5pbXBvcnQge0RlcHRod2lzZUNvbnYyRFByb2dyYW19IGZyb20gJy4uL2NvbnZfZ3B1X2RlcHRod2lzZSc7XG5pbXBvcnQge0RlcHRod2lzZUNvbnZQYWNrZWQyRFByb2dyYW19IGZyb20gJy4uL2NvbnZfcGFja2VkX2dwdV9kZXB0aHdpc2UnO1xuXG5leHBvcnQgZnVuY3Rpb24gZGVwdGh3aXNlQ29udjJkTmF0aXZlKGFyZ3M6IHtcbiAgaW5wdXRzOiBEZXB0aHdpc2VDb252MmROYXRpdmVJbnB1dHMsXG4gIGF0dHJzOiBEZXB0aHdpc2VDb252MmROYXRpdmVBdHRycyxcbi
AgYmFja2VuZDogTWF0aEJhY2tlbmRXZWJHTFxufSkge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7eCwgZmlsdGVyfSA9IGlucHV0cztcbiAgY29uc3Qge3N0cmlkZXMsIHBhZCwgZGlsYXRpb25zLCBkaW1Sb3VuZGluZ01vZGV9ID0gYXR0cnM7XG5cbiAgbGV0ICRkaWxhdGlvbnMgPSBkaWxhdGlvbnM7XG4gIGlmICgkZGlsYXRpb25zID09IG51bGwpIHtcbiAgICAkZGlsYXRpb25zID0gWzEsIDFdO1xuICB9XG5cbiAgdXRpbC5hc3NlcnQoXG4gICAgICBiYWNrZW5kX3V0aWwuZWl0aGVyU3RyaWRlc09yRGlsYXRpb25zQXJlT25lKHN0cmlkZXMsICRkaWxhdGlvbnMpLFxuICAgICAgKCkgPT4gJ0Vycm9yIGluIGRlcHRod2lzZUNvbnYyZDogRWl0aGVyIHN0cmlkZXMgb3IgZGlsYXRpb25zIG11c3QgYmUgJyArXG4gICAgICAgICAgYDEuIEdvdCBzdHJpZGVzICR7c3RyaWRlc30gYW5kIGRpbGF0aW9ucyAnJHskZGlsYXRpb25zfSdgKTtcblxuICBjb25zdCBjb252SW5mbyA9IGJhY2tlbmRfdXRpbC5jb21wdXRlQ29udjJESW5mbyhcbiAgICAgIHguc2hhcGUgYXMgW251bWJlciwgbnVtYmVyLCBudW1iZXIsIG51bWJlcl0sXG4gICAgICBmaWx0ZXIuc2hhcGUgYXMgW251bWJlciwgbnVtYmVyLCBudW1iZXIsIG51bWJlcl0sIHN0cmlkZXMsICRkaWxhdGlvbnMsXG4gICAgICBwYWQsIGRpbVJvdW5kaW5nTW9kZSwgdHJ1ZSAvKiBkZXB0aHdpc2UgKi8pO1xuXG4gIGxldCBwcm9ncmFtOiBEZXB0aHdpc2VDb252MkRQcm9ncmFtfERlcHRod2lzZUNvbnZQYWNrZWQyRFByb2dyYW07XG4gIGlmIChlbnYoKS5nZXRCb29sKCdXRUJHTF9QQUNLX0RFUFRIV0lTRUNPTlYnKSAmJiBjb252SW5mby5zdHJpZGVXaWR0aCA8PSAyICYmXG4gICAgICBjb252SW5mby5vdXRDaGFubmVscyAvIGNvbnZJbmZvLmluQ2hhbm5lbHMgPT09IDEpIHtcbiAgICBwcm9ncmFtID0gbmV3IERlcHRod2lzZUNvbnZQYWNrZWQyRFByb2dyYW0oY29udkluZm8pO1xuICB9IGVsc2Uge1xuICAgIHByb2dyYW0gPSBuZXcgRGVwdGh3aXNlQ29udjJEUHJvZ3JhbShjb252SW5mbyk7XG4gIH1cbiAgY29uc3QgY3VzdG9tVmFsdWVzID0gW1xuICAgIFtjb252SW5mby5wYWRJbmZvLnRvcCwgY29udkluZm8ucGFkSW5mby5sZWZ0XSxcbiAgICBbY29udkluZm8uc3RyaWRlSGVpZ2h0LCBjb252SW5mby5zdHJpZGVXaWR0aF0sXG4gICAgW2NvbnZJbmZvLmRpbGF0aW9uSGVpZ2h0LCBjb252SW5mby5kaWxhdGlvbldpZHRoXSxcbiAgICBbY29udkluZm8uaW5IZWlnaHQsIGNvbnZJbmZvLmluV2lkdGhdXG4gIF07XG4gIHJldHVybiBiYWNrZW5kLnJ1bldlYkdMUHJvZ3JhbShwcm9ncmFtLCBbeCwgZmlsdGVyXSwgJ2Zsb2F0MzInLCBjdXN0b21WYWx1ZXMpO1xufVxuXG5leHBvcnQgY29uc3QgZGVwdGh3aXNlQ29udjJkTmF0aXZlQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IERlcHRod2lzZUNvbnYyZE5hdGl2ZSxcbi
AgYmFja2VuZE5hbWU6ICd3ZWJnbCcsXG4gIGtlcm5lbEZ1bmM6IGRlcHRod2lzZUNvbnYyZE5hdGl2ZSBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmMsXG59O1xuIl19