import { convertToTensor } from '../tensor_util_env';
import * as util from '../util';
import { conv2d } from './conv2d';
import * as conv_util from './conv_util';
import { op } from './operation';
import { reshape } from './reshape';
/**
 * Computes a 1D convolution over the input `x`.
 *
 * The 1D problem is expressed as a 2D one: a unit-height axis is inserted
 * into both the input and the filter, `conv2d` is run in `NHWC` layout with
 * stride and dilation fixed to 1 along that height axis, and the unit axis is
 * dropped from the result again.
 *
 * @param x The input tensor, either rank 3 of shape
 *     `[batch, width, inChannels]` or rank 2 of shape `[width, inChannels]`.
 *     A rank-2 input is treated as a batch of 1, and the result is then
 *     returned without the batch axis.
 * @param filter The filter, rank 3, of shape
 *     `[filterWidth, inDepth, outDepth]`. `inDepth` must equal the input's
 *     channel count.
 * @param stride The number of entries by which the filter is moved right at
 *     each step. Must be larger than 0.
 * @param pad The padding, forwarded unchanged to `conv2d`: `'valid'`,
 *     `'same'`, or (per the TypeScript signature) a number or explicit
 *     padding. It is also checked against `dimRoundingMode` via
 *     `conv_util.checkPadOnDimRoundingMode`.
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if the filter is wider
 *       than 1.
 *   - For more info, see this guide:
 *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dataFormat An optional string from "NWC", "NCW". Defaults to "NWC",
 *     the data is stored in the order of [batch, in_width, in_channels]. Only
 *     "NWC" is currently supported; any other value fails validation.
 * @param dilation The dilation rate in which we sample input values in
 *     atrous convolution. Defaults to `1`. Must be larger than 0, and if it
 *     is greater than 1, then stride must be `1`.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate. Forwarded to `conv2d`.
 * @returns The convolution output, reshaped back to the input's rank:
 *     `[batch, outWidth, outDepth]` for rank-3 input, `[outWidth, outDepth]`
 *     for rank-2 input.
 *
 * @doc {heading: 'Operations', subheading: 'Convolution'}
 */
