"use strict"; /** * @license * Copyright 2017 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ Object.defineProperty(exports, "__esModule", { value: true }); var util = require("../util"); function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) { if (dataFormat === void 0) { dataFormat = 'channelsLast'; } var _a = parseTupleParam(filterSize), filterHeight = _a[0], filterWidth = _a[1]; var filterShape; if (dataFormat === 'channelsLast') { filterShape = [filterHeight, filterWidth, inShape[3], inShape[3]]; } else if (dataFormat === 'channelsFirst') { filterShape = [filterHeight, filterWidth, inShape[1], inShape[1]]; } else { throw new Error("Unknown dataFormat " + dataFormat); } return computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, false, dataFormat); } exports.computePool2DInfo = computePool2DInfo; /** * Computes the information for a forward pass of a pooling3D operation. */ function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) { if (dataFormat === void 0) { dataFormat = 'NDHWC'; } var _a = parse3TupleParam(filterSize), filterDepth = _a[0], filterHeight = _a[1], filterWidth = _a[2]; var filterShape; var $dataFormat; if (dataFormat === 'NDHWC') { $dataFormat = 'channelsLast'; filterShape = [filterDepth, filterHeight, filterWidth, inShape[4], inShape[4]]; } else if (dataFormat === 'NCDHW') { $dataFormat = 'channelsFirst'; filterShape = [filterDepth, filterHeight, filterWidth, inShape[1], inShape[1]]; } else { throw new Error("Unknown dataFormat " + dataFormat); } return computeConv3DInfo(inShape, filterShape, strides, dilations, pad, false, $dataFormat, roundingMode); } exports.computePool3DInfo = computePool3DInfo; /** * Computes the information for a forward pass of a convolution/pooling * operation. */ function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise, dataFormat) { if (depthwise === void 0) { depthwise = false; } if (dataFormat === void 0) { dataFormat = 'channelsLast'; } var _a = [-1, -1, -1, -1], batchSize = _a[0], inHeight = _a[1], inWidth = _a[2], inChannels = _a[3]; if (dataFormat === 'channelsLast') { batchSize = inShape[0], inHeight = inShape[1], inWidth = inShape[2], inChannels = inShape[3]; } else if (dataFormat === 'channelsFirst') { batchSize = inShape[0], inChannels = inShape[1], inHeight = inShape[2], inWidth = inShape[3]; } else { throw new Error("Unknown dataFormat " + dataFormat); } var filterHeight = filterShape[0], filterWidth = filterShape[1], filterChannels = filterShape[3]; var _b = parseTupleParam(strides), strideHeight = _b[0], strideWidth = _b[1]; var _c = parseTupleParam(dilations), dilationHeight = _c[0], dilationWidth = _c[1]; var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight); var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth); var _d = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode), padInfo = _d.padInfo, outHeight = _d.outHeight, outWidth = _d.outWidth; var outChannels = depthwise ? filterChannels * inChannels : filterChannels; var outShape; if (dataFormat === 'channelsFirst') { outShape = [batchSize, outChannels, outHeight, outWidth]; } else if (dataFormat === 'channelsLast') { outShape = [batchSize, outHeight, outWidth, outChannels]; } return { batchSize: batchSize, dataFormat: dataFormat, inHeight: inHeight, inWidth: inWidth, inChannels: inChannels, outHeight: outHeight, outWidth: outWidth, outChannels: outChannels, padInfo: padInfo, strideHeight: strideHeight, strideWidth: strideWidth, filterHeight: filterHeight, filterWidth: filterWidth, effectiveFilterHeight: effectiveFilterHeight, effectiveFilterWidth: effectiveFilterWidth, dilationHeight: dilationHeight, dilationWidth: dilationWidth, inShape: inShape, outShape: outShape, filterShape: filterShape }; } exports.computeConv2DInfo = computeConv2DInfo; /** * Computes the information for a forward pass of a 3D convolution/pooling * operation. */ function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise, dataFormat, roundingMode) { if (depthwise === void 0) { depthwise = false; } if (dataFormat === void 0) { dataFormat = 'channelsLast'; } var _a = [-1, -1, -1, -1, -1], batchSize = _a[0], inDepth = _a[1], inHeight = _a[2], inWidth = _a[3], inChannels = _a[4]; if (dataFormat === 'channelsLast') { batchSize = inShape[0], inDepth = inShape[1], inHeight = inShape[2], inWidth = inShape[3], inChannels = inShape[4]; } else if (dataFormat === 'channelsFirst') { batchSize = inShape[0], inChannels = inShape[1], inDepth = inShape[2], inHeight = inShape[3], inWidth = inShape[4]; } else { throw new Error("Unknown dataFormat " + dataFormat); } var filterDepth = filterShape[0], filterHeight = filterShape[1], filterWidth = filterShape[2], filterChannels = filterShape[4]; var _b = parse3TupleParam(strides), strideDepth = _b[0], strideHeight = _b[1], strideWidth = _b[2]; var _c = parse3TupleParam(dilations), dilationDepth = _c[0], dilationHeight = _c[1], dilationWidth = _c[2]; var effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth); var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight); var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth); var _d = get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode), padInfo = _d.padInfo, outDepth = _d.outDepth, outHeight = _d.outHeight, outWidth = _d.outWidth; var outChannels = depthwise ? filterChannels * inChannels : filterChannels; var outShape; if (dataFormat === 'channelsFirst') { outShape = [batchSize, outChannels, outDepth, outHeight, outWidth]; } else if (dataFormat === 'channelsLast') { outShape = [batchSize, outDepth, outHeight, outWidth, outChannels]; } return { batchSize: batchSize, dataFormat: dataFormat, inDepth: inDepth, inHeight: inHeight, inWidth: inWidth, inChannels: inChannels, outDepth: outDepth, outHeight: outHeight, outWidth: outWidth, outChannels: outChannels, padInfo: padInfo, strideDepth: strideDepth, strideHeight: strideHeight, strideWidth: strideWidth, filterDepth: filterDepth, filterHeight: filterHeight, filterWidth: filterWidth, effectiveFilterDepth: effectiveFilterDepth, effectiveFilterHeight: effectiveFilterHeight, effectiveFilterWidth: effectiveFilterWidth, dilationDepth: dilationDepth, dilationHeight: dilationHeight, dilationWidth: dilationWidth, inShape: inShape, outShape: outShape, filterShape: filterShape }; } exports.computeConv3DInfo = computeConv3DInfo; function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) { if (zeroPad == null) { zeroPad = computeDefaultPad(inShape, fieldSize, stride); } var inputRows = inShape[0]; var inputCols = inShape[1]; var outputRows = conditionalRound((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); util.assert(util.isInt(outputRows), function () { return "The output # of rows (" + outputRows + ") must be an integer. " + "Change the stride and/or zero pad parameters"; }); var outputCols = conditionalRound((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); util.assert(util.isInt(outputCols), function () { return "The output # of columns (" + outputCols + ") must be an integer. " + "Change the stride and/or zero pad parameters"; }); return [outputRows, outputCols]; } function computeOutputShape4D(inShape, fieldSize, outChannels, stride, zeroPad, roundingMode) { if (zeroPad == null) { zeroPad = computeDefaultPad(inShape, fieldSize, stride); } var inputDepth = inShape[0]; var inputRows = inShape[1]; var inputCols = inShape[2]; var outputDepths = conditionalRound((inputDepth - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); util.assert(util.isInt(outputDepths), function () { return "The output # of depths (" + outputDepths + ") must be an integer. " + "Change the stride and/or zero pad parameters"; }); var outputRows = conditionalRound((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); util.assert(util.isInt(outputRows), function () { return "The output # of rows (" + outputRows + ") must be an integer. " + "Change the stride and/or zero pad parameters"; }); var outputCols = conditionalRound((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); util.assert(util.isInt(outputCols), function () { return "The output # of columns (" + outputCols + ") must be an integer. " + "Change the stride and/or zero pad parameters"; }); return [outputDepths, outputRows, outputCols, outChannels]; } function computeDefaultPad(inputShape, fieldSize, stride, dilation) { if (dilation === void 0) { dilation = 1; } var effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation); return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2); } exports.computeDefaultPad = computeDefaultPad; function parseTupleParam(param) { if (typeof param === 'number') { return [param, param, param]; } if (param.length === 2) { return [param[0], param[1], 1]; } return param; } function parse3TupleParam(param) { return typeof param === 'number' ? [param, param, param] : param; } /* See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d * Atrous convolution is equivalent to standard convolution with upsampled * filters with effective_filter_height = * filter_height + (filter_height - 1) * (dilation - 1) * and effective_filter_width = * filter_width + (filter_width - 1) * (dilation - 1), * produced by inserting dilation - 1 zeros along consecutive elements across * the filters' spatial dimensions. * When there is a dilation, this converts a filter dimension to the * effective filter dimension, so it can be used in a standard convolution. */ function getEffectiveFilterSize(filterSize, dilation) { if (dilation <= 1) { return filterSize; } return filterSize + (filterSize - 1) * (dilation - 1); } function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode) { var padInfo; var outHeight; var outWidth; if (typeof pad === 'number') { var padType = (pad === 0) ? 'VALID' : 'NUMBER'; padInfo = { top: pad, bottom: pad, left: pad, right: pad, type: padType }; var outShape = computeOutputShape2D([inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode); outHeight = outShape[0]; outWidth = outShape[1]; } else if (pad === 'same') { outHeight = Math.ceil(inHeight / strideHeight); outWidth = Math.ceil(inWidth / strideWidth); var padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight); var padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth); var top_1 = Math.floor(padAlongHeight / 2); var bottom = padAlongHeight - top_1; var left = Math.floor(padAlongWidth / 2); var right = padAlongWidth - left; padInfo = { top: top_1, bottom: bottom, left: left, right: right, type: 'SAME' }; } else if (pad === 'valid') { padInfo = { top: 0, bottom: 0, left: 0, right: 0, type: 'VALID' }; outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight); outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth); } else { throw Error("Unknown padding parameter: " + pad); } return { padInfo: padInfo, outHeight: outHeight, outWidth: outWidth }; } function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) { var padInfo; var outDepth; var outHeight; var outWidth; if (typeof pad === 'number') { var padType = (pad === 0) ? 'VALID' : 'NUMBER'; padInfo = { top: pad, bottom: pad, left: pad, right: pad, front: pad, back: pad, type: padType }; var outShape = computeOutputShape4D([inDepth, inHeight, inWidth, 1], filterDepth, 1, strideDepth, pad, roundingMode); outDepth = outShape[0]; outHeight = outShape[1]; outWidth = outShape[2]; } else if (pad === 'same') { outDepth = Math.ceil(inDepth / strideDepth); outHeight = Math.ceil(inHeight / strideHeight); outWidth = Math.ceil(inWidth / strideWidth); var padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth; var padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight; var padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth; var front = Math.floor(padAlongDepth / 2); var back = padAlongDepth - front; var top_2 = Math.floor(padAlongHeight / 2); var bottom = padAlongHeight - top_2; var left = Math.floor(padAlongWidth / 2); var right = padAlongWidth - left; padInfo = { top: top_2, bottom: bottom, left: left, right: right, front: front, back: back, type: 'SAME' }; } else if (pad === 'valid') { padInfo = { top: 0, bottom: 0, left: 0, right: 0, front: 0, back: 0, type: 'VALID' }; outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth); outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight); outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth); } else { throw Error("Unknown padding parameter: " + pad); } return { padInfo: padInfo, outDepth: outDepth, outHeight: outHeight, outWidth: outWidth }; } /** * Rounds a value depending on the rounding mode * @param value * @param roundingMode */ function conditionalRound(value, roundingMode) { if (!roundingMode) { return value; } switch (roundingMode) { case 'round': // used for Caffe Conv return Math.round(value); case 'ceil': // used for Caffe Pool return Math.ceil(value); case 'floor': return Math.floor(value); default: throw new Error("Unknown roundingMode " + roundingMode); } } function tupleValuesAreOne(param) { var _a = parseTupleParam(param), dimA = _a[0], dimB = _a[1], dimC = _a[2]; return dimA === 1 && dimB === 1 && dimC === 1; } exports.tupleValuesAreOne = tupleValuesAreOne; function eitherStridesOrDilationsAreOne(strides, dilations) { return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations); } exports.eitherStridesOrDilationsAreOne = eitherStridesOrDilationsAreOne; /** * Convert Conv2D dataFormat from 'NHWC'|'NCHW' to * 'channelsLast'|'channelsFirst' * @param dataFormat in 'NHWC'|'NCHW' mode * @return dataFormat in 'channelsLast'|'channelsFirst' mode * @throws unknown dataFormat */ function convertConv2DDataFormat(dataFormat) { if (dataFormat === 'NHWC') { return 'channelsLast'; } else if (dataFormat === 'NCHW') { return 'channelsFirst'; } else { throw new Error("Unknown dataFormat " + dataFormat); } } exports.convertConv2DDataFormat = convertConv2DDataFormat; //# sourceMappingURL=conv_util.js.map