gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
/**
 * @license
 * Copyright 2022 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { util } from '@tensorflow/tfjs-core';
/**
 * Left-side binary search over an ascending-sorted array.
 *
 * Finds the smallest index `k` such that `array[k] < value` is false, i.e.
 * the leftmost slot where `value` could be inserted while keeping the order.
 * If every element compares less than `value`, the result is `array.length`;
 * an empty array always yields 0.
 *
 * @param array Array-like of numbers, assumed sorted ascending.
 * @param value Query value.
 * @returns Insertion index in the range [0, array.length].
 */
function lowerBound(array, value) {
    let lo = 0;
    let hi = array.length;
    // Invariant: every index < lo holds an element < value; every index >= hi
    // holds an element that is not < value.
    while (hi > lo) {
        const middle = lo + Math.floor((hi - lo) / 2);
        if (!(array[middle] < value)) {
            hi = middle;
            continue;
        }
        lo = middle + 1;
    }
    return hi;
}
/**
 * Right-side binary search over an ascending-sorted array.
 *
 * Finds the smallest index `k` such that `array[k] <= value` is false, i.e.
 * the rightmost slot where `value` could be inserted while keeping the order
 * (after any run of elements equal to `value`). If every element compares
 * `<= value`, the result is `array.length`; an empty array always yields 0.
 *
 * @param array Array-like of numbers, assumed sorted ascending.
 * @param value Query value.
 * @returns Insertion index in the range [0, array.length].
 */
function upperBound(array, value) {
    let lo = 0;
    let hi = array.length;
    // Invariant: every index < lo holds an element <= value; every index >= hi
    // holds an element that is not <= value.
    while (hi > lo) {
        const middle = lo + Math.floor((hi - lo) / 2);
        if (!(array[middle] <= value)) {
            hi = middle;
            continue;
        }
        lo = middle + 1;
    }
    return hi;
}
/**
 * Batched "searchsorted": for each query value, finds its insertion index in
 * the corresponding row of sorted inputs.
 *
 * `sortedInputs` is read as `batchSize` consecutive rows of `numInputs`
 * elements, and `values` as `batchSize` consecutive rows of `numValues`
 * elements. Each query in row `b` of `values` is searched only within row `b`
 * of `sortedInputs`. Rows of `sortedInputs` are assumed to be sorted ascending
 * (not checked here).
 *
 * @param sortedInputs Flat buffer of sorted rows.
 * @param values Flat buffer of query rows.
 * @param batchSize Number of rows.
 * @param numInputs Length of each sorted row.
 * @param numValues Number of queries per row.
 * @param side 'left' selects the first valid insertion index (lower bound);
 *     any other value selects the last one (upper bound).
 * @returns An int32 array of length `batchSize * numValues` holding, for each
 *     query, an index in [0, numInputs] relative to the start of its row.
 */
