/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { ENGINE } from '../engine';
|
import { Select } from '../kernel_names';
|
import { convertToTensor } from '../tensor_util_env';
|
import { broadcastTo } from './broadcast_to';
|
import { assertAndGetBroadcastShape } from './broadcast_util';
|
import { op } from './operation';
|
/**
 * Selects elements from `a` or `b` depending on `condition`.
 *
 * Where the condition is true the element comes from `a`, otherwise it comes
 * from `b`.
 *
 * ```js
 * const cond = tf.tensor1d([false, false, true], 'bool');
 * const a = tf.tensor1d([1 , 2, 3]);
 * const b = tf.tensor1d([-1, -2, -3]);
 *
 * a.where(cond, b).print();
 * ```
 *
 * `condition`, `a` and `b` are each broadcast to one common shape before the
 * `Select` kernel runs, so their shapes must be mutually broadcast-compatible
 * (checked by `assertAndGetBroadcastShape`).
 *
 * @param condition The input condition. Must be of dtype bool.
 * @param a The values used where `condition` is true. Its shape must be
 *     broadcast-compatible with `condition` and `b`.
 * @param b The values used where `condition` is false. A tensor with the same
 *     dtype as `a` and a shape that is broadcast-compatible with `condition`
 *     and `a`.
 * @return A tensor with same dtype as `a` and `b`, whose shape is the common
 *     broadcast shape of `condition`, `a` and `b`.
 *
 * @doc {heading: 'Operations', subheading: 'Logical'}
 */
