/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, DepthwiseConv2dNative, env, util } from '@tensorflow/tfjs-core';
|
import { DepthwiseConv2DProgram } from '../conv_gpu_depthwise';
|
import { DepthwiseConvPacked2DProgram } from '../conv_packed_gpu_depthwise';
|
/**
 * WebGL implementation of the DepthwiseConv2dNative kernel.
 *
 * Chooses between the packed and the unpacked depthwise shader program and
 * runs the chosen one on the WebGL backend, together with the convolution
 * geometry computed by `backend_util.computeConv2DInfo`.
 *
 * @param {{inputs: {x: *, filter: *}, backend: *, attrs: *}} args
 *     `inputs` holds the `x` and `filter` tensors; `attrs` holds `strides`,
 *     `pad`, an optional `dilations` and `dimRoundingMode`; `backend` must
 *     expose `runWebGLProgram`.
 * @returns {*} Whatever `backend.runWebGLProgram` returns for the 'float32'
 *     output of the selected program.
 * @throws {Error} Raised through `util.assert` when neither the strides nor
 *     the dilations are 1.
 */
export function depthwiseConv2dNative(args) {
  const { inputs: { x, filter }, backend, attrs } = args;
  const { strides, pad, dilations: requestedDilations, dimRoundingMode } = attrs;

  // An omitted `dilations` attribute falls back to [1, 1].
  const dilations = requestedDilations == null ? [1, 1] : requestedDilations;

  util.assert(
    backend_util.eitherStridesOrDilationsAreOne(strides, dilations),
    () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
      `1. Got strides ${strides} and dilations '${dilations}'`);

  const convInfo = backend_util.computeConv2DInfo(
    x.shape, filter.shape, strides, dilations, pad, dimRoundingMode,
    true /* depthwise */);

  // The packed shader is used only when the feature flag is on, the
  // horizontal stride is at most 2, and outChannels equals inChannels.
  const usePackedProgram = env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
    convInfo.strideWidth <= 2 &&
    convInfo.outChannels / convInfo.inChannels === 1;
  const program = usePackedProgram ?
    new DepthwiseConvPacked2DProgram(convInfo) :
    new DepthwiseConv2DProgram(convInfo);

  const {
    padInfo,
    strideHeight,
    strideWidth,
    dilationHeight,
    dilationWidth,
    inHeight,
    inWidth,
  } = convInfo;
  // Geometry handed to the program at run time as [height, width]-ordered
  // pairs (presumably bound as shader uniforms by runWebGLProgram).
  const customValues = [
    [padInfo.top, padInfo.left],
    [strideHeight, strideWidth],
    [dilationHeight, dilationWidth],
    [inHeight, inWidth],
  ];
  return backend.runWebGLProgram(program, [x, filter], 'float32', customValues);
}
|
/**
 * Kernel registration entry that binds `depthwiseConv2dNative` to the
 * DepthwiseConv2dNative kernel name for the 'webgl' backend.
 */
