/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
|
import { FusedConv2D } from '@tensorflow/tfjs-core';
|
import { applyActivation } from '../utils/fused_utils';
|
import { add } from './Add';
|
import { conv2D } from './Conv2D';
|
import { reshape } from './Reshape';
|
/**
 * Fused conv2d kernel: runs `conv2D`, then optionally adds `bias`, then
 * optionally applies `activation`. Every tensor that is produced along the
 * way but not returned is released through
 * `backend.disposeIntermediateTensorInfo`.
 *
 * @param args Kernel arguments.
 * @param args.inputs `x` and `filter`, plus the optional `bias` and
 *     `preluActivationWeights`.
 * @param args.backend Backend handed to every sub-kernel and used to dispose
 *     intermediates.
 * @param args.attrs `strides`, `pad`, `dataFormat`, `dilations` and
 *     `dimRoundingMode` (forwarded to `conv2D`), plus `activation` and
 *     `leakyreluAlpha` (forwarded to `applyActivation`).
 * @returns The output of the last stage that ran.
 */
export function fusedConv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;

    const isChannelsFirst = dataFormat === 'NCHW';
    // A rank-1 tensor whose single dimension is not 1.
    const isMultiElementVector = (tensor) => tensor.shape.length === 1 && tensor.shape[0] !== 1;
    // Reshapes a rank-1 tensor of length C to [C, 1, 1] so that it lines up
    // with the channel axis of an NCHW conv2d result.
    const toChannelColumn = (tensor) => reshape({
        inputs: { x: tensor },
        backend,
        attrs: { shape: [tensor.shape[0], 1, 1] }
    });

    let result = conv2D({
        inputs: { x, filter },
        backend,
        attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
    });

    if (bias) {
        const convOutput = result;
        // Only an NCHW result combined with a multi-element 1-D bias needs
        // the reshape. NHWC, and a bias with shape [1] or any other rank,
        // are added as-is and rely on the broadcasting done by `add`.
        const reshapeBias = isChannelsFirst && isMultiElementVector(bias);
        const addend = reshapeBias ? toChannelColumn(bias) : bias;
        result = add({ inputs: { a: convOutput, b: addend }, backend });
        if (reshapeBias) {
            backend.disposeIntermediateTensorInfo(addend);
        }
        backend.disposeIntermediateTensorInfo(convOutput);
    }

    if (activation) {
        const preActivation = result;
        // Same channel alignment as for the bias, but only relevant to
        // 'prelu'. The weights are inspected last so they are never touched
        // for other activations or for non-NCHW data.
        const reshapeAlpha = isChannelsFirst && activation === 'prelu' &&
            isMultiElementVector(preluActivationWeights);
        const alpha = reshapeAlpha ?
            toChannelColumn(preluActivationWeights) :
            preluActivationWeights;
        result = applyActivation(backend, preActivation, activation, alpha, leakyreluAlpha);
        if (reshapeAlpha) {
            backend.disposeIntermediateTensorInfo(alpha);
        }
        backend.disposeIntermediateTensorInfo(preActivation);
    }

    return result;
}
|
/**
 * Kernel config object that pairs the `FusedConv2D` kernel name with the
 * `fusedConv2D` function for the 'cpu' backend.
 */
