gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as util from '../util';
/**
 * Checks whether `axes` is exactly the trailing run of dimension indices of a
 * rank-`rank` array, i.e. `[rank - axes.length, ..., rank - 1]` in that order.
 *
 * @param {number[]} axes Dimension indices to test.
 * @param {number} rank Number of dimensions of the array.
 * @returns {boolean} True when every entry of `axes` matches the
 *     corresponding inner-most dimension; an empty `axes` yields true.
 */
export function axesAreInnerMostDims(axes, rank) {
    // The entry at position `pos` must equal `pos` shifted by this offset.
    const offset = rank - axes.length;
    for (let pos = axes.length - 1; pos >= 0; pos--) {
        if (axes[pos] !== offset + pos) {
            return false;
        }
    }
    return true;
}
/**
 * Merges two coordinate lists into one of length
 * `outputLoc.length + reduceLoc.length`.
 *
 * Positions listed in `axes` are filled, in order, from `reduceLoc`; every
 * other position is filled, in order, from `outputLoc`.
 *
 * @param {number[]} outputLoc Values for the positions not listed in `axes`.
 * @param {number[]} reduceLoc Values for the positions listed in `axes`.
 * @param {number[]} axes Positions that take their value from `reduceLoc`.
 * @returns {number[]} The merged list.
 */
export function combineLocations(outputLoc, reduceLoc, axes) {
    const totalRank = outputLoc.length + reduceLoc.length;
    const combined = [];
    let nextOut = 0;
    let nextReduce = 0;
    for (let pos = 0; pos < totalRank; pos++) {
        const isReducedAxis = axes.indexOf(pos) !== -1;
        combined.push(isReducedAxis ? reduceLoc[nextReduce++] : outputLoc[nextOut++]);
    }
    return combined;
}
/**
 * Splits a shape into the sizes of the dimensions that are kept and the sizes
 * of the dimensions listed in `axes`.
 *
 * @param {number[]} aShape Full shape.
 * @param {number[]} axes Dimension indices to separate out.
 * @returns {[number[], number[]]} A pair `[outShape, reduceShape]`:
 *     `outShape` holds the sizes of dimensions not in `axes`, in dimension
 *     order; `reduceShape` holds `aShape[axis]` for each axis, in the order
 *     given by `axes`.
 */
export function computeOutAndReduceShapes(aShape, axes) {
    const reduceShape = axes.map(axis => aShape[axis]);
    const outShape = [];
    for (let axis = 0; axis < aShape.length; axis++) {
        if (axes.indexOf(axis) !== -1) {
            continue;
        }
        outShape.push(aShape[axis]);
    }
    return [outShape, reduceShape];
}
/**
 * Re-inserts a size-1 entry into `shape` at each position listed in `axes`.
 *
 * Builds a list of ones (one per axis) and merges it with `shape` through
 * `combineLocations`.
 *
 * @param {number[]} shape Sizes for the positions not listed in `axes`.
 * @param {number[]} axes Positions that receive a size of 1.
 * @returns {number[]} The expanded shape.
 */
export function expandShapeToKeepDim(shape, axes) {
    const unitSizes = axes.map(() => 1);
    return combineLocations(shape, unitSizes, axes);
}
/**
 * Asserts, via `util.assert`, that `axes` are the inner-most dimensions of a
 * rank-`rank` input.
 *
 * @param {string} msg Prefix for the failure message (the message reads
 *     "`msg` supports only inner-most axes for now. ...").
 * @param {number[]} axes Dimension indices to validate.
 * @param {number} rank Number of dimensions of the input.
 */
export function assertAxesAreInnerMostDims(msg, axes, rank) {
    const isInnerMost = axesAreInnerMostDims(axes, rank);
    // Passed as a callback so the message is only built if util.assert asks for it.
    const describeFailure = () => `${msg} supports only inner-most axes for now. ` +
        `Got axes ${axes} and rank-${rank} input.`;
    util.assert(isInnerMost, describeFailure);
}
/**
 * Computes the permutation to pass to `tf.transpose` so that `axes` become
 * the inner-most dimensions. Used by operations that only work on inner-most
 * axes.
 *
 * @param {number[]} axes Dimension indices that should end up inner-most.
 * @param {number} rank Number of dimensions of the input.
 * @returns {number[]|null} `null` when `axes` are already inner-most;
 *     otherwise every dimension not in `axes` (ascending) followed by `axes`
 *     in their given order.
 */