export const depthwiseConv2dNativeConfig = {
  kernelName: DepthwiseConv2dNative,
  backendName: 'webgl',
  kernelFunc: depthwiseConv2dNative
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiRGVwdGh3aXNlQ29udjJkTmF0aXZlLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLXdlYmdsL3NyYy9rZXJuZWxzL0RlcHRod2lzZUNvbnYyZE5hdGl2ZS50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsWUFBWSxFQUFFLHFCQUFxQixFQUEyRCxHQUFHLEVBQTRCLElBQUksRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBR3hLLE9BQU8sRUFBQyxzQkFBc0IsRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBQzdELE9BQU8sRUFBQyw0QkFBNEIsRUFBQyxNQUFNLDhCQUE4QixDQUFDO0FBRTFFLE1BQU0sVUFBVSxxQkFBcUIsQ0FBQyxJQUlyQztJQUNDLE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztJQUN0QyxNQUFNLEVBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUMzQixNQUFNLEVBQUMsT0FBTyxFQUFFLEdBQUcsRUFBRSxTQUFTLEVBQUUsZUFBZSxFQUFDLEdBQUcsS0FBSyxDQUFDO0lBRXpELElBQUksVUFBVSxHQUFHLFNBQVMsQ0FBQztJQUMzQixJQUFJLFVBQVUsSUFBSSxJQUFJLEVBQUU7UUFDdEIsVUFBVSxHQUFHLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDO0tBQ3JCO0lBRUQsSUFBSSxDQUFDLE1BQU0sQ0FDUCxZQUFZLENBQUMsOEJBQThCLENBQUMsT0FBTyxFQUFFLFVBQVUsQ0FBQyxFQUNoRSxHQUFHLEVBQUUsQ0FBQyxnRUFBZ0U7UUFDbEUsa0JBQWtCLE9BQU8sbUJBQW1CLFVBQVUsR0FBRyxDQUFDLENBQUM7SUFFbkUsTUFBTSxRQUFRLEdBQUcsWUFBWSxDQUFDLGlCQUFpQixDQUMzQyxDQUFDLENBQUMsS0FBeUMsRUFDM0MsTUFBTSxDQUFDLEtBQXlDLEVBQUUsT0FBTyxFQUFFLFVBQVUsRUFDckUsR0FBRyxFQUFFLGVBQWUsRUFBRSxJQUFJLENBQUMsZUFBZSxDQUFDLENBQUM7SUFFaEQsSUFBSSxPQUE0RCxDQUFDO0lBQ2pFLElBQUksR0FBRyxFQUFFLENBQUMsT0FBTyxDQUFDLDBCQUEwQixDQUFDLElBQUksUUFBUSxDQUFDLFdBQVcsSUFBSSxDQUFDO1FBQ3RFLFFBQVEsQ0FBQyxXQUFXLEdBQUcsUUFBUSxDQUFDLFVBQVUsS0FBSyxDQUFDLEVBQUU7UUFDcEQsT0FBTyxHQUFHLElBQUksNEJBQTRCLENBQUMsUUFBUSxDQUFDLENBQUM7S0FDdEQ7U0FBTTtRQUNMLE9BQU8sR0FBRyxJQUFJLHNCQUFzQixDQUFDLFFBQVEsQ0FBQyxDQUFDO0tBQ2hEO0lBQ0QsTUFBTSxZQUFZLEdBQUc7UUFDbkIsQ0FBQyxRQUFRLENBQUMsT0FBTyxDQUFDLEdBQUcsRUFBRSxRQUFRLENBQUMsT0FBTyxDQUFDLElBQUksQ0FBQztRQUM3QyxDQUFDLFFBQVEsQ0FBQyxZQUFZLEVBQUUsUUFBUSxDQUFDLFdBQVcsQ0FBQztRQUM3QyxDQUFDLFFBQVEsQ0FBQyxjQUFjLEVBQUUsUUFBUSxDQUFDLGFBQWEsQ0FBQztRQUNqRCxDQUFDLFFBQV
EsQ0FBQyxRQUFRLEVBQUUsUUFBUSxDQUFDLE9BQU8sQ0FBQztLQUN0QyxDQUFDO0lBQ0YsT0FBTyxPQUFPLENBQUMsZUFBZSxDQUFDLE9BQU8sRUFBRSxDQUFDLENBQUMsRUFBRSxNQUFNLENBQUMsRUFBRSxTQUFTLEVBQUUsWUFBWSxDQUFDLENBQUM7QUFDaEYsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLDJCQUEyQixHQUFpQjtJQUN2RCxVQUFVLEVBQUUscUJBQXFCO0lBQ2pDLFdBQVcsRUFBRSxPQUFPO0lBQ3BCLFVBQVUsRUFBRSxxQkFBOEM7Q0FDM0QsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWwsIERlcHRod2lzZUNvbnYyZE5hdGl2ZSwgRGVwdGh3aXNlQ29udjJkTmF0aXZlQXR0cnMsIERlcHRod2lzZUNvbnYyZE5hdGl2ZUlucHV0cywgZW52LCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5pbXBvcnQge0RlcHRod2lzZUNvbnYyRFByb2dyYW19IGZyb20gJy4uL2NvbnZfZ3B1X2RlcHRod2lzZSc7XG5pbXBvcnQge0RlcHRod2lzZUNvbnZQYWNrZWQyRFByb2dyYW19IGZyb20gJy4uL2NvbnZfcGFja2VkX2dwdV9kZXB0aHdpc2UnO1xuXG5leHBvcnQgZnVuY3Rpb24gZGVwdGh3aXNlQ29udjJkTmF0aXZlKGFyZ3M6IHtcbiAgaW5wdXRzOiBEZXB0aHdpc2VDb252MmROYXRpdmVJbnB1dHMsXG4gIGF0dHJzOiBEZXB0aHdpc2VDb252MmROYXRpdmVBdHRycyxcbi
AgYmFja2VuZDogTWF0aEJhY2tlbmRXZWJHTFxufSkge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7eCwgZmlsdGVyfSA9IGlucHV0cztcbiAgY29uc3Qge3N0cmlkZXMsIHBhZCwgZGlsYXRpb25zLCBkaW1Sb3VuZGluZ01vZGV9ID0gYXR0cnM7XG5cbiAgbGV0ICRkaWxhdGlvbnMgPSBkaWxhdGlvbnM7XG4gIGlmICgkZGlsYXRpb25zID09IG51bGwpIHtcbiAgICAkZGlsYXRpb25zID0gWzEsIDFdO1xuICB9XG5cbiAgdXRpbC5hc3NlcnQoXG4gICAgICBiYWNrZW5kX3V0aWwuZWl0aGVyU3RyaWRlc09yRGlsYXRpb25zQXJlT25lKHN0cmlkZXMsICRkaWxhdGlvbnMpLFxuICAgICAgKCkgPT4gJ0Vycm9yIGluIGRlcHRod2lzZUNvbnYyZDogRWl0aGVyIHN0cmlkZXMgb3IgZGlsYXRpb25zIG11c3QgYmUgJyArXG4gICAgICAgICAgYDEuIEdvdCBzdHJpZGVzICR7c3RyaWRlc30gYW5kIGRpbGF0aW9ucyAnJHskZGlsYXRpb25zfSdgKTtcblxuICBjb25zdCBjb252SW5mbyA9IGJhY2tlbmRfdXRpbC5jb21wdXRlQ29udjJESW5mbyhcbiAgICAgIHguc2hhcGUgYXMgW251bWJlciwgbnVtYmVyLCBudW1iZXIsIG51bWJlcl0sXG4gICAgICBmaWx0ZXIuc2hhcGUgYXMgW251bWJlciwgbnVtYmVyLCBudW1iZXIsIG51bWJlcl0sIHN0cmlkZXMsICRkaWxhdGlvbnMsXG4gICAgICBwYWQsIGRpbVJvdW5kaW5nTW9kZSwgdHJ1ZSAvKiBkZXB0aHdpc2UgKi8pO1xuXG4gIGxldCBwcm9ncmFtOiBEZXB0aHdpc2VDb252MkRQcm9ncmFtfERlcHRod2lzZUNvbnZQYWNrZWQyRFByb2dyYW07XG4gIGlmIChlbnYoKS5nZXRCb29sKCdXRUJHTF9QQUNLX0RFUFRIV0lTRUNPTlYnKSAmJiBjb252SW5mby5zdHJpZGVXaWR0aCA8PSAyICYmXG4gICAgICBjb252SW5mby5vdXRDaGFubmVscyAvIGNvbnZJbmZvLmluQ2hhbm5lbHMgPT09IDEpIHtcbiAgICBwcm9ncmFtID0gbmV3IERlcHRod2lzZUNvbnZQYWNrZWQyRFByb2dyYW0oY29udkluZm8pO1xuICB9IGVsc2Uge1xuICAgIHByb2dyYW0gPSBuZXcgRGVwdGh3aXNlQ29udjJEUHJvZ3JhbShjb252SW5mbyk7XG4gIH1cbiAgY29uc3QgY3VzdG9tVmFsdWVzID0gW1xuICAgIFtjb252SW5mby5wYWRJbmZvLnRvcCwgY29udkluZm8ucGFkSW5mby5sZWZ0XSxcbiAgICBbY29udkluZm8uc3RyaWRlSGVpZ2h0LCBjb252SW5mby5zdHJpZGVXaWR0aF0sXG4gICAgW2NvbnZJbmZvLmRpbGF0aW9uSGVpZ2h0LCBjb252SW5mby5kaWxhdGlvbldpZHRoXSxcbiAgICBbY29udkluZm8uaW5IZWlnaHQsIGNvbnZJbmZvLmluV2lkdGhdXG4gIF07XG4gIHJldHVybiBiYWNrZW5kLnJ1bldlYkdMUHJvZ3JhbShwcm9ncmFtLCBbeCwgZmlsdGVyXSwgJ2Zsb2F0MzInLCBjdXN0b21WYWx1ZXMpO1xufVxuXG5leHBvcnQgY29uc3QgZGVwdGh3aXNlQ29udjJkTmF0aXZlQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IERlcHRod2lzZUNvbnYyZE5hdGl2ZSxcbi
AgYmFja2VuZE5hbWU6ICd3ZWJnbCcsXG4gIGtlcm5lbEZ1bmM6IGRlcHRod2lzZUNvbnYyZE5hdGl2ZSBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmMsXG59O1xuIl19
|