/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { convertToTensor } from '../tensor_util_env';
|
import * as util from '../util';
|
import { gather } from './gather';
|
import { reshape } from './reshape';
|
import { squeeze } from './squeeze';
|
import { whereAsync } from './where_async';
|
/**
 * Apply boolean mask to tensor.
 *
 * ```js
 * const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
 * const mask = tf.tensor1d([1, 0, 1], 'bool');
 * const result = await tf.booleanMaskAsync(tensor, mask);
 * result.print();
 * ```
 *
 * The `mask.rank` dimensions of `tensor` starting at `axis` are collapsed into
 * a single dimension, the mask is flattened to 1-D, and the positions reported
 * by `whereAsync` for the flattened mask are gathered along `axis`.
 *
 * Every temporary tensor created by this call (including tensors produced by
 * converting non-tensor inputs) is disposed before the returned promise
 * settles, whether it fulfills or rejects. Inputs that `convertToTensor`
 * returns unchanged are left alone.
 *
 * @param tensor N-D tensor.
 * @param mask K-D boolean tensor, K <= N and K must be known statically.
 * @param axis Number giving the axis in `tensor` to mask from. When `null` or
 *     `undefined`, axis is 0, which masks from the first dimension.
 *     Otherwise K + axis <= N.
 * @returns Promise resolving to the gathered tensor.
 *     The returned promise rejects when the mask has rank 0, when the mask
 *     shape does not match the corresponding slice of the tensor shape, or
 *     when any underlying op fails.
 *
 * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
 */
