gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { Tensor } from './tensor';
import { upcastType } from './types';
import { assert } from './util';
/**
 * Brings two tensors to a common dtype.
 *
 * @param a First tensor.
 * @param b Second tensor.
 * @returns A two-element array. When both dtypes are already equal the
 *   original `a` and `b` are returned untouched; otherwise both are cast to
 *   the dtype produced by `upcastType(a.dtype, b.dtype)`.
 */
export function makeTypesMatch(a, b) {
    if (a.dtype !== b.dtype) {
        // Dtypes disagree: promote both operands to the shared upcast dtype.
        const commonDtype = upcastType(a.dtype, b.dtype);
        return [a.cast(commonDtype), b.cast(commonDtype)];
    }
    return [a, b];
}
/**
 * Asserts that two tensors share the same dtype.
 *
 * The failure message is supplied as a callback, so it is only built if the
 * `assert` helper decides to invoke it.
 *
 * @param a First tensor.
 * @param b Second tensor.
 */
export function assertTypesMatch(a, b) {
    const dtypesAgree = a.dtype === b.dtype;
    const describeMismatch = () => `The dtypes of the first(${a.dtype}) and second(${b.dtype}) input must match`;
    assert(dtypesAgree, describeMismatch);
}
/**
 * Checks whether a tensor is present in a list, comparing by `id` rather
 * than by object identity.
 *
 * @param tensor The tensor to look for.
 * @param tensorList The tensors to search.
 * @returns `true` if some entry of `tensorList` has the same `id` as
 *   `tensor`, `false` otherwise.
 */
export function isTensorInList(tensor, tensorList) {
    const hasSameId = (candidate) => candidate.id === tensor.id;
    return tensorList.some(hasSameId);
}
/**
 * Extracts any `Tensor`s found within the provided object.
 *
 * @param result an object that may be a `Tensor` or may contain `Tensor`s,
 *   such as a `Tensor[]` or `{key: Tensor, ...}`. In general it is safe to
 *   pass any object here, except that `Promise`s are not supported.
 * @returns An array of the `Tensor`s found within the passed object. If the
 *   argument is simply a `Tensor`, a list containing that `Tensor` is
 *   returned. If the object is not a `Tensor` and does not contain
 *   `Tensor`s, an empty list is returned.
 */
export function getTensorsInContainer(result) {
    const tensors = [];
    // A fresh Set per call tracks values already visited during the walk.
    walkTensorContainer(result, tensors, new Set());
    return tensors;
}
/**
 * Recursively collects `Tensor`s reachable from `container` into `list`.
 *
 * @param container Value to inspect; `null`/`undefined` and values rejected
 *   by `isIterable` are ignored.
 * @param list Output array; every `Tensor` encountered is appended to it.
 * @param seen Set of values already visited. Each nested value is added
 *   before it is descended into, so a value reached a second time (shared
 *   references, cycles) is skipped.
 */
function walkTensorContainer(container, list, seen) {
    if (container == null) {
        return;
    }
    if (container instanceof Tensor) {
        list.push(container);
    }
    else if (isIterable(container)) {
        // `for...in` enumerates array indices as well as object keys, so one
        // loop handles both arrays and plain objects.
        for (const key in container) {
            const child = container[key];
            if (seen.has(child)) {
                continue;
            }
            seen.add(child);
            walkTensorContainer(child, list, seen);
        }
    }
}
/**
 * Reports whether a value is a non-null object (arrays included) whose
 * keys can be enumerated.
 *
 * @param obj Any value.
 * @returns `true` for arrays and other non-null objects; `false` for
 *   `null`, `undefined`, primitives and functions.
 */
