/**
 * @license
 * Copyright 2022 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { util } from '@tensorflow/tfjs-core';

/**
 * Binary-searches `array` for the first position whose element is not
 * strictly less than `value`.
 *
 * The result is only meaningful when `array` is sorted in ascending order;
 * this function does not check that.
 *
 * @param array Indexable sequence with a `length` property.
 * @param value Value to locate.
 * @returns An index in `[0, array.length]`; `array.length` when every element
 *     compares less than `value`, and `0` for an empty array.
 */
function lowerBound(array, value) {
  let lo = 0;
  let hi = array.length;
  while (lo < hi) {
    // Midpoint of the half-open window [lo, hi), rounded down.
    const probe = lo + Math.floor((hi - lo) / 2);
    if (array[probe] < value) {
      // Everything up to and including `probe` is too small.
      lo = probe + 1;
    } else {
      hi = probe;
    }
  }
  return hi;
}

/**
 * Binary-searches `array` for the first position whose element is strictly
 * greater than `value` (i.e. elements equal to `value` are skipped over).
 *
 * The result is only meaningful when `array` is sorted in ascending order;
 * this function does not check that.
 *
 * @param array Indexable sequence with a `length` property.
 * @param value Value to locate.
 * @returns An index in `[0, array.length]`; `array.length` when every element
 *     compares less than or equal to `value`, and `0` for an empty array.
 */
function upperBound(array, value) {
  let lo = 0;
  let hi = array.length;
  while (lo < hi) {
    // Midpoint of the half-open window [lo, hi), rounded down.
    const probe = lo + Math.floor((hi - lo) / 2);
    if (array[probe] <= value) {
      // Elements equal to `value` also belong to the left part.
      lo = probe + 1;
    } else {
      hi = probe;
    }
  }
  return hi;
}

/**
 * Computes, for every entry of `values`, its insertion index within the
 * matching batch row of `sortedInputs`.
 *
 * Both inputs are read as flat, row-major buffers: batch `b` occupies
 * `sortedInputs[b * numInputs, (b + 1) * numInputs)` and
 * `values[b * numValues, (b + 1) * numValues)`. Each row of `sortedInputs`
 * is copied with `slice` before it is searched.
 *
 * @param sortedInputs Flat buffer holding `batchSize` rows of `numInputs`
 *     elements each (rows are expected to be sorted ascending).
 * @param values Flat buffer holding `batchSize` rows of `numValues` elements.
 * @param batchSize Number of rows.
 * @param numInputs Number of elements per row of `sortedInputs`.
 * @param numValues Number of elements per row of `values`.
 * @param side `'left'` selects the lower-bound search; any other value
 *     selects the upper-bound search.
 * @returns The array produced by
 *     `util.getArrayFromDType('int32', batchSize * numValues)`, filled with
 *     row-relative indices laid out like `values`.
 */
export function searchSortedImpl(
    sortedInputs, values, batchSize, numInputs, numValues, side) {
  const result = util.getArrayFromDType('int32', batchSize * numValues);
  // `side` never changes, so the search routine is picked once up front.
  const search = side === 'left' ? lowerBound : upperBound;

  for (let batch = 0; batch < batchSize; ++batch) {
    const row =
        sortedInputs.slice(batch * numInputs, (batch + 1) * numInputs);
    const base = batch * numValues;
    for (let k = 0; k < numValues; ++k) {
      const flat = base + k;
      result[flat] = search(row, values[flat]);
    }
  }
  return result;
}