/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// NOTE(review): this is compiled output; per the inline source map below the
// canonical source is tfjs-backend-cpu/src/kernels/FusedConv2D.ts. Edits here
// are not reflected in that source map's line mappings.
import { FusedConv2D } from '@tensorflow/tfjs-core';
import { applyActivation } from '../utils/fused_utils';
import { add } from './Add';
import { conv2D } from './Conv2D';
import { reshape } from './Reshape';

/**
 * Reports whether `operand` has to be reshaped before it can be combined with
 * an NCHW conv2d result: the layout is NCHW and the operand is a 1-D tensor
 * whose single dimension is not 1. Scalars and length-1 vectors broadcast
 * as-is, so they are excluded.
 *
 * @param {string} dataFormat The conv2d data format attribute.
 * @param {{shape: number[]}} operand The bias or PReLU weights tensor info.
 * @returns {boolean}
 */
function needsNCHWChannelReshape(dataFormat, operand) {
  return dataFormat === 'NCHW' && operand.shape.length === 1 &&
      operand.shape[0] !== 1;
}

/**
 * Reshapes a 1-D operand of length C to `[C, 1, 1]`, so that it lines up with
 * the channel axis of an NCHW conv2d result when broadcast.
 * The caller owns the returned tensor info and must dispose it.
 *
 * @param {{shape: number[]}} operand 1-D tensor info of length C.
 * @param {Object} backend The CPU backend running the reshape.
 * @returns {Object} The reshaped tensor info.
 */
function reshapeToNCHWChannel(operand, backend) {
  return reshape({
    inputs: { x: operand },
    backend,
    attrs: { shape: [operand.shape[0], 1, 1] }
  });
}

/**
 * CPU kernel for FusedConv2D. It runs a conv2d, then optionally adds a bias,
 * then optionally applies an activation.
 *
 * @param {{inputs: Object, backend: Object, attrs: Object}} args
 *   - inputs: `x` and `filter`, plus optional `bias` and
 *     `preluActivationWeights` (FusedConv2DInputs).
 *   - backend: the CPU backend (MathBackendCPU). It runs the sub-kernels and
 *     disposes intermediates.
 *   - attrs: the conv2d attributes (`strides`, `pad`, `dataFormat`,
 *     `dilations`, `dimRoundingMode`), plus `activation` and
 *     `leakyreluAlpha` (FusedConv2DAttrs).
 * @returns {Object} The final TensorInfo. Every intermediate produced on the
 *   way (the pre-bias result, the pre-activation result, reshaped operands)
 *   is released with `backend.disposeIntermediateTensorInfo`.
 */
export function fusedConv2D(args) {
  const { inputs, backend, attrs } = args;
  const { x, filter, bias, preluActivationWeights } = inputs;
  const {
    strides,
    pad,
    dataFormat,
    dilations,
    dimRoundingMode,
    activation,
    leakyreluAlpha
  } = attrs;

  let result = conv2D({
    inputs: { x, filter },
    backend,
    attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
  });

  if (bias) {
    const resultOld = result;
    // For NCHW format, a 1-D bias is aligned to the channel axis of the
    // conv2d result. A scalar bias is added as if it were broadcast to the
    // shape of the conv2d result. For NHWC, the bias already lines up with
    // the trailing channel axis.
    if (needsNCHWChannelReshape(dataFormat, bias)) {
      const reshapedBias = reshapeToNCHWChannel(bias, backend);
      result = add({ inputs: { a: result, b: reshapedBias }, backend });
      backend.disposeIntermediateTensorInfo(reshapedBias);
    } else {
      result = add({ inputs: { a: result, b: bias }, backend });
    }
    backend.disposeIntermediateTensorInfo(resultOld);
  }

  if (activation) {
    const resultOld = result;
    // For NCHW format, 1-D PReLU weights are aligned with the channel axis of
    // the conv2d result. In every other case (either data format), the
    // result is already aligned with the activation weights. The `activation`
    // check comes first, so the weights' shape is only read for 'prelu'.
    if (activation === 'prelu' &&
        needsNCHWChannelReshape(dataFormat, preluActivationWeights)) {
      const reshapedAlpha =
          reshapeToNCHWChannel(preluActivationWeights, backend);
      result = applyActivation(
          backend, result, activation, reshapedAlpha, leakyreluAlpha);
      backend.disposeIntermediateTensorInfo(reshapedAlpha);
    } else {
      result = applyActivation(
          backend, result, activation, preluActivationWeights, leakyreluAlpha);
    }
    backend.disposeIntermediateTensorInfo(resultOld);
  }

  return result;
}