export function searchSortedImpl(sortedInputs, values, batchSize, numInputs, numValues, side) {
    const output = util.getArrayFromDType('int32', batchSize * numValues);
    // The side never changes across the loops, so pick the search once.
    const findBound = side === 'left' ? lowerBound : upperBound;
    // Typed arrays support subarray(), a zero-copy view; slice() would copy
    // every row just to read it. Plain arrays keep using slice().
    const canView = ArrayBuffer.isView(sortedInputs);
    for (let b = 0; b < batchSize; ++b) {
        const rowStart = b * numInputs;
        const rowEnd = rowStart + numInputs;
        const sortedRow = canView ?
            sortedInputs.subarray(rowStart, rowEnd) :
            sortedInputs.slice(rowStart, rowEnd);
        const outputOffset = b * numValues;
        for (let i = 0; i < numValues; ++i) {
            output[outputOffset + i] =
                findBound(sortedRow, values[outputOffset + i]);
        }
    }
    return output;
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU2VhcmNoU29ydGVkX2ltcGwuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy9rZXJuZWxzL1NlYXJjaFNvcnRlZF9pbXBsLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBYSxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUV2RCxTQUFTLFVBQVUsQ0FBQyxLQUFpQixFQUFFLEtBQWE7SUFDbEQsSUFBSSxJQUFJLEdBQUcsQ0FBQyxDQUFDO0lBQ2IsSUFBSSxLQUFLLEdBQUcsS0FBSyxDQUFDLE1BQU0sQ0FBQztJQUN6QixJQUFJLEdBQUcsR0FBRyxDQUFDLENBQUM7SUFDWixPQUFPLElBQUksR0FBRyxLQUFLLEVBQUU7UUFDbkIsR0FBRyxHQUFHLElBQUksQ0FBQyxLQUFLLENBQUMsQ0FBQyxJQUFJLEdBQUcsS0FBSyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUM7UUFDckMsSUFBSSxLQUFLLENBQUMsR0FBRyxDQUFDLEdBQUcsS0FBSyxFQUFFO1lBQ3RCLElBQUksR0FBRyxHQUFHLEdBQUcsQ0FBQyxDQUFDO1NBQ2hCO2FBQU07WUFDTCxLQUFLLEdBQUcsR0FBRyxDQUFDO1NBQ2I7S0FDRjtJQUNELE9BQU8sS0FBSyxDQUFDO0FBQ2YsQ0FBQztBQUVELFNBQVMsVUFBVSxDQUFDLEtBQWlCLEVBQUUsS0FBYTtJQUNsRCxJQUFJLElBQUksR0FBRyxDQUFDLENBQUM7SUFDYixJQUFJLEtBQUssR0FBRyxLQUFLLENBQUMsTUFBTSxDQUFDO0lBQ3pCLElBQUksR0FBRyxHQUFHLENBQUMsQ0FBQztJQUNaLE9BQU8sSUFBSSxHQUFHLEtBQUssRUFBRTtRQUNuQixHQUFHLEdBQUcsSUFBSSxDQUFDLEtBQUssQ0FBQyxDQUFDLElBQUksR0FBRyxLQUFLLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQztRQUNyQyxJQUFJLEtBQUssQ0FBQyxHQUFHLENBQUMsSUFBSSxLQUFLLEVBQUU7WUFDdkIsSUFBSSxHQUFHLEdBQUcsR0FBRyxDQUFDLENBQUM7U0FDaEI7YUFBTTtZQUNMLEtBQUssR0FBRyxHQUFHLENBQUM7U0FDYjtLQUNGO0lBQ0QsT0FBTyxLQUFLLENBQUM7QUFDZixDQUFDO0FBRUQsTUFBTSxVQUFVLGdCQUFnQixDQUM1QixZQUF3QixFQUFFLE1BQWtCLEVBQUUsU0FBaUIsRUFDL0QsU0FBaUIsRUFBRSxTQUFpQixFQUFFLElBQW9CO0lBQzVELE1BQU0sTUFBTSxHQUNSLElBQUksQ0FBQyxpQkFBaUIsQ0FBQyxPQUFPLEVBQUUsU0FBUyxHQUFHLFNBQVMsQ0FBZSxDQUFDO0lBQ3pFLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxTQUFTLEVBQUUsRUFBRSxDQUFDLEVBQUU7UUFDbEMsTUFBTSxpQkFBaUIsR0FDbkIsWUFBWSxDQUFDLEtBQUssQ0FBQyxDQUFDLEdBQUcsU0FBUyxFQUFFLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxHQUFHLFNBQVMsQ0FBQyxDQUFDO1FBQzNELE1BQU0sWUFBWSxHQUFHLENBQUMsR0FBRyxTQUFTLENBQUM7UUFDbkMsS0FBSyxJQUFJLENBQUMsR0FBRyxDQU
FDLEVBQUUsQ0FBQyxHQUFHLFNBQVMsRUFBRSxFQUFFLENBQUMsRUFBRTtZQUNsQyxNQUFNLENBQUMsWUFBWSxHQUFHLENBQUMsQ0FBQyxHQUFHLElBQUksS0FBSyxNQUFNLENBQUMsQ0FBQztnQkFDeEMsVUFBVSxDQUFDLGlCQUFpQixFQUFFLE1BQU0sQ0FBQyxDQUFDLEdBQUcsWUFBWSxDQUFDLENBQUMsQ0FBQyxDQUFDO2dCQUN6RCxVQUFVLENBQUMsaUJBQWlCLEVBQUUsTUFBTSxDQUFDLENBQUMsR0FBRyxZQUFZLENBQUMsQ0FBQyxDQUFDO1NBQzdEO0tBQ0Y7SUFDRCxPQUFPLE1BQU0sQ0FBQztBQUNoQixDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjIgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge1R5cGVkQXJyYXksIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmZ1bmN0aW9uIGxvd2VyQm91bmQoYXJyYXk6IFR5cGVkQXJyYXksIHZhbHVlOiBudW1iZXIpIHtcbiAgbGV0IGxlZnQgPSAwO1xuICBsZXQgcmlnaHQgPSBhcnJheS5sZW5ndGg7XG4gIGxldCBtaWQgPSAwO1xuICB3aGlsZSAobGVmdCA8IHJpZ2h0KSB7XG4gICAgbWlkID0gTWF0aC5mbG9vcigobGVmdCArIHJpZ2h0KSAvIDIpO1xuICAgIGlmIChhcnJheVttaWRdIDwgdmFsdWUpIHtcbiAgICAgIGxlZnQgPSBtaWQgKyAxO1xuICAgIH0gZWxzZSB7XG4gICAgICByaWdodCA9IG1pZDtcbiAgICB9XG4gIH1cbiAgcmV0dXJuIHJpZ2h0O1xufVxuXG5mdW5jdGlvbiB1cHBlckJvdW5kKGFycmF5OiBUeXBlZEFycmF5LCB2YWx1ZTogbnVtYmVyKSB7XG4gIGxldCBsZWZ0ID0gMDtcbiAgbGV0IHJpZ2h0ID0gYXJyYX
kubGVuZ3RoO1xuICBsZXQgbWlkID0gMDtcbiAgd2hpbGUgKGxlZnQgPCByaWdodCkge1xuICAgIG1pZCA9IE1hdGguZmxvb3IoKGxlZnQgKyByaWdodCkgLyAyKTtcbiAgICBpZiAoYXJyYXlbbWlkXSA8PSB2YWx1ZSkge1xuICAgICAgbGVmdCA9IG1pZCArIDE7XG4gICAgfSBlbHNlIHtcbiAgICAgIHJpZ2h0ID0gbWlkO1xuICAgIH1cbiAgfVxuICByZXR1cm4gcmlnaHQ7XG59XG5cbmV4cG9ydCBmdW5jdGlvbiBzZWFyY2hTb3J0ZWRJbXBsKFxuICAgIHNvcnRlZElucHV0czogVHlwZWRBcnJheSwgdmFsdWVzOiBUeXBlZEFycmF5LCBiYXRjaFNpemU6IG51bWJlcixcbiAgICBudW1JbnB1dHM6IG51bWJlciwgbnVtVmFsdWVzOiBudW1iZXIsIHNpZGU6ICdsZWZ0J3wncmlnaHQnKTogVHlwZWRBcnJheSB7XG4gIGNvbnN0IG91dHB1dCA9XG4gICAgICB1dGlsLmdldEFycmF5RnJvbURUeXBlKCdpbnQzMicsIGJhdGNoU2l6ZSAqIG51bVZhbHVlcykgYXMgVHlwZWRBcnJheTtcbiAgZm9yIChsZXQgYiA9IDA7IGIgPCBiYXRjaFNpemU7ICsrYikge1xuICAgIGNvbnN0IHNvcnRlZElucHV0c1NsaWNlID1cbiAgICAgICAgc29ydGVkSW5wdXRzLnNsaWNlKGIgKiBudW1JbnB1dHMsIChiICsgMSkgKiBudW1JbnB1dHMpO1xuICAgIGNvbnN0IG91dHB1dE9mZnNldCA9IGIgKiBudW1WYWx1ZXM7XG4gICAgZm9yIChsZXQgaSA9IDA7IGkgPCBudW1WYWx1ZXM7ICsraSkge1xuICAgICAgb3V0cHV0W291dHB1dE9mZnNldCArIGldID0gc2lkZSA9PT0gJ2xlZnQnID9cbiAgICAgICAgICBsb3dlckJvdW5kKHNvcnRlZElucHV0c1NsaWNlLCB2YWx1ZXNbaSArIG91dHB1dE9mZnNldF0pIDpcbiAgICAgICAgICB1cHBlckJvdW5kKHNvcnRlZElucHV0c1NsaWNlLCB2YWx1ZXNbaSArIG91dHB1dE9mZnNldF0pO1xuICAgIH1cbiAgfVxuICByZXR1cm4gb3V0cHV0O1xufVxuIl19