gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { ENGINE } from '../engine';
import { Select } from '../kernel_names';
import { convertToTensor } from '../tensor_util_env';
import { broadcastTo } from './broadcast_to';
import { assertAndGetBroadcastShape } from './broadcast_util';
import { op } from './operation';
/**
 * Picks elements from `a` or `b`, position by position, based on `condition`.
 *
 * Where the condition is true the element is taken from `a`; otherwise it is
 * taken from `b`.
 *
 * ```js
 * const cond = tf.tensor1d([false, false, true], 'bool');
 * const a = tf.tensor1d([1 , 2, 3]);
 * const b = tf.tensor1d([-1, -2, -3]);
 *
 * a.where(cond, b).print();
 * ```
 *
 * The shapes of `condition`, `a` and `b` are combined into a single broadcast
 * shape, and all three inputs are broadcast to that shape before the `Select`
 * kernel is run.
 *
 * @param condition The input condition. Must be of dtype bool.
 * @param a The values to select where `condition` is true. Its shape is
 *     combined with the shape of `condition` (and then `b`) to find the
 *     common broadcast shape.
 * @param b The values to select where `condition` is false. A tensor with
 *     the same dtype as `a` and with a shape that is compatible with `a`.
 * @return The tensor produced by the `Select` kernel: same dtype as `a` and
 *     `b`, computed from inputs broadcast to the common shape.
 *
 * @doc {heading: 'Operations', subheading: 'Logical'}
 */