function where_(condition, a, b) {
  // Inputs are converted in the order a, b, condition.
  const aTensor = convertToTensor(a, 'a', 'where');
  const bTensor = convertToTensor(b, 'b', 'where');
  const conditionTensor =
      convertToTensor(condition, 'condition', 'where', 'bool');

  // TODO: move this logic to forward function when the broadcastTo op is
  // implemented in WASM.
  // Resolve the common shape in two steps: (condition, a) first, then b.
  const conditionAndAShape =
      assertAndGetBroadcastShape(conditionTensor.shape, aTensor.shape);
  const outputShape =
      assertAndGetBroadcastShape(conditionAndAShape, bTensor.shape);

  // Expand every operand to the common shape: condition, then a, then b.
  const [conditionInput, thenInput, elseInput] =
      [conditionTensor, aTensor, bTensor].map(
          (tensor) => broadcastTo(tensor, outputShape));

  // The Select kernel receives the operands under the keys `condition`,
  // `t` (taken when true) and `e` (taken otherwise).
  return ENGINE.runKernel(Select, {
    condition: conditionInput,
    t: thenInput,
    e: elseInput
  });
}
|
// Public `where` op: the value returned by passing `where_` to `op` (imported
// from './operation'). The `@__PURE__` annotation is the conventional hint
// that lets bundlers/minifiers drop this call when `where` is unused.
export const where = /* @__PURE__ */ op({ where_ });
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoid2hlcmUuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL29wcy93aGVyZS50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsTUFBTSxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBQ2pDLE9BQU8sRUFBQyxNQUFNLEVBQWUsTUFBTSxpQkFBaUIsQ0FBQztBQUdyRCxPQUFPLEVBQUMsZUFBZSxFQUFDLE1BQU0sb0JBQW9CLENBQUM7QUFHbkQsT0FBTyxFQUFDLFdBQVcsRUFBQyxNQUFNLGdCQUFnQixDQUFDO0FBQzNDLE9BQU8sRUFBQywwQkFBMEIsRUFBQyxNQUFNLGtCQUFrQixDQUFDO0FBQzVELE9BQU8sRUFBQyxFQUFFLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFFL0I7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7R0FzQkc7QUFDSCxTQUFTLE1BQU0sQ0FDWCxTQUE0QixFQUFFLENBQWUsRUFBRSxDQUFlO0lBQ2hFLE1BQU0sRUFBRSxHQUFHLGVBQWUsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLE9BQU8sQ0FBQyxDQUFDO0lBQzVDLE1BQU0sRUFBRSxHQUFHLGVBQWUsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLE9BQU8sQ0FBQyxDQUFDO0lBQzVDLE1BQU0sVUFBVSxHQUFHLGVBQWUsQ0FBQyxTQUFTLEVBQUUsV0FBVyxFQUFFLE9BQU8sRUFBRSxNQUFNLENBQUMsQ0FBQztJQUM1RSx1RUFBdUU7SUFDdkUsdUJBQXVCO0lBQ3ZCLDJEQUEyRDtJQUMzRCxNQUFNLGNBQWMsR0FBRywwQkFBMEIsQ0FDN0MsMEJBQTBCLENBQUMsVUFBVSxDQUFDLEtBQUssRUFBRSxFQUFFLENBQUMsS0FBSyxDQUFDLEVBQUUsRUFBRSxDQUFDLEtBQUssQ0FBQyxDQUFDO0lBQ3RFLE1BQU0scUJBQXFCLEdBQUcsV0FBVyxDQUFDLFVBQVUsRUFBRSxjQUFjLENBQUMsQ0FBQztJQUN0RSxNQUFNLGFBQWEsR0FBRyxXQUFXLENBQUMsRUFBRSxFQUFFLGNBQWMsQ0FBQyxDQUFDO0lBQ3RELE1BQU0sYUFBYSxHQUFHLFdBQVcsQ0FBQyxFQUFFLEVBQUUsY0FBYyxDQUFDLENBQUM7SUFFdEQsTUFBTSxNQUFNLEdBQWlCO1FBQzNCLFNBQVMsRUFBRSxxQkFBcUI7UUFDaEMsQ0FBQyxFQUFFLGFBQWE7UUFDaEIsQ0FBQyxFQUFFLGFBQWE7S0FDakIsQ0FBQztJQUNGLE9BQU8sTUFBTSxDQUFDLFNBQVMsQ0FBQyxNQUFNLEVBQUUsTUFBbUMsQ0FBQyxDQUFDO0FBQ3ZFLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxLQUFLLEdBQUcsZUFBZSxDQUFDLEVBQUUsQ0FBQyxFQUFDLE1BQU0sRUFBQyxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbX
BsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7RU5HSU5FfSBmcm9tICcuLi9lbmdpbmUnO1xuaW1wb3J0IHtTZWxlY3QsIFNlbGVjdElucHV0c30gZnJvbSAnLi4va2VybmVsX25hbWVzJztcbmltcG9ydCB7VGVuc29yfSBmcm9tICcuLi90ZW5zb3InO1xuaW1wb3J0IHtOYW1lZFRlbnNvck1hcH0gZnJvbSAnLi4vdGVuc29yX3R5cGVzJztcbmltcG9ydCB7Y29udmVydFRvVGVuc29yfSBmcm9tICcuLi90ZW5zb3JfdXRpbF9lbnYnO1xuaW1wb3J0IHtUZW5zb3JMaWtlfSBmcm9tICcuLi90eXBlcyc7XG5cbmltcG9ydCB7YnJvYWRjYXN0VG99IGZyb20gJy4vYnJvYWRjYXN0X3RvJztcbmltcG9ydCB7YXNzZXJ0QW5kR2V0QnJvYWRjYXN0U2hhcGV9IGZyb20gJy4vYnJvYWRjYXN0X3V0aWwnO1xuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuXG4vKipcbiAqIFJldHVybnMgdGhlIGVsZW1lbnRzLCBlaXRoZXIgYGFgIG9yIGBiYCBkZXBlbmRpbmcgb24gdGhlIGBjb25kaXRpb25gLlxuICpcbiAqIElmIHRoZSBjb25kaXRpb24gaXMgdHJ1ZSwgc2VsZWN0IGZyb20gYGFgLCBvdGhlcndpc2Ugc2VsZWN0IGZyb20gYGJgLlxuICpcbiAqIGBgYGpzXG4gKiBjb25zdCBjb25kID0gdGYudGVuc29yMWQoW2ZhbHNlLCBmYWxzZSwgdHJ1ZV0sICdib29sJyk7XG4gKiBjb25zdCBhID0gdGYudGVuc29yMWQoWzEgLCAyLCAzXSk7XG4gKiBjb25zdCBiID0gdGYudGVuc29yMWQoWy0xLCAtMiwgLTNdKTtcbiAqXG4gKiBhLndoZXJlKGNvbmQsIGIpLnByaW50KCk7XG4gKiBgYGBcbiAqXG4gKiBAcGFyYW0gY29uZGl0aW9uIFRoZSBpbnB1dCBjb25kaXRpb24uIE11c3QgYmUgb2YgZHR5cGUgYm9vbC5cbiAqIEBwYXJhbSBhIElmIGBjb25kaXRpb25gIGlzIHJhbmsgMSwgYGFgIG1heSBoYXZlIGEgaGlnaGVyIHJhbmsgYnV0XG4gKiAgICAgaXRzIGZpcnN0IGRpbWVuc2lvbiBtdXN0IG1hdGNoIHRoZSBzaXplIG9mIGBjb25kaX
Rpb25gLlxuICogQHBhcmFtIGIgQSB0ZW5zb3Igd2l0aCB0aGUgc2FtZSBkdHlwZSBhcyBgYWAgYW5kIHdpdGggc2hhcGUgdGhhdCBpc1xuICogICAgIGNvbXBhdGlibGUgd2l0aCBgYWAuXG4gKiBAcmV0dXJuIEEgdGVuc29yIHdpdGggc2FtZSBkdHlwZSBhcyBgYWAgYW5kIGBiYCwgYW5kIHNoYXBlIHRoYXQgaXNcbiAqICAgICBicm9hZGNhc3RhYmxlIGZyb20gYGFgIGFuZCBgYmAuXG4gKlxuICogQGRvYyB7aGVhZGluZzogJ09wZXJhdGlvbnMnLCBzdWJoZWFkaW5nOiAnTG9naWNhbCd9XG4gKi9cbmZ1bmN0aW9uIHdoZXJlXzxUIGV4dGVuZHMgVGVuc29yPihcbiAgICBjb25kaXRpb246IFRlbnNvcnxUZW5zb3JMaWtlLCBhOiBUfFRlbnNvckxpa2UsIGI6IFR8VGVuc29yTGlrZSk6IFQge1xuICBjb25zdCAkYSA9IGNvbnZlcnRUb1RlbnNvcihhLCAnYScsICd3aGVyZScpO1xuICBjb25zdCAkYiA9IGNvbnZlcnRUb1RlbnNvcihiLCAnYicsICd3aGVyZScpO1xuICBjb25zdCAkY29uZGl0aW9uID0gY29udmVydFRvVGVuc29yKGNvbmRpdGlvbiwgJ2NvbmRpdGlvbicsICd3aGVyZScsICdib29sJyk7XG4gIC8vIFRPRE86IG1vdmUgdGhpcyBsb2dpYyB0byBmb3J3YXJkIGZ1bmN0aW9uIHdoZW4gdGhlIGJyb2FkY2FzdFRvIG9wIGlzXG4gIC8vIGltcGxlbWVudGVkIGluIFdBU00uXG4gIC8vIEZpbmQgdGhlIGJyb2FkY2FzdGFibGUgc2hhcGUgZm9yICRjb25kaXRpb24sICRhLCBhbmQgJGIuXG4gIGNvbnN0IGJyb2FkY2FzdFNoYXBlID0gYXNzZXJ0QW5kR2V0QnJvYWRjYXN0U2hhcGUoXG4gICAgICBhc3NlcnRBbmRHZXRCcm9hZGNhc3RTaGFwZSgkY29uZGl0aW9uLnNoYXBlLCAkYS5zaGFwZSksICRiLnNoYXBlKTtcbiAgY29uc3QgJGJyb2FkY2FzdGVkQ29uZGl0aW9uID0gYnJvYWRjYXN0VG8oJGNvbmRpdGlvbiwgYnJvYWRjYXN0U2hhcGUpO1xuICBjb25zdCAkYnJvYWRjYXN0ZWRBID0gYnJvYWRjYXN0VG8oJGEsIGJyb2FkY2FzdFNoYXBlKTtcbiAgY29uc3QgJGJyb2FkY2FzdGVkQiA9IGJyb2FkY2FzdFRvKCRiLCBicm9hZGNhc3RTaGFwZSk7XG5cbiAgY29uc3QgaW5wdXRzOiBTZWxlY3RJbnB1dHMgPSB7XG4gICAgY29uZGl0aW9uOiAkYnJvYWRjYXN0ZWRDb25kaXRpb24sXG4gICAgdDogJGJyb2FkY2FzdGVkQSxcbiAgICBlOiAkYnJvYWRjYXN0ZWRCXG4gIH07XG4gIHJldHVybiBFTkdJTkUucnVuS2VybmVsKFNlbGVjdCwgaW5wdXRzIGFzIHVua25vd24gYXMgTmFtZWRUZW5zb3JNYXApO1xufVxuXG5leHBvcnQgY29uc3Qgd2hlcmUgPSAvKiBAX19QVVJFX18gKi8gb3Aoe3doZXJlX30pO1xuIl19
|