function conv1d_(x, filter, stride, pad, dataFormat = 'NWC', dilation = 1, dimRoundingMode) {
    const $x = convertToTensor(x, 'x', 'conv1d');
    const $filter = convertToTensor(filter, 'filter', 'conv1d');
    // An unbatched (rank-2) input is promoted to a batch of one; remember that
    // so the batch axis can be removed from the result at the end.
    const addedBatchAxis = $x.rank === 2;
    const input3D = addedBatchAxis ?
        reshape($x, [1, $x.shape[0], $x.shape[1]]) :
        $x;
    // Validation: order is significant because the first failing check wins.
    util.assert(input3D.rank === 3, () => `Error in conv1d: input must be rank 3, but got rank ${input3D.rank}.`);
    util.assert($filter.rank === 3, () => `Error in conv1d: filter must be rank 3, but got rank ${$filter.rank}.`);
    conv_util.checkPadOnDimRoundingMode('conv1d', pad, dimRoundingMode);
    const inputDepth = input3D.shape[2];
    const filterInDepth = $filter.shape[1];
    util.assert(inputDepth === filterInDepth, () => `Error in conv1d: depth of input (${inputDepth}) must match ` +
        `input depth for filter ${filterInDepth}.`);
    util.assert(conv_util.eitherStridesOrDilationsAreOne(stride, dilation), () => 'Error in conv1D: Either stride or dilation must be 1. ' +
        `Got stride ${stride} and dilation '${dilation}'`);
    util.assert(conv_util.stridesOrDilationsArePositive(dilation), () => 'Error in conv1D: Dilated rates should be larger than 0.');
    util.assert(conv_util.stridesOrDilationsArePositive(stride), () => 'Error in conv1D: Stride should be larger than 0.');
    util.assert(dataFormat === 'NWC', () => `Error in conv1d: got dataFormat of ${dataFormat} but only NWC is currently supported.`);
    // Lift both operands to 4D by inserting a unit height axis at index 1
    // (filter: [1, fw, inDepth, outDepth]; input: [batch, 1, width, inDepth]).
    const [filterWidth, filterIn, filterOut] = $filter.shape;
    const filter4D = reshape($filter, [1, filterWidth, filterIn, filterOut]);
    const [batch, width, channels] = input3D.shape;
    const input4D = reshape(input3D, [batch, 1, width, channels]);
    // Height stride/dilation are 1; the caller's values apply along width.
    const output4D = conv2d(input4D, filter4D, [1, stride], pad, 'NHWC', [1, dilation], dimRoundingMode);
    // Drop the unit height axis (and the batch axis when it was added above).
    const [outBatch, , outWidth, outDepth] = output4D.shape;
    return addedBatchAxis ?
        reshape(output4D, [outWidth, outDepth]) :
        reshape(output4D, [outBatch, outWidth, outDepth]);
}
export const conv1d = /* @__PURE__ */ op({ conv1d_ });
//# 
sourceMappingURL=data:application/json;base64,{"version":3,"file":"conv1d.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/conv1d.ts"],"names":[],"mappings":"AAiBA,OAAO,EAAC,eAAe,EAAC,MAAM,oBAAoB,CAAC;AAEnD,OAAO,KAAK,IAAI,MAAM,SAAS,CAAC;AAEhC,OAAO,EAAC,MAAM,EAAC,MAAM,UAAU,CAAC;AAChC,OAAO,KAAK,SAAS,MAAM,aAAa,CAAC;AACzC,OAAO,EAAC,EAAE,EAAC,MAAM,aAAa,CAAC;AAC/B,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAElC;;;;;;;;;;;;;;;;;;;;;;;;;;;GA2BG;AACH,SAAS,OAAO,CACZ,CAAe,EAAE,MAA2B,EAAE,MAAc,EAC5D,GAAoD,EACpD,aAA0B,KAAK,EAAE,QAAQ,GAAG,CAAC,EAC7C,eAAwC;IAC1C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,GAAG,EAAE,QAAQ,CAAC,CAAC;IAC7C,MAAM,OAAO,GAAG,eAAe,CAAC,MAAM,EAAE,QAAQ,EAAE,QAAQ,CAAC,CAAC;IAE5D,IAAI,GAAG,GAAG,EAAc,CAAC;IACzB,IAAI,YAAY,GAAG,KAAK,CAAC;IACzB,IAAI,EAAE,CAAC,IAAI,KAAK,CAAC,EAAE;QACjB,YAAY,GAAG,IAAI,CAAC;QACpB,GAAG,GAAG,OAAO,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;KAClD;IAED,IAAI,CAAC,MAAM,CACP,GAAG,CAAC,IAAI,KAAK,CAAC,EACd,GAAG,EAAE,CAAC,uDAAuD,GAAG,CAAC,IAAI,GAAG,CAAC,CAAC;IAC9E,IAAI,CAAC,MAAM,CACP,OAAO,CAAC,IAAI,KAAK,CAAC,EAClB,GAAG,EAAE,CAAC,uDAAuD;QACzD,GAAG,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC;IAC5B,SAAS,CAAC,yBAAyB,CAAC,QAAQ,EAAE,GAAG,EAAE,eAAe,CAAC,CAAC;IACpE,IAAI,CAAC,MAAM,CACP,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,EACjC,GAAG,EAAE,CAAC,oCAAoC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,eAAe;QACjE,0BAA0B,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;IACvD,IAAI,CAAC,MAAM,CACP,SAAS,CAAC,8BAA8B,CAAC,MAAM,EAAE,QAAQ,CAAC,EAC1D,GAAG,EAAE,CAAC,wDAAwD;QAC1D,cAAc,MAAM,kBAAkB,QAAQ,GAAG,CAAC,CAAC;IAC3D,IAAI,CAAC,MAAM,CACP,SAAS,CAAC,6BAA6B,CAAC,QAAQ,CAAC,EACjD,GAAG,EAAE,CAAC,yDAAyD,CAAC,CAAC;IACrE,IAAI,CAAC,MAAM,CACP,SAAS,CAAC,6BAA6B,CAAC,MAAM,CAAC,EAC/C,GAAG,EAAE,CAAC,kDAAkD,CAAC,CAAC;IAC9D,IAAI,CAAC,MAAM,CACP,UAAU,KAAK,KAAK,EACpB,GAAG,EAAE,CAAC,sCACF,UAAU,uCAAuC,CAAC,CAAC;IAE3D,MAAM,QAAQ,GAAG,OAAO,CACpB,OAAO,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,OAAO,CAAC,KAAK,CAAC,CAAC,CA
AC,EAAE,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACxE,MAAM,OAAO,GAAG,OAAO,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAC5E,MAAM,OAAO,GAAqB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;IAC9C,MAAM,SAAS,GAAqB,CAAC,CAAC,EAAE,QAAQ,CAAC,CAAC;IAElD,MAAM,gBAAgB,GAAG,MAAM,CAAC;IAEhC,MAAM,GAAG,GAAG,MAAM,CACb,OAAoB,EAAG,QAAqB,EAAE,OAAO,EAAE,GAAG,EAC3D,gBAAgB,EAAE,SAAS,EAAE,eAAe,CAAC,CAAC;IAElD,IAAI,YAAY,EAAE;QAChB,OAAO,OAAO,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAM,CAAC;KACxD;IAED,OAAO,OAAO,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAM,CAAC;AACvE,CAAC;AAED,MAAM,CAAC,MAAM,MAAM,GAAG,eAAe,CAAC,EAAE,CAAC,EAAC,OAAO,EAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {Tensor2D, Tensor3D, Tensor4D} from '../tensor';\nimport {convertToTensor} from '../tensor_util_env';\nimport {TensorLike} from '../types';\nimport * as util from '../util';\n\nimport {conv2d} from './conv2d';\nimport * as conv_util from './conv_util';\nimport {op} from './operation';\nimport {reshape} from './reshape';\n\n/**\n * Computes a 1D convolution over the input x.\n *\n * @param x The 
input tensor, of rank 3 or rank 2, of shape\n *     `[batch, width, inChannels]`. If rank 2, batch of 1 is assumed.\n * @param filter The filter, rank 3, of shape\n *     `[filterWidth, inDepth, outDepth]`.\n * @param stride The number of entries by which the filter is moved right at\n *     each step.\n * @param pad The type of padding algorithm.\n *    - `same` and stride 1: output will be of same size as input,\n *       regardless of filter size.\n *    - `valid`: output will be smaller than input if filter is larger\n *       than 1x1.\n *   - For more info, see this guide:\n *     [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](\n *          https://www.tensorflow.org/api_docs/python/tf/nn/convolution)\n * @param dataFormat An optional string from \"NWC\", \"NCW\". Defaults to \"NWC\",\n *     the data is stored in the order of [batch, in_width, in_channels]. Only\n *     \"NWC\" is currently supported.\n * @param dilation The dilation rate in which we sample input values in\n *     atrous convolution. Defaults to `1`. If it is greater than 1, then\n *     stride must be `1`.\n * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. 
If none is\n *     provided, it will default to truncate.\n *\n * @doc {heading: 'Operations', subheading: 'Convolution'}\n */\nfunction conv1d_<T extends Tensor2D|Tensor3D>(\n    x: T|TensorLike, filter: Tensor3D|TensorLike, stride: number,\n    pad: 'valid'|'same'|number|conv_util.ExplicitPadding,\n    dataFormat: 'NWC'|'NCW' = 'NWC', dilation = 1,\n    dimRoundingMode?: 'floor'|'round'|'ceil'): T {\n  const $x = convertToTensor(x, 'x', 'conv1d');\n  const $filter = convertToTensor(filter, 'filter', 'conv1d');\n\n  let x3D = $x as Tensor3D;\n  let reshapedTo3D = false;\n  if ($x.rank === 2) {\n    reshapedTo3D = true;\n    x3D = reshape($x, [1, $x.shape[0], $x.shape[1]]);\n  }\n\n  util.assert(\n      x3D.rank === 3,\n      () => `Error in conv1d: input must be rank 3, but got rank ${x3D.rank}.`);\n  util.assert(\n      $filter.rank === 3,\n      () => `Error in conv1d: filter must be rank 3, but got rank ` +\n          `${$filter.rank}.`);\n  conv_util.checkPadOnDimRoundingMode('conv1d', pad, dimRoundingMode);\n  util.assert(\n      x3D.shape[2] === $filter.shape[1],\n      () => `Error in conv1d: depth of input (${x3D.shape[2]}) must match ` +\n          `input depth for filter ${$filter.shape[1]}.`);\n  util.assert(\n      conv_util.eitherStridesOrDilationsAreOne(stride, dilation),\n      () => 'Error in conv1D: Either stride or dilation must be 1. 
' +\n          `Got stride ${stride} and dilation '${dilation}'`);\n  util.assert(\n      conv_util.stridesOrDilationsArePositive(dilation),\n      () => 'Error in conv1D: Dilated rates should be larger than 0.');\n  util.assert(\n      conv_util.stridesOrDilationsArePositive(stride),\n      () => 'Error in conv1D: Stride should be larger than 0.');\n  util.assert(\n      dataFormat === 'NWC',\n      () => `Error in conv1d: got dataFormat of ${\n          dataFormat} but only NWC is currently supported.`);\n\n  const filter4D = reshape(\n      $filter, [1, $filter.shape[0], $filter.shape[1], $filter.shape[2]]);\n  const input4D = reshape(x3D, [x3D.shape[0], 1, x3D.shape[1], x3D.shape[2]]);\n  const strides: [number, number] = [1, stride];\n  const dilations: [number, number] = [1, dilation];\n\n  const conv2dDataFormat = 'NHWC';\n\n  const res = conv2d(\n      (input4D as Tensor4D), (filter4D as Tensor4D), strides, pad,\n      conv2dDataFormat, dilations, dimRoundingMode);\n\n  if (reshapedTo3D) {\n    return reshape(res, [res.shape[2], res.shape[3]]) as T;\n  }\n\n  return reshape(res, [res.shape[0], res.shape[2], res.shape[3]]) as T;\n}\n\nexport const conv1d = /* @__PURE__ */ op({conv1d_});\n"]}