"use strict";
/**
 * @license
 * Copyright 2018 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
Object.defineProperty(exports, "__esModule", { value: true });
var engine_1 = require("../engine");
var tensor_util_env_1 = require("../tensor_util_env");
var util_1 = require("../util");
var util_2 = require("../util");
var concat_util_1 = require("./concat_util");
var operation_1 = require("./operation");
var tensor_ops_1 = require("./tensor_ops");
/**
 * Concatenates a list of `tf.Tensor1D`s along an axis. See `concat` for
 * details.
 *
 * For example, if:
 * A: shape(3) = |r1, g1, b1|
 * B: shape(2) = |r2, g2|
 * C = tf.concat1d([A, B]) == |r1, g1, b1, r2, g2|
 *
 * This wrapper simply forwards to `concat` with axis 0; it performs no
 * runtime rank check of its own.
 *
 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @return The concatenated array.
 */
function concat1d_(tensors) {
    return exports.concat(tensors, 0 /* axis */);
}
/**
 * Concatenates a list of `tf.Tensor2D`s along an axis. See `concat` for
 * details.
 *
 * For example, if:
 * A: shape(2, 3) = | r1, g1, b1 |
 *                  | r2, g2, b2 |
 *
 * B: shape(2, 3) = | r3, g3, b3 |
 *                  | r4, g4, b4 |
 *
 * C = tf.concat2d([A, B], axis)
 *
 * if axis = 0:
 * C: shape(4, 3) = | r1, g1, b1 |
 *                  | r2, g2, b2 |
 *                  | r3, g3, b3 |
 *                  | r4, g4, b4 |
 *
 * if axis = 1:
 * C: shape(2, 6) = | r1, g1, b1, r3, g3, b3 |
 *                  | r2, g2, b2, r4, g4, b4 |
 *
 * This wrapper simply forwards to `concat`; it performs no runtime rank
 * check of its own.
 *
 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @param axis The axis to concatenate along.
 * @return The concatenated array.
 */
function concat2d_(tensors, axis) {
    return exports.concat(tensors, axis);
}
/**
 * Concatenates a list of `tf.Tensor3D`s along an axis.
 * See `concat` for details.
 *
 * For example, if:
 * A: shape(2, 1, 3) = | r1, g1, b1 |
 *                     | r2, g2, b2 |
 *
 * B: shape(2, 1, 3) = | r3, g3, b3 |
 *                     | r4, g4, b4 |
 *
 * C = tf.concat3d([A, B], axis)
 *
 * if axis = 0:
 * C: shape(4, 1, 3) = | r1, g1, b1 |
 *                     | r2, g2, b2 |
 *                     | r3, g3, b3 |
 *                     | r4, g4, b4 |
 *
 * if axis = 1:
 * C: shape(2, 2, 3) = | r1, g1, b1, r3, g3, b3 |
 *                     | r2, g2, b2, r4, g4, b4 |
 *
 * if axis = 2:
 * C: shape(2, 1, 6) = | r1, g1, b1, r3, g3, b3 |
 *                     | r2, g2, b2, r4, g4, b4 |
 *
 * This wrapper simply forwards to `concat`; it performs no runtime rank
 * check of its own.
 *
 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @param axis The axis to concatenate along.
 * @return The concatenated array.
 */
function concat3d_(tensors, axis) {
    return exports.concat(tensors, axis);
}
/**
 * Concatenates a list of `tf.Tensor4D`s along an axis.
 * See `concat` for details.
 *
 * This wrapper simply forwards to `concat`; it performs no runtime rank
 * check of its own.
 *
 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @param axis The axis to concatenate along.
 * @return The concatenated array.
 */
function concat4d_(tensors, axis) {
    return exports.concat(tensors, axis);
}
/**
 * Concatenates a list of `tf.Tensor`s along a given axis.
 *
 * The tensors' ranks and types must match, and their sizes must match in all
 * dimensions except `axis`.
 *
 * Also available are stricter rank-specific methods that assert that
 * `tensors` are of the given rank:
 *   - `tf.concat1d`
 *   - `tf.concat2d`
 *   - `tf.concat3d`
 *   - `tf.concat4d`
 *
 * Except `tf.concat1d` (which does not have axis param), all methods have
 * same signature as this method.
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 * const b = tf.tensor1d([3, 4]);
 * a.concat(b).print();  // or tf.concat([a, b])
 * ```
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 * const b = tf.tensor1d([3, 4]);
 * const c = tf.tensor1d([5, 6]);
 * tf.concat([a, b, c]).print();
 * ```
 *
 * ```js
 * const a = tf.tensor2d([[1, 2], [10, 20]]);
 * const b = tf.tensor2d([[3, 4], [30, 40]]);
 * const axis = 1;
 * tf.concat([a, b], axis).print();
 * ```
 * @param tensors A list of tensors to concatenate.
 * @param axis The axis to concatenate along. Defaults to 0 (the first dim).
 * @return The concatenated tensor. When only one non-empty input remains, a
 *     clone of that input is returned so the result never aliases the
 *     caller's tensor.
 */
