/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { FusedConv2D } from '@tensorflow/tfjs-core';
|
import { applyActivation } from '../utils/fused_utils';
|
import { add } from './Add';
|
import { conv2D } from './Conv2D';
|
import { reshape } from './Reshape';
|
/**
 * CPU implementation of the FusedConv2D kernel.
 *
 * Runs a plain conv2D, then (if `bias` is given) adds the bias, then (if
 * `activation` is set) applies the activation. Each TensorInfo that this
 * function creates and then supersedes (the reshaped bias/alpha and the
 * pre-bias / pre-activation results) is released via
 * `backend.disposeIntermediateTensorInfo`.
 *
 * @param args.inputs  `x` and `filter`, plus optional `bias` and
 *     `preluActivationWeights`.
 * @param args.backend Backend used to run the sub-kernels and to dispose
 *     intermediates.
 * @param args.attrs   conv2D attributes (`strides`, `pad`, `dataFormat`,
 *     `dilations`, `dimRoundingMode`) plus `activation` and `leakyreluAlpha`.
 * @returns The TensorInfo holding the fused result.
 */
export function fusedConv2D(args) {
  const { inputs, backend, attrs } = args;
  const { x, filter, bias, preluActivationWeights } = inputs;
  const {
    strides,
    pad,
    dataFormat,
    dilations,
    dimRoundingMode,
    activation,
    leakyreluAlpha
  } = attrs;

  // In NCHW layout, a 1-D operand of length C (C !== 1) has to be viewed as
  // [C, 1, 1] so it lines up with the channel axis of the conv output.
  // Scalars (length 1) and operands in any other layout are used unchanged.
  const isChannelVectorInNCHW = (t) =>
      dataFormat === 'NCHW' && t.shape.length === 1 && t.shape[0] !== 1;
  const toChannelColumn = (t) => reshape(
      { inputs: { x: t }, backend, attrs: { shape: [t.shape[0], 1, 1] } });

  let output = conv2D({
    inputs: { x, filter },
    backend,
    attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
  });

  if (bias) {
    const convOutput = output;
    if (isChannelVectorInNCHW(bias)) {
      const biasColumn = toChannelColumn(bias);
      output = add({ inputs: { a: convOutput, b: biasColumn }, backend });
      backend.disposeIntermediateTensorInfo(biasColumn);
    } else {
      // NHWC, or NCHW with a scalar bias: broadcasting already aligns it.
      output = add({ inputs: { a: convOutput, b: bias }, backend });
    }
    backend.disposeIntermediateTensorInfo(convOutput);
  }

  if (activation) {
    const preActivation = output;
    // Only PReLU carries per-channel weights that may need the NCHW reshape;
    // the shape is inspected only when activation is 'prelu'.
    if (activation === 'prelu' &&
        isChannelVectorInNCHW(preluActivationWeights)) {
      const alphaColumn = toChannelColumn(preluActivationWeights);
      output = applyActivation(
          backend, preActivation, activation, alphaColumn, leakyreluAlpha);
      backend.disposeIntermediateTensorInfo(alphaColumn);
    } else {
      output = applyActivation(
          backend, preActivation, activation, preluActivationWeights,
          leakyreluAlpha);
    }
    backend.disposeIntermediateTensorInfo(preActivation);
  }

  return output;
}
|
/**
 * Kernel configuration that binds `fusedConv2D` as the 'cpu' backend
 * implementation of the `FusedConv2D` kernel.
 */