export function getAxesPermutation(axes, rank) {
    if (axesAreInnerMostDims(axes, rank)) {
        return null;
    }
    const keptDims = [];
    for (let dim = 0; dim < rank; dim++) {
        if (axes.indexOf(dim) < 0) {
            keptDims.push(dim);
        }
    }
    return [...keptDims, ...axes];
}
/**
 * Returns the axes permutation that undoes the original permutation.
 *
 * The result lists the positions of `axes` ordered by ascending axis value,
 * which for a permutation is its inverse.
 *
 * @param {number[]} axes The permutation to invert.
 * @returns {number[]} A new array; `axes` is not modified.
 */
export function getUndoAxesPermutation(axes) {
    const positions = axes.map((_, position) => position);
    positions.sort((left, right) => axes[left] - axes[right]);
    return positions;
}
/**
 * Lists the indices of the last `numAxes` dimensions of a rank-`rank` array.
 *
 * @param {number} numAxes How many trailing dimensions to list.
 * @param {number} rank Number of dimensions of the array.
 * @returns {number[]} The values `rank - numAxes` up to `rank - 1`, ascending.
 */
export function getInnerMostAxes(numAxes, rank) {
    const innerAxes = [];
    let axis = rank - numAxes;
    while (axis < rank) {
        innerAxes.push(axis);
        axis += 1;
    }
    return innerAxes;
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYXhpc191dGlsLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvYXhpc191dGlsLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sS0FBSyxJQUFJLE1BQU0sU0FBUyxDQUFDO0FBRWhDOzs7R0FHRztBQUNILE1BQU0sVUFBVSxvQkFBb0IsQ0FBQyxJQUFjLEVBQUUsSUFBWTtJQUMvRCxLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsSUFBSSxDQUFDLE1BQU0sRUFBRSxFQUFFLENBQUMsRUFBRTtRQUNwQyxJQUFJLElBQUksQ0FBQyxJQUFJLENBQUMsTUFBTSxHQUFHLENBQUMsR0FBRyxDQUFDLENBQUMsS0FBSyxJQUFJLEdBQUcsQ0FBQyxHQUFHLENBQUMsRUFBRTtZQUM5QyxPQUFPLEtBQUssQ0FBQztTQUNkO0tBQ0Y7SUFDRCxPQUFPLElBQUksQ0FBQztBQUNkLENBQUM7QUFFRCxNQUFNLFVBQVUsZ0JBQWdCLENBQzVCLFNBQW1CLEVBQUUsU0FBbUIsRUFBRSxJQUFjO0lBQzFELE1BQU0sSUFBSSxHQUFHLFNBQVMsQ0FBQyxNQUFNLEdBQUcsU0FBUyxDQUFDLE1BQU0sQ0FBQztJQUNqRCxNQUFNLEdBQUcsR0FBRyxFQUFFLENBQUM7SUFDZixJQUFJLE1BQU0sR0FBRyxDQUFDLENBQUM7SUFDZixJQUFJLFNBQVMsR0FBRyxDQUFDLENBQUM7SUFDaEIsS0FBSyxJQUFJLEdBQUcsR0FBRyxDQUFDLEVBQUUsR0FBRyxHQUFHLElBQUksRUFBRSxHQUFHLEVBQUUsRUFBRTtRQUNyQyxJQUFJLElBQUksQ0FBQyxPQUFPLENBQUMsR0FBRyxDQUFDLEtBQUssQ0FBQyxDQUFDLEVBQUU7WUFDNUIsR0FBRyxDQUFDLElBQUksQ0FBQyxTQUFTLENBQUMsTUFBTSxFQUFFLENBQUMsQ0FBQyxDQUFDO1NBQy9CO2FBQU07WUFDTCxHQUFHLENBQUMsSUFBSSxDQUFDLFNBQVMsQ0FBQyxTQUFTLEVBQUUsQ0FBQyxDQUFDLENBQUM7U0FDbEM7S0FDRjtJQUNELE9BQU8sR0FBRyxDQUFDO0FBQ2IsQ0FBQztBQUVELE1BQU0sVUFBVSx5QkFBeUIsQ0FDckMsTUFBZ0IsRUFBRSxJQUFjO0lBQ2xDLE1BQU0sUUFBUSxHQUFHLEVBQUUsQ0FBQztJQUNwQixNQUFNLElBQUksR0FBRyxNQUFNLENBQUMsTUFBTSxDQUFDO0lBQzNCLEtBQUssSUFBSSxHQUFHLEdBQUcsQ0FBQyxFQUFFLEdBQUcsR0FBRyxJQUFJLEVBQUUsR0FBRyxFQUFFLEVBQUU7UUFDbkMsSUFBSSxJQUFJLENBQUMsT0FBTyxDQUFDLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQyxFQUFFO1lBQzVCLFFBQVEsQ0FBQyxJQUFJLENBQUMsTUFBTSxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUM7U0FDNUI7S0FDRjtJQUNELE1BQU0sV0FBVyxHQUFHLElBQUksQ0FBQyxHQUFHLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxNQUFNLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQztJQUNqRCxPQUFPLENBQUMsUUFBUSxFQUFFLFdBQVcsQ0FBQyxDQUFDO0FBQ2pDLENBQUM7QUFFRC
xNQUFNLFVBQVUsb0JBQW9CLENBQ2hDLEtBQWUsRUFBRSxJQUFjO0lBQ2pDLE1BQU0sY0FBYyxHQUFHLElBQUksQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUN4QyxPQUFPLGdCQUFnQixDQUFDLEtBQUssRUFBRSxjQUFjLEVBQUUsSUFBSSxDQUFDLENBQUM7QUFDdkQsQ0FBQztBQUVELE1BQU0sVUFBVSwwQkFBMEIsQ0FDdEMsR0FBVyxFQUFFLElBQWMsRUFBRSxJQUFZO0lBQzNDLElBQUksQ0FBQyxNQUFNLENBQ1Asb0JBQW9CLENBQUMsSUFBSSxFQUFFLElBQUksQ0FBQyxFQUNoQyxHQUFHLEVBQUUsQ0FBQyxHQUFHLEdBQUcsMENBQTBDO1FBQ2xELFlBQVksSUFBSSxhQUFhLElBQUksU0FBUyxDQUFDLENBQUM7QUFDdEQsQ0FBQztBQUVEOzs7O0dBSUc7QUFDSCxNQUFNLFVBQVUsa0JBQWtCLENBQUMsSUFBYyxFQUFFLElBQVk7SUFFN0QsSUFBSSxvQkFBb0IsQ0FBQyxJQUFJLEVBQUUsSUFBSSxDQUFDLEVBQUU7UUFDcEMsT0FBTyxJQUFJLENBQUM7S0FDYjtJQUNELE1BQU0sTUFBTSxHQUFhLEVBQUUsQ0FBQztJQUM1QixLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsSUFBSSxFQUFFLEVBQUUsQ0FBQyxFQUFFO1FBQzdCLElBQUksSUFBSSxDQUFDLE9BQU8sQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsRUFBRTtZQUMxQixNQUFNLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxDQUFDO1NBQ2hCO0tBQ0Y7SUFDRCxJQUFJLENBQUMsT0FBTyxDQUFDLElBQUksQ0FBQyxFQUFFLENBQUMsTUFBTSxDQUFDLElBQUksQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDO0lBQ3hDLE9BQU8sTUFBTSxDQUFDO0FBQ2hCLENBQUM7QUFFRCx5RUFBeUU7QUFDekUsTUFBTSxVQUFVLHNCQUFzQixDQUFDLElBQWM7SUFDbkQsT0FBTyxJQUFJLENBQUMsR0FBRyxDQUFDLENBQUMsSUFBSSxFQUFFLENBQUMsRUFBRSxFQUFFLENBQUMsQ0FBQyxDQUFDLEVBQUUsSUFBSSxDQUFDLENBQUM7U0FDbEMsSUFBSSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxFQUFFLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztTQUMzQixHQUFHLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztBQUN0QixDQUFDO0FBRUQsTUFBTSxVQUFVLGdCQUFnQixDQUFDLE9BQWUsRUFBRSxJQUFZO0lBQzVELE1BQU0sR0FBRyxHQUFhLEVBQUUsQ0FBQztJQUN6QixLQUFLLElBQUksQ0FBQyxHQUFHLElBQUksR0FBRyxPQUFPLEVBQUUsQ0FBQyxHQUFHLElBQUksRUFBRSxFQUFFLENBQUMsRUFBRTtRQUMxQyxHQUFHLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQ2I7SUFDRCxPQUFPLEdBQUcsQ0FBQztBQUNiLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxNyBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZX
IgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCAqIGFzIHV0aWwgZnJvbSAnLi4vdXRpbCc7XG5cbi8qKlxuICogUmV0dXJucyB0cnVlIGlmIHRoZSBheGlzIHNwZWNpZmllcyB0aGUgaW5uZXIgbW9zdCBkaW1lbnNpb25zIG9mIHRoZVxuICogYXJyYXkuXG4gKi9cbmV4cG9ydCBmdW5jdGlvbiBheGVzQXJlSW5uZXJNb3N0RGltcyhheGVzOiBudW1iZXJbXSwgcmFuazogbnVtYmVyKTogYm9vbGVhbiB7XG4gIGZvciAobGV0IGkgPSAwOyBpIDwgYXhlcy5sZW5ndGg7ICsraSkge1xuICAgIGlmIChheGVzW2F4ZXMubGVuZ3RoIC0gaSAtIDFdICE9PSByYW5rIC0gMSAtIGkpIHtcbiAgICAgIHJldHVybiBmYWxzZTtcbiAgICB9XG4gIH1cbiAgcmV0dXJuIHRydWU7XG59XG5cbmV4cG9ydCBmdW5jdGlvbiBjb21iaW5lTG9jYXRpb25zKFxuICAgIG91dHB1dExvYzogbnVtYmVyW10sIHJlZHVjZUxvYzogbnVtYmVyW10sIGF4ZXM6IG51bWJlcltdKTogbnVtYmVyW10ge1xuICBjb25zdCByYW5rID0gb3V0cHV0TG9jLmxlbmd0aCArIHJlZHVjZUxvYy5sZW5ndGg7XG4gIGNvbnN0IGxvYyA9IFtdO1xuICBsZXQgb3V0SWR4ID0gMDtcbiAgbGV0IHJlZHVjZUlkeCA9IDA7XG4gIMKgIGZvciAobGV0IGRpbSA9IDA7IGRpbSA8IHJhbms7IGRpbSsrKSB7XG4gICAgaWYgKGF4ZXMuaW5kZXhPZihkaW0pID09PSAtMSkge1xuICAgICAgbG9jLnB1c2gob3V0cHV0TG9jW291dElkeCsrXSk7XG4gICAgfSBlbHNlIHtcbiAgICAgIGxvYy5wdXNoKHJlZHVjZUxvY1tyZWR1Y2VJZHgrK10pO1xuICAgIH1cbiAgfVxuICByZXR1cm4gbG9jO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gY29tcHV0ZU91dEFuZFJlZHVjZVNoYXBlcyhcbiAgICBhU2hhcGU6IG51bWJlcltdLCBheGVzOi
BudW1iZXJbXSk6IFtudW1iZXJbXSwgbnVtYmVyW11dIHtcbiAgY29uc3Qgb3V0U2hhcGUgPSBbXTtcbiAgY29uc3QgcmFuayA9IGFTaGFwZS5sZW5ndGg7XG4gIGZvciAobGV0IGRpbSA9IDA7IGRpbSA8IHJhbms7IGRpbSsrKSB7XG4gICAgaWYgKGF4ZXMuaW5kZXhPZihkaW0pID09PSAtMSkge1xuICAgICAgb3V0U2hhcGUucHVzaChhU2hhcGVbZGltXSk7XG4gICAgfVxuICB9XG4gIGNvbnN0IHJlZHVjZVNoYXBlID0gYXhlcy5tYXAoZGltID0+IGFTaGFwZVtkaW1dKTtcbiAgcmV0dXJuIFtvdXRTaGFwZSwgcmVkdWNlU2hhcGVdO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gZXhwYW5kU2hhcGVUb0tlZXBEaW0oXG4gICAgc2hhcGU6IG51bWJlcltdLCBheGVzOiBudW1iZXJbXSk6IG51bWJlcltdIHtcbiAgY29uc3QgcmVkdWNlU3ViU2hhcGUgPSBheGVzLm1hcCh4ID0+IDEpO1xuICByZXR1cm4gY29tYmluZUxvY2F0aW9ucyhzaGFwZSwgcmVkdWNlU3ViU2hhcGUsIGF4ZXMpO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gYXNzZXJ0QXhlc0FyZUlubmVyTW9zdERpbXMoXG4gICAgbXNnOiBzdHJpbmcsIGF4ZXM6IG51bWJlcltdLCByYW5rOiBudW1iZXIpOiB2b2lkIHtcbiAgdXRpbC5hc3NlcnQoXG4gICAgICBheGVzQXJlSW5uZXJNb3N0RGltcyhheGVzLCByYW5rKSxcbiAgICAgICgpID0+IGAke21zZ30gc3VwcG9ydHMgb25seSBpbm5lci1tb3N0IGF4ZXMgZm9yIG5vdy4gYCArXG4gICAgICAgICAgYEdvdCBheGVzICR7YXhlc30gYW5kIHJhbmstJHtyYW5rfSBpbnB1dC5gKTtcbn1cblxuLyoqXG4gKiBSZXR1cm5zIHRoZSBheGVzIHBlcm11dGF0aW9uIHRvIGJlIHVzZWQgd2l0aCBgdGYudHJhbnNwb3NlYCwgaWYgc3VjaFxuICogcGVybXV0YXRpb24gaXMgbmVjZXNzYXJ5LiBPdGhlcndpc2UgaXQgcmV0dXJucyBudWxsLiBUaGlzIG1ldGhvZCBpcyB1c2VkIGJ5XG4gKiBvcGVyYXRpb25zIHRoYXQgb3BlcmF0ZSBvbmx5IG9uIGlubmVyLW1vc3QgYXhlcy5cbiAqL1xuZXhwb3J0IGZ1bmN0aW9uIGdldEF4ZXNQZXJtdXRhdGlvbihheGVzOiBudW1iZXJbXSwgcmFuazogbnVtYmVyKTogbnVtYmVyW118XG4gICAgbnVsbCB7XG4gIGlmIChheGVzQXJlSW5uZXJNb3N0RGltcyhheGVzLCByYW5rKSkge1xuICAgIHJldHVybiBudWxsO1xuICB9XG4gIGNvbnN0IHJlc3VsdDogbnVtYmVyW10gPSBbXTtcbiAgZm9yIChsZXQgaSA9IDA7IGkgPCByYW5rOyArK2kpIHtcbiAgICBpZiAoYXhlcy5pbmRleE9mKGkpID09PSAtMSkge1xuICAgICAgcmVzdWx0LnB1c2goaSk7XG4gICAgfVxuICB9XG4gIGF4ZXMuZm9yRWFjaChheGlzID0+IHJlc3VsdC5wdXNoKGF4aXMpKTtcbiAgcmV0dXJuIHJlc3VsdDtcbn1cblxuLyoqIFJldHVybnMgdGhlIGF4ZXMgcGVybXV0YXRpb24gdGhhdCB1bmRvZXMgdGhlIG9yaWdpbmFsIHBlcm11dGF0aW9uLiAqL1xuZXhwb3J0IGZ1bmN0aW9uIGdldFVuZG9BeGVzUGVybXV0YXRpb24oYXhlczogbnVtYmVyW10pOiBudW1iZX
JbXSB7XG4gIHJldHVybiBheGVzLm1hcCgoYXhpcywgaSkgPT4gW2ksIGF4aXNdKVxuICAgICAgLnNvcnQoKGEsIGIpID0+IGFbMV0gLSBiWzFdKVxuICAgICAgLm1hcCh4ID0+IHhbMF0pO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gZ2V0SW5uZXJNb3N0QXhlcyhudW1BeGVzOiBudW1iZXIsIHJhbms6IG51bWJlcik6IG51bWJlcltdIHtcbiAgY29uc3QgcmVzOiBudW1iZXJbXSA9IFtdO1xuICBmb3IgKGxldCBpID0gcmFuayAtIG51bUF4ZXM7IGkgPCByYW5rOyArK2kpIHtcbiAgICByZXMucHVzaChpKTtcbiAgfVxuICByZXR1cm4gcmVzO1xufVxuIl19