import { assert } from '../util';
|
/**
 * Prepare the split size array.
 *
 * When `numOrSizeSplits` is a number, the axis is divided evenly into that
 * many pieces. When it is an array of sizes, at most one entry may be -1;
 * that entry is replaced by whatever remains of the axis after the positive
 * sizes have been accounted for.
 *
 * The caller's array is never modified: the array form is resolved on a copy.
 *
 * @param x Object whose `shape[axis]` gives the size of the axis to split.
 *     Only `x.shape` is read.
 * @param numOrSizeSplits Either the number of equal splits, or an array of
 *     explicit split sizes (optionally containing a single -1 placeholder).
 * @param axis Index into `x.shape` of the axis being split. Defaults to 0.
 * @returns An array with the size of each split along `axis`.
 * @throws Via `assert` when the number of splits does not evenly divide the
 *     axis, when more than one -1 is present, or when the sizes do not sum
 *     to the axis size.
 */
export function prepareSplitSize(x, numOrSizeSplits, axis = 0) {
  const axisSize = x.shape[axis];

  if (typeof numOrSizeSplits === 'number') {
    assert(axisSize % numOrSizeSplits === 0, () => 'Number of splits must evenly divide the axis.');
    return new Array(numOrSizeSplits).fill(axisSize / numOrSizeSplits);
  }

  // Work on a copy so resolving the -1 placeholder does not mutate the
  // caller's array. `slice()` keeps the container type of the input.
  const splitSizes = numOrSizeSplits.slice();

  const numOfNegs = splitSizes.reduce((count, value) => (value === -1 ? count + 1 : count), 0);
  assert(numOfNegs <= 1, () => 'There should be only one negative value in split array.');

  // Allow one entry of the split array to be -1, which indicates the rest
  // of the dimension is allocated to that split.
  const negIndex = splitSizes.indexOf(-1);
  if (negIndex !== -1) {
    // Seed with 0: without a seed, a leading -1 would become the starting
    // accumulator and the computed total would be off by one.
    const total = splitSizes.reduce((sum, size) => (size > 0 ? sum + size : sum), 0);
    splitSizes[negIndex] = axisSize - total;
  }

  // Seed with 0 so an empty size array fails this assertion instead of
  // throwing a TypeError from reduce().
  const sizeSum = splitSizes.reduce((sum, size) => sum + size, 0);
  assert(axisSize === sizeSum, () => 'The sum of sizes must match the size of the axis dimension.');

  return splitSizes;
}
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoic3BsaXRfdXRpbC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtY29yZS9zcmMvb3BzL3NwbGl0X3V0aWwudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBa0JBLE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxTQUFTLENBQUM7QUFFL0I7Ozs7R0FJRztBQUNILE1BQU0sVUFBVSxnQkFBZ0IsQ0FDNUIsQ0FBb0IsRUFBRSxlQUFnQyxFQUN0RCxJQUFJLEdBQUcsQ0FBQztJQUNWLElBQUksVUFBVSxHQUFHLEVBQUUsQ0FBQztJQUNwQixJQUFJLE9BQU8sQ0FBQyxlQUFlLENBQUMsS0FBSyxRQUFRLEVBQUU7UUFDekMsTUFBTSxDQUNGLENBQUMsQ0FBQyxLQUFLLENBQUMsSUFBSSxDQUFDLEdBQUcsZUFBZSxLQUFLLENBQUMsRUFDckMsR0FBRyxFQUFFLENBQUMsK0NBQStDLENBQUMsQ0FBQztRQUMzRCxVQUFVO1lBQ04sSUFBSSxLQUFLLENBQUMsZUFBZSxDQUFDLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsSUFBSSxDQUFDLEdBQUcsZUFBZSxDQUFDLENBQUM7S0FDdEU7U0FBTTtRQUNMLE1BQU0sU0FBUyxHQUFHLGVBQWUsQ0FBQyxNQUFNLENBQUMsQ0FBQyxLQUFLLEVBQUUsS0FBSyxFQUFFLEVBQUU7WUFDeEQsSUFBSSxLQUFLLEtBQUssQ0FBQyxDQUFDLEVBQUU7Z0JBQ2hCLEtBQUssSUFBSSxDQUFDLENBQUM7YUFDWjtZQUNELE9BQU8sS0FBSyxDQUFDO1FBQ2YsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDO1FBQ04sTUFBTSxDQUNGLFNBQVMsSUFBSSxDQUFDLEVBQ2QsR0FBRyxFQUFFLENBQUMseURBQXlELENBQUMsQ0FBQztRQUNyRSxNQUFNLFFBQVEsR0FBRyxlQUFlLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDN0MscUVBQXFFO1FBQ3JFLDJDQUEyQztRQUMzQyxJQUFJLFFBQVEsS0FBSyxDQUFDLENBQUMsRUFBRTtZQUNuQixNQUFNLEtBQUssR0FBRyxlQUFlLENBQUMsTUFBTSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxFQUFFLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7WUFDbEUsZUFBZSxDQUFDLFFBQVEsQ0FBQyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsSUFBSSxDQUFDLEdBQUcsS0FBSyxDQUFDO1NBQ25EO1FBQ0QsTUFBTSxDQUNGLENBQUMsQ0FBQyxLQUFLLENBQUMsSUFBSSxDQUFDLEtBQUssZUFBZSxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUMsRUFDekQsR0FBRyxFQUFFLENBQUMsNkRBQTZELENBQUMsQ0FBQztRQUN6RSxVQUFVLEdBQUcsZUFBZSxDQUFDO0tBQzlCO0lBRUQsT0FBTyxVQUFVLENBQUM7QUFDcEIsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIE
FsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cbmltcG9ydCB7IFRlbnNvckluZm8gfSBmcm9tICcuLi90ZW5zb3JfaW5mbyc7XG5pbXBvcnQge1RlbnNvcn0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCB7YXNzZXJ0fSBmcm9tICcuLi91dGlsJztcblxuLyoqXG4gKiBQcmVwYXJlIHRoZSBzcGxpdCBzaXplIGFycmF5LiBXaGVuIHRoZSBpbnB1dCBpcyBhIG51bWJlciwgdGhlIGF4aXMgaXMgZXZlbmx5XG4gKiBkaXZpZGVkIGFtb25nIHRoZSBzcGxpdCBzaXplLiBXaGVuIHRoZSBpbnB1dCBjb250YWlucyB0aGUgbmVnYXRpdmUgdmFsdWUsIHRoZVxuICogcmVzdCBvZiB0aGUgYXhpcyBpcyBhbGxvY2F0ZWQgdG93YXJkIHRoYXQuXG4gKi9cbmV4cG9ydCBmdW5jdGlvbiBwcmVwYXJlU3BsaXRTaXplKFxuICAgIHg6IFRlbnNvcnxUZW5zb3JJbmZvLCBudW1PclNpemVTcGxpdHM6IG51bWJlcltdfG51bWJlcixcbiAgICBheGlzID0gMCk6IG51bWJlcltdIHtcbiAgbGV0IHNwbGl0U2l6ZXMgPSBbXTtcbiAgaWYgKHR5cGVvZiAobnVtT3JTaXplU3BsaXRzKSA9PT0gJ251bWJlcicpIHtcbiAgICBhc3NlcnQoXG4gICAgICAgIHguc2hhcGVbYXhpc10gJSBudW1PclNpemVTcGxpdHMgPT09IDAsXG4gICAgICAgICgpID0+ICdOdW1iZXIgb2Ygc3BsaXRzIG11c3QgZXZlbmx5IGRpdmlkZSB0aGUgYXhpcy4nKTtcbiAgICBzcGxpdFNpemVzID1cbiAgICAgICAgbmV3IEFycmF5KG51bU9yU2l6ZVNwbGl0cykuZmlsbCh4LnNoYXBlW2F4aXNdIC8gbnVtT3JTaXplU3BsaXRzKTtcbiAgfSBlbHNlIHtcbiAgICBjb25zdCBudW1PZk5lZ3MgPSBudW1PclNpemVTcGxpdHMucmVkdWNlKChjb3VudCwgdmFsdWUpID0+IH
tcbiAgICAgIGlmICh2YWx1ZSA9PT0gLTEpIHtcbiAgICAgICAgY291bnQgKz0gMTtcbiAgICAgIH1cbiAgICAgIHJldHVybiBjb3VudDtcbiAgICB9LCAwKTtcbiAgICBhc3NlcnQoXG4gICAgICAgIG51bU9mTmVncyA8PSAxLFxuICAgICAgICAoKSA9PiAnVGhlcmUgc2hvdWxkIGJlIG9ubHkgb25lIG5lZ2F0aXZlIHZhbHVlIGluIHNwbGl0IGFycmF5LicpO1xuICAgIGNvbnN0IG5lZ0luZGV4ID0gbnVtT3JTaXplU3BsaXRzLmluZGV4T2YoLTEpO1xuICAgIC8vIEFsbG93IHRoZSBudW1iZXIgb2Ygc3BsaXQgYXJyYXkgdG8gYmUgLTEsIHdoaWNoIGluZGljYXRlcyB0aGUgcmVzdFxuICAgIC8vIG9mIGRpbWVuc2lvbiBpcyBhbGxvY2F0ZWQgdG8gdGhhdCBzcGxpdC5cbiAgICBpZiAobmVnSW5kZXggIT09IC0xKSB7XG4gICAgICBjb25zdCB0b3RhbCA9IG51bU9yU2l6ZVNwbGl0cy5yZWR1Y2UoKGEsIGIpID0+IGIgPiAwID8gYSArIGIgOiBhKTtcbiAgICAgIG51bU9yU2l6ZVNwbGl0c1tuZWdJbmRleF0gPSB4LnNoYXBlW2F4aXNdIC0gdG90YWw7XG4gICAgfVxuICAgIGFzc2VydChcbiAgICAgICAgeC5zaGFwZVtheGlzXSA9PT0gbnVtT3JTaXplU3BsaXRzLnJlZHVjZSgoYSwgYikgPT4gYSArIGIpLFxuICAgICAgICAoKSA9PiAnVGhlIHN1bSBvZiBzaXplcyBtdXN0IG1hdGNoIHRoZSBzaXplIG9mIHRoZSBheGlzIGRpbWVuc2lvbi4nKTtcbiAgICBzcGxpdFNpemVzID0gbnVtT3JTaXplU3BsaXRzO1xuICB9XG5cbiAgcmV0dXJuIHNwbGl0U2l6ZXM7XG59XG4iXX0=
|