/**
|
* @license
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { ENGINE } from '../../engine';
|
import { customGrad } from '../../gradients';
|
import { FusedConv2D } from '../../kernel_names';
|
import { makeTypesMatch } from '../../tensor_util';
|
import { convertToTensor } from '../../tensor_util_env';
|
import * as util from '../../util';
|
import { add } from '../add';
|
import * as broadcast_util from '../broadcast_util';
|
import { conv2d as unfusedConv2d } from '../conv2d';
|
import { conv2DBackpropFilter } from '../conv2d_backprop_filter';
|
import { conv2DBackpropInput } from '../conv2d_backprop_input';
|
import * as conv_util from '../conv_util';
|
import { applyActivation, getFusedBiasGradient, getFusedDyActivation, shouldFuse } from '../fused_util';
|
import { op } from '../operation';
|
import { reshape } from '../reshape';
|
/**
 * Computes a 2D convolution over the input x, optionally fused with adding a
 * bias and applying an activation.
 *
 * ```js
 * const inputDepth = 2;
 * const inShape = [2, 2, 2, inputDepth];
 * const outputDepth = 2;
 * const fSize = 1;
 * const pad = 0;
 * const strides = 1;
 *
 * const x = tf.tensor4d( [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
 * 16], inShape);
 * const w = tf.tensor4d([-1, 1, -2, 0.5], [fSize, fSize, inputDepth,
 * outputDepth]);
 *
 * tf.fused.conv2d({ x, filter: w, strides, pad, dataFormat: 'NHWC',
 * dilations: [1, 1], bias: tf.scalar(5), activation: 'relu' }).print();
 * ```
 *
 * @param obj An object with the following properties:
 * @param x The input tensor, of rank 4 or rank 3, of shape
 *     `[batch, height, width, inChannels]`. If rank 3, a batch of 1 is
 *     assumed.
 * @param filter The filter, rank 4, of shape
 *     `[filterHeight, filterWidth, inDepth, outDepth]`.
 * @param strides The strides of the convolution: `[strideHeight,
 *     strideWidth]`.
 * @param pad The type of padding algorithm.
 *   - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *   - `valid` output will be smaller than input if filter is larger
 *       than 1x1.
 *   - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
 *     "NHWC". Specifies the data format of the input and output data. With
 *     the default format "NHWC", the data is stored in the order of: [batch,
 *     height, width, channels]. Only "NHWC" is fully supported: both the
 *     unfused fallback and the gradient reject any other format.
 * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
 *     in which we sample input values across the height and width dimensions
 *     in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
 *     number, then `dilationHeight == dilationWidth`. If it is greater than
 *     1, then all values of `strides` must be 1.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate. It is also forwarded to the
 *     gradient ops so they derive the same output shape as the forward pass.
 * @param bias Tensor to be added to the result. For "NHWC" it must broadcast
 *     against the conv2d output shape; for "NCHW" it must be a scalar or a
 *     1-D tensor of length 1 or outChannels.
 * @param activation Name of activation kernel (defaults to `linear`) to be
 *     applied after biasAdd.
 * @param preluActivationWeights Weights for a `prelu` activation: a scalar,
 *     a 1-D tensor of length 1 or outChannels, or a 3-D tensor that
 *     broadcasts against the conv2d output shape.
 * @param leakyreluAlpha Optional. Alpha to be applied as part of a `leakyrelu`
 *     activation.
 * @returns The convolution result, with the bias and activation applied.
 */