function where_(condition, a, b) {
    // Inputs are converted in the order a, b, condition so that any
    // conversion failure surfaces for the same argument as before.
    const thenTensor = convertToTensor(a, 'a', 'where');
    const elseTensor = convertToTensor(b, 'b', 'where');
    const condTensor = convertToTensor(condition, 'condition', 'where', 'bool');
    // TODO: move this logic to forward function when the broadcastTo op is
    // implemented in WASM.
    // Resolve the common shape in two steps: condition with a, then that
    // intermediate result with b.
    const condWithThenShape = assertAndGetBroadcastShape(condTensor.shape, thenTensor.shape);
    const targetShape = assertAndGetBroadcastShape(condWithThenShape, elseTensor.shape);
    // Expand every input to the common shape (condition first, then a, then b).
    const [wideCond, wideThen, wideElse] = [condTensor, thenTensor, elseTensor]
        .map((tensor) => broadcastTo(tensor, targetShape));
    return ENGINE.runKernel(Select, {
        condition: wideCond,
        t: wideThen,
        e: wideElse
    });
}
// Public `where` export: the value returned by `op` (imported from
// './operation') when given the `where_` implementation. The shorthand
// `{ where_ }` passes the function under the key `where_`; presumably `op`
// uses that key to name the wrapped op -- confirm in './operation'.
// The `/* @__PURE__ */` annotation marks the call as side-effect free so
// bundlers can drop this export when it is unused.
export const where = /* @__PURE__ */ op({ where_ });
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoid2hlcmUuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL29wcy93aGVyZS50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsTUFBTSxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBQ2pDLE9BQU8sRUFBQyxNQUFNLEVBQWUsTUFBTSxpQkFBaUIsQ0FBQztBQUdyRCxPQUFPLEVBQUMsZUFBZSxFQUFDLE1BQU0sb0JBQW9CLENBQUM7QUFHbkQsT0FBTyxFQUFDLFdBQVcsRUFBQyxNQUFNLGdCQUFnQixDQUFDO0FBQzNDLE9BQU8sRUFBQywwQkFBMEIsRUFBQyxNQUFNLGtCQUFrQixDQUFDO0FBQzVELE9BQU8sRUFBQyxFQUFFLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFFL0I7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7R0FzQkc7QUFDSCxTQUFTLE1BQU0sQ0FDWCxTQUE0QixFQUFFLENBQWUsRUFBRSxDQUFlO0lBQ2hFLE1BQU0sRUFBRSxHQUFHLGVBQWUsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLE9BQU8sQ0FBQyxDQUFDO0lBQzVDLE1BQU0sRUFBRSxHQUFHLGVBQWUsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLE9BQU8sQ0FBQyxDQUFDO0lBQzVDLE1BQU0sVUFBVSxHQUFHLGVBQWUsQ0FBQyxTQUFTLEVBQUUsV0FBVyxFQUFFLE9BQU8sRUFBRSxNQUFNLENBQUMsQ0FBQztJQUM1RSx1RUFBdUU7SUFDdkUsdUJBQXVCO0lBQ3ZCLDJEQUEyRDtJQUMzRCxNQUFNLGNBQWMsR0FBRywwQkFBMEIsQ0FDN0MsMEJBQTBCLENBQUMsVUFBVSxDQUFDLEtBQUssRUFBRSxFQUFFLENBQUMsS0FBSyxDQUFDLEVBQUUsRUFBRSxDQUFDLEtBQUssQ0FBQyxDQUFDO0lBQ3RFLE1BQU0scUJBQXFCLEdBQUcsV0FBVyxDQUFDLFVBQVUsRUFBRSxjQUFjLENBQUMsQ0FBQztJQUN0RSxNQUFNLGFBQWEsR0FBRyxXQUFXLENBQUMsRUFBRSxFQUFFLGNBQWMsQ0FBQyxDQUFDO0lBQ3RELE1BQU0sYUFBYSxHQUFHLFdBQVcsQ0FBQyxFQUFFLEVBQUUsY0FBYyxDQUFDLENBQUM7SUFFdEQsTUFBTSxNQUFNLEdBQWlCO1FBQzNCLFNBQVMsRUFBRSxxQkFBcUI7UUFDaEMsQ0FBQyxFQUFFLGFBQWE7UUFDaEIsQ0FBQyxFQUFFLGFBQWE7S0FDakIsQ0FBQztJQUNGLE9BQU8sTUFBTSxDQUFDLFNBQVMsQ0FBQyxNQUFNLEVBQUUsTUFBbUMsQ0FBQyxDQUFDO0FBQ3ZFLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxLQUFLLEdBQUcsZUFBZSxDQUFDLEVBQUUsQ0FBQyxFQUFDLE1BQU0sRUFBQyxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbX
BsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7RU5HSU5FfSBmcm9tICcuLi9lbmdpbmUnO1xuaW1wb3J0IHtTZWxlY3QsIFNlbGVjdElucHV0c30gZnJvbSAnLi4va2VybmVsX25hbWVzJztcbmltcG9ydCB7VGVuc29yfSBmcm9tICcuLi90ZW5zb3InO1xuaW1wb3J0IHtOYW1lZFRlbnNvck1hcH0gZnJvbSAnLi4vdGVuc29yX3R5cGVzJztcbmltcG9ydCB7Y29udmVydFRvVGVuc29yfSBmcm9tICcuLi90ZW5zb3JfdXRpbF9lbnYnO1xuaW1wb3J0IHtUZW5zb3JMaWtlfSBmcm9tICcuLi90eXBlcyc7XG5cbmltcG9ydCB7YnJvYWRjYXN0VG99IGZyb20gJy4vYnJvYWRjYXN0X3RvJztcbmltcG9ydCB7YXNzZXJ0QW5kR2V0QnJvYWRjYXN0U2hhcGV9IGZyb20gJy4vYnJvYWRjYXN0X3V0aWwnO1xuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuXG4vKipcbiAqIFJldHVybnMgdGhlIGVsZW1lbnRzLCBlaXRoZXIgYGFgIG9yIGBiYCBkZXBlbmRpbmcgb24gdGhlIGBjb25kaXRpb25gLlxuICpcbiAqIElmIHRoZSBjb25kaXRpb24gaXMgdHJ1ZSwgc2VsZWN0IGZyb20gYGFgLCBvdGhlcndpc2Ugc2VsZWN0IGZyb20gYGJgLlxuICpcbiAqIGBgYGpzXG4gKiBjb25zdCBjb25kID0gdGYudGVuc29yMWQoW2ZhbHNlLCBmYWxzZSwgdHJ1ZV0sICdib29sJyk7XG4gKiBjb25zdCBhID0gdGYudGVuc29yMWQoWzEgLCAyLCAzXSk7XG4gKiBjb25zdCBiID0gdGYudGVuc29yMWQoWy0xLCAtMiwgLTNdKTtcbiAqXG4gKiBhLndoZXJlKGNvbmQsIGIpLnByaW50KCk7XG4gKiBgYGBcbiAqXG4gKiBAcGFyYW0gY29uZGl0aW9uIFRoZSBpbnB1dCBjb25kaXRpb24uIE11c3QgYmUgb2YgZHR5cGUgYm9vbC5cbiAqIEBwYXJhbSBhIElmIGBjb25kaXRpb25gIGlzIHJhbmsgMSwgYGFgIG1heSBoYXZlIGEgaGlnaGVyIHJhbmsgYnV0XG4gKiAgICAgaXRzIGZpcnN0IGRpbWVuc2lvbiBtdXN0IG1hdGNoIHRoZSBzaXplIG9mIGBjb25kaX
Rpb25gLlxuICogQHBhcmFtIGIgQSB0ZW5zb3Igd2l0aCB0aGUgc2FtZSBkdHlwZSBhcyBgYWAgYW5kIHdpdGggc2hhcGUgdGhhdCBpc1xuICogICAgIGNvbXBhdGlibGUgd2l0aCBgYWAuXG4gKiBAcmV0dXJuIEEgdGVuc29yIHdpdGggc2FtZSBkdHlwZSBhcyBgYWAgYW5kIGBiYCwgYW5kIHNoYXBlIHRoYXQgaXNcbiAqICAgICBicm9hZGNhc3RhYmxlIGZyb20gYGFgIGFuZCBgYmAuXG4gKlxuICogQGRvYyB7aGVhZGluZzogJ09wZXJhdGlvbnMnLCBzdWJoZWFkaW5nOiAnTG9naWNhbCd9XG4gKi9cbmZ1bmN0aW9uIHdoZXJlXzxUIGV4dGVuZHMgVGVuc29yPihcbiAgICBjb25kaXRpb246IFRlbnNvcnxUZW5zb3JMaWtlLCBhOiBUfFRlbnNvckxpa2UsIGI6IFR8VGVuc29yTGlrZSk6IFQge1xuICBjb25zdCAkYSA9IGNvbnZlcnRUb1RlbnNvcihhLCAnYScsICd3aGVyZScpO1xuICBjb25zdCAkYiA9IGNvbnZlcnRUb1RlbnNvcihiLCAnYicsICd3aGVyZScpO1xuICBjb25zdCAkY29uZGl0aW9uID0gY29udmVydFRvVGVuc29yKGNvbmRpdGlvbiwgJ2NvbmRpdGlvbicsICd3aGVyZScsICdib29sJyk7XG4gIC8vIFRPRE86IG1vdmUgdGhpcyBsb2dpYyB0byBmb3J3YXJkIGZ1bmN0aW9uIHdoZW4gdGhlIGJyb2FkY2FzdFRvIG9wIGlzXG4gIC8vIGltcGxlbWVudGVkIGluIFdBU00uXG4gIC8vIEZpbmQgdGhlIGJyb2FkY2FzdGFibGUgc2hhcGUgZm9yICRjb25kaXRpb24sICRhLCBhbmQgJGIuXG4gIGNvbnN0IGJyb2FkY2FzdFNoYXBlID0gYXNzZXJ0QW5kR2V0QnJvYWRjYXN0U2hhcGUoXG4gICAgICBhc3NlcnRBbmRHZXRCcm9hZGNhc3RTaGFwZSgkY29uZGl0aW9uLnNoYXBlLCAkYS5zaGFwZSksICRiLnNoYXBlKTtcbiAgY29uc3QgJGJyb2FkY2FzdGVkQ29uZGl0aW9uID0gYnJvYWRjYXN0VG8oJGNvbmRpdGlvbiwgYnJvYWRjYXN0U2hhcGUpO1xuICBjb25zdCAkYnJvYWRjYXN0ZWRBID0gYnJvYWRjYXN0VG8oJGEsIGJyb2FkY2FzdFNoYXBlKTtcbiAgY29uc3QgJGJyb2FkY2FzdGVkQiA9IGJyb2FkY2FzdFRvKCRiLCBicm9hZGNhc3RTaGFwZSk7XG5cbiAgY29uc3QgaW5wdXRzOiBTZWxlY3RJbnB1dHMgPSB7XG4gICAgY29uZGl0aW9uOiAkYnJvYWRjYXN0ZWRDb25kaXRpb24sXG4gICAgdDogJGJyb2FkY2FzdGVkQSxcbiAgICBlOiAkYnJvYWRjYXN0ZWRCXG4gIH07XG4gIHJldHVybiBFTkdJTkUucnVuS2VybmVsKFNlbGVjdCwgaW5wdXRzIGFzIHVua25vd24gYXMgTmFtZWRUZW5zb3JNYXApO1xufVxuXG5leHBvcnQgY29uc3Qgd2hlcmUgPSAvKiBAX19QVVJFX18gKi8gb3Aoe3doZXJlX30pO1xuIl19