/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { nearestDivisor } from '../util';
import { PARALLELIZE_THRESHOLD } from './reduce_util';

/**
 * Chooses a window size for a segment op over an input of length `inSize`.
 *
 * If `inSize` is at most `PARALLELIZE_THRESHOLD`, `inSize` itself is returned.
 * Otherwise the search starts at `nearestDivisor(inSize, floor(sqrt(inSize)))`
 * and repeatedly asks `nearestDivisor(inSize, candidate + 1)` until the
 * candidate is greater than `numSegments` or equal to `inSize`.
 *
 * NOTE(review): termination relies on `nearestDivisor` (from ../util)
 * eventually returning `inSize` as the start value grows — presumably it
 * returns the smallest divisor >= start, falling back to `inSize`; confirm.
 *
 * @param {number} inSize Length of the input being reduced.
 * @param {number} numSegments Number of output segments.
 * @returns {number} The selected window size.
 */
export function segOpComputeOptimalWindowSize(inSize, numSegments) {
  if (inSize <= PARALLELIZE_THRESHOLD) {
    // At or below the threshold the whole input is used as one window.
    return inSize;
  }
  let windowSize = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
  // Keep growing until the window exceeds numSegments or covers all of inSize.
  while (!(windowSize > numSegments || windowSize === inSize)) {
    windowSize = nearestDivisor(inSize, windowSize + 1);
  }
  return windowSize;
}

/**
 * Returns a new shape equal to `aShape` except that the entry at `axis` is
 * replaced by `numSegments`. `aShape` itself is not modified.
 *
 * @param {ArrayLike<number>} aShape Input shape.
 * @param {number} axis Index of the dimension to replace.
 * @param {number} numSegments Value written at `axis`.
 * @returns {number[]} The output shape (always a plain array).
 */
export function computeOutShape(aShape, axis, numSegments) {
  return Array.from(aShape, (size, dim) => (dim === axis ? numSegments : size));
}

/**
 * Validates the arguments of a gather with optional batch dimensions and
 * derives the sizes needed to index into `x`.
 *
 * With `b` = normalized batchDims and `r` = rank(x):
 *   - batchSize   = product of x.shape[0 .. b)
 *   - outerSize   = product of x.shape[b .. axis)
 *   - dimSize     = x.shape[axis]
 *   - sliceSize   = product of x.shape(axis .. r)
 *   - outputShape = x.shape[0 .. axis) ++ indices.shape[b ..) ++ x.shape(axis .. r)
 *
 * @param {{shape: number[]}} x Object whose `shape` is read (presumably a
 *     TensorInfo — confirm against callers).
 * @param {{shape: number[]}} indices Object whose `shape` is read.
 * @param {number} axis Axis of `x` to gather along. Must already be
 *     non-negative: a negative value always fails the `axis >= batchDims` check.
 * @param {number} batchDims Number of leading dimensions shared by `x` and
 *     `indices`. A negative value has rank(indices) added to it.
 * @returns {{batchSize: number, sliceSize: number, outerSize: number,
 *     dimSize: number, outputShape: number[]}}
 * @throws {Error} If batchDims is outside [-rank(indices), rank(indices)], if
 *     the normalized batchDims is not less than rank(x), if axis is less than
 *     batchDims or not less than rank(x), or if a batch dimension of `x` differs
 *     from the matching dimension of `indices`.
 */
