/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as util from '../util';

/**
 * Returns true if the axis specifies the inner most dimensions of the
 * array.
 *
 * Concretely, `axes` must equal `[rank - axes.length, ..., rank - 1]` in that
 * exact order. An empty `axes` array yields true.
 *
 * @param {number[]} axes Axes to test.
 * @param {number} rank Rank of the array the axes refer to.
 * @returns {boolean} Whether `axes` are the trailing dimensions, in order.
 */
export function axesAreInnerMostDims(axes, rank) {
  // Walk both sequences from the back: the last axis must be rank - 1, the
  // one before it rank - 2, and so on.
  for (let i = 0; i < axes.length; ++i) {
    if (axes[axes.length - i - 1] !== rank - 1 - i) {
      return false;
    }
  }
  return true;
}

/**
 * Interleaves two coordinate lists into one list of length
 * `outputLoc.length + reduceLoc.length`.
 *
 * For each dimension index `dim` in ascending order, the next unread entry of
 * `reduceLoc` is taken when `dim` is listed in `axes`; otherwise the next
 * unread entry of `outputLoc` is taken.
 *
 * @param {number[]} outputLoc Values for the dimensions not listed in `axes`.
 * @param {number[]} reduceLoc Values for the dimensions listed in `axes`.
 * @param {number[]} axes Dimension indices that are filled from `reduceLoc`.
 * @returns {number[]} The combined list.
 */
export function combineLocations(outputLoc, reduceLoc, axes) {
  const rank = outputLoc.length + reduceLoc.length;
  const loc = [];
  let outIdx = 0;
  let reduceIdx = 0;
  for (let dim = 0; dim < rank; dim++) {
    if (axes.indexOf(dim) === -1) {
      loc.push(outputLoc[outIdx++]);
    } else {
      loc.push(reduceLoc[reduceIdx++]);
    }
  }
  return loc;
}

/**
 * Splits a shape into the sizes of the dimensions that are kept and the sizes
 * of the dimensions listed in `axes`.
 *
 * @param {number[]} aShape Full input shape.
 * @param {number[]} axes Dimension indices to separate out.
 * @returns {Array<number[]>} A pair `[outShape, reduceShape]`: `outShape`
 *     holds the sizes of dimensions not in `axes` (ascending dimension
 *     order); `reduceShape` holds `aShape[axis]` for every axis, in the order
 *     given by `axes`.
 */
export function computeOutAndReduceShapes(aShape, axes) {
  const outShape = [];
  const rank = aShape.length;
  for (let dim = 0; dim < rank; dim++) {
    if (axes.indexOf(dim) === -1) {
      outShape.push(aShape[dim]);
    }
  }
  const reduceShape = axes.map(dim => aShape[dim]);
  return [outShape, reduceShape];
}

/**
 * Re-inserts a size-1 entry at every position listed in `axes`, using
 * `combineLocations` to interleave them with the entries of `shape`.
 *
 * @param {number[]} shape Shape without the dimensions listed in `axes`.
 * @param {number[]} axes Positions at which a size of 1 is inserted.
 * @returns {number[]} Shape of length `shape.length + axes.length`.
 */
export function expandShapeToKeepDim(shape, axes) {
  const reduceSubShape = axes.map(() => 1);
  return combineLocations(shape, reduceSubShape, axes);
}

/**
 * Passes the result of `axesAreInnerMostDims(axes, rank)` to `util.assert`
 * together with a lazily built error message.
 *
 * NOTE(review): `util.assert` is defined outside this file; presumably it
 * throws with the supplied message when the condition is false - confirm.
 *
 * @param {string} msg Prefix for the error message (e.g. the caller's name).
 * @param {number[]} axes Axes to validate.
 * @param {number} rank Rank of the input the axes refer to.
 */
export function assertAxesAreInnerMostDims(msg, axes, rank) {
  // Each template literal is closed on its own line so that no stray line
  // break ends up inside the message text.
  util.assert(
      axesAreInnerMostDims(axes, rank),
      () => `${msg} supports only inner-most axes for now. ` +
          `Got axes ${axes} and rank-${rank} input.`);
}

/**
 * Returns the axes permutation to be used with `tf.transpose`, if such
 * permutation is necessary. Otherwise it returns null. This method is used by
 * operations that operate only on inner-most axes.
 *
 * The permutation lists every dimension not in `axes` in ascending order,
 * followed by the entries of `axes` in their given order.
 *
 * @param {number[]} axes Axes that should end up inner-most.
 * @param {number} rank Rank of the input.
 * @returns {?number[]} The permutation, or null when `axes` are already the
 *     inner-most dimensions.
 */
export function getAxesPermutation(axes, rank) {
  if (axesAreInnerMostDims(axes, rank)) {
    return null;
  }
  const result = [];
  for (let i = 0; i < rank; ++i) {
    if (axes.indexOf(i) === -1) {
      result.push(i);
    }
  }
  axes.forEach(axis => result.push(axis));
  return result;
}

/**
 * Returns the axes permutation that undoes the original permutation.
 *
 * Implemented as an argsort: each entry is paired with its index, the pairs
 * are ordered by entry value, and the indices are returned.
 *
 * @param {number[]} axes The permutation to invert.
 * @returns {number[]} Indices of `axes` ordered by ascending entry value.
 */
export function getUndoAxesPermutation(axes) {
  return axes.map((axis, i) => [i, axis])
      .sort((a, b) => a[1] - b[1])
      .map(x => x[0]);
}

/**
 * Builds the list `[rank - numAxes, ..., rank - 1]`.
 *
 * @param {number} numAxes Number of trailing axes to list.
 * @param {number} rank Rank of the array.
 * @returns {number[]} The last `numAxes` dimension indices in ascending
 *     order.
 */
export function getInnerMostAxes(numAxes, rank) {
  const res = [];
  for (let i = rank - numAxes; i < rank; ++i) {
    res.push(i);
  }
  return res;
}