gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { ENGINE } from '../../engine';
import { assert } from '../../util';
import { div } from '../div';
import { mul } from '../mul';
import { norm } from '../norm';
import { op } from '../operation';
import { split } from '../split';
import { squeeze } from '../squeeze';
import { stack } from '../stack';
import { sub } from '../sub';
import { sum } from '../sum';
/**
 * Gram-Schmidt orthogonalization.
 *
 * ```js
 * const x = tf.tensor2d([[1, 2], [3, 4]]);
 * let y = tf.linalg.gramSchmidt(x);
 * y.print();
 * console.log('Orthogonalized:');
 * y.dot(y.transpose()).print();  // should be nearly the identity matrix.
 * console.log('First row direction maintained:');
 * const data = await y.array();
 * console.log(data[0][1] / data[0][0]);  // should be nearly 2.
 * ```
 *
 * @param xs The vectors to be orthogonalized, in one of the two following
 *   formats:
 *   - An Array of `tf.Tensor1D`.
 *   - A `tf.Tensor2D`, i.e., a matrix, in which case the vectors are the rows
 *     of `xs`.
 *   In each case, all the vectors must have the same length and the length
 *   must be greater than or equal to the number of vectors.
 * @returns The orthogonalized and normalized vectors or matrix.
 *   Orthogonalization means that the vectors or the rows of the matrix
 *   are orthogonal (zero inner products). Normalization means that each
 *   vector or each row of the matrix has an L2 norm that equals `1`.
 *   The result has the same format as the input: an Array of vectors for an
 *   Array input, a stacked matrix otherwise.
 * @throws If `xs` is null, undefined, or an empty Array; if the vectors of an
 *   Array input do not all have the same length; or if the number of vectors
 *   exceeds the length of each vector.
 *
 * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
 */
