/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { convertToTensor } from '../../tensor_util_env';
|
import { assert } from '../../util';
|
import { greaterEqual } from '../greater_equal';
|
import { less } from '../less';
|
import { lessEqual } from '../less_equal';
|
import { logicalAnd } from '../logical_and';
|
import { minimum } from '../minimum';
|
import { neg } from '../neg';
|
import { op } from '../operation';
|
import { range } from '../range';
|
import { reshape } from '../reshape';
|
import { stack } from '../stack';
|
import { sub } from '../sub';
|
import { unstack } from '../unstack';
|
import { where } from '../where';
|
import { zeros } from '../zeros';
|
/**
 * Copy a tensor setting everything outside a central band in each innermost
 * matrix to zero.
 *
 * The band part is computed as follows: Assume input has `k` dimensions
 * `[I, J, K, ..., M, N]`, then the output is a tensor with the same shape where
 * `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
 * The indicator function is
 * `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)`
 * `&& (num_upper < 0 || (n-m) <= num_upper)`
 *
 * ```js
 * const x = tf.tensor2d([[ 0,  1,  2, 3],
 *                        [-1,  0,  1, 2],
 *                        [-2, -1,  0, 1],
 *                        [-3, -2, -1, 0]]);
 * let y = tf.linalg.bandPart(x, 1, -1);
 * y.print(); // [[ 0,  1,  2, 3],
 *            //  [-1,  0,  1, 2],
 *            //  [ 0, -1,  0, 1],
 *            //  [ 0,  0, -1, 0]]
 * let z = tf.linalg.bandPart(x, 2, 1);
 * z.print(); // [[ 0,  1,  0, 0],
 *            //  [-1,  0,  1, 0],
 *            //  [-2, -1,  0, 1],
 *            //  [ 0, -2, -1, 0]]
 * ```
 *
 * Throws if `a` has rank < 2. Throws if a numeric bound is not an integer,
 * or exceeds the number of rows (`numLower`) or columns (`numUpper`).
 * Throws if a tensor bound is not int32.
 *
 * @param a Rank `k` tensor, with `k >= 2`.
 * @param numLower Number of subdiagonals to keep: an integer, or an int32
 *     scalar tensor. If negative, keep entire lower triangle.
 * @param numUpper Number of superdiagonals to keep: an integer, or an int32
 *     scalar tensor. If negative, keep entire upper triangle.
 * @returns Rank `k` tensor of the same shape as input.
 *     The extracted banded tensor.
 *
 * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
 */
