"use strict"; /** * @license * Copyright 2017 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ Object.defineProperty(exports, "__esModule", { value: true }); var util = require("../util"); /** * Returns true if the axis specifies the inner most dimensions of the * array. */ function axesAreInnerMostDims(axes, rank) { for (var i = 0; i < axes.length; ++i) { if (axes[axes.length - i - 1] !== rank - 1 - i) { return false; } } return true; } exports.axesAreInnerMostDims = axesAreInnerMostDims; function combineLocations(outputLoc, reduceLoc, axes) { var rank = outputLoc.length + reduceLoc.length; var loc = []; var outIdx = 0; var reduceIdx = 0; for (var dim = 0; dim < rank; dim++) { if (axes.indexOf(dim) === -1) { loc.push(outputLoc[outIdx++]); } else { loc.push(reduceLoc[reduceIdx++]); } } return loc; } exports.combineLocations = combineLocations; function computeOutAndReduceShapes(aShape, axes) { var outShape = []; var rank = aShape.length; for (var dim = 0; dim < rank; dim++) { if (axes.indexOf(dim) === -1) { outShape.push(aShape[dim]); } } var reduceShape = axes.map(function (dim) { return aShape[dim]; }); return [outShape, reduceShape]; } exports.computeOutAndReduceShapes = computeOutAndReduceShapes; function expandShapeToKeepDim(shape, axes) { var reduceSubShape = axes.map(function (x) { return 1; }); return combineLocations(shape, reduceSubShape, axes); } 
exports.expandShapeToKeepDim = expandShapeToKeepDim;
/**
 * Checks, via `util.assert`, that `axes` are the inner-most dimensions of a
 * rank-`rank` input. `msg` (typically the calling op's name) prefixes the
 * failure message. The message is handed over as a callback, presumably
 * invoked by `util.assert` only on failure. NOTE(review): confirm this
 * against ../util.
 *
 * @param {string} msg Prefix for the failure message.
 * @param {number[]} axes Axes being checked.
 * @param {number} rank Rank of the input.
 */
function assertAxesAreInnerMostDims(msg, axes, rank) {
    var isInnerMost = axesAreInnerMostDims(axes, rank);
    util.assert(isInnerMost, function () {
        var prefix = msg + " supports only inner-most axes for now. ";
        return prefix + ("Got axes " + axes + " and rank-" + rank + " input.");
    });
}
exports.assertAxesAreInnerMostDims = assertAxesAreInnerMostDims;
/**
 * Returns the axes permutation to use with `tf.transpose` when one is
 * needed, and null otherwise. It is null when `axes` already are the
 * inner-most dimensions. Ops that only support inner-most axes use this.
 *
 * The permutation lists the dimensions not in `axes` in ascending order,
 * followed by `axes` in the order given.
 *
 * @param {number[]} axes Axes that should end up inner-most.
 * @param {number} rank Rank of the input.
 * @returns {?number[]} The permutation, or null if none is needed.
 */
function getAxesPermutation(axes, rank) {
    if (axesAreInnerMostDims(axes, rank)) {
        return null;
    }
    var perm = [];
    // Untouched dimensions keep their relative order at the front...
    for (var dim = 0; dim < rank; dim++) {
        if (axes.indexOf(dim) < 0) {
            perm.push(dim);
        }
    }
    // ...followed by the requested axes, in the order given.
    axes.forEach(function (axis) {
        perm.push(axis);
    });
    return perm;
}
exports.getAxesPermutation = getAxesPermutation;
/**
 * Returns the permutation that undoes `axes`. Entry j of the result is the
 * index i at which `axes[i] === j`. It is computed by sorting
 * (index, axis) pairs by axis value.
 *
 * @param {number[]} axes A permutation of dimension indices.
 * @returns {number[]} The inverse permutation.
 */
function getUndoAxesPermutation(axes) {
    var indexedAxes = axes.map(function (axis, index) { return [index, axis]; });
    indexedAxes.sort(function (lhs, rhs) { return lhs[1] - rhs[1]; });
    return indexedAxes.map(function (pair) { return pair[0]; });
}
exports.getUndoAxesPermutation = getUndoAxesPermutation;
/**
 * Returns the last `numAxes` axis indices of a rank-`rank` array in
 * ascending order, i.e. `[rank - numAxes, ..., rank - 1]`.
 *
 * @param {number} numAxes How many trailing axes to return.
 * @param {number} rank Rank of the array.
 * @returns {number[]} The trailing axis indices.
 */
function getInnerMostAxes(numAxes, rank) {
    var innerAxes = [];
    var axis = rank - numAxes;
    while (axis < rank) {
        innerAxes.push(axis);
        axis += 1;
    }
    return innerAxes;
}
exports.getInnerMostAxes = getInnerMostAxes;
//# sourceMappingURL=axis_util.js.map