/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { buffer } from './buffer';
|
import { expandDims } from './expand_dims';
|
import { op } from './operation';
|
import { reshape } from './reshape';
|
import { tile } from './tile';
|
/**
 * Create an identity matrix.
 *
 * @param numRows Number of rows.
 * @param numColumns Number of columns. Defaults to `numRows`.
 * @param batchShape If provided, will add the batch shape to the beginning
 *   of the shape of the returned `tf.Tensor` by repeating the identity
 *   matrix. Must have 1, 2 or 3 entries.
 * @param dtype Data type. Defaults to `'float32'`.
 * @returns Identity matrix of the specified size and data type, possibly
 *   with batch repetition if `batchShape` is specified.
 * @throws {Error} If `batchShape` is provided and its length is not 1, 2
 *   or 3.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function eye_(numRows, numColumns, batchShape, dtype = 'float32') {
  if (numColumns == null) {
    numColumns = numRows;
  }

  // Write ones along the main diagonal of a [numRows, numColumns] buffer;
  // every other entry keeps the buffer's initial value.
  const buff = buffer([numRows, numColumns], dtype);
  const diagonalLength = numRows <= numColumns ? numRows : numColumns;
  for (let i = 0; i < diagonalLength; ++i) {
    buff.set(1, i, i);
  }
  const out = reshape(buff.toTensor(), [numRows, numColumns]);

  if (batchShape == null) {
    return out;
  }

  const batchRank = batchShape.length;
  if (batchRank < 1 || batchRank > 3) {
    // The message lists every rank that is actually handled below.
    throw new Error(`eye() currently supports only 1D, 2D and 3D ` +
        `batchShapes, but received ${batchRank}D.`);
  }

  // Prepend one singleton axis per batch dimension, then repeat the matrix
  // along those axes; the trailing `1, 1` leaves the matrix axes as they are.
  let expanded = out;
  for (let i = 0; i < batchRank; ++i) {
    expanded = expandDims(expanded, 0);
  }
  return tile(expanded, [...batchShape, 1, 1]);
}
|
// Public `eye` export: the private `eye_` implementation wrapped by `op`
// (imported from './operation'). The `@__PURE__` annotation must stay
// directly before the call; by bundler convention it marks the call as
// side-effect free so it can be dropped when `eye` is never imported.
export const eye = /* @__PURE__ */ op({ eye_ });
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiZXllLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvZXllLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUtILE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxVQUFVLENBQUM7QUFDaEMsT0FBTyxFQUFDLFVBQVUsRUFBQyxNQUFNLGVBQWUsQ0FBQztBQUN6QyxPQUFPLEVBQUMsRUFBRSxFQUFDLE1BQU0sYUFBYSxDQUFDO0FBQy9CLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDbEMsT0FBTyxFQUFDLElBQUksRUFBQyxNQUFNLFFBQVEsQ0FBQztBQUU1Qjs7Ozs7Ozs7Ozs7OztHQWFHO0FBQ0gsU0FBUyxJQUFJLENBQ1QsT0FBZSxFQUFFLFVBQW1CLEVBQ3BDLFVBSXdFLEVBQ3hFLFFBQWtCLFNBQVM7SUFDN0IsSUFBSSxVQUFVLElBQUksSUFBSSxFQUFFO1FBQ3RCLFVBQVUsR0FBRyxPQUFPLENBQUM7S0FDdEI7SUFDRCxNQUFNLElBQUksR0FBRyxNQUFNLENBQUMsQ0FBQyxPQUFPLEVBQUUsVUFBVSxDQUFDLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFDbEQsTUFBTSxDQUFDLEdBQUcsT0FBTyxJQUFJLFVBQVUsQ0FBQyxDQUFDLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxVQUFVLENBQUM7SUFDdkQsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLENBQUMsRUFBRSxFQUFFLENBQUMsRUFBRTtRQUMxQixJQUFJLENBQUMsR0FBRyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUM7S0FDbkI7SUFDRCxNQUFNLEdBQUcsR0FBYSxPQUFPLENBQUMsSUFBSSxDQUFDLFFBQVEsRUFBRSxFQUFFLENBQUMsT0FBTyxFQUFFLFVBQVUsQ0FBQyxDQUFDLENBQUM7SUFDdEUsSUFBSSxVQUFVLElBQUksSUFBSSxFQUFFO1FBQ3RCLE9BQU8sR0FBRyxDQUFDO0tBQ1o7U0FBTTtRQUNMLElBQUksVUFBVSxDQUFDLE1BQU0sS0FBSyxDQUFDLEVBQUU7WUFDM0IsT0FBTyxJQUFJLENBQUMsVUFBVSxDQUFDLEdBQUcsRUFBRSxDQUFDLENBQUMsRUFBRSxDQUFDLFVBQVUsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQWEsQ0FBQztTQUNwRTthQUFNLElBQUksVUFBVSxDQUFDLE1BQU0sS0FBSyxDQUFDLEVBQUU7WUFDbEMseURBQXlEO1lBQ3pELE9BQU8sSUFBSSxDQUNBLFVBQVUsQ0FBQyxVQUFVLENBQUMsR0FBRyxFQUFFLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxFQUNqQyxDQUFDLFVBQVUsQ0FBQyxDQUFDLENBQUMsRUFBRSxVQUFVLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFhLENBQUM7U0FDOUQ7YUFBTSxJQUFJLFVBQVUsQ0FBQyxNQUFNLEtBQUssQ0FBQyxFQUFFO1lBQ2xDLHlEQUF5RDtZQUN6RCxPQUFPLElBQUksQ0FBQyxVQUFVLENBQUMsVUFBVSxDQUFDLFVBQVUsQ0FBQyxHQUFHLEVBQUUsQ0FBQyxDQUFDLE
VBQUUsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUU7Z0JBQ3JELFVBQVUsQ0FBQyxDQUFDLENBQUMsRUFBRSxVQUFVLENBQUMsQ0FBQyxDQUFDLEVBQUUsVUFBVSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDO2FBQ2xELENBQWEsQ0FBQztTQUN2QjthQUFNO1lBQ0wsTUFBTSxJQUFJLEtBQUssQ0FDWCwwQ0FBMEM7Z0JBQzFDLGtDQUFrQztnQkFDbEMsNkJBQThCLFVBQWtCLENBQUMsTUFBTSxJQUFJLENBQUMsQ0FBQztTQUNsRTtLQUNGO0FBQ0gsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLEdBQUcsR0FBRyxlQUFlLENBQUMsRUFBRSxDQUFDLEVBQUMsSUFBSSxFQUFDLENBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtUZW5zb3IyRH0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCB7RGF0YVR5cGV9IGZyb20gJy4uL3R5cGVzJztcblxuaW1wb3J0IHtidWZmZXJ9IGZyb20gJy4vYnVmZmVyJztcbmltcG9ydCB7ZXhwYW5kRGltc30gZnJvbSAnLi9leHBhbmRfZGltcyc7XG5pbXBvcnQge29wfSBmcm9tICcuL29wZXJhdGlvbic7XG5pbXBvcnQge3Jlc2hhcGV9IGZyb20gJy4vcmVzaGFwZSc7XG5pbXBvcnQge3RpbGV9IGZyb20gJy4vdGlsZSc7XG5cbi8qKlxuICogQ3JlYXRlIGFuIGlkZW50aXR5IG1hdHJpeC5cbiAqXG4gKiBAcGFyYW0gbnVtUm93cyBOdW1iZXIgb2Ygcm93cy5cbiAqIEBwYXJhbSBudW1Db2x1bW5zIE51bWJlciBvZiBjb2x1bW5zLiBEZWZhdWx0cyB0byBgbnVtUm93c2AuXG4gKiBAcGFyYW0gYmF0Y2hTaG
FwZSBJZiBwcm92aWRlZCwgd2lsbCBhZGQgdGhlIGJhdGNoIHNoYXBlIHRvIHRoZSBiZWdpbm5pbmdcbiAqICAgb2YgdGhlIHNoYXBlIG9mIHRoZSByZXR1cm5lZCBgdGYuVGVuc29yYCBieSByZXBlYXRpbmcgdGhlIGlkZW50aXR5XG4gKiAgIG1hdHJpeC5cbiAqIEBwYXJhbSBkdHlwZSBEYXRhIHR5cGUuXG4gKiBAcmV0dXJucyBJZGVudGl0eSBtYXRyaXggb2YgdGhlIHNwZWNpZmllZCBzaXplIGFuZCBkYXRhIHR5cGUsIHBvc3NpYmx5XG4gKiAgIHdpdGggYmF0Y2ggcmVwZXRpdGlvbiBpZiBgYmF0Y2hTaGFwZWAgaXMgc3BlY2lmaWVkLlxuICpcbiAqIEBkb2Mge2hlYWRpbmc6ICdUZW5zb3JzJywgc3ViaGVhZGluZzogJ0NyZWF0aW9uJ31cbiAqL1xuZnVuY3Rpb24gZXllXyhcbiAgICBudW1Sb3dzOiBudW1iZXIsIG51bUNvbHVtbnM/OiBudW1iZXIsXG4gICAgYmF0Y2hTaGFwZT86XG4gICAgICAgIFtcbiAgICAgICAgICBudW1iZXJcbiAgICAgICAgXXxbbnVtYmVyLFxuICAgICAgICAgICBudW1iZXJdfFtudW1iZXIsIG51bWJlciwgbnVtYmVyXXxbbnVtYmVyLCBudW1iZXIsIG51bWJlciwgbnVtYmVyXSxcbiAgICBkdHlwZTogRGF0YVR5cGUgPSAnZmxvYXQzMicpOiBUZW5zb3IyRCB7XG4gIGlmIChudW1Db2x1bW5zID09IG51bGwpIHtcbiAgICBudW1Db2x1bW5zID0gbnVtUm93cztcbiAgfVxuICBjb25zdCBidWZmID0gYnVmZmVyKFtudW1Sb3dzLCBudW1Db2x1bW5zXSwgZHR5cGUpO1xuICBjb25zdCBuID0gbnVtUm93cyA8PSBudW1Db2x1bW5zID8gbnVtUm93cyA6IG51bUNvbHVtbnM7XG4gIGZvciAobGV0IGkgPSAwOyBpIDwgbjsgKytpKSB7XG4gICAgYnVmZi5zZXQoMSwgaSwgaSk7XG4gIH1cbiAgY29uc3Qgb3V0OiBUZW5zb3IyRCA9IHJlc2hhcGUoYnVmZi50b1RlbnNvcigpLCBbbnVtUm93cywgbnVtQ29sdW1uc10pO1xuICBpZiAoYmF0Y2hTaGFwZSA9PSBudWxsKSB7XG4gICAgcmV0dXJuIG91dDtcbiAgfSBlbHNlIHtcbiAgICBpZiAoYmF0Y2hTaGFwZS5sZW5ndGggPT09IDEpIHtcbiAgICAgIHJldHVybiB0aWxlKGV4cGFuZERpbXMob3V0LCAwKSwgW2JhdGNoU2hhcGVbMF0sIDEsIDFdKSBhcyBUZW5zb3IyRDtcbiAgICB9IGVsc2UgaWYgKGJhdGNoU2hhcGUubGVuZ3RoID09PSAyKSB7XG4gICAgICAvLyB0c2xpbnQ6ZGlzYWJsZS1uZXh0LWxpbmU6bm8tdW5uZWNlc3NhcnktdHlwZS1hc3NlcnRpb25cbiAgICAgIHJldHVybiB0aWxlKFxuICAgICAgICAgICAgICAgICBleHBhbmREaW1zKGV4cGFuZERpbXMob3V0LCAwKSwgMCksXG4gICAgICAgICAgICAgICAgIFtiYXRjaFNoYXBlWzBdLCBiYXRjaFNoYXBlWzFdLCAxLCAxXSkgYXMgVGVuc29yMkQ7XG4gICAgfSBlbHNlIGlmIChiYXRjaFNoYXBlLmxlbmd0aCA9PT0gMykge1xuICAgICAgLy8gdHNsaW50OmRpc2FibGUtbmV4dC1saW5lOm5vLXVubmVjZXNzYXJ5LXR5cGUtYXNzZXJ0aW9uXG4gICAgICByZXR1cm4gdGlsZShleHBhbmREaW1zKGV4cGFuZERpbXMoZX
hwYW5kRGltcyhvdXQsIDApLCAwKSwgMCksIFtcbiAgICAgICAgICAgICAgIGJhdGNoU2hhcGVbMF0sIGJhdGNoU2hhcGVbMV0sIGJhdGNoU2hhcGVbMl0sIDEsIDFcbiAgICAgICAgICAgICBdKSBhcyBUZW5zb3IyRDtcbiAgICB9IGVsc2Uge1xuICAgICAgdGhyb3cgbmV3IEVycm9yKFxuICAgICAgICAgIGBleWUoKSBjdXJyZW50bHkgc3VwcG9ydHMgb25seSAxRCBhbmQgMkQgYCArXG4gICAgICAgICAgLy8gdHNsaW50OmRpc2FibGUtbmV4dC1saW5lOm5vLWFueVxuICAgICAgICAgIGBiYXRjaFNoYXBlcywgYnV0IHJlY2VpdmVkICR7KGJhdGNoU2hhcGUgYXMgYW55KS5sZW5ndGh9RC5gKTtcbiAgICB9XG4gIH1cbn1cblxuZXhwb3J0IGNvbnN0IGV5ZSA9IC8qIEBfX1BVUkVfXyAqLyBvcCh7ZXllX30pO1xuIl19
|