// tslint:disable-next-line:no-any
function isIterable(obj) {
    // `typeof null` is 'object', so null has to be excluded explicitly.
    // Arrays already report 'object', so no separate Array.isArray check
    // is needed.
    return obj !== null && typeof obj === 'object';
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoidGVuc29yX3V0aWwuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL3RlbnNvcl91dGlsLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxVQUFVLENBQUM7QUFFaEMsT0FBTyxFQUFDLFVBQVUsRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUNuQyxPQUFPLEVBQUMsTUFBTSxFQUFDLE1BQU0sUUFBUSxDQUFDO0FBRTlCLE1BQU0sVUFBVSxjQUFjLENBQW1CLENBQUksRUFBRSxDQUFJO0lBQ3pELElBQUksQ0FBQyxDQUFDLEtBQUssS0FBSyxDQUFDLENBQUMsS0FBSyxFQUFFO1FBQ3ZCLE9BQU8sQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUM7S0FDZjtJQUNELE1BQU0sS0FBSyxHQUFHLFVBQVUsQ0FBQyxDQUFDLENBQUMsS0FBSyxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUMzQyxPQUFPLENBQUMsQ0FBQyxDQUFDLElBQUksQ0FBQyxLQUFLLENBQUMsRUFBRSxDQUFDLENBQUMsSUFBSSxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUM7QUFDeEMsQ0FBQztBQUVELE1BQU0sVUFBVSxnQkFBZ0IsQ0FBQyxDQUFTLEVBQUUsQ0FBUztJQUNuRCxNQUFNLENBQ0YsQ0FBQyxDQUFDLEtBQUssS0FBSyxDQUFDLENBQUMsS0FBSyxFQUNuQixHQUFHLEVBQUUsQ0FBQywyQkFBMkIsQ0FBQyxDQUFDLEtBQUssT0FBTztRQUMzQyxXQUFXLENBQUMsQ0FBQyxLQUFLLG9CQUFvQixDQUFDLENBQUM7QUFDbEQsQ0FBQztBQUVELE1BQU0sVUFBVSxjQUFjLENBQUMsTUFBYyxFQUFFLFVBQW9CO0lBQ2pFLE9BQU8sVUFBVSxDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxFQUFFLEtBQUssTUFBTSxDQUFDLEVBQUUsQ0FBQyxDQUFDO0FBQ2xELENBQUM7QUFFRDs7Ozs7Ozs7Ozs7R0FXRztBQUNILE1BQU0sVUFBVSxxQkFBcUIsQ0FBQyxNQUF1QjtJQUMzRCxNQUFNLElBQUksR0FBYSxFQUFFLENBQUM7SUFDMUIsTUFBTSxJQUFJLEdBQUcsSUFBSSxHQUFHLEVBQVcsQ0FBQztJQUNoQyxtQkFBbUIsQ0FBQyxNQUFNLEVBQUUsSUFBSSxFQUFFLElBQUksQ0FBQyxDQUFDO0lBQ3hDLE9BQU8sSUFBSSxDQUFDO0FBQ2QsQ0FBQztBQUVELFNBQVMsbUJBQW1CLENBQ3hCLFNBQTBCLEVBQUUsSUFBYyxFQUFFLElBQWtCO0lBQ2hFLElBQUksU0FBUyxJQUFJLElBQUksRUFBRTtRQUNyQixPQUFPO0tBQ1I7SUFDRCxJQUFJLFNBQVMsWUFBWSxNQUFNLEVBQUU7UUFDL0IsSUFBSSxDQUFDLElBQUksQ0FBQyxTQUFTLENBQUMsQ0FBQztRQUNyQixPQUFPO0tBQ1I7SUFDRCxJQUFJLENBQUMsVUFBVSxDQUFDLFNBQVMsQ0FBQyxFQUFFO1FBQzFCLE9BQU87S0FDUjtJQUNELDZDQUE2QztJQUM3QyxNQUFNLFFBQVEsR0FBRyxTQUFpQyxDQUFDO0lBQ25ELEtBQUssTUFBTSxDQUFDLElBQUksUUFBUS
xFQUFFO1FBQ3hCLE1BQU0sR0FBRyxHQUFHLFFBQVEsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUN4QixJQUFJLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxHQUFHLENBQUMsRUFBRTtZQUNsQixJQUFJLENBQUMsR0FBRyxDQUFDLEdBQUcsQ0FBQyxDQUFDO1lBQ2QsbUJBQW1CLENBQUMsR0FBRyxFQUFFLElBQUksRUFBRSxJQUFJLENBQUMsQ0FBQztTQUN0QztLQUNGO0FBQ0gsQ0FBQztBQUVELGtDQUFrQztBQUNsQyxTQUFTLFVBQVUsQ0FBQyxHQUFRO0lBQzFCLE9BQU8sS0FBSyxDQUFDLE9BQU8sQ0FBQyxHQUFHLENBQUMsSUFBSSxPQUFPLEdBQUcsS0FBSyxRQUFRLENBQUM7QUFDdkQsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDE4IEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtUZW5zb3J9IGZyb20gJy4vdGVuc29yJztcbmltcG9ydCB7VGVuc29yQ29udGFpbmVyLCBUZW5zb3JDb250YWluZXJBcnJheX0gZnJvbSAnLi90ZW5zb3JfdHlwZXMnO1xuaW1wb3J0IHt1cGNhc3RUeXBlfSBmcm9tICcuL3R5cGVzJztcbmltcG9ydCB7YXNzZXJ0fSBmcm9tICcuL3V0aWwnO1xuXG5leHBvcnQgZnVuY3Rpb24gbWFrZVR5cGVzTWF0Y2g8VCBleHRlbmRzIFRlbnNvcj4oYTogVCwgYjogVCk6IFtULCBUXSB7XG4gIGlmIChhLmR0eXBlID09PSBiLmR0eXBlKSB7XG4gICAgcmV0dXJuIFthLCBiXTtcbiAgfVxuICBjb25zdCBkdHlwZSA9IHVwY2FzdFR5cGUoYS5kdHlwZSwgYi5kdHlwZSk7XG4gIHJldHVybiBbYS5jYXN0KGR0eXBlKSwgYi5jYXN0KGR0eXBlKV07XG59XG5cbmV4cG9ydCBmdW5jdGlvbiBhc3Nlcn
RUeXBlc01hdGNoKGE6IFRlbnNvciwgYjogVGVuc29yKTogdm9pZCB7XG4gIGFzc2VydChcbiAgICAgIGEuZHR5cGUgPT09IGIuZHR5cGUsXG4gICAgICAoKSA9PiBgVGhlIGR0eXBlcyBvZiB0aGUgZmlyc3QoJHthLmR0eXBlfSkgYW5kYCArXG4gICAgICAgICAgYCBzZWNvbmQoJHtiLmR0eXBlfSkgaW5wdXQgbXVzdCBtYXRjaGApO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gaXNUZW5zb3JJbkxpc3QodGVuc29yOiBUZW5zb3IsIHRlbnNvckxpc3Q6IFRlbnNvcltdKTogYm9vbGVhbiB7XG4gIHJldHVybiB0ZW5zb3JMaXN0LnNvbWUoeCA9PiB4LmlkID09PSB0ZW5zb3IuaWQpO1xufVxuXG4vKipcbiAqIEV4dHJhY3RzIGFueSBgVGVuc29yYHMgZm91bmQgd2l0aGluIHRoZSBwcm92aWRlZCBvYmplY3QuXG4gKlxuICogQHBhcmFtIGNvbnRhaW5lciBhbiBvYmplY3QgdGhhdCBtYXkgYmUgYSBgVGVuc29yYCBvciBtYXkgZGlyZWN0bHkgY29udGFpblxuICogICBgVGVuc29yYHMsIHN1Y2ggYXMgYSBgVGVuc29yW11gIG9yIGB7a2V5OiBUZW5zb3IsIC4uLn1gLiBJbiBnZW5lcmFsIGl0XG4gKiAgIGlzIHNhZmUgdG8gcGFzcyBhbnkgb2JqZWN0IGhlcmUsIGV4Y2VwdCB0aGF0IGBQcm9taXNlYHMgYXJlIG5vdFxuICogICBzdXBwb3J0ZWQuXG4gKiBAcmV0dXJucyBBbiBhcnJheSBvZiBgVGVuc29yc2AgZm91bmQgd2l0aGluIHRoZSBwYXNzZWQgb2JqZWN0LiBJZiB0aGVcbiAqICAgYXJndW1lbnQgaXMgc2ltcGx5IGEgYFRlbnNvcicsIGEgbGlzdCBjb250YWluaW5nIHRoYXQgYFRlbnNvcmAgaXNcbiAqICAgcmV0dXJuZWQuIElmIHRoZSBvYmplY3QgaXMgbm90IGEgYFRlbnNvcmAgb3IgZG9lcyBub3RcbiAqICAgY29udGFpbiBgVGVuc29yc2AsIGFuIGVtcHR5IGxpc3QgaXMgcmV0dXJuZWQuXG4gKi9cbmV4cG9ydCBmdW5jdGlvbiBnZXRUZW5zb3JzSW5Db250YWluZXIocmVzdWx0OiBUZW5zb3JDb250YWluZXIpOiBUZW5zb3JbXSB7XG4gIGNvbnN0IGxpc3Q6IFRlbnNvcltdID0gW107XG4gIGNvbnN0IHNlZW4gPSBuZXcgU2V0PHt9fHZvaWQ+KCk7XG4gIHdhbGtUZW5zb3JDb250YWluZXIocmVzdWx0LCBsaXN0LCBzZWVuKTtcbiAgcmV0dXJuIGxpc3Q7XG59XG5cbmZ1bmN0aW9uIHdhbGtUZW5zb3JDb250YWluZXIoXG4gICAgY29udGFpbmVyOiBUZW5zb3JDb250YWluZXIsIGxpc3Q6IFRlbnNvcltdLCBzZWVuOiBTZXQ8e318dm9pZD4pOiB2b2lkIHtcbiAgaWYgKGNvbnRhaW5lciA9PSBudWxsKSB7XG4gICAgcmV0dXJuO1xuICB9XG4gIGlmIChjb250YWluZXIgaW5zdGFuY2VvZiBUZW5zb3IpIHtcbiAgICBsaXN0LnB1c2goY29udGFpbmVyKTtcbiAgICByZXR1cm47XG4gIH1cbiAgaWYgKCFpc0l0ZXJhYmxlKGNvbnRhaW5lcikpIHtcbiAgICByZXR1cm47XG4gIH1cbiAgLy8gSXRlcmF0aW9uIG92ZXIga2V5cyB3b3JrcyBhbHNvIGZvciBhcnJheXMuXG4gIGNvbnN0IGl0ZXJhYmxlID0gY29udGFpbmVyIGFzIFRlbnNvckNvbnRhaW5lck
FycmF5O1xuICBmb3IgKGNvbnN0IGsgaW4gaXRlcmFibGUpIHtcbiAgICBjb25zdCB2YWwgPSBpdGVyYWJsZVtrXTtcbiAgICBpZiAoIXNlZW4uaGFzKHZhbCkpIHtcbiAgICAgIHNlZW4uYWRkKHZhbCk7XG4gICAgICB3YWxrVGVuc29yQ29udGFpbmVyKHZhbCwgbGlzdCwgc2Vlbik7XG4gICAgfVxuICB9XG59XG5cbi8vIHRzbGludDpkaXNhYmxlLW5leHQtbGluZTpuby1hbnlcbmZ1bmN0aW9uIGlzSXRlcmFibGUob2JqOiBhbnkpOiBib29sZWFuIHtcbiAgcmV0dXJuIEFycmF5LmlzQXJyYXkob2JqKSB8fCB0eXBlb2Ygb2JqID09PSAnb2JqZWN0Jztcbn1cbiJdfQ==