export const fusedConv2DConfig = {
    kernelName: FusedConv2D,
    backendName: 'cpu',
    kernelFunc: fusedConv2D,
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiRnVzZWRDb252MkQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy9rZXJuZWxzL0Z1c2VkQ29udjJELnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxXQUFXLEVBQTRFLE1BQU0sdUJBQXVCLENBQUM7QUFHN0gsT0FBTyxFQUFDLGVBQWUsRUFBQyxNQUFNLHNCQUFzQixDQUFDO0FBQ3JELE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxPQUFPLENBQUM7QUFDMUIsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLFVBQVUsQ0FBQztBQUNoQyxPQUFPLEVBQUMsT0FBTyxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBRWxDLE1BQU0sVUFBVSxXQUFXLENBQUMsSUFJM0I7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBRSxNQUFNLEVBQUUsSUFBSSxFQUFFLHNCQUFzQixFQUFDLEdBQUcsTUFBTSxDQUFDO0lBQ3pELE1BQU0sRUFDSixPQUFPLEVBQ1AsR0FBRyxFQUNILFVBQVUsRUFDVixTQUFTLEVBQ1QsZUFBZSxFQUNmLFVBQVUsRUFDVixjQUFjLEVBQ2YsR0FBRyxLQUFLLENBQUM7SUFFVixJQUFJLE1BQU0sR0FBRyxNQUFNLENBQUM7UUFDbEIsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBQztRQUNuQixPQUFPO1FBQ1AsS0FBSyxFQUFFLEVBQUMsT0FBTyxFQUFFLEdBQUcsRUFBRSxVQUFVLEVBQUUsU0FBUyxFQUFFLGVBQWUsRUFBQztLQUM5RCxDQUFDLENBQUM7SUFFSCxJQUFJLElBQUksRUFBRTtRQUNSLE1BQU0sU0FBUyxHQUFHLE1BQU0sQ0FBQztRQUN6Qix5RUFBeUU7UUFDekUsc0VBQXNFO1FBQ3RFLDBFQUEwRTtRQUMxRSxtQkFBbUI7UUFDbkIsSUFBSSxVQUFVLEtBQUssTUFBTSxJQUFJLElBQUksQ0FBQyxLQUFLLENBQUMsTUFBTSxLQUFLLENBQUM7WUFDaEQsSUFBSSxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEVBQUU7WUFDdkIsTUFBTSxZQUFZLEdBQUcsT0FBTyxDQUN4QixFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxJQUFJLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLENBQUMsSUFBSSxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUMsRUFBQyxDQUFDLENBQUM7WUFDekUsTUFBTTtnQkFDRixHQUFHLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsTUFBTSxFQUFFLENBQUMsRUFBRSxZQUFZLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBZSxDQUFDO1lBQ3ZFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxZQUFZLENBQUMsQ0FBQztTQUNyRDthQUFNO1lBQ0wsMEVBQTBFO1lBQzFFLHVDQUF1QztZQUN2QyxNQUFNLEdBQUcsR0FBRyxDQUFDLEVBQUMsTUFBTSxFQUFFLE
VBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBRSxDQUFDLEVBQUUsSUFBSSxFQUFDLEVBQUUsT0FBTyxFQUFDLENBQWUsQ0FBQztTQUNyRTtRQUNELE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxTQUFTLENBQUMsQ0FBQztLQUNsRDtJQUVELElBQUksVUFBVSxFQUFFO1FBQ2QsTUFBTSxTQUFTLEdBQUcsTUFBTSxDQUFDO1FBQ3pCLHNFQUFzRTtRQUN0RSw0RUFBNEU7UUFDNUUsZ0VBQWdFO1FBQ2hFLCtDQUErQztRQUMvQyxJQUFJLFVBQVUsS0FBSyxNQUFNLElBQUksVUFBVSxLQUFLLE9BQU87WUFDL0Msc0JBQXNCLENBQUMsS0FBSyxDQUFDLE1BQU0sS0FBSyxDQUFDO1lBQ3pDLHNCQUFzQixDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEVBQUU7WUFDekMsTUFBTSxhQUFhLEdBQUcsT0FBTyxDQUFDO2dCQUM1QixNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsc0JBQXNCLEVBQUM7Z0JBQ25DLE9BQU87Z0JBQ1AsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLENBQUMsc0JBQXNCLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBQzthQUN4RCxDQUFDLENBQUM7WUFDSCxNQUFNLEdBQUcsZUFBZSxDQUNwQixPQUFPLEVBQUUsTUFBTSxFQUFFLFVBQVUsRUFBRSxhQUFhLEVBQUUsY0FBYyxDQUFDLENBQUM7WUFDaEUsT0FBTyxDQUFDLDZCQUE2QixDQUFDLGFBQWEsQ0FBQyxDQUFDO1NBQ3REO2FBQU07WUFDTCxNQUFNLEdBQUcsZUFBZSxDQUNwQixPQUFPLEVBQUUsTUFBTSxFQUFFLFVBQVUsRUFBRSxzQkFBc0IsRUFBRSxjQUFjLENBQUMsQ0FBQztTQUMxRTtRQUNELE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxTQUFTLENBQUMsQ0FBQztLQUNsRDtJQUVELE9BQU8sTUFBTSxDQUFDO0FBQ2hCLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxpQkFBaUIsR0FBaUI7SUFDN0MsVUFBVSxFQUFFLFdBQVc7SUFDdkIsV0FBVyxFQUFFLEtBQUs7SUFDbEIsVUFBVSxFQUFFLFdBQW9DO0NBQ2pELENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELC
BlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7RnVzZWRDb252MkQsIEZ1c2VkQ29udjJEQXR0cnMsIEZ1c2VkQ29udjJESW5wdXRzLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvckluZm99IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcbmltcG9ydCB7YXBwbHlBY3RpdmF0aW9ufSBmcm9tICcuLi91dGlscy9mdXNlZF91dGlscyc7XG5pbXBvcnQge2FkZH0gZnJvbSAnLi9BZGQnO1xuaW1wb3J0IHtjb252MkR9IGZyb20gJy4vQ29udjJEJztcbmltcG9ydCB7cmVzaGFwZX0gZnJvbSAnLi9SZXNoYXBlJztcblxuZXhwb3J0IGZ1bmN0aW9uIGZ1c2VkQ29udjJEKGFyZ3M6IHtcbiAgaW5wdXRzOiBGdXNlZENvbnYyRElucHV0cyxcbiAgYmFja2VuZDogTWF0aEJhY2tlbmRDUFUsXG4gIGF0dHJzOiBGdXNlZENvbnYyREF0dHJzXG59KTogVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHt4LCBmaWx0ZXIsIGJpYXMsIHByZWx1QWN0aXZhdGlvbldlaWdodHN9ID0gaW5wdXRzO1xuICBjb25zdCB7XG4gICAgc3RyaWRlcyxcbiAgICBwYWQsXG4gICAgZGF0YUZvcm1hdCxcbiAgICBkaWxhdGlvbnMsXG4gICAgZGltUm91bmRpbmdNb2RlLFxuICAgIGFjdGl2YXRpb24sXG4gICAgbGVha3lyZWx1QWxwaGFcbiAgfSA9IGF0dHJzO1xuXG4gIGxldCByZXN1bHQgPSBjb252MkQoe1xuICAgIGlucHV0czoge3gsIGZpbHRlcn0sXG4gICAgYmFja2VuZCxcbiAgICBhdHRyczoge3N0cmlkZXMsIHBhZCwgZGF0YUZvcm1hdCwgZGlsYXRpb25zLCBkaW1Sb3VuZGluZ01vZGV9XG4gIH0pO1xuXG4gIGlmIChiaWFzKSB7XG4gICAgY29uc3QgcmVzdWx0T2xkID0gcmVzdWx0O1xuICAgIC8vIEZvciBOQ0hXIGZvcm1hdCwgaWYgYmlhcyBpcyBhIDEtRCB0ZW5zb3IsIGl0IGlzIHN1cHBvc2VkIHRvIGJlIGFsaWduZWRcbiAgICAvLyB0byB0aGUgY2hhbm5lbCBvZiB0aGUgY29udjJkJ3MgcmVzdWx0OyBpZiB0aGUgYmlhcyBpcyBhIHNjYWxhciwgdGhlXG4gICAgLy8gYmlhc19hZGQgaXMgY29tcHV0ZWQgYXMgaWYgdGhlIGJpYXMgd2FzIGJyb2FkY2FzdGVkIHRvIHRoZSBzaGFwZSBvZiB0aGVcbiAgICAvLyBjb252MmQncyByZXN1bHQuXG4gICAgaWYgKGRhdGFGb3JtYXQgPT09ICdOQ0hXJyAmJiBiaWFzLnNoYXBlLmxlbmd0aCA9PT0gMSAmJlxuICAgICAgICBiaWFzLnNoYXBlWzBdICE9PSAxKSB7XG4gICAgICBjb25zdCByZXNoYX
BlZEJpYXMgPSByZXNoYXBlKFxuICAgICAgICAgIHtpbnB1dHM6IHt4OiBiaWFzfSwgYmFja2VuZCwgYXR0cnM6IHtzaGFwZTogW2JpYXMuc2hhcGVbMF0sIDEsIDFdfX0pO1xuICAgICAgcmVzdWx0ID1cbiAgICAgICAgICBhZGQoe2lucHV0czoge2E6IHJlc3VsdCwgYjogcmVzaGFwZWRCaWFzfSwgYmFja2VuZH0pIGFzIFRlbnNvckluZm87XG4gICAgICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHJlc2hhcGVkQmlhcyk7XG4gICAgfSBlbHNlIHtcbiAgICAgIC8vIFRoaXMgY29uZGl0aW9uIGhhbmRsZXMgTkhXQyBhbmQgTkNIVyAoc2NhbGFyIGNhc2UpLiBUaGUgb25seSBvdGhlciBjYXNlXG4gICAgICAvLyBmb3IgTkNIVyAoMUQgY2FzZSkgaXMgaGFuZGxlZCBhYm92ZS5cbiAgICAgIHJlc3VsdCA9IGFkZCh7aW5wdXRzOiB7YTogcmVzdWx0LCBiOiBiaWFzfSwgYmFja2VuZH0pIGFzIFRlbnNvckluZm87XG4gICAgfVxuICAgIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8ocmVzdWx0T2xkKTtcbiAgfVxuXG4gIGlmIChhY3RpdmF0aW9uKSB7XG4gICAgY29uc3QgcmVzdWx0T2xkID0gcmVzdWx0O1xuICAgIC8vIEZvciBOQ0hXIGZvcm1hdCwgaWYgUFJlTHUgYWN0aXZhdGlvbiB3ZWlnaHRzIGlzIGEgMS1EIHRlbnNvciwgaXQgaXNcbiAgICAvLyBzdXBwb3NlZCB0byBiZSBhbGlnbmVkIHdpdGggdGhlIGNoYW5uZWwgb2YgdGhlIGNvbnYyZCdzIHJlc3VsdC4gRm9yIG90aGVyXG4gICAgLy8gY2FzZXMsIHdoZXRoZXIgTkNIVyBvciBOSFdDIGRhdGEgZm9ybWF0LCB0aGUgY29udjJkIHJlc3VsdCBpc1xuICAgIC8vIGFscmVhZHkgYWxpZ25lZCB3aXRoIHRoZSBhY3RpdmF0aW9uIHdlaWdodHMuXG4gICAgaWYgKGRhdGFGb3JtYXQgPT09ICdOQ0hXJyAmJiBhY3RpdmF0aW9uID09PSAncHJlbHUnICYmXG4gICAgICAgIHByZWx1QWN0aXZhdGlvbldlaWdodHMuc2hhcGUubGVuZ3RoID09PSAxICYmXG4gICAgICAgIHByZWx1QWN0aXZhdGlvbldlaWdodHMuc2hhcGVbMF0gIT09IDEpIHtcbiAgICAgIGNvbnN0IHJlc2hhcGVkQWxwaGEgPSByZXNoYXBlKHtcbiAgICAgICAgaW5wdXRzOiB7eDogcHJlbHVBY3RpdmF0aW9uV2VpZ2h0c30sXG4gICAgICAgIGJhY2tlbmQsXG4gICAgICAgIGF0dHJzOiB7c2hhcGU6IFtwcmVsdUFjdGl2YXRpb25XZWlnaHRzLnNoYXBlWzBdLCAxLCAxXX1cbiAgICAgIH0pO1xuICAgICAgcmVzdWx0ID0gYXBwbHlBY3RpdmF0aW9uKFxuICAgICAgICAgIGJhY2tlbmQsIHJlc3VsdCwgYWN0aXZhdGlvbiwgcmVzaGFwZWRBbHBoYSwgbGVha3lyZWx1QWxwaGEpO1xuICAgICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhyZXNoYXBlZEFscGhhKTtcbiAgICB9IGVsc2Uge1xuICAgICAgcmVzdWx0ID0gYXBwbHlBY3RpdmF0aW9uKFxuICAgICAgICAgIGJhY2tlbmQsIHJlc3VsdCwgYWN0aXZhdGlvbiwgcHJlbHVBY3RpdmF0aW9uV2VpZ2h0cywgbGVha3
lyZWx1QWxwaGEpO1xuICAgIH1cbiAgICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHJlc3VsdE9sZCk7XG4gIH1cblxuICByZXR1cm4gcmVzdWx0O1xufVxuXG5leHBvcnQgY29uc3QgZnVzZWRDb252MkRDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogRnVzZWRDb252MkQsXG4gIGJhY2tlbmROYW1lOiAnY3B1JyxcbiAga2VybmVsRnVuYzogZnVzZWRDb252MkQgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19
|