gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import { assert } from '../util';
/**
 * Prepare the split size array. When the input is a number, the axis is evenly
 * divided into that many splits. When the input is an array containing a single
 * -1, the rest of the axis (axis size minus the other sizes) is allocated to
 * that entry.
 *
 * @param x Object exposing a `shape` array; only `x.shape[axis]` is read.
 * @param {number|number[]} numOrSizeSplits Either the number of equal splits,
 *     or an explicit list of split sizes in which at most one entry may be -1.
 *     The caller's array is never modified.
 * @param {number} [axis=0] Index into `x.shape` of the dimension being split.
 * @returns {number[]} The size of each split; the sizes sum to `x.shape[axis]`.
 */
export function prepareSplitSize(x, numOrSizeSplits, axis = 0) {
    const axisSize = x.shape[axis];
    if (typeof numOrSizeSplits === 'number') {
        assert(axisSize % numOrSizeSplits === 0, () => 'Number of splits must evenly divide the axis.');
        return new Array(numOrSizeSplits).fill(axisSize / numOrSizeSplits);
    }
    // Work on a copy so resolving the -1 entry never mutates the caller's array.
    const splitSizes = [...numOrSizeSplits];
    const numOfNegs = splitSizes.filter((value) => value === -1).length;
    assert(numOfNegs <= 1, () => 'There should be only one negative value in split array.');
    const negIndex = splitSizes.indexOf(-1);
    // Allow one entry of the split array to be -1, which indicates the rest
    // of the dimension is allocated to that split.
    if (negIndex !== -1) {
        // Seed with 0: without an initial value, reduce() uses the first element
        // as the accumulator, so a leading -1 would leak into the total.
        const total = splitSizes.reduce((acc, size) => (size > 0 ? acc + size : acc), 0);
        splitSizes[negIndex] = axisSize - total;
    }
    // Seed with 0 so an empty split array reaches this assertion instead of
    // throwing a TypeError from reduce().
    assert(axisSize === splitSizes.reduce((acc, size) => acc + size, 0), () => 'The sum of sizes must match the size of the axis dimension.');
    return splitSizes;
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoic3BsaXRfdXRpbC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtY29yZS9zcmMvb3BzL3NwbGl0X3V0aWwudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBa0JBLE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxTQUFTLENBQUM7QUFFL0I7Ozs7R0FJRztBQUNILE1BQU0sVUFBVSxnQkFBZ0IsQ0FDNUIsQ0FBb0IsRUFBRSxlQUFnQyxFQUN0RCxJQUFJLEdBQUcsQ0FBQztJQUNWLElBQUksVUFBVSxHQUFHLEVBQUUsQ0FBQztJQUNwQixJQUFJLE9BQU8sQ0FBQyxlQUFlLENBQUMsS0FBSyxRQUFRLEVBQUU7UUFDekMsTUFBTSxDQUNGLENBQUMsQ0FBQyxLQUFLLENBQUMsSUFBSSxDQUFDLEdBQUcsZUFBZSxLQUFLLENBQUMsRUFDckMsR0FBRyxFQUFFLENBQUMsK0NBQStDLENBQUMsQ0FBQztRQUMzRCxVQUFVO1lBQ04sSUFBSSxLQUFLLENBQUMsZUFBZSxDQUFDLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsSUFBSSxDQUFDLEdBQUcsZUFBZSxDQUFDLENBQUM7S0FDdEU7U0FBTTtRQUNMLE1BQU0sU0FBUyxHQUFHLGVBQWUsQ0FBQyxNQUFNLENBQUMsQ0FBQyxLQUFLLEVBQUUsS0FBSyxFQUFFLEVBQUU7WUFDeEQsSUFBSSxLQUFLLEtBQUssQ0FBQyxDQUFDLEVBQUU7Z0JBQ2hCLEtBQUssSUFBSSxDQUFDLENBQUM7YUFDWjtZQUNELE9BQU8sS0FBSyxDQUFDO1FBQ2YsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDO1FBQ04sTUFBTSxDQUNGLFNBQVMsSUFBSSxDQUFDLEVBQ2QsR0FBRyxFQUFFLENBQUMseURBQXlELENBQUMsQ0FBQztRQUNyRSxNQUFNLFFBQVEsR0FBRyxlQUFlLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDN0MscUVBQXFFO1FBQ3JFLDJDQUEyQztRQUMzQyxJQUFJLFFBQVEsS0FBSyxDQUFDLENBQUMsRUFBRTtZQUNuQixNQUFNLEtBQUssR0FBRyxlQUFlLENBQUMsTUFBTSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxFQUFFLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7WUFDbEUsZUFBZSxDQUFDLFFBQVEsQ0FBQyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsSUFBSSxDQUFDLEdBQUcsS0FBSyxDQUFDO1NBQ25EO1FBQ0QsTUFBTSxDQUNGLENBQUMsQ0FBQyxLQUFLLENBQUMsSUFBSSxDQUFDLEtBQUssZUFBZSxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUMsRUFDekQsR0FBRyxFQUFFLENBQUMsNkRBQTZELENBQUMsQ0FBQztRQUN6RSxVQUFVLEdBQUcsZUFBZSxDQUFDO0tBQzlCO0lBRUQsT0FBTyxVQUFVLENBQUM7QUFDcEIsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIE
FsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cbmltcG9ydCB7IFRlbnNvckluZm8gfSBmcm9tICcuLi90ZW5zb3JfaW5mbyc7XG5pbXBvcnQge1RlbnNvcn0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCB7YXNzZXJ0fSBmcm9tICcuLi91dGlsJztcblxuLyoqXG4gKiBQcmVwYXJlIHRoZSBzcGxpdCBzaXplIGFycmF5LiBXaGVuIHRoZSBpbnB1dCBpcyBhIG51bWJlciwgdGhlIGF4aXMgaXMgZXZlbmx5XG4gKiBkaXZpZGVkIGFtb25nIHRoZSBzcGxpdCBzaXplLiBXaGVuIHRoZSBpbnB1dCBjb250YWlucyB0aGUgbmVnYXRpdmUgdmFsdWUsIHRoZVxuICogcmVzdCBvZiB0aGUgYXhpcyBpcyBhbGxvY2F0ZWQgdG93YXJkIHRoYXQuXG4gKi9cbmV4cG9ydCBmdW5jdGlvbiBwcmVwYXJlU3BsaXRTaXplKFxuICAgIHg6IFRlbnNvcnxUZW5zb3JJbmZvLCBudW1PclNpemVTcGxpdHM6IG51bWJlcltdfG51bWJlcixcbiAgICBheGlzID0gMCk6IG51bWJlcltdIHtcbiAgbGV0IHNwbGl0U2l6ZXMgPSBbXTtcbiAgaWYgKHR5cGVvZiAobnVtT3JTaXplU3BsaXRzKSA9PT0gJ251bWJlcicpIHtcbiAgICBhc3NlcnQoXG4gICAgICAgIHguc2hhcGVbYXhpc10gJSBudW1PclNpemVTcGxpdHMgPT09IDAsXG4gICAgICAgICgpID0+ICdOdW1iZXIgb2Ygc3BsaXRzIG11c3QgZXZlbmx5IGRpdmlkZSB0aGUgYXhpcy4nKTtcbiAgICBzcGxpdFNpemVzID1cbiAgICAgICAgbmV3IEFycmF5KG51bU9yU2l6ZVNwbGl0cykuZmlsbCh4LnNoYXBlW2F4aXNdIC8gbnVtT3JTaXplU3BsaXRzKTtcbiAgfSBlbHNlIHtcbiAgICBjb25zdCBudW1PZk5lZ3MgPSBudW1PclNpemVTcGxpdHMucmVkdWNlKChjb3VudCwgdmFsdWUpID0+IH
tcbiAgICAgIGlmICh2YWx1ZSA9PT0gLTEpIHtcbiAgICAgICAgY291bnQgKz0gMTtcbiAgICAgIH1cbiAgICAgIHJldHVybiBjb3VudDtcbiAgICB9LCAwKTtcbiAgICBhc3NlcnQoXG4gICAgICAgIG51bU9mTmVncyA8PSAxLFxuICAgICAgICAoKSA9PiAnVGhlcmUgc2hvdWxkIGJlIG9ubHkgb25lIG5lZ2F0aXZlIHZhbHVlIGluIHNwbGl0IGFycmF5LicpO1xuICAgIGNvbnN0IG5lZ0luZGV4ID0gbnVtT3JTaXplU3BsaXRzLmluZGV4T2YoLTEpO1xuICAgIC8vIEFsbG93IHRoZSBudW1iZXIgb2Ygc3BsaXQgYXJyYXkgdG8gYmUgLTEsIHdoaWNoIGluZGljYXRlcyB0aGUgcmVzdFxuICAgIC8vIG9mIGRpbWVuc2lvbiBpcyBhbGxvY2F0ZWQgdG8gdGhhdCBzcGxpdC5cbiAgICBpZiAobmVnSW5kZXggIT09IC0xKSB7XG4gICAgICBjb25zdCB0b3RhbCA9IG51bU9yU2l6ZVNwbGl0cy5yZWR1Y2UoKGEsIGIpID0+IGIgPiAwID8gYSArIGIgOiBhKTtcbiAgICAgIG51bU9yU2l6ZVNwbGl0c1tuZWdJbmRleF0gPSB4LnNoYXBlW2F4aXNdIC0gdG90YWw7XG4gICAgfVxuICAgIGFzc2VydChcbiAgICAgICAgeC5zaGFwZVtheGlzXSA9PT0gbnVtT3JTaXplU3BsaXRzLnJlZHVjZSgoYSwgYikgPT4gYSArIGIpLFxuICAgICAgICAoKSA9PiAnVGhlIHN1bSBvZiBzaXplcyBtdXN0IG1hdGNoIHRoZSBzaXplIG9mIHRoZSBheGlzIGRpbWVuc2lvbi4nKTtcbiAgICBzcGxpdFNpemVzID0gbnVtT3JTaXplU3BsaXRzO1xuICB9XG5cbiAgcmV0dXJuIHNwbGl0U2l6ZXM7XG59XG4iXX0=