gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { buffer } from './buffer';
import { expandDims } from './expand_dims';
import { op } from './operation';
import { reshape } from './reshape';
import { tile } from './tile';
/**
 * Create an identity matrix.
 *
 * The result is a `[numRows, numColumns]` matrix holding `1` at every
 * position `(i, i)` with `i < min(numRows, numColumns)` and `0` elsewhere,
 * so non-square shapes are supported.
 *
 * @param numRows Number of rows.
 * @param numColumns Number of columns. Defaults to `numRows`.
 * @param batchShape If provided, will add the batch shape to the beginning
 *   of the shape of the returned `tf.Tensor` by repeating the identity
 *   matrix. Batch shapes of rank 1, 2 and 3 are supported.
 * @param dtype Data type. Defaults to `'float32'`.
 * @returns Identity matrix of the specified size and data type, possibly
 *   with batch repetition if `batchShape` is specified.
 * @throws Error if `batchShape` is provided but does not have length 1, 2
 *   or 3.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function eye_(numRows, numColumns, batchShape, dtype = 'float32') {
    const cols = numColumns == null ? numRows : numColumns;
    const buff = buffer([numRows, cols], dtype);
    // Only the leading square part of a rectangular matrix has a diagonal.
    const diagLength = Math.min(numRows, cols);
    for (let i = 0; i < diagLength; ++i) {
        buff.set(1, i, i);
    }
    const out = reshape(buff.toTensor(), [numRows, cols]);
    if (batchShape == null) {
        return out;
    }
    const batchRank = batchShape.length;
    // Explicit equality checks (rather than a range test) so that a
    // non-array `batchShape` (undefined length) is rejected here too.
    if (batchRank !== 1 && batchRank !== 2 && batchRank !== 3) {
        throw new Error(`eye() currently supports only 1D, 2D, and 3D ` +
            `batchShapes, but received ${batchRank}D.`);
    }
    // Prepend one singleton axis per batch dimension, then repeat along those
    // axes; the trailing [1, 1] leaves the matrix itself un-tiled.
    let batched = out;
    for (let k = 0; k < batchRank; ++k) {
        batched = expandDims(batched, 0);
    }
    return tile(batched, [...batchShape, 1, 1]);
}
// Public `eye` op: `eye_` wrapped by `op` (imported from './operation').
// The object shorthand `{ eye_ }` hands `op` both the implementation and its
// key name (presumably used as the op's identifier — confirm in
// operation.ts). The `/* @__PURE__ */` annotation tells minifiers/bundlers
// this call is side-effect free so `eye` can be tree-shaken when unused; it
// must stay immediately before the call expression.
export const eye = /* @__PURE__ */ op({ eye_ });
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiZXllLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvZXllLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUtILE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxVQUFVLENBQUM7QUFDaEMsT0FBTyxFQUFDLFVBQVUsRUFBQyxNQUFNLGVBQWUsQ0FBQztBQUN6QyxPQUFPLEVBQUMsRUFBRSxFQUFDLE1BQU0sYUFBYSxDQUFDO0FBQy9CLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDbEMsT0FBTyxFQUFDLElBQUksRUFBQyxNQUFNLFFBQVEsQ0FBQztBQUU1Qjs7Ozs7Ozs7Ozs7OztHQWFHO0FBQ0gsU0FBUyxJQUFJLENBQ1QsT0FBZSxFQUFFLFVBQW1CLEVBQ3BDLFVBSXdFLEVBQ3hFLFFBQWtCLFNBQVM7SUFDN0IsSUFBSSxVQUFVLElBQUksSUFBSSxFQUFFO1FBQ3RCLFVBQVUsR0FBRyxPQUFPLENBQUM7S0FDdEI7SUFDRCxNQUFNLElBQUksR0FBRyxNQUFNLENBQUMsQ0FBQyxPQUFPLEVBQUUsVUFBVSxDQUFDLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFDbEQsTUFBTSxDQUFDLEdBQUcsT0FBTyxJQUFJLFVBQVUsQ0FBQyxDQUFDLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxVQUFVLENBQUM7SUFDdkQsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLENBQUMsRUFBRSxFQUFFLENBQUMsRUFBRTtRQUMxQixJQUFJLENBQUMsR0FBRyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUM7S0FDbkI7SUFDRCxNQUFNLEdBQUcsR0FBYSxPQUFPLENBQUMsSUFBSSxDQUFDLFFBQVEsRUFBRSxFQUFFLENBQUMsT0FBTyxFQUFFLFVBQVUsQ0FBQyxDQUFDLENBQUM7SUFDdEUsSUFBSSxVQUFVLElBQUksSUFBSSxFQUFFO1FBQ3RCLE9BQU8sR0FBRyxDQUFDO0tBQ1o7U0FBTTtRQUNMLElBQUksVUFBVSxDQUFDLE1BQU0sS0FBSyxDQUFDLEVBQUU7WUFDM0IsT0FBTyxJQUFJLENBQUMsVUFBVSxDQUFDLEdBQUcsRUFBRSxDQUFDLENBQUMsRUFBRSxDQUFDLFVBQVUsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQWEsQ0FBQztTQUNwRTthQUFNLElBQUksVUFBVSxDQUFDLE1BQU0sS0FBSyxDQUFDLEVBQUU7WUFDbEMseURBQXlEO1lBQ3pELE9BQU8sSUFBSSxDQUNBLFVBQVUsQ0FBQyxVQUFVLENBQUMsR0FBRyxFQUFFLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxFQUNqQyxDQUFDLFVBQVUsQ0FBQyxDQUFDLENBQUMsRUFBRSxVQUFVLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFhLENBQUM7U0FDOUQ7YUFBTSxJQUFJLFVBQVUsQ0FBQyxNQUFNLEtBQUssQ0FBQyxFQUFFO1lBQ2xDLHlEQUF5RDtZQUN6RCxPQUFPLElBQUksQ0FBQyxVQUFVLENBQUMsVUFBVSxDQUFDLFVBQVUsQ0FBQyxHQUFHLEVBQUUsQ0FBQyxDQUFDLE
VBQUUsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUU7Z0JBQ3JELFVBQVUsQ0FBQyxDQUFDLENBQUMsRUFBRSxVQUFVLENBQUMsQ0FBQyxDQUFDLEVBQUUsVUFBVSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDO2FBQ2xELENBQWEsQ0FBQztTQUN2QjthQUFNO1lBQ0wsTUFBTSxJQUFJLEtBQUssQ0FDWCwwQ0FBMEM7Z0JBQzFDLGtDQUFrQztnQkFDbEMsNkJBQThCLFVBQWtCLENBQUMsTUFBTSxJQUFJLENBQUMsQ0FBQztTQUNsRTtLQUNGO0FBQ0gsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLEdBQUcsR0FBRyxlQUFlLENBQUMsRUFBRSxDQUFDLEVBQUMsSUFBSSxFQUFDLENBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtUZW5zb3IyRH0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCB7RGF0YVR5cGV9IGZyb20gJy4uL3R5cGVzJztcblxuaW1wb3J0IHtidWZmZXJ9IGZyb20gJy4vYnVmZmVyJztcbmltcG9ydCB7ZXhwYW5kRGltc30gZnJvbSAnLi9leHBhbmRfZGltcyc7XG5pbXBvcnQge29wfSBmcm9tICcuL29wZXJhdGlvbic7XG5pbXBvcnQge3Jlc2hhcGV9IGZyb20gJy4vcmVzaGFwZSc7XG5pbXBvcnQge3RpbGV9IGZyb20gJy4vdGlsZSc7XG5cbi8qKlxuICogQ3JlYXRlIGFuIGlkZW50aXR5IG1hdHJpeC5cbiAqXG4gKiBAcGFyYW0gbnVtUm93cyBOdW1iZXIgb2Ygcm93cy5cbiAqIEBwYXJhbSBudW1Db2x1bW5zIE51bWJlciBvZiBjb2x1bW5zLiBEZWZhdWx0cyB0byBgbnVtUm93c2AuXG4gKiBAcGFyYW0gYmF0Y2hTaG
FwZSBJZiBwcm92aWRlZCwgd2lsbCBhZGQgdGhlIGJhdGNoIHNoYXBlIHRvIHRoZSBiZWdpbm5pbmdcbiAqICAgb2YgdGhlIHNoYXBlIG9mIHRoZSByZXR1cm5lZCBgdGYuVGVuc29yYCBieSByZXBlYXRpbmcgdGhlIGlkZW50aXR5XG4gKiAgIG1hdHJpeC5cbiAqIEBwYXJhbSBkdHlwZSBEYXRhIHR5cGUuXG4gKiBAcmV0dXJucyBJZGVudGl0eSBtYXRyaXggb2YgdGhlIHNwZWNpZmllZCBzaXplIGFuZCBkYXRhIHR5cGUsIHBvc3NpYmx5XG4gKiAgIHdpdGggYmF0Y2ggcmVwZXRpdGlvbiBpZiBgYmF0Y2hTaGFwZWAgaXMgc3BlY2lmaWVkLlxuICpcbiAqIEBkb2Mge2hlYWRpbmc6ICdUZW5zb3JzJywgc3ViaGVhZGluZzogJ0NyZWF0aW9uJ31cbiAqL1xuZnVuY3Rpb24gZXllXyhcbiAgICBudW1Sb3dzOiBudW1iZXIsIG51bUNvbHVtbnM/OiBudW1iZXIsXG4gICAgYmF0Y2hTaGFwZT86XG4gICAgICAgIFtcbiAgICAgICAgICBudW1iZXJcbiAgICAgICAgXXxbbnVtYmVyLFxuICAgICAgICAgICBudW1iZXJdfFtudW1iZXIsIG51bWJlciwgbnVtYmVyXXxbbnVtYmVyLCBudW1iZXIsIG51bWJlciwgbnVtYmVyXSxcbiAgICBkdHlwZTogRGF0YVR5cGUgPSAnZmxvYXQzMicpOiBUZW5zb3IyRCB7XG4gIGlmIChudW1Db2x1bW5zID09IG51bGwpIHtcbiAgICBudW1Db2x1bW5zID0gbnVtUm93cztcbiAgfVxuICBjb25zdCBidWZmID0gYnVmZmVyKFtudW1Sb3dzLCBudW1Db2x1bW5zXSwgZHR5cGUpO1xuICBjb25zdCBuID0gbnVtUm93cyA8PSBudW1Db2x1bW5zID8gbnVtUm93cyA6IG51bUNvbHVtbnM7XG4gIGZvciAobGV0IGkgPSAwOyBpIDwgbjsgKytpKSB7XG4gICAgYnVmZi5zZXQoMSwgaSwgaSk7XG4gIH1cbiAgY29uc3Qgb3V0OiBUZW5zb3IyRCA9IHJlc2hhcGUoYnVmZi50b1RlbnNvcigpLCBbbnVtUm93cywgbnVtQ29sdW1uc10pO1xuICBpZiAoYmF0Y2hTaGFwZSA9PSBudWxsKSB7XG4gICAgcmV0dXJuIG91dDtcbiAgfSBlbHNlIHtcbiAgICBpZiAoYmF0Y2hTaGFwZS5sZW5ndGggPT09IDEpIHtcbiAgICAgIHJldHVybiB0aWxlKGV4cGFuZERpbXMob3V0LCAwKSwgW2JhdGNoU2hhcGVbMF0sIDEsIDFdKSBhcyBUZW5zb3IyRDtcbiAgICB9IGVsc2UgaWYgKGJhdGNoU2hhcGUubGVuZ3RoID09PSAyKSB7XG4gICAgICAvLyB0c2xpbnQ6ZGlzYWJsZS1uZXh0LWxpbmU6bm8tdW5uZWNlc3NhcnktdHlwZS1hc3NlcnRpb25cbiAgICAgIHJldHVybiB0aWxlKFxuICAgICAgICAgICAgICAgICBleHBhbmREaW1zKGV4cGFuZERpbXMob3V0LCAwKSwgMCksXG4gICAgICAgICAgICAgICAgIFtiYXRjaFNoYXBlWzBdLCBiYXRjaFNoYXBlWzFdLCAxLCAxXSkgYXMgVGVuc29yMkQ7XG4gICAgfSBlbHNlIGlmIChiYXRjaFNoYXBlLmxlbmd0aCA9PT0gMykge1xuICAgICAgLy8gdHNsaW50OmRpc2FibGUtbmV4dC1saW5lOm5vLXVubmVjZXNzYXJ5LXR5cGUtYXNzZXJ0aW9uXG4gICAgICByZXR1cm4gdGlsZShleHBhbmREaW1zKGV4cGFuZERpbXMoZX
hwYW5kRGltcyhvdXQsIDApLCAwKSwgMCksIFtcbiAgICAgICAgICAgICAgIGJhdGNoU2hhcGVbMF0sIGJhdGNoU2hhcGVbMV0sIGJhdGNoU2hhcGVbMl0sIDEsIDFcbiAgICAgICAgICAgICBdKSBhcyBUZW5zb3IyRDtcbiAgICB9IGVsc2Uge1xuICAgICAgdGhyb3cgbmV3IEVycm9yKFxuICAgICAgICAgIGBleWUoKSBjdXJyZW50bHkgc3VwcG9ydHMgb25seSAxRCBhbmQgMkQgYCArXG4gICAgICAgICAgLy8gdHNsaW50OmRpc2FibGUtbmV4dC1saW5lOm5vLWFueVxuICAgICAgICAgIGBiYXRjaFNoYXBlcywgYnV0IHJlY2VpdmVkICR7KGJhdGNoU2hhcGUgYXMgYW55KS5sZW5ndGh9RC5gKTtcbiAgICB9XG4gIH1cbn1cblxuZXhwb3J0IGNvbnN0IGV5ZSA9IC8qIEBfX1BVUkVfXyAqLyBvcCh7ZXllX30pO1xuIl19