"use strict";
|
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
|
Object.defineProperty(exports, "__esModule", { value: true });
|
var util = require("../util");
|
/**
 * Returns true if `axes` lists exactly the inner-most (trailing) dimensions
 * of an array of the given rank, in ascending order. For example, with
 * rank 4, `[2, 3]` qualifies while `[1, 3]` and `[3, 2]` do not.
 *
 * An empty `axes` trivially qualifies.
 *
 * @param {number[]} axes Axis indices to check.
 * @param {number} rank Rank of the array the axes refer to.
 * @returns {boolean} Whether `axes` equals the last `axes.length` dimensions.
 */
function axesAreInnerMostDims(axes, rank) {
    // Walk `axes` from its last entry backwards; the entry `i` steps from the
    // end must be the dimension `i` steps from the end of the array.
    for (var i = 0; i < axes.length; ++i) {
        var axisFromEnd = axes[axes.length - i - 1];
        var dimFromEnd = rank - 1 - i;
        if (axisFromEnd !== dimFromEnd) {
            return false;
        }
    }
    return true;
}
|
exports.axesAreInnerMostDims = axesAreInnerMostDims;
|
/**
 * Interleaves an output location and a reduce location into one location of
 * rank `outputLoc.length + reduceLoc.length`.
 *
 * For each dimension of the combined rank, the next unused entry of
 * `reduceLoc` is taken when the dimension appears in `axes`; otherwise the
 * next unused entry of `outputLoc` is taken.
 *
 * @param {number[]} outputLoc Entries for the dimensions not listed in `axes`.
 * @param {number[]} reduceLoc Entries for the dimensions listed in `axes`.
 * @param {number[]} axes Dimensions that are filled from `reduceLoc`.
 * @returns {Array} The combined location, one entry per dimension.
 */
function combineLocations(outputLoc, reduceLoc, axes) {
    var rank = outputLoc.length + reduceLoc.length;
    var loc = [];
    var outIdx = 0;
    var reduceIdx = 0;
    for (var dim = 0; dim < rank; dim++) {
        var isReducedDim = axes.indexOf(dim) !== -1;
        // Each source array is consumed in order, independently of the other.
        loc.push(isReducedDim ? reduceLoc[reduceIdx++] : outputLoc[outIdx++]);
    }
    return loc;
}
|
exports.combineLocations = combineLocations;
|
/**
 * Splits a shape into the part that is kept and the part that is reduced.
 *
 * @param {number[]} aShape Shape to split.
 * @param {number[]} axes Dimensions being reduced.
 * @returns {Array} A two-element array `[outShape, reduceShape]`: `outShape`
 *     holds the sizes of the dimensions not in `axes` (in dimension order);
 *     `reduceShape` holds the sizes of the dimensions in `axes` (in the order
 *     `axes` lists them).
 */
function computeOutAndReduceShapes(aShape, axes) {
    var rank = aShape.length;
    var outShape = [];
    for (var dim = 0; dim < rank; dim++) {
        if (axes.indexOf(dim) === -1) {
            outShape.push(aShape[dim]);
        }
    }
    // Follows the order of `axes`, not the dimension order of `aShape`.
    var reduceShape = axes.map(function (dim) {
        return aShape[dim];
    });
    return [outShape, reduceShape];
}
|
exports.computeOutAndReduceShapes = computeOutAndReduceShapes;
|
/**
 * Re-inserts size-1 dimensions at the positions listed in `axes`, producing
 * a shape of rank `shape.length + axes.length`.
 *
 * @param {number[]} shape Shape without the dimensions listed in `axes`.
 * @param {number[]} axes Dimensions that should appear with size 1.
 * @returns {Array} The expanded shape, as built by `combineLocations`.
 */
function expandShapeToKeepDim(shape, axes) {
    // One size-1 entry per axis in `axes`.
    var reduceSubShape = axes.map(function () {
        return 1;
    });
    return combineLocations(shape, reduceSubShape, axes);
}
|
exports.expandShapeToKeepDim = expandShapeToKeepDim;
|
/**
 * Asserts, via `util.assert`, that `axes` are the inner-most dimensions of a
 * rank-`rank` input.
 *
 * The failure message is handed to `util.assert` as a callback, so the
 * string is only built if `util.assert` chooses to invoke it.
 * NOTE(review): presumably `util.assert` throws when the condition is false;
 * confirm against `../util`.
 *
 * @param {string} msg Prefix for the failure message (e.g. the op name).
 * @param {number[]} axes Axis indices to validate.
 * @param {number} rank Rank of the input the axes refer to.
 */
function assertAxesAreInnerMostDims(msg, axes, rank) {
    var buildMessage = function () {
        return msg + " supports only inner-most axes for now. " +
            ("Got axes " + axes + " and rank-" + rank + " input.");
    };
    util.assert(axesAreInnerMostDims(axes, rank), buildMessage);
}
|
exports.assertAxesAreInnerMostDims = assertAxesAreInnerMostDims;
|
/**
 * Returns the axes permutation to be used with `tf.transpose`, if such a
 * permutation is necessary; otherwise returns null. This method is used by
 * operations that operate only on inner-most axes.
 *
 * The permutation lists every dimension not in `axes` first (in ascending
 * order), followed by the entries of `axes` in their given order.
 *
 * @param {number[]} axes Axis indices the operation works on.
 * @param {number} rank Rank of the input.
 * @returns {?number[]} The permutation, or null when `axes` are already the
 *     inner-most dimensions.
 */
function getAxesPermutation(axes, rank) {
    if (axesAreInnerMostDims(axes, rank)) {
        return null;
    }
    var result = [];
    for (var dim = 0; dim < rank; ++dim) {
        if (axes.indexOf(dim) === -1) {
            result.push(dim);
        }
    }
    // Append the requested axes last so they end up as the trailing dims.
    axes.forEach(function (axis) {
        result.push(axis);
    });
    return result;
}
|
exports.getAxesPermutation = getAxesPermutation;
|
/**
 * Returns the axes permutation that undoes the original permutation.
 *
 * Each entry of `axes` is paired with its position, the pairs are sorted by
 * ascending axis value, and the positions are returned in that order. For
 * example, `[2, 0, 1]` yields `[1, 2, 0]`.
 *
 * @param {number[]} axes The permutation to invert.
 * @returns {number[]} The positions of `axes`, ordered by axis value.
 */
function getUndoAxesPermutation(axes) {
    var indexedAxes = axes.map(function (axis, i) {
        return [i, axis];
    });
    indexedAxes.sort(function (a, b) {
        return a[1] - b[1];
    });
    return indexedAxes.map(function (pair) {
        return pair[0];
    });
}
|
exports.getUndoAxesPermutation = getUndoAxesPermutation;
|
/**
 * Returns the indices of the last `numAxes` dimensions of a rank-`rank`
 * array, in ascending order: `rank - numAxes` up to `rank - 1`.
 *
 * @param {number} numAxes How many trailing dimensions to list.
 * @param {number} rank Rank of the array.
 * @returns {number[]} The trailing axis indices; empty when `numAxes` is 0.
 */
function getInnerMostAxes(numAxes, rank) {
    var res = [];
    var axis = rank - numAxes;
    while (axis < rank) {
        res.push(axis);
        ++axis;
    }
    return res;
}
|
exports.getInnerMostAxes = getInnerMostAxes;
|
//# sourceMappingURL=axis_util.js.map
|