function bandPart_(a, numLower, numUpper) {
  const $a = convertToTensor(a, 'a', 'bandPart');
  assert($a.rank >= 2, () => `bandPart(): Rank must be at least 2, got ${$a.rank}.`);

  const shape = $a.shape;
  // The band is defined on the two innermost dimensions: M rows, N columns.
  const [M, N] = $a.shape.slice(-2);

  let $numLower;
  let $numUpper;
  if (typeof numLower === 'number') {
    assert(numLower % 1 === 0, () => `bandPart(): numLower must be an integer, got ${numLower}.`);
    assert(numLower <= M, () => `bandPart(): numLower (${numLower})` +
        ` must not be greater than the number of rows (${M}).`);
    // A negative bound means "keep the whole lower triangle". Within an
    // M x N matrix, (m - n) is at most M - 1, so M is a sufficient bound.
    $numLower =
        convertToTensor(numLower < 0 ? M : numLower, 'numLower', 'bandPart');
  } else {
    assert(numLower.dtype === 'int32', () => `bandPart(): numLower's dtype must be an int32.`);
    // If numLower is a Scalar, checking `numLower <= M` could hurt performance,
    // but minimum(numLower, M) could avoid unexpected results.
    $numLower = where(less(numLower, 0), M, minimum(numLower, M));
  }

  if (typeof numUpper === 'number') {
    assert(numUpper % 1 === 0, () => `bandPart(): numUpper must be an integer, got ${numUpper}.`);
    assert(numUpper <= N, () => `bandPart(): numUpper (${numUpper})` +
        ` must not be greater than the number of columns (${N}).`);
    // Likewise, (n - m) is at most N - 1, so N keeps the whole upper triangle.
    $numUpper =
        convertToTensor(numUpper < 0 ? N : numUpper, 'numUpper', 'bandPart');
  } else {
    assert(numUpper.dtype === 'int32', () => `bandPart(): numUpper's dtype must be an int32.`);
    $numUpper = where(less(numUpper, 0), N, minimum(numUpper, N));
  }

  // An input with no elements has nothing to mask. The per-matrix path below
  // cannot handle it:
  //  - an empty batch, e.g. shape [0, M, N], unstacks to [] and stack([])
  //    throws;
  //  - M === 0 or N === 0 makes reshape(..., [-1, M, N]) unable to infer -1.
  // Return a same-shaped (empty) tensor instead.
  if ($a.size === 0) {
    return reshape($a, shape);
  }

  // ij[m][n] = m - n: row index minus column index for every cell of an
  // M x N matrix (column vector i broadcast against row vector j).
  const i = reshape(range(0, M, 1, 'int32'), [-1, 1]);
  const j = range(0, N, 1, 'int32');
  const ij = sub(i, j);

  // (m - n) <= numLower  AND  (n - m) <= numUpper  <=>  (m - n) >= -numUpper
  const inBand = logicalAnd(lessEqual(ij, $numLower), greaterEqual(ij, neg($numUpper)));

  const zero = zeros([M, N], $a.dtype);

  // Flatten all leading dimensions into one batch axis. Mask each M x N
  // matrix with the shared [M, N] band mask, then restore the original shape.
  return reshape(stack(unstack(reshape($a, [-1, M, N]))
      .map(mat => where(inBand, mat, zero))), shape);
}
|
// Public entry point: bandPart_ wrapped by `op` (imported from ../operation).
// The `@__PURE__` annotation tells bundlers this call has no side effects, so
// it can be tree-shaken away when `bandPart` is never imported.
export const bandPart = /* @__PURE__ */ op({bandPart_: bandPart_});
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"band_part.js","sourceRoot":"","sources":["../../../../../../../tfjs-core/src/ops/linalg/band_part.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,OAAO,EAAC,eAAe,EAAC,MAAM,uBAAuB,CAAC;AAEtD,OAAO,EAAC,MAAM,EAAC,MAAM,YAAY,CAAC;AAElC,OAAO,EAAC,YAAY,EAAC,MAAM,kBAAkB,CAAC;AAC9C,OAAO,EAAC,IAAI,EAAC,MAAM,SAAS,CAAC;AAC7B,OAAO,EAAC,SAAS,EAAC,MAAM,eAAe,CAAC;AACxC,OAAO,EAAC,UAAU,EAAC,MAAM,gBAAgB,CAAC;AAC1C,OAAO,EAAC,OAAO,EAAC,MAAM,YAAY,CAAC;AACnC,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,EAAE,EAAC,MAAM,cAAc,CAAC;AAChC,OAAO,EAAC,KAAK,EAAC,MAAM,UAAU,CAAC;AAC/B,OAAO,EAAC,OAAO,EAAC,MAAM,YAAY,CAAC;AACnC,OAAO,EAAC,KAAK,EAAC,MAAM,UAAU,CAAC;AAC/B,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,OAAO,EAAC,MAAM,YAAY,CAAC;AACnC,OAAO,EAAC,KAAK,EAAC,MAAM,UAAU,CAAC;AAC/B,OAAO,EAAC,KAAK,EAAC,MAAM,UAAU,CAAC;AAE/B;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAqCG;AACH,SAAS,SAAS,CACd,CAAe,EAAE,QAAuB,EAAE,QAAuB;IACnE,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,GAAG,EAAE,UAAU,CAAC,CAAC;IAC/C,MAAM,CACF,EAAE,CAAC,IAAI,IAAI,CAAC,EACZ,GAAG,EAAE,CAAC,4CAA4C,EAAE,CAAC,IAAI,GAAG,CAAC,CAAC;IAElE,MAAM,KAAK,GAAG,EAAE,CAAC,KAAK,CAAC;IACvB,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;IAElC,IAAI,SAAiB,CAAC;IACtB,IAAI,SAAiB,CAAC;IACtB,IAAI,OAAO,QAAQ,KAAK,QAAQ,EAAE;QAChC,MAAM,CACF,QAAQ,GAAG,CAAC,KAAK,CAAC,EAClB,GAAG,EAAE,CAAC,gDAAgD,QAAQ,GAAG,CAAC,CAAC;QACvE,MAAM,CACF,QAAQ,IAAI,CAAC,EACb,GAAG,EAAE,CAAC,yBAAyB,QAAQ,GAAG;YACtC,iDAAiD,CAAC,IAAI,CAAC,CAAC;QAChE,SAAS;YACL,eAAe,CAAC,QAAQ,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,QAAQ,EAAE,UAAU,EAAE,UAAU,CAC7D,CAAC;KACZ;SAAM;QACL,MAAM,CACF,QAAQ,CAAC,KAAK,KAAK,OAAO,EAC1B,GAAG,EAAE,CAAC,gDAAgD,CAAC,CAAC;QAC5D,4EAA4E;QAC5E,2DAA2D;QAC3D,SAAS,GAAG,KAAK,CAAC,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,OAAO,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAW,CAAC;KACzE;IAED,IAAI,OAAO,QAAQ,KAAK,QAAQ,EAAE;QAChC,MAAM,CACF,QAAQ,GAAG,CAAC,KAAK,CAAC,EAClB,GAAG,EAAE,CAAC,gDAAgD,QAAQ,GAAG,
CAAC,CAAC;QACvE,MAAM,CACF,QAAQ,IAAI,CAAC,EACb,GAAG,EAAE,CAAC,yBAAyB,QAAQ,GAAG;YACtC,oDAAoD,CAAC,IAAI,CAAC,CAAC;QACnE,SAAS;YACL,eAAe,CAAC,QAAQ,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,QAAQ,EAAE,UAAU,EAAE,UAAU,CAC7D,CAAC;KACZ;SAAM;QACL,MAAM,CACF,QAAQ,CAAC,KAAK,KAAK,OAAO,EAC1B,GAAG,EAAE,CAAC,gDAAgD,CAAC,CAAC;QAC5D,SAAS,GAAG,KAAK,CAAC,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,OAAO,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAW,CAAC;KACzE;IAED,MAAM,CAAC,GAAG,OAAO,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,OAAO,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACpD,MAAM,CAAC,GAAG,KAAK,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,OAAO,CAAC,CAAC;IAClC,MAAM,EAAE,GAAG,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;IAErB,MAAM,MAAM,GACR,UAAU,CAAC,SAAS,CAAC,EAAE,EAAE,SAAS,CAAC,EAAE,YAAY,CAAC,EAAE,EAAE,GAAG,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC;IAE3E,MAAM,IAAI,GAAG,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC;IAErC,OAAO,OAAO,CACH,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;SAC3B,GAAG,CAAC,GAAG,CAAC,EAAE,CAAC,KAAK,CAAC,MAAM,EAAE,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC,EAChD,KAAK,CAAM,CAAC;AACzB,CAAC;AAED,MAAM,CAAC,MAAM,QAAQ,GAAG,eAAe,CAAC,EAAE,CAAC,EAAC,SAAS,EAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Scalar, Tensor} from '../../tensor';\nimport {convertToTensor} from '../../tensor_util_env';\nimport {TensorLike} from '../../types';\nimport {assert} from '../../util';\n\nimport {greaterEqual} from '../greater_equal';\nimport {less} from '../less';\nimport {lessEqual} from '../less_equal';\nimport {logicalAnd} from '../logical_and';\nimport {minimum} from '../minimum';\nimport {neg} from '../neg';\nimport {op} from '../operation';\nimport {range} from '../range';\nimport {reshape} from '../reshape';\nimport {stack} from '../stack';\nimport {sub} from '../sub';\nimport {unstack} from '../unstack';\nimport {where} from '../where';\nimport {zeros} from '../zeros';\n\n/**\n * Copy a tensor setting everything outside a central band in each innermost\n * matrix to zero.\n *\n * The band part is computed as follows: Assume input has `k` dimensions\n * `[I, J, K, ..., M, N]`, then the output is a tensor with the same shape where\n * `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.\n * The indicator function\n * `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)`\n * `&& (num_upper < 0 || (n-m) <= num_upper)`\n *\n * ```js\n * const x = tf.tensor2d([[ 0,  1,  2, 3],\n *                        [-1,  0,  1, 2],\n *                        [-2, -1,  0, 1],\n *                   
     [-3, -2, -1, 0]]);\n * let y = tf.linalg.bandPart(x, 1, -1);\n * y.print(); // [[ 0,  1,  2, 3],\n *            //  [-1,  0,  1, 2],\n *            //  [ 0, -1,  0, 1],\n *            //  [ 0, 0 , -1, 0]]\n * let z = tf.linalg.bandPart(x, 2, 1);\n * z.print(); // [[ 0,  1,  0, 0],\n *            //  [-1,  0,  1, 0],\n *            //  [-2, -1,  0, 1],\n *            //  [ 0, -2, -1, 0]]\n * ```\n *\n * @param x Rank `k` tensor\n * @param numLower Number of subdiagonals to keep.\n *   If negative, keep entire lower triangle.\n * @param numUpper Number of subdiagonals to keep.\n *   If negative, keep entire upper triangle.\n * @returns Rank `k` tensor of the same shape as input.\n *   The extracted banded tensor.\n *\n * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}\n */\nfunction bandPart_<T extends Tensor>(\n    a: T|TensorLike, numLower: number|Scalar, numUpper: number|Scalar): T {\n  const $a = convertToTensor(a, 'a', 'bandPart');\n  assert(\n      $a.rank >= 2,\n      () => `bandPart(): Rank must be at least 2, got ${$a.rank}.`);\n\n  const shape = $a.shape;\n  const [M, N] = $a.shape.slice(-2);\n\n  let $numLower: Scalar;\n  let $numUpper: Scalar;\n  if (typeof numLower === 'number') {\n    assert(\n        numLower % 1 === 0,\n        () => `bandPart(): numLower must be an integer, got ${numLower}.`);\n    assert(\n        numLower <= M,\n        () => `bandPart(): numLower (${numLower})` +\n            ` must not be greater than the number of rows (${M}).`);\n    $numLower =\n        convertToTensor(numLower < 0 ? 
M : numLower, 'numLower', 'bandPart') as\n        Scalar;\n  } else {\n    assert(\n        numLower.dtype === 'int32',\n        () => `bandPart(): numLower's dtype must be an int32.`);\n    // If numLower is a Scalar, checking `numLower <= M` could hurt performance,\n    // but minimum(numLower, M) could avoid unexpected results.\n    $numLower = where(less(numLower, 0), M, minimum(numLower, M)) as Scalar;\n  }\n\n  if (typeof numUpper === 'number') {\n    assert(\n        numUpper % 1 === 0,\n        () => `bandPart(): numUpper must be an integer, got ${numUpper}.`);\n    assert(\n        numUpper <= N,\n        () => `bandPart(): numUpper (${numUpper})` +\n            ` must not be greater than the number of columns (${N}).`);\n    $numUpper =\n        convertToTensor(numUpper < 0 ? N : numUpper, 'numUpper', 'bandPart') as\n        Scalar;\n  } else {\n    assert(\n        numUpper.dtype === 'int32',\n        () => `bandPart(): numUpper's dtype must be an int32.`);\n    $numUpper = where(less(numUpper, 0), N, minimum(numUpper, N)) as Scalar;\n  }\n\n  const i = reshape(range(0, M, 1, 'int32'), [-1, 1]);\n  const j = range(0, N, 1, 'int32');\n  const ij = sub(i, j);\n\n  const inBand =\n      logicalAnd(lessEqual(ij, $numLower), greaterEqual(ij, neg($numUpper)));\n\n  const zero = zeros([M, N], $a.dtype);\n\n  return reshape(\n             stack(unstack(reshape($a, [-1, M, N]))\n                       .map(mat => where(inBand, mat, zero))),\n             shape) as T;\n}\n\nexport const bandPart = /* @__PURE__ */ op({bandPart_});\n"]}
|