/** Registers `fusedConv2D` as the 'cpu' backend kernel for FusedConv2D. */
export const fusedConv2DConfig = {
  kernelName: FusedConv2D,
  backendName: 'cpu',
  kernelFunc: fusedConv2D
};
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiRnVzZWRDb252MkQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy9rZXJuZWxzL0Z1c2VkQ29udjJELnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxXQUFXLEVBQTRFLE1BQU0sdUJBQXVCLENBQUM7QUFHN0gsT0FBTyxFQUFDLGVBQWUsRUFBQyxNQUFNLHNCQUFzQixDQUFDO0FBQ3JELE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxPQUFPLENBQUM7QUFDMUIsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLFVBQVUsQ0FBQztBQUNoQyxPQUFPLEVBQUMsT0FBTyxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBRWxDLE1BQU0sVUFBVSxXQUFXLENBQUMsSUFJM0I7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBRSxNQUFNLEVBQUUsSUFBSSxFQUFFLHNCQUFzQixFQUFDLEdBQUcsTUFBTSxDQUFDO0lBQ3pELE1BQU0sRUFDSixPQUFPLEVBQ1AsR0FBRyxFQUNILFVBQVUsRUFDVixTQUFTLEVBQ1QsZUFBZSxFQUNmLFVBQVUsRUFDVixjQUFjLEVBQ2YsR0FBRyxLQUFLLENBQUM7SUFFVixJQUFJLE1BQU0sR0FBRyxNQUFNLENBQUM7UUFDbEIsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBQztRQUNuQixPQUFPO1FBQ1AsS0FBSyxFQUFFLEVBQUMsT0FBTyxFQUFFLEdBQUcsRUFBRSxVQUFVLEVBQUUsU0FBUyxFQUFFLGVBQWUsRUFBQztLQUM5RCxDQUFDLENBQUM7SUFFSCxJQUFJLElBQUksRUFBRTtRQUNSLE1BQU0sU0FBUyxHQUFHLE1BQU0sQ0FBQztRQUN6Qix5RUFBeUU7UUFDekUsc0VBQXNFO1FBQ3RFLDBFQUEwRTtRQUMxRSxtQkFBbUI7UUFDbkIsSUFBSSxVQUFVLEtBQUssTUFBTSxJQUFJLElBQUksQ0FBQyxLQUFLLENBQUMsTUFBTSxLQUFLLENBQUM7WUFDaEQsSUFBSSxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEVBQUU7WUFDdkIsTUFBTSxZQUFZLEdBQUcsT0FBTyxDQUN4QixFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxJQUFJLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLENBQUMsSUFBSSxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUMsRUFBQyxDQUFDLENBQUM7WUFDekUsTUFBTTtnQkFDRixHQUFHLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsTUFBTSxFQUFFLENBQUMsRUFBRSxZQUFZLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBZSxDQUFDO1lBQ3ZFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxZQUFZLENBQUMsQ0FBQztTQUNyRDthQUFNO1lBQ0wsMEVBQTBFO1lBQzFFLHVDQUF1QztZQUN2QyxNQUFNLEdBQUcsR0FBRyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQU
MsQ0FBQyxFQUFFLE1BQU0sRUFBRSxDQUFDLEVBQUUsSUFBSSxFQUFDLEVBQUUsT0FBTyxFQUFDLENBQWUsQ0FBQztTQUNyRTtRQUNELE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxTQUFTLENBQUMsQ0FBQztLQUNsRDtJQUVELElBQUksVUFBVSxFQUFFO1FBQ2QsTUFBTSxTQUFTLEdBQUcsTUFBTSxDQUFDO1FBQ3pCLHNFQUFzRTtRQUN0RSw0RUFBNEU7UUFDNUUsZ0VBQWdFO1FBQ2hFLCtDQUErQztRQUMvQyxJQUFJLFVBQVUsS0FBSyxNQUFNLElBQUksVUFBVSxLQUFLLE9BQU87WUFDL0Msc0JBQXNCLENBQUMsS0FBSyxDQUFDLE1BQU0sS0FBSyxDQUFDO1lBQ3pDLHNCQUFzQixDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEVBQUU7WUFDekMsTUFBTSxhQUFhLEdBQUcsT0FBTyxDQUFDO2dCQUM1QixNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsc0JBQXNCLEVBQUM7Z0JBQ25DLE9BQU87Z0JBQ1AsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLENBQUMsc0JBQXNCLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBQzthQUN4RCxDQUFDLENBQUM7WUFDSCxNQUFNLEdBQUcsZUFBZSxDQUNwQixPQUFPLEVBQUUsTUFBTSxFQUFFLFVBQVUsRUFBRSxhQUFhLEVBQUUsY0FBYyxDQUFDLENBQUM7WUFDaEUsT0FBTyxDQUFDLDZCQUE2QixDQUFDLGFBQWEsQ0FBQyxDQUFDO1NBQ3REO2FBQU07WUFDTCxNQUFNLEdBQUcsZUFBZSxDQUNwQixPQUFPLEVBQUUsTUFBTSxFQUFFLFVBQVUsRUFBRSxzQkFBc0IsRUFBRSxjQUFjLENBQUMsQ0FBQztTQUMxRTtRQUNELE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxTQUFTLENBQUMsQ0FBQztLQUNsRDtJQUVELE9BQU8sTUFBTSxDQUFDO0FBQ2hCLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxpQkFBaUIsR0FBaUI7SUFDN0MsVUFBVSxFQUFFLFdBQVc7SUFDdkIsV0FBVyxFQUFFLEtBQUs7SUFDbEIsVUFBVSxFQUFFLFdBQW9DO0NBQ2pELENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaX
RoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7RnVzZWRDb252MkQsIEZ1c2VkQ29udjJEQXR0cnMsIEZ1c2VkQ29udjJESW5wdXRzLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvckluZm99IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcbmltcG9ydCB7YXBwbHlBY3RpdmF0aW9ufSBmcm9tICcuLi91dGlscy9mdXNlZF91dGlscyc7XG5pbXBvcnQge2FkZH0gZnJvbSAnLi9BZGQnO1xuaW1wb3J0IHtjb252MkR9IGZyb20gJy4vQ29udjJEJztcbmltcG9ydCB7cmVzaGFwZX0gZnJvbSAnLi9SZXNoYXBlJztcblxuZXhwb3J0IGZ1bmN0aW9uIGZ1c2VkQ29udjJEKGFyZ3M6IHtcbiAgaW5wdXRzOiBGdXNlZENvbnYyRElucHV0cyxcbiAgYmFja2VuZDogTWF0aEJhY2tlbmRDUFUsXG4gIGF0dHJzOiBGdXNlZENvbnYyREF0dHJzXG59KTogVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHt4LCBmaWx0ZXIsIGJpYXMsIHByZWx1QWN0aXZhdGlvbldlaWdodHN9ID0gaW5wdXRzO1xuICBjb25zdCB7XG4gICAgc3RyaWRlcyxcbiAgICBwYWQsXG4gICAgZGF0YUZvcm1hdCxcbiAgICBkaWxhdGlvbnMsXG4gICAgZGltUm91bmRpbmdNb2RlLFxuICAgIGFjdGl2YXRpb24sXG4gICAgbGVha3lyZWx1QWxwaGFcbiAgfSA9IGF0dHJzO1xuXG4gIGxldCByZXN1bHQgPSBjb252MkQoe1xuICAgIGlucHV0czoge3gsIGZpbHRlcn0sXG4gICAgYmFja2VuZCxcbiAgICBhdHRyczoge3N0cmlkZXMsIHBhZCwgZGF0YUZvcm1hdCwgZGlsYXRpb25zLCBkaW1Sb3VuZGluZ01vZGV9XG4gIH0pO1xuXG4gIGlmIChiaWFzKSB7XG4gICAgY29uc3QgcmVzdWx0T2xkID0gcmVzdWx0O1xuICAgIC8vIEZvciBOQ0hXIGZvcm1hdCwgaWYgYmlhcyBpcyBhIDEtRCB0ZW5zb3IsIGl0IGlzIHN1cHBvc2VkIHRvIGJlIGFsaWduZWRcbiAgICAvLyB0byB0aGUgY2hhbm5lbCBvZiB0aGUgY29udjJkJ3MgcmVzdWx0OyBpZiB0aGUgYmlhcyBpcyBhIHNjYWxhciwgdGhlXG4gICAgLy8gYmlhc19hZGQgaXMgY29tcHV0ZWQgYXMgaWYgdGhlIGJpYXMgd2FzIGJyb2FkY2FzdGVkIHRvIHRoZSBzaGFwZSBvZiB0aGVcbiAgICAvLyBjb252MmQncyByZXN1bHQuXG4gICAgaWYgKGRhdGFGb3JtYXQgPT09ICdOQ0hXJyAmJiBiaWFzLnNoYXBlLmxlbmd0aCA9PT0gMSAmJlxuICAgICAgICBiaWFzLnNoYXBlWzBdICE9PSAxKSB7XG4gICAgICBjb25zdCByZXNoYXBlZE
JpYXMgPSByZXNoYXBlKFxuICAgICAgICAgIHtpbnB1dHM6IHt4OiBiaWFzfSwgYmFja2VuZCwgYXR0cnM6IHtzaGFwZTogW2JpYXMuc2hhcGVbMF0sIDEsIDFdfX0pO1xuICAgICAgcmVzdWx0ID1cbiAgICAgICAgICBhZGQoe2lucHV0czoge2E6IHJlc3VsdCwgYjogcmVzaGFwZWRCaWFzfSwgYmFja2VuZH0pIGFzIFRlbnNvckluZm87XG4gICAgICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHJlc2hhcGVkQmlhcyk7XG4gICAgfSBlbHNlIHtcbiAgICAgIC8vIFRoaXMgY29uZGl0aW9uIGhhbmRsZXMgTkhXQyBhbmQgTkNIVyAoc2NhbGFyIGNhc2UpLiBUaGUgb25seSBvdGhlciBjYXNlXG4gICAgICAvLyBmb3IgTkNIVyAoMUQgY2FzZSkgaXMgaGFuZGxlZCBhYm92ZS5cbiAgICAgIHJlc3VsdCA9IGFkZCh7aW5wdXRzOiB7YTogcmVzdWx0LCBiOiBiaWFzfSwgYmFja2VuZH0pIGFzIFRlbnNvckluZm87XG4gICAgfVxuICAgIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8ocmVzdWx0T2xkKTtcbiAgfVxuXG4gIGlmIChhY3RpdmF0aW9uKSB7XG4gICAgY29uc3QgcmVzdWx0T2xkID0gcmVzdWx0O1xuICAgIC8vIEZvciBOQ0hXIGZvcm1hdCwgaWYgUFJlTHUgYWN0aXZhdGlvbiB3ZWlnaHRzIGlzIGEgMS1EIHRlbnNvciwgaXQgaXNcbiAgICAvLyBzdXBwb3NlZCB0byBiZSBhbGlnbmVkIHdpdGggdGhlIGNoYW5uZWwgb2YgdGhlIGNvbnYyZCdzIHJlc3VsdC4gRm9yIG90aGVyXG4gICAgLy8gY2FzZXMsIHdoZXRoZXIgTkNIVyBvciBOSFdDIGRhdGEgZm9ybWF0LCB0aGUgY29udjJkIHJlc3VsdCBpc1xuICAgIC8vIGFscmVhZHkgYWxpZ25lZCB3aXRoIHRoZSBhY3RpdmF0aW9uIHdlaWdodHMuXG4gICAgaWYgKGRhdGFGb3JtYXQgPT09ICdOQ0hXJyAmJiBhY3RpdmF0aW9uID09PSAncHJlbHUnICYmXG4gICAgICAgIHByZWx1QWN0aXZhdGlvbldlaWdodHMuc2hhcGUubGVuZ3RoID09PSAxICYmXG4gICAgICAgIHByZWx1QWN0aXZhdGlvbldlaWdodHMuc2hhcGVbMF0gIT09IDEpIHtcbiAgICAgIGNvbnN0IHJlc2hhcGVkQWxwaGEgPSByZXNoYXBlKHtcbiAgICAgICAgaW5wdXRzOiB7eDogcHJlbHVBY3RpdmF0aW9uV2VpZ2h0c30sXG4gICAgICAgIGJhY2tlbmQsXG4gICAgICAgIGF0dHJzOiB7c2hhcGU6IFtwcmVsdUFjdGl2YXRpb25XZWlnaHRzLnNoYXBlWzBdLCAxLCAxXX1cbiAgICAgIH0pO1xuICAgICAgcmVzdWx0ID0gYXBwbHlBY3RpdmF0aW9uKFxuICAgICAgICAgIGJhY2tlbmQsIHJlc3VsdCwgYWN0aXZhdGlvbiwgcmVzaGFwZWRBbHBoYSwgbGVha3lyZWx1QWxwaGEpO1xuICAgICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhyZXNoYXBlZEFscGhhKTtcbiAgICB9IGVsc2Uge1xuICAgICAgcmVzdWx0ID0gYXBwbHlBY3RpdmF0aW9uKFxuICAgICAgICAgIGJhY2tlbmQsIHJlc3VsdCwgYWN0aXZhdGlvbiwgcHJlbHVBY3RpdmF0aW9uV2VpZ2h0cywgbGVha3lyZW
x1QWxwaGEpO1xuICAgIH1cbiAgICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHJlc3VsdE9sZCk7XG4gIH1cblxuICByZXR1cm4gcmVzdWx0O1xufVxuXG5leHBvcnQgY29uc3QgZnVzZWRDb252MkRDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogRnVzZWRDb252MkQsXG4gIGJhY2tlbmROYW1lOiAnY3B1JyxcbiAga2VybmVsRnVuYzogZnVzZWRDb252MkQgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19