function gramSchmidt_(xs) {
    const emptyInputMessage = () => 'Gram-Schmidt process: input must not be null, undefined, or empty';
    // Reject null/undefined before branching on the input format. Checking
    // this only inside the Array branch would let a missing input fall through
    // to the matrix branch and fail on `xs.shape` with an unrelated TypeError.
    assert(xs != null, emptyInputMessage);
    const inputIsTensor2D = !Array.isArray(xs);
    // Work on a local list of vectors instead of reassigning the parameter.
    let vectors;
    if (inputIsTensor2D) {
        // One piece per row along axis 0, then drop that axis so every row
        // becomes a 1D vector.
        vectors = split(xs, xs.shape[0], 0).map(row => squeeze(row, [0]));
    }
    else {
        assert(xs.length > 0, emptyInputMessage);
        const dim = xs[0].shape[0];
        for (let i = 1; i < xs.length; ++i) {
            assert(xs[i].shape[0] === dim, () => 'Gram-Schmidt: Non-unique lengths found in the input vectors: ' +
                `(${xs[i].shape[0]} vs. ${dim})`);
        }
        vectors = xs;
    }
    assert(vectors.length <= vectors[0].shape[0], () => `Gram-Schmidt: Number of vectors (${vectors.length}) exceeds ` +
        `number of dimensions (${vectors[0].shape[0]}).`);
    const ys = [];
    for (let i = 0; i < vectors.length; ++i) {
        // Each step runs inside ENGINE.tidy and hands back only the final
        // normalized vector for this index.
        ys.push(ENGINE.tidy(() => {
            let x = vectors[i];
            // Subtract from x its projection onto every previously produced
            // (already unit-length) vector. For i === 0 the loop is a no-op.
            for (let j = 0; j < i; ++j) {
                const proj = mul(sum(mul(ys[j], x)), ys[j]);
                x = sub(x, proj);
            }
            // NOTE: a zero norm is not guarded against here, so zero or
            // linearly dependent input vectors lead to a division by zero.
            return div(x, norm(x, 'euclidean'));
        }));
    }
    // Mirror the input format: re-stack rows for a matrix input, otherwise
    // return the list of vectors.
    return inputIsTensor2D ? stack(ys, 0) : ys;
}
// Public export: `gramSchmidt_` passed through the `op` wrapper imported from
// '../operation'. The `@__PURE__` annotation is a hint for bundlers/minifiers
// that this call can be dropped when the export is unused.
export const gramSchmidt = /* @__PURE__ */ op({ gramSchmidt_ });
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiZ3JhbV9zY2htaWR0LmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvbGluYWxnL2dyYW1fc2NobWlkdC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsTUFBTSxFQUFDLE1BQU0sY0FBYyxDQUFDO0FBRXBDLE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxZQUFZLENBQUM7QUFFbEMsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLFFBQVEsQ0FBQztBQUMzQixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sUUFBUSxDQUFDO0FBQzNCLE9BQU8sRUFBQyxJQUFJLEVBQUMsTUFBTSxTQUFTLENBQUM7QUFDN0IsT0FBTyxFQUFDLEVBQUUsRUFBQyxNQUFNLGNBQWMsQ0FBQztBQUNoQyxPQUFPLEVBQUMsS0FBSyxFQUFDLE1BQU0sVUFBVSxDQUFDO0FBQy9CLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxZQUFZLENBQUM7QUFDbkMsT0FBTyxFQUFDLEtBQUssRUFBQyxNQUFNLFVBQVUsQ0FBQztBQUMvQixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sUUFBUSxDQUFDO0FBQzNCLE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxRQUFRLENBQUM7QUFFM0I7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7OztHQTJCRztBQUNILFNBQVMsWUFBWSxDQUFDLEVBQXVCO0lBQzNDLElBQUksZUFBd0IsQ0FBQztJQUM3QixJQUFJLEtBQUssQ0FBQyxPQUFPLENBQUMsRUFBRSxDQUFDLEVBQUU7UUFDckIsZUFBZSxHQUFHLEtBQUssQ0FBQztRQUN4QixNQUFNLENBQ0YsRUFBRSxJQUFJLElBQUksSUFBSSxFQUFFLENBQUMsTUFBTSxHQUFHLENBQUMsRUFDM0IsR0FBRyxFQUFFLENBQUMsOERBQThEO1lBQ2hFLE9BQU8sQ0FBQyxDQUFDO1FBQ2pCLE1BQU0sR0FBRyxHQUFHLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDM0IsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLEVBQUUsQ0FBQyxNQUFNLEVBQUUsRUFBRSxDQUFDLEVBQUU7WUFDbEMsTUFBTSxDQUNGLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEtBQUssR0FBRyxFQUN0QixHQUFHLEVBQUUsQ0FDRCwrREFBK0Q7Z0JBQy9ELElBQUssRUFBaUIsQ0FBQyxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLFFBQVEsR0FBRyxHQUFHLENBQUMsQ0FBQztTQUMzRDtLQUNGO1NBQU07UUFDTCxlQUFlLEdBQUcsSUFBSSxDQUFDO1FBQ3ZCLEVBQUUsR0FBRyxLQUFLLENBQUMsRUFBRSxFQUFFLEVBQUUsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsT0FBTyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztLQUMxRDtJQUVELE1BQU0sQ0FDRixFQUFFLE
NBQUMsTUFBTSxJQUFJLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEVBQzNCLEdBQUcsRUFBRSxDQUFDLG9DQUNLLEVBQWlCLENBQUMsTUFBTSxZQUFZO1FBQzNDLHlCQUEwQixFQUFpQixDQUFDLENBQUMsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsSUFBSSxDQUFDLENBQUM7SUFFckUsTUFBTSxFQUFFLEdBQWUsRUFBRSxDQUFDO0lBQzFCLE1BQU0sSUFBSSxHQUFHLEVBQUUsQ0FBQztJQUNoQixLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsRUFBRSxDQUFDLE1BQU0sRUFBRSxFQUFFLENBQUMsRUFBRTtRQUNsQyxFQUFFLENBQUMsSUFBSSxDQUFDLE1BQU0sQ0FBQyxJQUFJLENBQUMsR0FBRyxFQUFFO1lBQ3ZCLElBQUksQ0FBQyxHQUFHLElBQUksQ0FBQyxDQUFDLENBQUMsQ0FBQztZQUNoQixJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUU7Z0JBQ1QsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLENBQUMsRUFBRSxFQUFFLENBQUMsRUFBRTtvQkFDMUIsTUFBTSxJQUFJLEdBQUcsR0FBRyxDQUFDLEdBQUcsQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7b0JBQzVDLENBQUMsR0FBRyxHQUFHLENBQUMsQ0FBQyxFQUFFLElBQUksQ0FBQyxDQUFDO2lCQUNsQjthQUNGO1lBQ0QsT0FBTyxHQUFHLENBQUMsQ0FBQyxFQUFFLElBQUksQ0FBQyxDQUFDLEVBQUUsV0FBVyxDQUFDLENBQUMsQ0FBQztRQUN0QyxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQ0w7SUFFRCxJQUFJLGVBQWUsRUFBRTtRQUNuQixPQUFPLEtBQUssQ0FBQyxFQUFFLEVBQUUsQ0FBQyxDQUFhLENBQUM7S0FDakM7U0FBTTtRQUNMLE9BQU8sRUFBRSxDQUFDO0tBQ1g7QUFDSCxDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sV0FBVyxHQUFHLGVBQWUsQ0FBQyxFQUFFLENBQUMsRUFBQyxZQUFZLEVBQUMsQ0FBQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRV
MgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge0VOR0lORX0gZnJvbSAnLi4vLi4vZW5naW5lJztcbmltcG9ydCB7VGVuc29yMUQsIFRlbnNvcjJEfSBmcm9tICcuLi8uLi90ZW5zb3InO1xuaW1wb3J0IHthc3NlcnR9IGZyb20gJy4uLy4uL3V0aWwnO1xuXG5pbXBvcnQge2Rpdn0gZnJvbSAnLi4vZGl2JztcbmltcG9ydCB7bXVsfSBmcm9tICcuLi9tdWwnO1xuaW1wb3J0IHtub3JtfSBmcm9tICcuLi9ub3JtJztcbmltcG9ydCB7b3B9IGZyb20gJy4uL29wZXJhdGlvbic7XG5pbXBvcnQge3NwbGl0fSBmcm9tICcuLi9zcGxpdCc7XG5pbXBvcnQge3NxdWVlemV9IGZyb20gJy4uL3NxdWVlemUnO1xuaW1wb3J0IHtzdGFja30gZnJvbSAnLi4vc3RhY2snO1xuaW1wb3J0IHtzdWJ9IGZyb20gJy4uL3N1Yic7XG5pbXBvcnQge3N1bX0gZnJvbSAnLi4vc3VtJztcblxuLyoqXG4gKiBHcmFtLVNjaG1pZHQgb3J0aG9nb25hbGl6YXRpb24uXG4gKlxuICogYGBganNcbiAqIGNvbnN0IHggPSB0Zi50ZW5zb3IyZChbWzEsIDJdLCBbMywgNF1dKTtcbiAqIGxldCB5ID0gdGYubGluYWxnLmdyYW1TY2htaWR0KHgpO1xuICogeS5wcmludCgpO1xuICogY29uc29sZS5sb2coJ09ydGhvZ29uYWxpemVkOicpO1xuICogeS5kb3QoeS50cmFuc3Bvc2UoKSkucHJpbnQoKTsgIC8vIHNob3VsZCBiZSBuZWFybHkgdGhlIGlkZW50aXR5IG1hdHJpeC5cbiAqIGNvbnNvbGUubG9nKCdGaXJzdCByb3cgZGlyZWN0aW9uIG1haW50YWluZWQ6Jyk7XG4gKiBjb25zdCBkYXRhID0gYXdhaXQgeS5hcnJheSgpO1xuICogY29uc29sZS5sb2coZGF0YVswXVsxXSAvIGRhdGFbMF1bMF0pOyAgLy8gc2hvdWxkIGJlIG5lYXJseSAyLlxuICogYGBgXG4gKlxuICogQHBhcmFtIHhzIFRoZSB2ZWN0b3JzIHRvIGJlIG9ydGhvZ29uYWxpemVkLCBpbiBvbmUgb2YgdGhlIHR3byBmb2xsb3dpbmdcbiAqICAgZm9ybWF0czpcbiAqICAgLSBBbiBBcnJheSBvZiBgdGYuVGVuc29yMURgLlxuICogICAtIEEgYHRmLlRlbnNvcjJEYCwgaS5lLiwgYSBtYXRyaXgsIGluIHdoaWNoIGNhc2UgdGhlIHZlY3RvcnMgYXJlIHRoZSByb3dzXG4gKiAgICAgb2YgYHhzYC5cbiAqICAgSW4gZWFjaCBjYXNlLCBhbGwgdGhlIHZlY3RvcnMgbXVzdCBoYXZlIHRoZSBzYW1lIGxlbmd0aCBhbmQgdGhlIGxlbmd0aFxuICogICBtdXN0IGJlIGdyZWF0ZXIgdGhhbiBvciBlcXVhbCB0byB0aGUgbnVtYmVyIG9mIHZlY3RvcnMuXG4gKiBAcmV0dXJucyBUaGUgb3J0aG9nb25hbGl6ZWQgYW5kIG5vcm1hbGl6ZW
QgdmVjdG9ycyBvciBtYXRyaXguXG4gKiAgIE9ydGhvZ29uYWxpemF0aW9uIG1lYW5zIHRoYXQgdGhlIHZlY3RvcnMgb3IgdGhlIHJvd3Mgb2YgdGhlIG1hdHJpeFxuICogICBhcmUgb3J0aG9nb25hbCAoemVybyBpbm5lciBwcm9kdWN0cykuIE5vcm1hbGl6YXRpb24gbWVhbnMgdGhhdCBlYWNoXG4gKiAgIHZlY3RvciBvciBlYWNoIHJvdyBvZiB0aGUgbWF0cml4IGhhcyBhbiBMMiBub3JtIHRoYXQgZXF1YWxzIGAxYC5cbiAqXG4gKiBAZG9jIHtoZWFkaW5nOidPcGVyYXRpb25zJywgc3ViaGVhZGluZzonTGluZWFyIEFsZ2VicmEnLCBuYW1lc3BhY2U6J2xpbmFsZyd9XG4gKi9cbmZ1bmN0aW9uIGdyYW1TY2htaWR0Xyh4czogVGVuc29yMURbXXxUZW5zb3IyRCk6IFRlbnNvcjFEW118VGVuc29yMkQge1xuICBsZXQgaW5wdXRJc1RlbnNvcjJEOiBib29sZWFuO1xuICBpZiAoQXJyYXkuaXNBcnJheSh4cykpIHtcbiAgICBpbnB1dElzVGVuc29yMkQgPSBmYWxzZTtcbiAgICBhc3NlcnQoXG4gICAgICAgIHhzICE9IG51bGwgJiYgeHMubGVuZ3RoID4gMCxcbiAgICAgICAgKCkgPT4gJ0dyYW0tU2NobWlkdCBwcm9jZXNzOiBpbnB1dCBtdXN0IG5vdCBiZSBudWxsLCB1bmRlZmluZWQsIG9yICcgK1xuICAgICAgICAgICAgJ2VtcHR5Jyk7XG4gICAgY29uc3QgZGltID0geHNbMF0uc2hhcGVbMF07XG4gICAgZm9yIChsZXQgaSA9IDE7IGkgPCB4cy5sZW5ndGg7ICsraSkge1xuICAgICAgYXNzZXJ0KFxuICAgICAgICAgIHhzW2ldLnNoYXBlWzBdID09PSBkaW0sXG4gICAgICAgICAgKCkgPT5cbiAgICAgICAgICAgICAgJ0dyYW0tU2NobWlkdDogTm9uLXVuaXF1ZSBsZW5ndGhzIGZvdW5kIGluIHRoZSBpbnB1dCB2ZWN0b3JzOiAnICtcbiAgICAgICAgICAgICAgYCgkeyh4cyBhcyBUZW5zb3IxRFtdKVtpXS5zaGFwZVswXX0gdnMuICR7ZGltfSlgKTtcbiAgICB9XG4gIH0gZWxzZSB7XG4gICAgaW5wdXRJc1RlbnNvcjJEID0gdHJ1ZTtcbiAgICB4cyA9IHNwbGl0KHhzLCB4cy5zaGFwZVswXSwgMCkubWFwKHggPT4gc3F1ZWV6ZSh4LCBbMF0pKTtcbiAgfVxuXG4gIGFzc2VydChcbiAgICAgIHhzLmxlbmd0aCA8PSB4c1swXS5zaGFwZVswXSxcbiAgICAgICgpID0+IGBHcmFtLVNjaG1pZHQ6IE51bWJlciBvZiB2ZWN0b3JzICgke1xuICAgICAgICAgICAgICAgICh4cyBhcyBUZW5zb3IxRFtdKS5sZW5ndGh9KSBleGNlZWRzIGAgK1xuICAgICAgICAgIGBudW1iZXIgb2YgZGltZW5zaW9ucyAoJHsoeHMgYXMgVGVuc29yMURbXSlbMF0uc2hhcGVbMF19KS5gKTtcblxuICBjb25zdCB5czogVGVuc29yMURbXSA9IFtdO1xuICBjb25zdCB4czFkID0geHM7XG4gIGZvciAobGV0IGkgPSAwOyBpIDwgeHMubGVuZ3RoOyArK2kpIHtcbiAgICB5cy5wdXNoKEVOR0lORS50aWR5KCgpID0+IHtcbiAgICAgIGxldCB4ID0geHMxZFtpXTtcbiAgICAgIGlmIChpID4gMCkge1xuICAgICAgICBmb3IgKGxldCBqID0gMDsgaiA8IGk7ICsraikge1xuICAgICAgICAgIG
NvbnN0IHByb2ogPSBtdWwoc3VtKG11bCh5c1tqXSwgeCkpLCB5c1tqXSk7XG4gICAgICAgICAgeCA9IHN1Yih4LCBwcm9qKTtcbiAgICAgICAgfVxuICAgICAgfVxuICAgICAgcmV0dXJuIGRpdih4LCBub3JtKHgsICdldWNsaWRlYW4nKSk7XG4gICAgfSkpO1xuICB9XG5cbiAgaWYgKGlucHV0SXNUZW5zb3IyRCkge1xuICAgIHJldHVybiBzdGFjayh5cywgMCkgYXMgVGVuc29yMkQ7XG4gIH0gZWxzZSB7XG4gICAgcmV0dXJuIHlzO1xuICB9XG59XG5cbmV4cG9ydCBjb25zdCBncmFtU2NobWlkdCA9IC8qIEBfX1BVUkVfXyAqLyBvcCh7Z3JhbVNjaG1pZHRffSk7XG4iXX0=