function fusedConv2d_({ x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode, bias, activation = 'linear', preluActivationWeights, leakyreluAlpha }) {
  // The destructuring default only covers `undefined`; also map `null` (and
  // any other falsy value) to 'linear'.
  activation = activation || 'linear';

  if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
    // Not fusing: compose the separate conv2d, add and activation ops.
    // TODO: Transpose bias and preluActivationWeights properly for NCHW
    // format before computation.
    util.assert(
        dataFormat === 'NHWC',
        () => `Error in fused conv2d: got dataFormat of ${dataFormat} but ` +
            `only NHWC is currently supported for the case of gradient depth ` +
            `is 0 and the activation is not linear.`);
    let result = unfusedConv2d(
        x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
    if (bias != null) {
      result = add(result, bias);
    }
    return applyActivation(
        result, activation, preluActivationWeights, leakyreluAlpha);
  }

  const $x = convertToTensor(x, 'x', 'conv2d', 'float32');
  const $filter = convertToTensor(filter, 'filter', 'conv2d', 'float32');

  // The fused kernel works on rank-4 input. A rank-3 input gets a leading
  // batch dimension of 1, which is stripped from the result again below.
  let x4D = $x;
  let reshapedTo4D = false;
  if ($x.rank === 3) {
    reshapedTo4D = true;
    x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
  }
  util.assert(
      x4D.rank === 4,
      () => `Error in fused conv2d: input must be rank 4, but got rank ` +
          `${x4D.rank}.`);
  util.assert(
      $filter.rank === 4,
      () => `Error in fused conv2d: filter must be rank 4, but got rank ` +
          `${$filter.rank}.`);
  conv_util.checkPadOnDimRoundingMode('fused conv2d', pad, dimRoundingMode);
  const inputChannels = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
  util.assert(
      $filter.shape[2] === inputChannels,
      () => `Error in conv2d: depth of input (${inputChannels}) must match ` +
          `input depth for filter ${$filter.shape[2]}.`);
  util.assert(
      conv_util.eitherStridesOrDilationsAreOne(strides, dilations),
      () => 'Error in conv2D: Either strides or dilations must be 1. ' +
          `Got strides ${strides} and dilations '${dilations}'`);

  // Pass the data format so `convInfo.outShape` is laid out like the real
  // output; the shape checks below compare against it for both formats.
  const convInfo = conv_util.computeConv2DInfo(
      x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode,
      false /* depthwise */, conv_util.convertConv2DDataFormat(dataFormat));

  let $bias;
  if (bias != null) {
    $bias = convertToTensor(bias, 'bias', 'fused conv2d');
    [$bias] = makeTypesMatch($bias, $x);

    // According to TensorFlow, the bias is supposed to be a 1-D tensor or a
    // scalar.
    //
    // 3-D or 4-D bias is not disabled for NHWC format, because they are
    // currently being used in some cases. For example, in our code base,
    // https://github.com/tensorflow/tfjs/blob/b53bd47e880367ae57493f0ea628abaf08db2d5d/tfjs-core/src/ops/fused/fused_conv2d_test.ts#L1972.
    if (dataFormat === 'NHWC') {
      broadcast_util.assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
    } else {
      util.assert(
          $bias.shape.length <= 1,
          () => `Error in fused conv2d: only supports scalar or 1-D Tensor ` +
              `bias for NCHW format but got the bias of ` +
              `rank-${$bias.shape.length}.`);
      util.assert(
          $bias.shape.length === 0 || $bias.shape[0] === convInfo.outChannels ||
              $bias.shape[0] === 1,
          () => `Error in fused conv2d: bias shape (${$bias.shape}) is not ` +
              `compatible with the number of output channels ` +
              `(${convInfo.outChannels})`);
    }
  }

  let $preluActivationWeights;
  if (preluActivationWeights != null) {
    // Convert before validating so tensor-like weights (which have no
    // `.shape`) are checked instead of crashing on the shape lookup.
    $preluActivationWeights = convertToTensor(
        preluActivationWeights, 'prelu weights', 'fused conv2d');

    // PReLU's activation weights could be a scalar, a 1-D tensor or a 3-D
    // tensor.
    const alphaShape = $preluActivationWeights.shape;
    util.assert(
        alphaShape.length <= 1 || alphaShape.length === 3,
        () => `Error in fused conv2d: only supports scalar, 1-D Tensor or ` +
            `3-D Tensor PReLU activation weights but got a tensor of ` +
            `rank-${alphaShape.length}.`);

    if (alphaShape.length === 1) {
      // Whether the data format is NCHW or NHWC, the 1-D PReLU activation
      // weights tensor should be aligned with the output channels of the
      // conv2d result.
      util.assert(
          alphaShape[0] === 1 || alphaShape[0] === convInfo.outChannels,
          () => `Error in fused conv2d: PReLU activation weights ` +
              `(${alphaShape}) is not compatible with the number of output ` +
              `channels (${convInfo.outChannels}).`);
    } else if (alphaShape.length === 3) {
      // Whether the data format is NCHW or NHWC, the PReLU activation weights
      // tensor should have a shape compatible with the conv2d result.
      try {
        broadcast_util.assertAndGetBroadcastShape(alphaShape, convInfo.outShape);
      } catch (e) {
        const errMsg =
            `Error in fused conv2d: PReLU activation weights (${alphaShape}) ` +
            `is not compatible with the output shape of the conv2d ` +
            `(${convInfo.outShape}).`;
        throw Error(errMsg);
      }
    }
  }

  // Gradient of the fused op. `saved` is [filter, x4D, y] or
  // [filter, x4D, y, bias], exactly as recorded by `runFusedKernel` below.
  const grad = (dy, saved) => {
    util.assert(
        dataFormat === 'NHWC',
        () => `Error in gradient of fused conv2D: got dataFormat of ${dataFormat} but only NHWC is currently supported.`);

    const [savedFilter, savedX4D, y, savedBias] = saved;

    // For a rank-3 input the op's value was reshaped to rank 3, so `dy`
    // arrives with rank 3. Lift it back to the rank-4 shape of the saved
    // kernel output so it lines up with the rank-4 x/filter tensors.
    const dy4D = reshapedTo4D ? reshape(dy, y.shape) : dy;
    const dyActivation = getFusedDyActivation(dy4D, y, activation);

    util.assert(
        conv_util.tupleValuesAreOne(dilations),
        () => 'Error in gradient of fused conv2D: ' +
            `dilation rates greater than 1 ` +
            `are not yet supported in gradients. Got dilations '${dilations}'`);

    // Forward dimRoundingMode so the backprop ops compute the same output
    // shape as the forward pass when `pad` is numeric.
    const xDer = conv2DBackpropInput(
        savedX4D.shape, dyActivation, savedFilter, strides, pad, dataFormat,
        dimRoundingMode);
    const filterDer = conv2DBackpropFilter(
        savedX4D, dyActivation, savedFilter.shape, strides, pad, dataFormat,
        dimRoundingMode);
    const der = [xDer, filterDer];

    if (savedBias != null) {
      der.push(getFusedBiasGradient(savedBias, dyActivation));
    }
    return der;
  };

  const inputs = {
    x: x4D,
    filter: $filter,
    bias: $bias,
    preluActivationWeights: $preluActivationWeights
  };

  const attrs = {
    strides,
    pad,
    dataFormat,
    dilations,
    dimRoundingMode,
    activation,
    leakyreluAlpha
  };

  // Forward pass shared by both customGrad variants: runs the fused kernel,
  // saves the tensors `grad` reads, and restores rank 3 when needed.
  const runFusedKernel = (x4DArg, filterArg, save, biasArg) => {
    let res = ENGINE.runKernel(FusedConv2D, inputs, attrs);
    save(
        biasArg == null ? [filterArg, x4DArg, res] :
                          [filterArg, x4DArg, res, biasArg]);
    if (reshapedTo4D) {
      res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return { value: res, gradFunc: grad };
  };

  // Depending on the params passed in we will have a different number of
  // inputs and thus a different number of elements in the gradient.
  if (bias == null) {
    const customOp = customGrad(
        (x4DArg, filterArg, save) => runFusedKernel(x4DArg, filterArg, save));
    return customOp(x4D, $filter);
  }
  const customOpWithBias = customGrad(
      (x4DArg, filterArg, biasArg, save) =>
          runFusedKernel(x4DArg, filterArg, save, biasArg));
  return customOpWithBias(x4D, $filter, $bias);
}
|
/**
 * Exported fused conv2d op, created by passing `fusedConv2d_` to `op`.
 */
export const conv2d = /* @__PURE__ */ op({ fusedConv2d_: fusedConv2d_ });
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"conv2d.js","sourceRoot":"","sources":["../../../../../../../tfjs-core/src/ops/fused/conv2d.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,MAAM,EAAC,MAAM,cAAc,CAAC;AACpC,OAAO,EAAC,UAAU,EAAC,MAAM,iBAAiB,CAAC;AAC3C,OAAO,EAAC,WAAW,EAAsC,MAAM,oBAAoB,CAAC;AAIpF,OAAO,EAAC,cAAc,EAAC,MAAM,mBAAmB,CAAC;AACjD,OAAO,EAAC,eAAe,EAAC,MAAM,uBAAuB,CAAC;AAEtD,OAAO,KAAK,IAAI,MAAM,YAAY,CAAC;AACnC,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,KAAK,cAAc,MAAM,mBAAmB,CAAC;AACpD,OAAO,EAAC,MAAM,IAAI,aAAa,EAAC,MAAM,WAAW,CAAC;AAClD,OAAO,EAAC,oBAAoB,EAAC,MAAM,2BAA2B,CAAC;AAC/D,OAAO,EAAC,mBAAmB,EAAC,MAAM,0BAA0B,CAAC;AAC7D,OAAO,KAAK,SAAS,MAAM,cAAc,CAAC;AAE1C,OAAO,EAAC,eAAe,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,UAAU,EAAC,MAAM,eAAe,CAAC;AACtG,OAAO,EAAC,EAAE,EAAC,MAAM,cAAc,CAAC;AAChC,OAAO,EAAC,OAAO,EAAC,MAAM,YAAY,CAAC;AAEnC;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAwDG;AACH,SAAS,YAAY,CAA8B,EACjD,CAAC,EACD,MAAM,EACN,OAAO,EACP,GAAG,EACH,UAAU,GAAG,MAAM,EACnB,SAAS,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,EAClB,eAAe,EACf,IAAI,EACJ,UAAU,GAAG,QAAQ,EACrB,sBAAsB,EACtB,cAAc,EAaf;IACC,UAAU,GAAG,UAAU,IAAI,QAAQ,CAAC;IAEpC,IAAI,UAAU,CAAC,MAAM,CAAC,KAAK,CAAC,aAAa,EAAE,UAAU,CAAC,KAAK,KAAK,EAAE;QAChE,oEAAoE;QACpE,6BAA6B;QAC7B,IAAI,CAAC,MAAM,CACP,UAAU,KAAK,MAAM,EACrB,GAAG,EAAE,CAAC,4CAA4C,UAAU,OAAO;YAC/D,kEAAkE;YAClE,wCAAwC,CAAC,CAAC;QAElD,IAAI,MAAM,GAAG,aAAa,CACtB,CAAC,EAAE,MAAM,EAAE,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,SAAS,EAAE,eAAe,CAAC,CAAC;QACrE,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,MAAM,GAAG,GAAG,CAAC,MAAM,EAAE,IAAI,CAAC,CAAC;SAC5B;QAED,OAAO,eAAe,CACX,MAAM,EAAE,UAAU,EAAE,sBAAsB,EAAE,cAAc,CAAM,CAAC;KAC7E;IAED,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,GAAG,EAAE,QAAQ,EAAE,SAAS,CAAC,CAAC;IACxD,MAAM,OAAO,GAAG,eAAe,CAAC,MAAM,EAAE,QAAQ,EAAE,QAAQ,EAAE,SAAS,CAAC,CAAC;IAEvE,IAAI,GAAG,GAAG,EAAc,CAAC;IACzB,IAAI,YAAY,GAAG,KAAK,CAAC;IAEzB,IAAI,EAAE,CAAC,IAAI,KAAK,CAAC,EAAE;QACjB,YAAY,GAAG,IAAI,CAAC;QACpB,GAAG,GAAG,OAAO,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,EAAE,CAAC,KAA
K,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;KAC/D;IACD,IAAI,CAAC,MAAM,CACP,GAAG,CAAC,IAAI,KAAK,CAAC,EACd,GAAG,EAAE,CAAC,4DAA4D;QAC9D,GAAG,GAAG,CAAC,IAAI,GAAG,CAAC,CAAC;IACxB,IAAI,CAAC,MAAM,CACP,OAAO,CAAC,IAAI,KAAK,CAAC,EAClB,GAAG,EAAE,CAAC,6DAA6D;QAC/D,GAAG,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC;IAC5B,SAAS,CAAC,yBAAyB,CAAC,cAAc,EAAE,GAAG,EAAE,eAAe,CAAC,CAAC;IAC1E,MAAM,aAAa,GAAG,UAAU,KAAK,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IAC1E,IAAI,CAAC,MAAM,CACP,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,aAAa,EAClC,GAAG,EAAE,CAAC,oCAAoC,aAAa,eAAe;QAClE,0BAA0B,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;IACvD,IAAI,CAAC,MAAM,CACP,SAAS,CAAC,8BAA8B,CAAC,OAAO,EAAE,SAAS,CAAC,EAC5D,GAAG,EAAE,CAAC,0DAA0D;QAC5D,eAAe,OAAO,mBAAmB,SAAS,GAAG,CAAC,CAAC;IAE/D,MAAM,QAAQ,GAAG,SAAS,CAAC,iBAAiB,CACxC,GAAG,CAAC,KAAK,EAAE,OAAO,CAAC,KAAK,EAAE,OAAO,EAAE,SAAS,EAAE,GAAG,EAAE,eAAe,CAAC,CAAC;IAExE,IAAI,KAAa,CAAC;IAClB,IAAI,IAAI,IAAI,IAAI,EAAE;QAChB,KAAK,GAAG,eAAe,CAAC,IAAI,EAAE,MAAM,EAAE,cAAc,CAAC,CAAC;QACtD,CAAC,KAAK,CAAC,GAAG,cAAc,CAAC,KAAK,EAAE,EAAE,CAAC,CAAC;QAEpC,qEAAqE;QACrE,UAAU;QACV,EAAE;QACF,oEAAoE;QACpE,qEAAqE;QACrE,uIAAuI;QACvI,IAAI,UAAU,KAAK,MAAM,EAAE;YACzB,cAAc,CAAC,0BAA0B,CAAC,QAAQ,CAAC,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC;SAC3E;aAAM;YACL,IAAI,CAAC,MAAM,CACP,KAAK,CAAC,KAAK,CAAC,MAAM,IAAI,CAAC,EACvB,GAAG,EAAE,CAAC,4DAA4D;gBAC9D,2CAA2C;gBAC3C,QAAQ,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;YAEvC,IAAI,CAAC,MAAM,CACP,KAAK,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,QAAQ,CAAC,WAAW;gBAC/D,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,EACxB,GAAG,EAAE,CAAC,sCAAsC,KAAK,CAAC,KAAK,WAAW;gBAC9D,gDAAgD;gBAChD,IAAI,QAAQ,CAAC,WAAW,GAAG,CAAC,CAAC;SACtC;KACF;IAED,IAAI,uBAA+B,CAAC;IACpC,IAAI,sBAAsB,IAAI,IAAI,EAAE;QAClC,sEAAsE;QACtE,UAAU;QACV,MAAM,UAAU,GAAG,sBAAsB,CAAC,KAAK,CAAC;QAChD,IAAI,CAAC,MAAM,CACP,UAAU,CAAC,MAAM,IAAI,CAAC,IAAI,UAAU,CAAC,MAAM,KAAK,CAAC,EACjD,GAAG,EAAE,CAAC,6DAA6D;YAC/D,0
DAA0D;YAC1D,QAAQ,UAAU,CAAC,MAAM,GAAG,CAAC,CAAC;QAEtC,IAAI,UAAU,CAAC,MAAM,KAAK,CAAC,EAAE;YAC3B,oEAAoE;YACpE,sEAAsE;YACtE,UAAU;YACV,IAAI,CAAC,MAAM,CACP,UAAU,CAAC,CAAC,CAAC,KAAK,CAAC,IAAI,UAAU,CAAC,CAAC,CAAC,KAAK,QAAQ,CAAC,WAAW,EAC7D,GAAG,EAAE,CAAC,kDAAkD;gBACpD,IAAI,UAAU,gDAAgD;gBAC9D,aAAa,QAAQ,CAAC,WAAW,IAAI,CAAC,CAAC;SAChD;aAAM,IAAI,UAAU,CAAC,MAAM,KAAK,CAAC,EAAE;YAClC,wEAAwE;YACxE,oEAAoE;YACpE,IAAI;gBACF,cAAc,CAAC,0BAA0B,CACrC,UAAU,EAAE,QAAQ,CAAC,QAAQ,CAAC,CAAC;aACpC;YAAC,OAAO,CAAC,EAAE;gBACV,MAAM,MAAM,GACR,oDAAoD,UAAU,IAAI;oBAClE,wDAAwD;oBACxD,IAAI,QAAQ,CAAC,QAAQ,IAAI,CAAC;gBAC9B,MAAM,KAAK,CAAC,MAAM,CAAC,CAAC;aACrB;SACF;QAED,uBAAuB,GAAG,eAAe,CACrC,sBAAsB,EAAE,eAAe,EAAE,cAAc,CAAC,CAAC;KAC9D;IAED,MAAM,IAAI,GAAG,CAAC,EAAY,EAAE,KAAe,EAAE,EAAE;QAC7C,IAAI,CAAC,MAAM,CACP,UAAU,KAAK,MAAM,EACrB,GAAG,EAAE,CAAC,wDACF,UAAU,wCAAwC,CAAC,CAAC;QAE5D,MAAM,CAAC,OAAO,EAAE,GAAG,EAAE,CAAC,EAAE,KAAK,CAAC,GAC1B,KAA+C,CAAC;QAEpD,MAAM,YAAY,GAAG,oBAAoB,CAAC,EAAE,EAAE,CAAC,EAAE,UAAU,CAAa,CAAC;QAEzE,IAAI,CAAC,MAAM,CACP,SAAS,CAAC,iBAAiB,CAAC,SAAS,CAAC,EACtC,GAAG,EAAE,CAAC,qCAAqC;YACvC,gCAAgC;YAChC,sDAAsD,SAAS,GAAG,CAAC,CAAC;QAE5E,MAAM,IAAI,GACN,mBAAmB,CAAC,GAAG,CAAC,KAAK,EAAE,YAAY,EAAE,OAAO,EAAE,OAAO,EAAE,GAAG,CAAC,CAAC;QACxE,MAAM,SAAS,GACX,oBAAoB,CAAC,GAAG,EAAE,YAAY,EAAE,OAAO,CAAC,KAAK,EAAE,OAAO,EAAE,GAAG,CAAC,CAAC;QACzE,MAAM,GAAG,GAAa,CAAC,IAAI,EAAE,SAAS,CAAC,CAAC;QAExC,IAAI,KAAK,IAAI,IAAI,EAAE;YACjB,MAAM,OAAO,GAAG,oBAAoB,CAAC,KAAK,EAAE,YAAY,CAAC,CAAC;YAC1D,GAAG,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;SACnB;QACD,OAAO,GAAG,CAAC;IACb,CAAC,CAAC;IAEF,MAAM,MAAM,GAAsB;QAChC,CAAC,EAAE,GAAG;QACN,MAAM,EAAE,OAAO;QACf,IAAI,EAAE,KAAK;QACX,sBAAsB,EAAE,uBAAuB;KAChD,CAAC;IAEF,MAAM,KAAK,GAAqB;QAC9B,OAAO;QACP,GAAG;QACH,UAAU;QACV,SAAS;QACT,eAAe;QACf,UAAU;QACV,cAAc;KACf,CAAC;IAEF,yEAAyE;IACzE,oEAAoE;IACpE,IAAI,IAAI,IAAI,IAAI,EAAE;QAChB,MAAM,QAAQ,GACV,UAAU,CAAC,CAAC,GAAa,EAAE,MAAgB,EAAE,IAAkB,EAAE,EAAE;YACjE,IAAI,GAAG;YACH,0DAA0D;YAC1D,MAAM,CAAC,SAAS,CACZ,WAAW,EAAE,MAAmC,EAChD,KAAgC,CAAC,CAAC;YAE1C,IAAI,CAAC,CAAC,MAAM,EAA
E,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;YAEzB,IAAI,YAAY,EAAE;gBAChB,0DAA0D;gBAC1D,GAAG,GAAG,OAAO,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CACjD,CAAC;aACd;YAED,OAAO,EAAC,KAAK,EAAE,GAAG,EAAE,QAAQ,EAAE,IAAI,EAAC,CAAC;QACtC,CAAC,CAAC,CAAC;QACP,OAAO,QAAQ,CAAC,GAAG,EAAE,OAAO,CAAM,CAAC;KACpC;SAAM;QACL,MAAM,gBAAgB,GAAG,UAAU,CAC/B,CAAC,GAAa,EAAE,MAAgB,EAAE,IAAY,EAAE,IAAkB,EAAE,EAAE;YACpE,IAAI,GAAG,GAAsB,MAAM,CAAC,SAAS,CACzC,WAAW,EAAE,MAAmC,EAChD,KAAgC,CAAC,CAAC;YAEtC,IAAI,CAAC,CAAC,MAAM,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC;YAE/B,IAAI,YAAY,EAAE;gBAChB,0DAA0D;gBAC1D,GAAG,GAAG,OAAO,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CACjD,CAAC;aACd;YAED,OAAO,EAAC,KAAK,EAAE,GAAG,EAAE,QAAQ,EAAE,IAAI,EAAC,CAAC;QACtC,CAAC,CAAC,CAAC;QAEP,OAAO,gBAAgB,CAAC,GAAG,EAAE,OAAO,EAAE,KAAK,CAAM,CAAC;KACnD;AACH,CAAC;AACD,MAAM,CAAC,MAAM,MAAM,GAAG,eAAe,CAAC,EAAE,CAAC,EAAC,YAAY,EAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2019 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../../engine';\nimport {customGrad} from '../../gradients';\nimport {FusedConv2D, FusedConv2DAttrs, FusedConv2DInputs} from '../../kernel_names';\nimport {NamedAttrMap} from '../../kernel_registry';\nimport {Tensor, Tensor3D, Tensor4D} from '../../tensor';\nimport {GradSaveFunc, NamedTensorMap} from '../../tensor_types';\nimport {makeTypesMatch} from '../../tensor_util';\nimport {convertToTensor} from '../../tensor_util_env';\nimport {TensorLike} from '../../types';\nimport * as util from '../../util';\nimport {add} from '../add';\nimport * as broadcast_util from '../broadcast_util';\nimport {conv2d as unfusedConv2d} from '../conv2d';\nimport {conv2DBackpropFilter} from '../conv2d_backprop_filter';\nimport {conv2DBackpropInput} from '../conv2d_backprop_input';\nimport * as conv_util from '../conv_util';\nimport {Activation} from '../fused_types';\nimport {applyActivation, getFusedBiasGradient, getFusedDyActivation, shouldFuse} from '../fused_util';\nimport {op} from '../operation';\nimport {reshape} from '../reshape';\n\n/**\n * Computes a 2D convolution over the input x, optionally fused with adding a\n * bias and applying an activation.\n *\n * ```js\n * const inputDepth = 2;\n * const inShape = [2, 2, 2, inputDepth];\n * const outputDepth = 2;\n * const fSize = 1;\n * const 
pad = 0;\n * const strides = 1;\n *\n * const x = tf.tensor4d( [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n * 16], inShape);\n * const w = tf.tensor4d([-1, 1, -2, 0.5], [fSize, fSize, inputDepth,\n * outputDepth]);\n *\n * tf.fused.conv2d({ x, filter: w, strides, pad, dataFormat: 'NHWC',\n * dilations: [1, 1], bias: tf.scalar(5), activation: 'relu' }).print();\n * ```\n *\n * @param obj An object with the following properties:\n * @param x The input tensor, of rank 4 or rank 3, of shape\n *     `[batch, height, width, inChannels]`. If rank 3, batch of 1 is\n * assumed.\n * @param filter The filter, rank 4, of shape\n *     `[filterHeight, filterWidth, inDepth, outDepth]`.\n * @param strides The strides of the convolution: `[strideHeight,\n * strideWidth]`.\n * @param pad The type of padding algorithm.\n *   - `same` and stride 1: output will be of same size as input,\n *       regardless of filter size.\n *   - `valid` output will be smaller than input if filter is larger\n *       than 1x1.\n *   - For more info, see this guide:\n *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](\n *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)\n * @param dataFormat An optional string from: \"NHWC\", \"NCHW\". Defaults to\n *     \"NHWC\". Specify the data format of the input and output data. With the\n *     default format \"NHWC\", the data is stored in the order of: [batch,\n *     height, width, channels]. Only \"NHWC\" is currently supported.\n * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`\n *     in which we sample input values across the height and width dimensions\n *     in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single\n *     number, then `dilationHeight == dilationWidth`. If it is greater than\n *     1, then all values of `strides` must be 1.\n * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. 
If none is\n *     provided, it will default to truncate.\n * @param bias Tensor to be added to the result.\n * @param activation Name of activation kernel (defaults to `linear`) to be\n *     applied\n *      after biasAdd.\n * @param preluActivationWeights Tensor of prelu weights to be applied as part\n *     of a `prelu` activation, typically the same shape as `x`.\n * @param leakyreluAlpha Optional. Alpha to be applied as part of a `leakyrelu`\n *     activation.\n */\nfunction fusedConv2d_<T extends Tensor3D|Tensor4D>({\n  x,\n  filter,\n  strides,\n  pad,\n  dataFormat = 'NHWC',\n  dilations = [1, 1],\n  dimRoundingMode,\n  bias,\n  activation = 'linear',\n  preluActivationWeights,\n  leakyreluAlpha\n}: {\n  x: T|TensorLike,\n  filter: Tensor4D|TensorLike,\n  strides: [number, number]|number,\n  pad: 'valid'|'same'|number|conv_util.ExplicitPadding,\n  dataFormat?: 'NHWC'|'NCHW',\n  dilations?: [number, number]|number,\n  dimRoundingMode?: 'floor'|'round'|'ceil',\n  bias?: Tensor|TensorLike,\n  activation?: Activation,\n  preluActivationWeights?: Tensor,\n  leakyreluAlpha?: number\n}): T {\n  activation = activation || 'linear';\n\n  if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {\n    // TODO: Transpose bias and preluActivationWeights properly for NCHW\n    // format before computation.\n    util.assert(\n        dataFormat === 'NHWC',\n        () => `Error in fused conv2d: got dataFormat of ${dataFormat} but ` +\n            `only NHWC is currently supported for the case of gradient depth ` +\n            `is 0 and the activation is not linear.`);\n\n    let result = unfusedConv2d(\n        x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);\n    if (bias != null) {\n      result = add(result, bias);\n    }\n\n    return applyActivation(\n               result, activation, preluActivationWeights, leakyreluAlpha) as T;\n  }\n\n  const $x = convertToTensor(x, 'x', 'conv2d', 'float32');\n  const $filter = 
convertToTensor(filter, 'filter', 'conv2d', 'float32');\n\n  let x4D = $x as Tensor4D;\n  let reshapedTo4D = false;\n\n  if ($x.rank === 3) {\n    reshapedTo4D = true;\n    x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);\n  }\n  util.assert(\n      x4D.rank === 4,\n      () => `Error in fused conv2d: input must be rank 4, but got rank ` +\n          `${x4D.rank}.`);\n  util.assert(\n      $filter.rank === 4,\n      () => `Error in fused conv2d: filter must be rank 4, but got rank ` +\n          `${$filter.rank}.`);\n  conv_util.checkPadOnDimRoundingMode('fused conv2d', pad, dimRoundingMode);\n  const inputChannels = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];\n  util.assert(\n      $filter.shape[2] === inputChannels,\n      () => `Error in conv2d: depth of input (${inputChannels}) must match ` +\n          `input depth for filter ${$filter.shape[2]}.`);\n  util.assert(\n      conv_util.eitherStridesOrDilationsAreOne(strides, dilations),\n      () => 'Error in conv2D: Either strides or dilations must be 1. ' +\n          `Got strides ${strides} and dilations '${dilations}'`);\n\n  const convInfo = conv_util.computeConv2DInfo(\n      x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode);\n\n  let $bias: Tensor;\n  if (bias != null) {\n    $bias = convertToTensor(bias, 'bias', 'fused conv2d');\n    [$bias] = makeTypesMatch($bias, $x);\n\n    // According to TensorFlow, the bias is supposed be a 1-D tensor or a\n    // scalar.\n    //\n    // 3-D or 4-D bias is not disabled for NHWC format, because they are\n    // currently being used in some cases. 
For examplem in our code base,\n    // https://github.com/tensorflow/tfjs/blob/b53bd47e880367ae57493f0ea628abaf08db2d5d/tfjs-core/src/ops/fused/fused_conv2d_test.ts#L1972.\n    if (dataFormat === 'NHWC') {\n      broadcast_util.assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);\n    } else {\n      util.assert(\n          $bias.shape.length <= 1,\n          () => `Error in fused conv2d: only supports scalar or 1-D Tensor ` +\n              `bias for NCHW format but got the bias of ` +\n              `rank-${$bias.shape.length}.`);\n\n      util.assert(\n          $bias.shape.length === 0 || $bias.shape[0] === convInfo.outChannels ||\n              $bias.shape[0] === 1,\n          () => `Error in fused conv2d: bias shape (${$bias.shape}) is not ` +\n              `compatible with the number of output channels ` +\n              `(${convInfo.outChannels})`);\n    }\n  }\n\n  let $preluActivationWeights: Tensor;\n  if (preluActivationWeights != null) {\n    // PReLU's activation weights could be a scalar, a 1-D tensor or a 3-D\n    // tensor.\n    const alphaShape = preluActivationWeights.shape;\n    util.assert(\n        alphaShape.length <= 1 || alphaShape.length === 3,\n        () => `Error in fused conv2d: only supports scalar, 1-D Tensor or ` +\n            `3-D Tensor PReLU activation weights but got a tensor of ` +\n            `rank-${alphaShape.length}.`);\n\n    if (alphaShape.length === 1) {\n      // Whether the data format is NCHW or NHWC, the 1-D PReLU activation\n      // weights tensor should be aligned with the output channels of conv2d\n      // result.\n      util.assert(\n          alphaShape[0] === 1 || alphaShape[0] === convInfo.outChannels,\n          () => `Error in fused conv2d: PReLU activation weights ` +\n              `(${alphaShape}) is not compatible with the number of output ` +\n              `channels (${convInfo.outChannels}).`);\n    } else if (alphaShape.length === 3) {\n      // Whether the data format is NCHW or NHWC, 
the PReLU activation weights\n      // tensor should has the compatible shape with the result of conv2d.\n      try {\n        broadcast_util.assertAndGetBroadcastShape(\n            alphaShape, convInfo.outShape);\n      } catch (e) {\n        const errMsg =\n            `Error in fused conv2d: PReLU activation weights (${alphaShape}) ` +\n            `is not compatible with the output shape of the conv2d ` +\n            `(${convInfo.outShape}).`;\n        throw Error(errMsg);\n      }\n    }\n\n    $preluActivationWeights = convertToTensor(\n        preluActivationWeights, 'prelu weights', 'fused conv2d');\n  }\n\n  const grad = (dy: Tensor4D, saved: Tensor[]) => {\n    util.assert(\n        dataFormat === 'NHWC',\n        () => `Error in gradient of fused conv2D: got dataFormat of ${\n            dataFormat} but only NHWC is currently supported.`);\n\n    const [$filter, x4D, y, $bias] =\n        saved as [Tensor4D, Tensor4D, Tensor4D, Tensor];\n\n    const dyActivation = getFusedDyActivation(dy, y, activation) as Tensor4D;\n\n    util.assert(\n        conv_util.tupleValuesAreOne(dilations),\n        () => 'Error in gradient of fused conv2D: ' +\n            `dilation rates greater than 1 ` +\n            `are not yet supported in gradients. 
Got dilations '${dilations}'`);\n\n    const xDer =\n        conv2DBackpropInput(x4D.shape, dyActivation, $filter, strides, pad);\n    const filterDer =\n        conv2DBackpropFilter(x4D, dyActivation, $filter.shape, strides, pad);\n    const der: Tensor[] = [xDer, filterDer];\n\n    if ($bias != null) {\n      const biasDer = getFusedBiasGradient($bias, dyActivation);\n      der.push(biasDer);\n    }\n    return der;\n  };\n\n  const inputs: FusedConv2DInputs = {\n    x: x4D,\n    filter: $filter,\n    bias: $bias,\n    preluActivationWeights: $preluActivationWeights\n  };\n\n  const attrs: FusedConv2DAttrs = {\n    strides,\n    pad,\n    dataFormat,\n    dilations,\n    dimRoundingMode,\n    activation,\n    leakyreluAlpha\n  };\n\n  // Depending on the the params passed in we will have different number of\n  // inputs and thus a a different number of elements in the gradient.\n  if (bias == null) {\n    const customOp =\n        customGrad((x4D: Tensor4D, filter: Tensor4D, save: GradSaveFunc) => {\n          let res: Tensor4D|Tensor3D =\n              // tslint:disable-next-line: no-unnecessary-type-assertion\n              ENGINE.runKernel(\n                  FusedConv2D, inputs as unknown as NamedTensorMap,\n                  attrs as unknown as NamedAttrMap);\n\n          save([filter, x4D, res]);\n\n          if (reshapedTo4D) {\n            // tslint:disable-next-line: no-unnecessary-type-assertion\n            res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]) as\n                Tensor3D;\n          }\n\n          return {value: res, gradFunc: grad};\n        });\n    return customOp(x4D, $filter) as T;\n  } else {\n    const customOpWithBias = customGrad(\n        (x4D: Tensor4D, filter: Tensor4D, bias: Tensor, save: GradSaveFunc) => {\n          let res: Tensor4D|Tensor3D = ENGINE.runKernel(\n              FusedConv2D, inputs as unknown as NamedTensorMap,\n              attrs as unknown as NamedAttrMap);\n\n          save([filter, x4D, res, 
bias]);\n\n          if (reshapedTo4D) {\n            // tslint:disable-next-line: no-unnecessary-type-assertion\n            res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]) as\n                Tensor3D;\n          }\n\n          return {value: res, gradFunc: grad};\n        });\n\n    return customOpWithBias(x4D, $filter, $bias) as T;\n  }\n}\nexport const conv2d = /* @__PURE__ */ op({fusedConv2d_});\n"]}
|