/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
function concat_(tensors, axis) {
    if (axis === void 0) { axis = 0; }
    util_1.assert(tensors.length >= 1, function () { return 'Pass at least one tensor to concat'; });
    var $tensors = tensor_util_env_1.convertToTensorArray(tensors, 'tensors', 'concat');
    // The dtype homogeneity check below only runs when the FIRST tensor is
    // complex64; other dtype combinations are not validated here.
    if ($tensors[0].dtype === 'complex64') {
        $tensors.forEach(function (tensor) {
            if (tensor.dtype !== 'complex64') {
                throw new Error("Cannot concatenate complex64 tensors with a tensor\n with dtype " + tensor.dtype + ". ");
            }
        });
    }
    // Normalize the axis against the first tensor's shape.
    axis = util_2.parseAxisParam(axis, $tensors[0].shape)[0];
    var outShape = concat_util_1.computeOutShape($tensors.map(function (t) { return t.shape; }), axis);
    if (util_1.sizeFromShape(outShape) === 0) {
        // NOTE(review): the empty result is created without an explicit dtype;
        // confirm whether it should inherit the dtype of the inputs.
        return tensor_ops_1.tensor([], outShape);
    }
    // Keep only non-empty tensors (ignore tensors with 0 in their shape).
    $tensors = $tensors.filter(function (t) { return t.size > 0; });
    if ($tensors.length === 1) {
        // Return a clone rather than the input itself: handing back the very
        // same tensor object would make the "result" alias the caller's input,
        // so disposing one would silently invalidate the other.
        return $tensors[0].clone();
    }
    var shapes = $tensors.map(function (t) { return t.shape; });
    concat_util_1.assertParamsConsistent(shapes, axis);
    // Gradient: split dy back into pieces sized like each input along `axis`.
    var der = function (dy) {
        var sizeSplits = shapes.map(function (s) { return s[axis]; });
        var derTensors = exports.split(dy, sizeSplits, axis);
        return derTensors.map(function (t) { return function () { return t; }; });
    };
    var inputs = $tensors;
    var attr = { axis: axis };
    return engine_1.ENGINE.runKernelFunc(function (backend) { return backend.concat($tensors, axis); }, inputs, der, 'Concat', attr);
}
/**
 * Splits a `tf.Tensor` into sub tensors.
 *
 * If `numOrSizeSplits` is a number, splits `x` along dimension `axis`
 * into `numOrSizeSplits` smaller tensors.
 * Requires that `numOrSizeSplits` evenly divides `x.shape[axis]`.
 *
 * If `numOrSizeSplits` is a number array, splits `x` into
 * `numOrSizeSplits.length` pieces. The shape of the `i`-th piece has the
 * same size as `x` except along dimension `axis` where the size is
 * `numOrSizeSplits[i]`.
 *
 * ```js
 * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]);
 * const [a, b] = tf.split(x, 2, 1);
 * a.print();
 * b.print();
 *
 * const [c, d, e] = tf.split(x, [1, 2, 1], 1);
 * c.print();
 * d.print();
 * e.print();
 * ```
 *
 * @param x The input tensor to split.
 * @param numOrSizeSplits Either an integer indicating the number of
 * splits along the axis or an array of integers containing the sizes of
 * each output tensor along the axis. If a number then it must be a positive
 * integer that evenly divides `x.shape[axis]`; otherwise the sum of sizes
 * must match `x.shape[axis]`.
 * @param axis The dimension along which to split. Defaults to 0 (the first
 * dim).
 * @return The pieces produced by the backend's `split`, one per entry of the
 *     computed split sizes.
 */
/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
function split_(x, numOrSizeSplits, axis) {
    if (axis === void 0) { axis = 0; }
    var $x = tensor_util_env_1.convertToTensor(x, 'x', 'split');
    axis = util_2.parseAxisParam(axis, $x.shape)[0];
    var splitSizes;
    if (typeof (numOrSizeSplits) === 'number') {
        util_1.assert($x.shape[axis] % numOrSizeSplits === 0, function () { return 'Number of splits must evenly divide the axis.'; });
        // Negative or fractional counts can pass the divisibility test above
        // (e.g. 4 % -2 === 0, 3 % 1.5 === 0) and would otherwise surface as an
        // opaque RangeError from `new Array(...)`.
        util_1.assert(numOrSizeSplits > 0 && numOrSizeSplits % 1 === 0, function () { return 'Number of splits must be a positive integer.'; });
        splitSizes = new Array(numOrSizeSplits).fill($x.shape[axis] / numOrSizeSplits);
    }
    else {
        // Seed the sum with 0 so an empty sizes array fails the assertion with
        // its message instead of throwing a TypeError from `reduce`.
        util_1.assert($x.shape[axis] === numOrSizeSplits.reduce(function (a, b) { return a + b; }, 0), function () { return 'The sum of sizes must match the size of the axis dimension.'; });
        splitSizes = numOrSizeSplits;
    }
    // Gradient: concatenating the per-piece gradients along the same axis
    // rebuilds the gradient with respect to x.
    var der = function (dy) { return ({ $x: function () { return exports.concat(dy, axis); } }); };
    return engine_1.ENGINE.runKernelFunc(function (backend) { return backend.split($x, splitSizes, axis); }, { $x: $x }, der);
}
exports.concat = operation_1.op({ concat_: concat_ });
exports.concat1d = operation_1.op({ concat1d_: concat1d_ });
exports.concat2d = operation_1.op({ concat2d_: concat2d_ });
exports.concat3d = operation_1.op({ concat3d_: concat3d_ });
exports.concat4d = operation_1.op({ concat4d_: concat4d_ });
exports.split = operation_1.op({ split_: split_ });
//# sourceMappingURL=concat_split.js.map