async function booleanMaskAsync_(tensor, mask, axis) {
    // Tensors owned by this call. They are released in `finally` so that a
    // failed assertion or a rejected `whereAsync` does not leak memory.
    const temporaries = [];
    const track = (t) => {
        temporaries.push(t);
        return t;
    };
    try {
        const $tensor = convertToTensor(tensor, 'tensor', 'boolMask');
        // Only dispose the converted value when conversion produced a new object.
        if (tensor !== $tensor) {
            track($tensor);
        }
        const $mask = convertToTensor(mask, 'mask', 'boolMask', 'bool');
        if (mask !== $mask) {
            track($mask);
        }
        const axisFrom = axis == null ? 0 : axis;
        const maskDim = $mask.rank;
        const tensorShape = $tensor.shape;
        util.assert(maskDim > 0, () => 'mask cannot be scalar');
        util.assertShapesMatch(tensorShape.slice(axisFrom, axisFrom + maskDim), $mask.shape, `mask's shape must match the first K dimensions of tensor's shape,`);
        // Product of the masked dimensions: they become one dimension so the
        // flattened mask indexes it directly.
        let leadingSize = 1;
        for (let i = axisFrom; i < axisFrom + maskDim; i++) {
            leadingSize *= tensorShape[i];
        }
        const targetTensorShape = tensorShape.slice(0, axisFrom)
            .concat([leadingSize], tensorShape.slice(axisFrom + maskDim));
        const reshapedTensor = track(reshape($tensor, targetTensorShape));
        const reshapedMask = track(reshape($mask, [-1]));
        const positivePositions = track(await whereAsync(reshapedMask));
        // Drop axis 1 of the `whereAsync` result to use it as gather indices.
        const indices = track(squeeze(positivePositions, [1]));
        // The gathered tensor is the result and is intentionally not tracked.
        return gather(reshapedTensor, indices, axisFrom);
    }
    finally {
        // Ensure no memory leak on either the success or the failure path.
        for (const t of temporaries) {
            t.dispose();
        }
    }
}
|
// Public export: `booleanMaskAsync` is an alias of the `booleanMaskAsync_` function.
export const booleanMaskAsync = booleanMaskAsync_;
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYm9vbGVhbl9tYXNrLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvYm9vbGVhbl9tYXNrLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUdILE9BQU8sRUFBQyxlQUFlLEVBQUMsTUFBTSxvQkFBb0IsQ0FBQztBQUVuRCxPQUFPLEtBQUssSUFBSSxNQUFNLFNBQVMsQ0FBQztBQUVoQyxPQUFPLEVBQUMsTUFBTSxFQUFDLE1BQU0sVUFBVSxDQUFDO0FBQ2hDLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDbEMsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNsQyxPQUFPLEVBQUMsVUFBVSxFQUFDLE1BQU0sZUFBZSxDQUFDO0FBRXpDOzs7Ozs7Ozs7Ozs7Ozs7OztHQWlCRztBQUNILEtBQUssVUFBVSxpQkFBaUIsQ0FDNUIsTUFBeUIsRUFBRSxJQUF1QixFQUNsRCxJQUFhO0lBQ2YsTUFBTSxPQUFPLEdBQUcsZUFBZSxDQUFDLE1BQU0sRUFBRSxRQUFRLEVBQUUsVUFBVSxDQUFDLENBQUM7SUFDOUQsTUFBTSxLQUFLLEdBQUcsZUFBZSxDQUFDLElBQUksRUFBRSxNQUFNLEVBQUUsVUFBVSxFQUFFLE1BQU0sQ0FBQyxDQUFDO0lBRWhFLE1BQU0sUUFBUSxHQUFHLElBQUksSUFBSSxJQUFJLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsSUFBSSxDQUFDO0lBQ3pDLE1BQU0sT0FBTyxHQUFHLEtBQUssQ0FBQyxJQUFJLENBQUM7SUFDM0IsTUFBTSxXQUFXLEdBQUcsT0FBTyxDQUFDLEtBQUssQ0FBQztJQUVsQyxJQUFJLENBQUMsTUFBTSxDQUFDLE9BQU8sR0FBRyxDQUFDLEVBQUUsR0FBRyxFQUFFLENBQUMsdUJBQXVCLENBQUMsQ0FBQztJQUN4RCxJQUFJLENBQUMsaUJBQWlCLENBQ2xCLFdBQVcsQ0FBQyxLQUFLLENBQUMsUUFBUSxFQUFFLFFBQVEsR0FBRyxPQUFPLENBQUMsRUFBRSxLQUFLLENBQUMsS0FBSyxFQUM1RCxtRUFBbUUsQ0FBQyxDQUFDO0lBRXpFLElBQUksV0FBVyxHQUFHLENBQUMsQ0FBQztJQUNwQixLQUFLLElBQUksQ0FBQyxHQUFHLFFBQVEsRUFBRSxDQUFDLEdBQUcsUUFBUSxHQUFHLE9BQU8sRUFBRSxDQUFDLEVBQUUsRUFBRTtRQUNsRCxXQUFXLElBQUksV0FBVyxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQy9CO0lBQ0QsTUFBTSxpQkFBaUIsR0FDbkIsV0FBVyxDQUFDLEtBQUssQ0FBQyxDQUFDLEVBQUUsUUFBUSxDQUFDO1NBQ3pCLE1BQU0sQ0FBQyxDQUFDLFdBQVcsQ0FBQyxFQUFFLFdBQVcsQ0FBQyxLQUFLLENBQUMsUUFBUSxHQUFHLE9BQU8sQ0FBQyxDQUFDLENBQUM7SUFDdEUsTUFBTSxjQUFjLEdBQUcsT0FBTyxDQUFDLE9BQU8sRUFBRSxpQkFBaUIsQ0FBQyxDQUFDO0lBQzNELE1BQU0sWUFBWSxHQUFHLE9BQU8sQ0FBQyxLQUFLLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFDMUMsTUFBTSxpQkFBaUIsR0FBRyxNQUFNLF
VBQVUsQ0FBQyxZQUFZLENBQUMsQ0FBQztJQUN6RCxNQUFNLE9BQU8sR0FBRyxPQUFPLENBQUMsaUJBQWlCLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBRWhELE1BQU0sR0FBRyxHQUFHLE1BQU0sQ0FBQyxjQUFjLEVBQUUsT0FBTyxFQUFFLFFBQVEsQ0FBQyxDQUFDO0lBRXRELHlCQUF5QjtJQUN6QixJQUFJLE1BQU0sS0FBSyxPQUFPLEVBQUU7UUFDdEIsT0FBTyxDQUFDLE9BQU8sRUFBRSxDQUFDO0tBQ25CO0lBQ0QsSUFBSSxJQUFJLEtBQUssS0FBSyxFQUFFO1FBQ2xCLEtBQUssQ0FBQyxPQUFPLEVBQUUsQ0FBQztLQUNqQjtJQUNELE9BQU8sQ0FBQyxPQUFPLEVBQUUsQ0FBQztJQUNsQixjQUFjLENBQUMsT0FBTyxFQUFFLENBQUM7SUFDekIsWUFBWSxDQUFDLE9BQU8sRUFBRSxDQUFDO0lBQ3ZCLGlCQUFpQixDQUFDLE9BQU8sRUFBRSxDQUFDO0lBRTVCLE9BQU8sR0FBRyxDQUFDO0FBQ2IsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLGdCQUFnQixHQUFHLGlCQUFpQixDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTggR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge1RlbnNvcn0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCB7Y29udmVydFRvVGVuc29yfSBmcm9tICcuLi90ZW5zb3JfdXRpbF9lbnYnO1xuaW1wb3J0IHtUZW5zb3JMaWtlfSBmcm9tICcuLi90eXBlcyc7XG5pbXBvcnQgKiBhcyB1dGlsIGZyb20gJy4uL3V0aWwnO1xuXG5pbXBvcnQge2dhdGhlcn0gZnJvbSAnLi9nYXRoZXInO1xuaW1wb3J0IHtyZXNoYXBlfSBmcm9tICcuL3Jlc2hhcGUnO1xuaW1wb3J0IHtzcXVlZXplfS
Bmcm9tICcuL3NxdWVlemUnO1xuaW1wb3J0IHt3aGVyZUFzeW5jfSBmcm9tICcuL3doZXJlX2FzeW5jJztcblxuLyoqXG4gKiBBcHBseSBib29sZWFuIG1hc2sgdG8gdGVuc29yLlxuICpcbiAqIGBgYGpzXG4gKiBjb25zdCB0ZW5zb3IgPSB0Zi50ZW5zb3IyZChbMSwgMiwgMywgNCwgNSwgNl0sIFszLCAyXSk7XG4gKiBjb25zdCBtYXNrID0gdGYudGVuc29yMWQoWzEsIDAsIDFdLCAnYm9vbCcpO1xuICogY29uc3QgcmVzdWx0ID0gYXdhaXQgdGYuYm9vbGVhbk1hc2tBc3luYyh0ZW5zb3IsIG1hc2spO1xuICogcmVzdWx0LnByaW50KCk7XG4gKiBgYGBcbiAqXG4gKiBAcGFyYW0gdGVuc29yIE4tRCB0ZW5zb3IuXG4gKiBAcGFyYW0gbWFzayBLLUQgYm9vbGVhbiB0ZW5zb3IsIEsgPD0gTiBhbmQgSyBtdXN0IGJlIGtub3duIHN0YXRpY2FsbHkuXG4gKiBAcGFyYW0gYXhpcyBBIDAtRCBpbnQgVGVuc29yIHJlcHJlc2VudGluZyB0aGUgYXhpcyBpbiB0ZW5zb3IgdG8gbWFzayBmcm9tLlxuICogICAgIEJ5IGRlZmF1bHQsIGF4aXMgaXMgMCB3aGljaCB3aWxsIG1hc2sgZnJvbSB0aGUgZmlyc3QgZGltZW5zaW9uLlxuICogICAgIE90aGVyd2lzZSBLICsgYXhpcyA8PSBOLlxuICpcbiAqIEBkb2Mge2hlYWRpbmc6ICdUZW5zb3JzJywgc3ViaGVhZGluZzogJ1NsaWNpbmcgYW5kIEpvaW5pbmcnfVxuICovXG5hc3luYyBmdW5jdGlvbiBib29sZWFuTWFza0FzeW5jXyhcbiAgICB0ZW5zb3I6IFRlbnNvcnxUZW5zb3JMaWtlLCBtYXNrOiBUZW5zb3J8VGVuc29yTGlrZSxcbiAgICBheGlzPzogbnVtYmVyKTogUHJvbWlzZTxUZW5zb3I+IHtcbiAgY29uc3QgJHRlbnNvciA9IGNvbnZlcnRUb1RlbnNvcih0ZW5zb3IsICd0ZW5zb3InLCAnYm9vbE1hc2snKTtcbiAgY29uc3QgJG1hc2sgPSBjb252ZXJ0VG9UZW5zb3IobWFzaywgJ21hc2snLCAnYm9vbE1hc2snLCAnYm9vbCcpO1xuXG4gIGNvbnN0IGF4aXNGcm9tID0gYXhpcyA9PSBudWxsID8gMCA6IGF4aXM7XG4gIGNvbnN0IG1hc2tEaW0gPSAkbWFzay5yYW5rO1xuICBjb25zdCB0ZW5zb3JTaGFwZSA9ICR0ZW5zb3Iuc2hhcGU7XG5cbiAgdXRpbC5hc3NlcnQobWFza0RpbSA+IDAsICgpID0+ICdtYXNrIGNhbm5vdCBiZSBzY2FsYXInKTtcbiAgdXRpbC5hc3NlcnRTaGFwZXNNYXRjaChcbiAgICAgIHRlbnNvclNoYXBlLnNsaWNlKGF4aXNGcm9tLCBheGlzRnJvbSArIG1hc2tEaW0pLCAkbWFzay5zaGFwZSxcbiAgICAgIGBtYXNrJ3Mgc2hhcGUgbXVzdCBtYXRjaCB0aGUgZmlyc3QgSyBkaW1lbnNpb25zIG9mIHRlbnNvcidzIHNoYXBlLGApO1xuXG4gIGxldCBsZWFkaW5nU2l6ZSA9IDE7XG4gIGZvciAobGV0IGkgPSBheGlzRnJvbTsgaSA8IGF4aXNGcm9tICsgbWFza0RpbTsgaSsrKSB7XG4gICAgbGVhZGluZ1NpemUgKj0gdGVuc29yU2hhcGVbaV07XG4gIH1cbiAgY29uc3QgdGFyZ2V0VGVuc29yU2hhcGUgPVxuICAgICAgdGVuc29yU2hhcGUuc2xpY2UoMCwgYXhpc0Zyb20pXG4gICAgICAgICAgLm
NvbmNhdChbbGVhZGluZ1NpemVdLCB0ZW5zb3JTaGFwZS5zbGljZShheGlzRnJvbSArIG1hc2tEaW0pKTtcbiAgY29uc3QgcmVzaGFwZWRUZW5zb3IgPSByZXNoYXBlKCR0ZW5zb3IsIHRhcmdldFRlbnNvclNoYXBlKTtcbiAgY29uc3QgcmVzaGFwZWRNYXNrID0gcmVzaGFwZSgkbWFzaywgWy0xXSk7XG4gIGNvbnN0IHBvc2l0aXZlUG9zaXRpb25zID0gYXdhaXQgd2hlcmVBc3luYyhyZXNoYXBlZE1hc2spO1xuICBjb25zdCBpbmRpY2VzID0gc3F1ZWV6ZShwb3NpdGl2ZVBvc2l0aW9ucywgWzFdKTtcblxuICBjb25zdCByZXMgPSBnYXRoZXIocmVzaGFwZWRUZW5zb3IsIGluZGljZXMsIGF4aXNGcm9tKTtcblxuICAvLyBFbnN1cmUgbm8gbWVtb3J5IGxlYWsuXG4gIGlmICh0ZW5zb3IgIT09ICR0ZW5zb3IpIHtcbiAgICAkdGVuc29yLmRpc3Bvc2UoKTtcbiAgfVxuICBpZiAobWFzayAhPT0gJG1hc2spIHtcbiAgICAkbWFzay5kaXNwb3NlKCk7XG4gIH1cbiAgaW5kaWNlcy5kaXNwb3NlKCk7XG4gIHJlc2hhcGVkVGVuc29yLmRpc3Bvc2UoKTtcbiAgcmVzaGFwZWRNYXNrLmRpc3Bvc2UoKTtcbiAgcG9zaXRpdmVQb3NpdGlvbnMuZGlzcG9zZSgpO1xuXG4gIHJldHVybiByZXM7XG59XG5cbmV4cG9ydCBjb25zdCBib29sZWFuTWFza0FzeW5jID0gYm9vbGVhbk1hc2tBc3luY187XG4iXX0=
|