/**
 * @license
 * Copyright 2022 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
|
import { env } from '@tensorflow/tfjs-core';
|
/**
 * Shader program description that, for every output element, bisects one
 * batch row of `sortedSequence` to find the insertion bound of the
 * corresponding entry of `values`.
 *
 * The instance exposes `variableNames`, `customUniforms`, `outputShape` and
 * `userCode` (the generated GLSL source).
 */
export class SearchSortedProgram {
  /**
   * @param {number} batchSize First dimension of the output shape.
   * @param {number} numInputs Length of the searched range. Used here only to
   *     cap the WebGL1 loop iteration count; the shader itself reads the
   *     `numInputs` uniform declared in `customUniforms`.
   * @param {number} numValues Second dimension of the output shape.
   * @param {string} side `'left'` selects a strict `<` comparison (lower
   *     bound); any other value selects `<=` (upper bound).
   */
  constructor(batchSize, numInputs, numValues, side) {
    this.variableNames = ['sortedSequence', 'values'];
    this.customUniforms = [{ name: 'numInputs', type: 'int' }];
    this.outputShape = [batchSize, numValues];

    const webGL2LoopHead = 'while (left < right) {';
    // WebGL1 doesn't accept non constant loop conditions, so upper bound loop
    // iterations with a compile-time constant and break out early instead.
    const webGL1LoopHead = `for (int i = 0; i < ${Math.ceil(Math.log2(numInputs + 1))}; ++i) { if (left >= right) break;`;
    const loopHead = env().getNumber('WEBGL_VERSION') === 2 ?
        webGL2LoopHead :
        webGL1LoopHead;

    // left corresponds to lower bound and right to upper bound.
    const boundComparator = side === 'left' ? '<' : '<=';

    // NOTE: everything between the backticks is emitted verbatim as GLSL, so
    // no stray characters may appear inside this template literal.
    this.userCode = `
       int findBound(int batch, float value) {
         int left = 0;
         int right = numInputs;
         int mid;
         ${loopHead}
           mid = (left + right) / 2;
           if (getSortedSequence(batch, mid) ${boundComparator} value) {
             left = mid + 1;
           } else {
             right = mid;
           }
         }
         return right;
       }

       void main() {
         ivec2 coords = getOutputCoords();
         int batch = coords[0];
         int valueIndex = coords[1];

         float value = getValues(batch, valueIndex);

         setOutput(float(findBound(batch, value)));
       }
     `;
  }
}
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoic2VhcmNoX3NvcnRlZF9ncHUuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2ViZ2wvc3JjL3NlYXJjaF9zb3J0ZWRfZ3B1LnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUkxQyxNQUFNLE9BQU8sbUJBQW1CO0lBTTlCLFlBQ0ksU0FBaUIsRUFBRSxTQUFpQixFQUFFLFNBQWlCLEVBQ3ZELElBQW9CO1FBUHhCLGtCQUFhLEdBQUcsQ0FBQyxnQkFBZ0IsRUFBRSxRQUFRLENBQUMsQ0FBQztRQUc3QyxtQkFBYyxHQUFHLENBQUMsRUFBQyxJQUFJLEVBQUUsV0FBVyxFQUFFLElBQUksRUFBRSxLQUFvQixFQUFDLENBQUMsQ0FBQztRQUtqRSxJQUFJLENBQUMsV0FBVyxHQUFHLENBQUMsU0FBUyxFQUFFLFNBQVMsQ0FBQyxDQUFDO1FBRTFDLE1BQU0sY0FBYyxHQUFHLHdCQUF3QixDQUFDO1FBQ2hELDBFQUEwRTtRQUMxRSxjQUFjO1FBQ2QsTUFBTSxjQUFjLEdBQUcsdUJBQ25CLElBQUksQ0FBQyxJQUFJLENBQUMsSUFBSSxDQUFDLElBQUksQ0FBQyxTQUFTLEdBQUcsQ0FBQyxDQUFDLENBQUMsb0NBQW9DLENBQUM7UUFDNUUsTUFBTSxRQUFRLEdBQUcsR0FBRyxFQUFFLENBQUMsU0FBUyxDQUFDLGVBQWUsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUMsY0FBYyxDQUFDLENBQUM7WUFDaEIsY0FBYyxDQUFDO1FBRXpFLDREQUE0RDtRQUM1RCxNQUFNLGVBQWUsR0FBRyxJQUFJLEtBQUssTUFBTSxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLElBQUksQ0FBQztRQUNyRCxJQUFJLENBQUMsUUFBUSxHQUFHOzs7OztXQUtULFFBQVE7OytDQUU0QixlQUFlOzs7Ozs7Ozs7Ozs7Ozs7Ozs7TUFrQnhELENBQUM7SUFDTCxDQUFDO0NBQ0YiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMiBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELC
BlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7ZW52fSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuaW1wb3J0IHtHUEdQVVByb2dyYW19IGZyb20gJy4vZ3BncHVfbWF0aCc7XG5pbXBvcnQge1VuaWZvcm1UeXBlfSBmcm9tICcuL3NoYWRlcl9jb21waWxlcic7XG5cbmV4cG9ydCBjbGFzcyBTZWFyY2hTb3J0ZWRQcm9ncmFtIGltcGxlbWVudHMgR1BHUFVQcm9ncmFtIHtcbiAgdmFyaWFibGVOYW1lcyA9IFsnc29ydGVkU2VxdWVuY2UnLCAndmFsdWVzJ107XG4gIG91dHB1dFNoYXBlOiBudW1iZXJbXTtcbiAgdXNlckNvZGU6IHN0cmluZztcbiAgY3VzdG9tVW5pZm9ybXMgPSBbe25hbWU6ICdudW1JbnB1dHMnLCB0eXBlOiAnaW50JyBhcyBVbmlmb3JtVHlwZX1dO1xuXG4gIGNvbnN0cnVjdG9yKFxuICAgICAgYmF0Y2hTaXplOiBudW1iZXIsIG51bUlucHV0czogbnVtYmVyLCBudW1WYWx1ZXM6IG51bWJlcixcbiAgICAgIHNpZGU6ICdsZWZ0J3wncmlnaHQnKSB7XG4gICAgdGhpcy5vdXRwdXRTaGFwZSA9IFtiYXRjaFNpemUsIG51bVZhbHVlc107XG5cbiAgICBjb25zdCB3ZWJHTDJMb29wSGVhZCA9ICd3aGlsZSAobGVmdCA8IHJpZ2h0KSB7JztcbiAgICAvLyBXZWJHTDEgZG9lc24ndCBhY2NlcHQgbm9uIGNvbnN0YW50IGxvb3AgY29uZGl0aW9ucywgc28gdXBwZXIgYm91bmQgbG9vcFxuICAgIC8vIGl0ZXJhdGlvbnMuXG4gICAgY29uc3Qgd2ViR0wxTG9vcEhlYWQgPSBgZm9yIChpbnQgaSA9IDA7IGkgPCAke1xuICAgICAgICBNYXRoLmNlaWwoTWF0aC5sb2cyKG51bUlucHV0cyArIDEpKX07ICsraSkgeyBpZiAobGVmdCA+PSByaWdodCkgYnJlYWs7YDtcbiAgICBjb25zdCBsb29wSGVhZCA9IGVudigpLmdldE51bWJlcignV0VCR0xfVkVSU0lPTicpID09PSAyID8gd2ViR0wyTG9vcEhlYWQgOlxuICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB3ZWJHTDFMb29wSGVhZDtcblxuICAgIC8vIGxlZnQgY29ycmVzcG9uZHMgdG8gbG93ZXIgYm91bmQgYW5kIHJpZ2h0IHRvIHVwcGVyIGJvdW5kLlxuICAgIGNvbnN0IGJvdW5kQ29tcGFyYXRvciA9IHNpZGUgPT09ICdsZWZ0JyA/ICc8JyA6ICc8PSc7XG4gICAgdGhpcy51c2VyQ29kZSA9IGBcbiAgICAgICBpbnQgZmluZEJvdW5kKGludCBiYXRjaCwgZmxvYXQgdmFsdWUpIHtcbiAgICAgICAgIGludCBsZWZ0ID0gMDtcbiAgICAgICAgIGludCByaWdodCA9IG51bUlucHV0cztcbiAgICAgICAgIGludCBtaWQ7XG4gICAgICAgICAke2xvb3BIZWFkfVxuICAgIC
AgICAgICBtaWQgPSAobGVmdCArIHJpZ2h0KSAvIDI7XG4gICAgICAgICAgIGlmIChnZXRTb3J0ZWRTZXF1ZW5jZShiYXRjaCwgbWlkKSAke2JvdW5kQ29tcGFyYXRvcn0gdmFsdWUpIHtcbiAgICAgICAgICAgICBsZWZ0ID0gbWlkICsgMTtcbiAgICAgICAgICAgfSBlbHNlIHtcbiAgICAgICAgICAgICByaWdodCA9IG1pZDtcbiAgICAgICAgICAgfVxuICAgICAgICAgfVxuICAgICAgICAgcmV0dXJuIHJpZ2h0O1xuICAgICAgIH1cblxuICAgICAgIHZvaWQgbWFpbigpIHtcbiAgICAgICAgIGl2ZWMyIGNvb3JkcyA9IGdldE91dHB1dENvb3JkcygpO1xuICAgICAgICAgaW50IGJhdGNoID0gY29vcmRzWzBdO1xuICAgICAgICAgaW50IHZhbHVlSW5kZXggPSBjb29yZHNbMV07XG5cbiAgICAgICAgIGZsb2F0IHZhbHVlID0gZ2V0VmFsdWVzKGJhdGNoLCB2YWx1ZUluZGV4KTtcblxuICAgICAgICAgc2V0T3V0cHV0KGZsb2F0KGZpbmRCb3VuZChiYXRjaCwgdmFsdWUpKSk7XG4gICAgICAgfVxuICAgICBgO1xuICB9XG59XG4iXX0=
|