export const fusedConv2DConfig = {
  kernelName: FusedConv2D,
  backendName: 'cpu',
  kernelFunc: fusedConv2D,
};
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"FusedConv2D.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-cpu/src/kernels/FusedConv2D.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,WAAW,EAA4E,MAAM,uBAAuB,CAAC;AAG7H,OAAO,EAAC,eAAe,EAAC,MAAM,sBAAsB,CAAC;AACrD,OAAO,EAAC,GAAG,EAAC,MAAM,OAAO,CAAC;AAC1B,OAAO,EAAC,MAAM,EAAC,MAAM,UAAU,CAAC;AAChC,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAElC,MAAM,UAAU,WAAW,CAAC,IAI3B;IACC,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,sBAAsB,EAAC,GAAG,MAAM,CAAC;IACzD,MAAM,EACJ,OAAO,EACP,GAAG,EACH,UAAU,EACV,SAAS,EACT,eAAe,EACf,UAAU,EACV,cAAc,EACf,GAAG,KAAK,CAAC;IAEV,IAAI,MAAM,GAAG,MAAM,CAAC;QAClB,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAC;QACnB,OAAO;QACP,KAAK,EAAE,EAAC,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,SAAS,EAAE,eAAe,EAAC;KAC9D,CAAC,CAAC;IAEH,IAAI,IAAI,EAAE;QACR,MAAM,SAAS,GAAG,MAAM,CAAC;QACzB,yEAAyE;QACzE,sEAAsE;QACtE,0EAA0E;QAC1E,mBAAmB;QACnB,IAAI,UAAU,KAAK,MAAM,IAAI,IAAI,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC;YAChD,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE;YACvB,MAAM,YAAY,GAAG,OAAO,CACxB,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,IAAI,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAC,EAAC,CAAC,CAAC;YACzE,MAAM;gBACF,GAAG,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,YAAY,EAAC,EAAE,OAAO,EAAC,CAAe,CAAC;YACvE,OAAO,CAAC,6BAA6B,CAAC,YAAY,CAAC,CAAC;SACrD;aAAM;YACL,0EAA0E;YAC1E,uCAAuC;YACvC,MAAM,GAAG,GAAG,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,IAAI,EAAC,EAAE,OAAO,EAAC,CAAe,CAAC;SACrE;QACD,OAAO,CAAC,6BAA6B,CAAC,SAAS,CAAC,CAAC;KAClD;IAED,IAAI,UAAU,EAAE;QACd,MAAM,SAAS,GAAG,MAAM,CAAC;QACzB,sEAAsE;QACtE,4EAA4E;QAC5E,gEAAgE;QAChE,+CAA+C;QAC/C,IAAI,UAAU,KAAK,MAAM,IAAI,UAAU,KAAK,OAAO;YAC/C,sBAAsB,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC;YACzC,sBAAsB,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE;YACzC,MAAM,aAAa,GAAG,OAAO,CAAC;gBAC5B,MAAM,EAAE,EAAC,CAAC,EAAE,sBAAsB,EAAC;gBACnC,OAAO;gBACP,KAAK,EAAE,EAAC,KAAK,EAAE,CA
AC,sBAAsB,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAC;aACxD,CAAC,CAAC;YACH,MAAM,GAAG,eAAe,CACpB,OAAO,EAAE,MAAM,EAAE,UAAU,EAAE,aAAa,EAAE,cAAc,CAAC,CAAC;YAChE,OAAO,CAAC,6BAA6B,CAAC,aAAa,CAAC,CAAC;SACtD;aAAM;YACL,MAAM,GAAG,eAAe,CACpB,OAAO,EAAE,MAAM,EAAE,UAAU,EAAE,sBAAsB,EAAE,cAAc,CAAC,CAAC;SAC1E;QACD,OAAO,CAAC,6BAA6B,CAAC,SAAS,CAAC,CAAC;KAClD;IAED,OAAO,MAAM,CAAC;AAChB,CAAC;AAED,MAAM,CAAC,MAAM,iBAAiB,GAAiB;IAC7C,UAAU,EAAE,WAAW;IACvB,WAAW,EAAE,KAAK;IAClB,UAAU,EAAE,WAAoC;CACjD,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {FusedConv2D, FusedConv2DAttrs, FusedConv2DInputs, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core';\n\nimport {MathBackendCPU} from '../backend_cpu';\nimport {applyActivation} from '../utils/fused_utils';\nimport {add} from './Add';\nimport {conv2D} from './Conv2D';\nimport {reshape} from './Reshape';\n\nexport function fusedConv2D(args: {\n  inputs: FusedConv2DInputs,\n  backend: MathBackendCPU,\n  attrs: FusedConv2DAttrs\n}): TensorInfo {\n  const {inputs, backend, attrs} = args;\n  const {x, filter, bias, preluActivationWeights} = inputs;\n  const {\n    strides,\n    pad,\n    dataFormat,\n    dilations,\n    dimRoundingMode,\n    activation,\n    leakyreluAlpha\n  } = attrs;\n\n  let result = conv2D({\n    inputs: {x, 
filter},\n    backend,\n    attrs: {strides, pad, dataFormat, dilations, dimRoundingMode}\n  });\n\n  if (bias) {\n    const resultOld = result;\n    // For NCHW format, if bias is a 1-D tensor, it is supposed to be aligned\n    // to the channel of the conv2d's result; if the bias is a scalar, the\n    // bias_add is computed as if the bias was broadcasted to the shape of the\n    // conv2d's result.\n    if (dataFormat === 'NCHW' && bias.shape.length === 1 &&\n        bias.shape[0] !== 1) {\n      const reshapedBias = reshape(\n          {inputs: {x: bias}, backend, attrs: {shape: [bias.shape[0], 1, 1]}});\n      result =\n          add({inputs: {a: result, b: reshapedBias}, backend}) as TensorInfo;\n      backend.disposeIntermediateTensorInfo(reshapedBias);\n    } else {\n      // This condition handles NHWC and NCHW (scalar case). The only other case\n      // for NCHW (1D case) is handled above.\n      result = add({inputs: {a: result, b: bias}, backend}) as TensorInfo;\n    }\n    backend.disposeIntermediateTensorInfo(resultOld);\n  }\n\n  if (activation) {\n    const resultOld = result;\n    // For NCHW format, if PReLu activation weights is a 1-D tensor, it is\n    // supposed to be aligned with the channel of the conv2d's result. 
For other\n    // cases, whether NCHW or NHWC data format, the conv2d result is\n    // already aligned with the activation weights.\n    if (dataFormat === 'NCHW' && activation === 'prelu' &&\n        preluActivationWeights.shape.length === 1 &&\n        preluActivationWeights.shape[0] !== 1) {\n      const reshapedAlpha = reshape({\n        inputs: {x: preluActivationWeights},\n        backend,\n        attrs: {shape: [preluActivationWeights.shape[0], 1, 1]}\n      });\n      result = applyActivation(\n          backend, result, activation, reshapedAlpha, leakyreluAlpha);\n      backend.disposeIntermediateTensorInfo(reshapedAlpha);\n    } else {\n      result = applyActivation(\n          backend, result, activation, preluActivationWeights, leakyreluAlpha);\n    }\n    backend.disposeIntermediateTensorInfo(resultOld);\n  }\n\n  return result;\n}\n\nexport const fusedConv2DConfig: KernelConfig = {\n  kernelName: FusedConv2D,\n  backendName: 'cpu',\n  kernelFunc: fusedConv2D as unknown as KernelFunc\n};\n"]}
|