export function collectGatherOpShapeInfo(x, indices, axis, batchDims) {
  const indicesRank = indices.shape.length;
  const xRank = x.shape.length;

  // batchDims === 0 always satisfies this range, so no separate zero guard.
  if (batchDims < -indicesRank || batchDims > indicesRank) {
    throw new Error(`Expect batchDims in the range of [-${indicesRank}, ${indicesRank}], but got ${batchDims}`);
  }
  // A negative batchDims counts back from the rank of `indices`.
  const numBatchDims = batchDims < 0 ? batchDims + indicesRank : batchDims;

  // `>=`, not `>`: with numBatchDims === xRank no dimension of x is left for
  // `axis`. This matches the message wording and TF's GatherV2 check.
  if (numBatchDims >= xRank) {
    throw new Error(`batchDims (${numBatchDims}) must be less than rank(x) (${xRank}).`);
  }
  if (axis < numBatchDims) {
    throw new Error(`batchDims (${numBatchDims}) must be less than or equal to axis (${axis}).`);
  }
  // Reject an out-of-range axis rather than reading x.shape[axis] as undefined.
  if (axis >= xRank) {
    throw new Error(`axis (${axis}) must be less than rank(x) (${xRank}).`);
  }
  for (let i = 0; i < numBatchDims; ++i) {
    if (x.shape[i] !== indices.shape[i]) {
      throw new Error(`x.shape[${i}]: ${x.shape[i]} should be equal to indices.shape[${i}]: ${indices.shape[i]}.`);
    }
  }

  const dimSize = x.shape[axis];
  const outputShape = [];
  let batchSize = 1;
  let outerSize = 1;
  let sliceSize = 1;

  // Leading batch dimensions (checked above to match between x and indices).
  for (let i = 0; i < numBatchDims; ++i) {
    outputShape.push(x.shape[i]);
    batchSize *= x.shape[i];
  }
  // Dimensions of x between the batch dimensions and the gathered axis.
  for (let i = numBatchDims; i < axis; ++i) {
    outputShape.push(x.shape[i]);
    outerSize *= x.shape[i];
  }
  // x.shape[axis] is replaced by the non-batch dimensions of indices.
  for (let i = numBatchDims; i < indicesRank; ++i) {
    outputShape.push(indices.shape[i]);
  }
  // Trailing dimensions of x after the gathered axis.
  for (let i = axis + 1; i < xRank; ++i) {
    outputShape.push(x.shape[i]);
    sliceSize *= x.shape[i];
  }

  return { batchSize, sliceSize, outerSize, dimSize, outputShape };
}
// NOTE(review): the inline source map below was generated for the original
// compiled output and no longer matches this hand-edited code; regenerate it
// from segment_util.ts.
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoic2VnbWVudF91dGlsLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvc2VnbWVudF91dGlsLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUdILE9BQU8sRUFBQyxjQUFjLEVBQUMsTUFBTSxTQUFTLENBQUM7QUFFdkMsT0FBTyxFQUFDLHFCQUFxQixFQUFDLE1BQU0sZUFBZSxDQUFDO0FBU3BELE1BQU0sVUFBVSw2QkFBNkIsQ0FDekMsTUFBYyxFQUFFLFdBQW1CO0lBQ3JDLElBQUksSUFBSSxHQUFHLEtBQUssQ0FBQztJQUNqQixJQUFJLEdBQUcsQ0FBQztJQUVSLElBQUksTUFBTSxJQUFJLHFCQUFxQixFQUFFO1FBQ25DLEdBQUcsR0FBRyxNQUFNLENBQUM7UUFDYixJQUFJLEdBQUcsSUFBSSxDQUFDO0tBQ2I7U0FBTTtRQUNMLEdBQUcsR0FBRyxjQUFjLENBQUMsTUFBTSxFQUFFLElBQUksQ0FBQyxLQUFLLENBQUMsSUFBSSxDQUFDLElBQUksQ0FBQyxNQUFNLENBQUMsQ0FBQyxDQUFDLENBQUM7S0FDN0Q7SUFFRCxPQUFPLENBQUMsSUFBSSxFQUFFO1FBQ1osSUFBSSxHQUFHLEdBQUcsV0FBVyxJQUFJLEdBQUcsS0FBSyxNQUFNLEVBQUU7WUFDdkMsSUFBSSxHQUFHLElBQUksQ0FBQztTQUNiO2FBQU07WUFDTCxHQUFHLEdBQUcsY0FBYyxDQUFDLE1BQU0sRUFBRSxHQUFHLEdBQUcsQ0FBQyxDQUFDLENBQUM7U0FDdkM7S0FDRjtJQUNELE9BQU8sR0FBRyxDQUFDO0FBQ2IsQ0FBQztBQUVELE1BQU0sVUFBVSxlQUFlLENBQzNCLE1BQWdCLEVBQUUsSUFBWSxFQUFFLFdBQW1CO0lBQ3JELE1BQU0sUUFBUSxHQUFHLEVBQUUsQ0FBQztJQUNwQixNQUFNLElBQUksR0FBRyxNQUFNLENBQUMsTUFBTSxD
QUFDO0lBQzNCLEtBQUssSUFBSSxHQUFHLEdBQUcsQ0FBQyxFQUFFLEdBQUcsR0FBRyxJQUFJLEVBQUUsR0FBRyxFQUFFLEVBQUU7UUFDbkMsSUFBSSxHQUFHLEtBQUssSUFBSSxFQUFFO1lBQ2hCLFFBQVEsQ0FBQyxJQUFJLENBQUMsTUFBTSxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUM7U0FDNUI7YUFBTTtZQUNMLFFBQVEsQ0FBQyxJQUFJLENBQUMsV0FBVyxDQUFDLENBQUM7U0FDNUI7S0FDRjtJQUNELE9BQU8sUUFBUSxDQUFDO0FBQ2xCLENBQUM7QUFVRCxNQUFNLFVBQVUsd0JBQXdCLENBQ3BDLENBQWEsRUFBRSxPQUFtQixFQUFFLElBQVksRUFDaEQsU0FBaUI7SUFDbkIsTUFBTSxXQUFXLEdBQUcsT0FBTyxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUM7SUFDekMsTUFBTSxLQUFLLEdBQUcsQ0FBQyxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUM7SUFFN0IsSUFBSSxTQUFTLEtBQUssQ0FBQyxFQUFFO1FBQ25CLElBQUksU0FBUyxHQUFHLENBQUMsV0FBVyxJQUFJLFNBQVMsR0FBRyxXQUFXLEVBQUU7WUFDdkQsTUFBTSxJQUFJLEtBQUssQ0FBQyxzQ0FBc0MsV0FBVyxLQUM3RCxXQUFXLGNBQWMsU0FBUyxFQUFFLENBQUMsQ0FBQztTQUMzQztLQUNGO0lBRUQsSUFBSSxTQUFTLEdBQUcsQ0FBQyxFQUFFO1FBQ2pCLFNBQVMsSUFBSSxXQUFXLENBQUM7S0FDMUI7SUFFRCxJQUFJLFNBQVMsR0FBRyxLQUFLLEVBQUU7UUFDckIsTUFBTSxJQUFJLEtBQUssQ0FBQyxjQUFjLFNBQVM7TUFDckMsS0FBSyxJQUFJLENBQUMsQ0FBQztLQUNkO0lBRUQsSUFBSSxJQUFJLEdBQUcsU0FBUyxFQUFFO1FBQ3BCLE1BQU0sSUFBSSxLQUFLLENBQUMsY0FDWixTQUFTLHlDQUF5QyxJQUFJLElBQUksQ0FBQyxDQUFDO0tBQ2pFO0lBRUQsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFNBQVMsRUFBRSxFQUFFLENBQUMsRUFBRTtRQUNsQyxJQUFJLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEtBQUssT0FBTyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRTtZQUNuQyxNQUFNLElBQUksS0FBSyxDQUNYLFdBQVcsQ0FBQyxNQUFNLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLHFDQUN4QixDQUFDLE1BQU0sT0FBTyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUM7U0FDckM7S0FDRjtJQUNELE1BQU0sT0FBTyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsSUFBSSxDQUFDLENBQUM7SUFFOUIsTUFBTSxXQUFXLEdBQWEsRUFBRSxDQUFDO0lBQ2pDLElBQUksU0FBUyxHQUFHLENBQUMsQ0FBQztJQUNsQixJQUFJLFNBQVMsR0FBRyxDQUFDLENBQUM7SUFDbEIsSUFBSSxTQUFTLEdBQUcsQ0FBQyxDQUFDO0lBRWxCLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxTQUFTLEVBQUUsRUFBRSxDQUFDLEVBQUU7UUFDbEMsV0FBVyxDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDN0IsU0FBUyxJQUFJLENBQUMsQ0FBQyxLQUFLLENBQUMs
Q0FBQyxDQUFDLENBQUM7S0FDekI7SUFFRCxLQUFLLElBQUksQ0FBQyxHQUFHLFNBQVMsRUFBRSxDQUFDLEdBQUcsSUFBSSxFQUFFLENBQUMsRUFBRSxFQUFFO1FBQ3JDLFdBQVcsQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQzdCLFNBQVMsSUFBSSxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQ3pCO0lBRUQsS0FBSyxJQUFJLENBQUMsR0FBRyxTQUFTLEVBQUUsQ0FBQyxHQUFHLFdBQVcsRUFBRSxDQUFDLEVBQUUsRUFBRTtRQUM1QyxXQUFXLENBQUMsSUFBSSxDQUFDLE9BQU8sQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztLQUNwQztJQUVELEtBQUssSUFBSSxDQUFDLEdBQUcsSUFBSSxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsS0FBSyxFQUFFLENBQUMsRUFBRSxFQUFFO1FBQ3JDLFdBQVcsQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQzdCLFNBQVMsSUFBSSxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQ3pCO0lBRUQsT0FBTyxFQUFDLFNBQVMsRUFBRSxTQUFTLEVBQUUsU0FBUyxFQUFFLE9BQU8sRUFBRSxXQUFXLEVBQUMsQ0FBQztBQUNqRSxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTggR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQgeyBUZW5zb3JJbmZvIH0gZnJvbSAnLi4vdGVuc29yX2luZm8nO1xuaW1wb3J0IHtuZWFyZXN0RGl2aXNvcn0gZnJvbSAnLi4vdXRpbCc7XG5cbmltcG9ydCB7UEFSQUxMRUxJWkVfVEhSRVNI
T0xEfSBmcm9tICcuL3JlZHVjZV91dGlsJztcblxuZXhwb3J0IGludGVyZmFjZSBTZWdPcEluZm8ge1xuICB3aW5kb3dTaXplOiBudW1iZXI7XG4gIGJhdGNoU2l6ZTogbnVtYmVyO1xuICBpblNpemU6IG51bWJlcjtcbiAgbnVtU2VnbWVudHM6IG51bWJlcjtcbn1cblxuZXhwb3J0IGZ1bmN0aW9uIHNlZ09wQ29tcHV0ZU9wdGltYWxXaW5kb3dTaXplKFxuICAgIGluU2l6ZTogbnVtYmVyLCBudW1TZWdtZW50czogbnVtYmVyKTogbnVtYmVyIHtcbiAgbGV0IGRvbmUgPSBmYWxzZTtcbiAgbGV0IHJlcztcblxuICBpZiAoaW5TaXplIDw9IFBBUkFMTEVMSVpFX1RIUkVTSE9MRCkge1xuICAgIHJlcyA9IGluU2l6ZTtcbiAgICBkb25lID0gdHJ1ZTtcbiAgfSBlbHNlIHtcbiAgICByZXMgPSBuZWFyZXN0RGl2aXNvcihpblNpemUsIE1hdGguZmxvb3IoTWF0aC5zcXJ0KGluU2l6ZSkpKTtcbiAgfVxuXG4gIHdoaWxlICghZG9uZSkge1xuICAgIGlmIChyZXMgPiBudW1TZWdtZW50cyB8fCByZXMgPT09IGluU2l6ZSkge1xuICAgICAgZG9uZSA9IHRydWU7XG4gICAgfSBlbHNlIHtcbiAgICAgIHJlcyA9IG5lYXJlc3REaXZpc29yKGluU2l6ZSwgcmVzICsgMSk7XG4gICAgfVxuICB9XG4gIHJldHVybiByZXM7XG59XG5cbmV4cG9ydCBmdW5jdGlvbiBjb21wdXRlT3V0U2hhcGUoXG4gICAgYVNoYXBlOiBudW1iZXJbXSwgYXhpczogbnVtYmVyLCBudW1TZWdtZW50czogbnVtYmVyKTogbnVtYmVyW10ge1xuICBjb25zdCBvdXRTaGFwZSA9IFtdO1xuICBjb25zdCByYW5rID0gYVNoYXBlLmxlbmd0aDtcbiAgZm9yIChsZXQgZGltID0gMDsgZGltIDwgcmFuazsgZGltKyspIHtcbiAgICBpZiAoZGltICE9PSBheGlzKSB7XG4gICAgICBvdXRTaGFwZS5wdXNoKGFTaGFwZVtkaW1dKTtcbiAgICB9IGVsc2Uge1xuICAgICAgb3V0U2hhcGUucHVzaChudW1TZWdtZW50cyk7XG4gICAgfVxuICB9XG4gIHJldHVybiBvdXRTaGFwZTtcbn1cblxuZXhwb3J0IGludGVyZmFjZSBHYXRoZXJPcFNoYXBlSW5mbyB7XG4gIGJhdGNoU2l6ZTogbnVtYmVyO1xuICBzbGljZVNpemU6IG51bWJlcjtcbiAgb3V0ZXJTaXplOiBudW1iZXI7XG4gIGRpbVNpemU6IG51bWJlcjtcbiAgb3V0cHV0U2hhcGU6IG51bWJlcltdO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gY29sbGVjdEdhdGhlck9wU2hhcGVJbmZvKFxuICAgIHg6IFRlbnNvckluZm8sIGluZGljZXM6IFRlbnNvckluZm8sIGF4aXM6IG51bWJlcixcbiAgICBiYXRjaERpbXM6IG51bWJlcik6IEdhdGhlck9wU2hhcGVJbmZvIHtcbiAgY29uc3QgaW5kaWNlc1JhbmsgPSBpbmRpY2VzLnNoYXBlLmxlbmd0aDtcbiAgY29uc3QgeFJhbmsgPSB4LnNoYXBlLmxlbmd0aDtcblxuICBpZiAoYmF0Y2hEaW1zICE9PSAwKSB7XG4gICAgaWYgKGJhdGNoRGltcyA8IC1pbmRpY2VzUmFuayB8fCBiYXRjaERpbXMgPiBpbmRpY2VzUmFuaykge1xuICAgICAgdGhyb3cgbmV3IEVycm9yKGBFeHBlY3QgYmF0Y2hEaW1zIGluIHRoZSByYW5nZSBv
ZiBbLSR7aW5kaWNlc1Jhbmt9LCAke1xuICAgICAgICAgIGluZGljZXNSYW5rfV0sIGJ1dCBnb3QgJHtiYXRjaERpbXN9YCk7XG4gICAgfVxuICB9XG5cbiAgaWYgKGJhdGNoRGltcyA8IDApIHtcbiAgICBiYXRjaERpbXMgKz0gaW5kaWNlc1Jhbms7XG4gIH1cblxuICBpZiAoYmF0Y2hEaW1zID4geFJhbmspIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoYGJhdGNoRGltcyAoJHtiYXRjaERpbXN9KSBtdXN0IGJlIGxlc3MgdGhhbiByYW5rKHgpIChcbiAgICAke3hSYW5rfSkuYCk7XG4gIH1cblxuICBpZiAoYXhpcyA8IGJhdGNoRGltcykge1xuICAgIHRocm93IG5ldyBFcnJvcihgYmF0Y2hEaW1zICgke1xuICAgICAgICBiYXRjaERpbXN9KSBtdXN0IGJlIGxlc3MgdGhhbiBvciBlcXVhbCB0byBheGlzICgke2F4aXN9KS5gKTtcbiAgfVxuXG4gIGZvciAobGV0IGkgPSAwOyBpIDwgYmF0Y2hEaW1zOyArK2kpIHtcbiAgICBpZiAoeC5zaGFwZVtpXSAhPT0gaW5kaWNlcy5zaGFwZVtpXSkge1xuICAgICAgdGhyb3cgbmV3IEVycm9yKFxuICAgICAgICAgIGB4LnNoYXBlWyR7aX1dOiAke3guc2hhcGVbaV19IHNob3VsZCBiZSBlcXVhbCB0byBpbmRpY2VzLnNoYXBlWyR7XG4gICAgICAgICAgICAgIGl9XTogJHtpbmRpY2VzLnNoYXBlW2ldfS5gKTtcbiAgICB9XG4gIH1cbiAgY29uc3QgZGltU2l6ZSA9IHguc2hhcGVbYXhpc107XG5cbiAgY29uc3Qgb3V0cHV0U2hhcGU6IG51bWJlcltdID0gW107XG4gIGxldCBiYXRjaFNpemUgPSAxO1xuICBsZXQgb3V0ZXJTaXplID0gMTtcbiAgbGV0IHNsaWNlU2l6ZSA9IDE7XG5cbiAgZm9yIChsZXQgaSA9IDA7IGkgPCBiYXRjaERpbXM7ICsraSkge1xuICAgIG91dHB1dFNoYXBlLnB1c2goeC5zaGFwZVtpXSk7XG4gICAgYmF0Y2hTaXplICo9IHguc2hhcGVbaV07XG4gIH1cblxuICBmb3IgKGxldCBpID0gYmF0Y2hEaW1zOyBpIDwgYXhpczsgaSsrKSB7XG4gICAgb3V0cHV0U2hhcGUucHVzaCh4LnNoYXBlW2ldKTtcbiAgICBvdXRlclNpemUgKj0geC5zaGFwZVtpXTtcbiAgfVxuXG4gIGZvciAobGV0IGkgPSBiYXRjaERpbXM7IGkgPCBpbmRpY2VzUmFuazsgaSsrKSB7XG4gICAgb3V0cHV0U2hhcGUucHVzaChpbmRpY2VzLnNoYXBlW2ldKTtcbiAgfVxuXG4gIGZvciAobGV0IGkgPSBheGlzICsgMTsgaSA8IHhSYW5rOyBpKyspIHtcbiAgICBvdXRwdXRTaGFwZS5wdXNoKHguc2hhcGVbaV0pO1xuICAgIHNsaWNlU2l6ZSAqPSB4LnNoYXBlW2ldO1xuICB9XG5cbiAgcmV0dXJuIHtiYXRjaFNpemUsIHNsaWNlU2l6ZSwgb3V0ZXJTaXplLCBkaW1TaXplLCBvdXRwdXRTaGFwZX07XG59XG4iXX0=