gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { Tensor } from '../tensor';
import { convertToTensor } from '../tensor_util_env';
import * as util from '../util';
import { add } from './add';
import { div } from './div';
import { getNoiseShape } from './dropout_util';
import { floor } from './floor';
import { mul } from './mul';
import { op } from './operation';
import { randomUniform } from './random_uniform';
/**
 * Randomly zeroes elements of `x` and rescales the survivors by
 * `1 / (1 - rate)` (inverted dropout).
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 2, 1]);
 * const rate = 0.75;
 * const output = tf.dropout(x, rate);
 * output.print();
 * ```
 *
 * @param x Tensor or TensorLike to apply dropout to. Its dtype must be
 *   'float32', since kept elements are divided by the keep probability.
 * @param rate Probability, in the range [0, 1), that any given element of x
 *   is dropped. A rate of 0 returns x without drawing any random numbers.
 * @param noiseShape Optional int32 shape for the randomly generated keep/drop
 *   flags. Null entries are replaced with the matching dimension size of x.
 * @param seed Optional seed forwarded to the uniform random generator.
 * @returns A Tensor with the same shape as x.
 *
 * @doc {heading: 'Operations', subheading: 'Dropout'}
 */
function dropout_(x, rate, noiseShape, seed) {
    const input = convertToTensor(x, 'x', 'dropout');

    // The dtype check runs before the rate check, so it is reported first.
    util.assert(input.dtype === 'float32', () => 'x has to be a floating point tensor since it\'s going to be ' +
        `scaled, but got a ${input.dtype} tensor instead.`);
    util.assert(rate >= 0 && rate < 1, () => `rate must be a float in the range [0, 1), but got ${rate}.`);

    if (rate === 0) {
        // Nothing is dropped. A Tensor argument is cloned so the caller receives
        // a distinct tensor; a TensorLike was already converted into a new
        // tensor above, so that one is returned directly.
        if (x instanceof Tensor) {
            return input.clone();
        }
        return input;
    }

    const flagShape = getNoiseShape(input, noiseShape);
    const keepProb = 1 - rate;

    // With u drawn from randomUniform(..., 0, 1), floor(u + keepProb) is 1 when
    // u >= rate and 0 otherwise (assuming u < 1), giving a 0/1 keep mask.
    const uniform = randomUniform(flagShape, 0, 1, 'float32', seed);
    const keepFlags = floor(add(uniform, keepProb));

    // Dividing the mask by keepProb scales kept elements up, so the expected
    // value of the output equals x.
    const scaledFlags = div(keepFlags, keepProb);
    return mul(input, scaledFlags);
}
// Public `dropout` op: `dropout_` wrapped by `op()`.
// NOTE(review): presumably `op` derives the op's name from the single object
// key (`dropout_`), so keep the shorthand key exactly as written — confirm in
// ./operation.
// `/* @__PURE__ */` marks the call as side-effect-free for bundlers, letting
// them tree-shake this export when it is unused.
export const dropout = /* @__PURE__ */ op({ dropout_ });
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiZHJvcG91dC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtY29yZS9zcmMvb3BzL2Ryb3BvdXQudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNqQyxPQUFPLEVBQUMsZUFBZSxFQUFDLE1BQU0sb0JBQW9CLENBQUM7QUFFbkQsT0FBTyxLQUFLLElBQUksTUFBTSxTQUFTLENBQUM7QUFFaEMsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLE9BQU8sQ0FBQztBQUMxQixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxhQUFhLEVBQUMsTUFBTSxnQkFBZ0IsQ0FBQztBQUM3QyxPQUFPLEVBQUMsS0FBSyxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBQzlCLE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxPQUFPLENBQUM7QUFDMUIsT0FBTyxFQUFDLEVBQUUsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUMvQixPQUFPLEVBQUMsYUFBYSxFQUFDLE1BQU0sa0JBQWtCLENBQUM7QUFFL0M7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7OztHQXFCRztBQUNILFNBQVMsUUFBUSxDQUNiLENBQW9CLEVBQUUsSUFBWSxFQUFFLFVBQXFCLEVBQ3pELElBQW9CO0lBQ3RCLE1BQU0sRUFBRSxHQUFHLGVBQWUsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLFNBQVMsQ0FBQyxDQUFDO0lBRTlDLElBQUksQ0FBQyxNQUFNLENBQ1AsRUFBRSxDQUFDLEtBQUssS0FBSyxTQUFTLEVBQ3RCLEdBQUcsRUFBRSxDQUFDLDZEQUE2RDtRQUMvRCxxQkFBcUIsRUFBRSxDQUFDLEtBQUssa0JBQWtCLENBQUMsQ0FBQztJQUN6RCxJQUFJLENBQUMsTUFBTSxDQUNQLElBQUksSUFBSSxDQUFDLElBQUksSUFBSSxHQUFHLENBQUMsRUFDckIsR0FBRyxFQUFFLENBQUMscURBQXFELElBQUksR0FBRyxDQUFDLENBQUM7SUFFeEUsSUFBSSxJQUFJLEtBQUssQ0FBQyxFQUFFO1FBQ2QsT0FBTyxDQUFDLFlBQVksTUFBTSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsS0FBSyxFQUFFLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQztLQUM5QztJQUVELE1BQU0sV0FBVyxHQUFHLGFBQWEsQ0FBQyxFQUFFLEVBQUUsVUFBVSxDQUFDLENBQUM7SUFDbEQsTUFBTSxRQUFRLEdBQUcsQ0FBQyxHQUFHLElBQUksQ0FBQztJQUMxQixNQUFNLFVBQVUsR0FBRyxHQUFHLENBQ2xCLEtBQUssQ0FBQyxHQUFHLENBQUMsYUFBYSxDQUFDLFdBQVcsRUFBRSxDQUFDLEVBQUUsQ0FBQyxFQUFFLFNBQVMsRUFBRSxJQUFJLENBQUMsRUFBRSxRQUFRLENBQUMsQ0FBQyxFQUN2RSxRQUFRLENBQUMsQ0FBQztJQUVkLE9BQU8sR0FBRyxDQUFDLEVBQUUsRUFBRSxVQUFVLENBQUMsQ0FBQztBQUM3QixDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sT0FBTyxHQUFHLGVBQWUsQ0FBQyxFQUFFLENBQUMsRUFBQyxRQUFRLEVBQUMsQ0FBQyxDQUFDIiwic2
91cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTggR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge1RlbnNvcn0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCB7Y29udmVydFRvVGVuc29yfSBmcm9tICcuLi90ZW5zb3JfdXRpbF9lbnYnO1xuaW1wb3J0IHtUZW5zb3JMaWtlfSBmcm9tICcuLi90eXBlcyc7XG5pbXBvcnQgKiBhcyB1dGlsIGZyb20gJy4uL3V0aWwnO1xuXG5pbXBvcnQge2FkZH0gZnJvbSAnLi9hZGQnO1xuaW1wb3J0IHtkaXZ9IGZyb20gJy4vZGl2JztcbmltcG9ydCB7Z2V0Tm9pc2VTaGFwZX0gZnJvbSAnLi9kcm9wb3V0X3V0aWwnO1xuaW1wb3J0IHtmbG9vcn0gZnJvbSAnLi9mbG9vcic7XG5pbXBvcnQge211bH0gZnJvbSAnLi9tdWwnO1xuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuaW1wb3J0IHtyYW5kb21Vbmlmb3JtfSBmcm9tICcuL3JhbmRvbV91bmlmb3JtJztcblxuLyoqXG4gKiBDb21wdXRlcyBkcm9wb3V0LlxuICpcbiAqIGBgYGpzXG4gKiBjb25zdCB4ID0gdGYudGVuc29yMWQoWzEsIDIsIDIsIDFdKTtcbiAqIGNvbnN0IHJhdGUgPSAwLjc1O1xuICogY29uc3Qgb3V0cHV0ID0gdGYuZHJvcG91dCh4LCByYXRlKTtcbiAqIG91dHB1dC5wcmludCgpO1xuICogYGBgXG4gKlxuICogQHBhcmFtIHggQSBmbG9hdGluZyBwb2ludCBUZW5zb3Igb3IgVGVuc29yTGlrZS5cbiAqIEBwYXJhbSByYXRlIEEgZmxvYXQgaW4gdGhlIHJhbmdlIFswLCAxKS4gVGhlIHByb2JhYmlsaXR5IHRoYXQgZWFjaCBlbGVtZW50XG4gKiAgIG9mIHggaXMgZGlzY2FyZGVkLlxuICogQHBhcmFtIG5vaXNlU2
hhcGUgQW4gYXJyYXkgb2YgbnVtYmVycyBvZiB0eXBlIGludDMyLCByZXByZXNlbnRpbmcgdGhlXG4gKiBzaGFwZSBmb3IgcmFuZG9tbHkgZ2VuZXJhdGVkIGtlZXAvZHJvcCBmbGFncy4gSWYgdGhlIG5vaXNlU2hhcGUgaGFzIG51bGxcbiAqIHZhbHVlLCBpdCB3aWxsIGJlIGF1dG9tYXRpY2FsbHkgcmVwbGFjZWQgd2l0aCB0aGUgeCdzIHJlbGF0aXZlIGRpbWVuc2lvblxuICogc2l6ZS4gT3B0aW9uYWwuXG4gKiBAcGFyYW0gc2VlZCBVc2VkIHRvIGNyZWF0ZSByYW5kb20gc2VlZHMuIE9wdGlvbmFsLlxuICogQHJldHVybnMgQSBUZW5zb3Igb2YgdGhlIHNhbWUgc2hhcGUgb2YgeC5cbiAqXG4gKiBAZG9jIHtoZWFkaW5nOiAnT3BlcmF0aW9ucycsIHN1YmhlYWRpbmc6ICdEcm9wb3V0J31cbiAqL1xuZnVuY3Rpb24gZHJvcG91dF8oXG4gICAgeDogVGVuc29yfFRlbnNvckxpa2UsIHJhdGU6IG51bWJlciwgbm9pc2VTaGFwZT86IG51bWJlcltdLFxuICAgIHNlZWQ/OiBudW1iZXJ8c3RyaW5nKTogVGVuc29yIHtcbiAgY29uc3QgJHggPSBjb252ZXJ0VG9UZW5zb3IoeCwgJ3gnLCAnZHJvcG91dCcpO1xuXG4gIHV0aWwuYXNzZXJ0KFxuICAgICAgJHguZHR5cGUgPT09ICdmbG9hdDMyJyxcbiAgICAgICgpID0+IGB4IGhhcyB0byBiZSBhIGZsb2F0aW5nIHBvaW50IHRlbnNvciBzaW5jZSBpdCdzIGdvaW5nIHRvIGJlIGAgK1xuICAgICAgICAgIGBzY2FsZWQsIGJ1dCBnb3QgYSAkeyR4LmR0eXBlfSB0ZW5zb3IgaW5zdGVhZC5gKTtcbiAgdXRpbC5hc3NlcnQoXG4gICAgICByYXRlID49IDAgJiYgcmF0ZSA8IDEsXG4gICAgICAoKSA9PiBgcmF0ZSBtdXN0IGJlIGEgZmxvYXQgaW4gdGhlIHJhbmdlIFswLCAxKSwgYnV0IGdvdCAke3JhdGV9LmApO1xuXG4gIGlmIChyYXRlID09PSAwKSB7XG4gICAgcmV0dXJuIHggaW5zdGFuY2VvZiBUZW5zb3IgPyAkeC5jbG9uZSgpIDogJHg7XG4gIH1cblxuICBjb25zdCAkbm9pc2VTaGFwZSA9IGdldE5vaXNlU2hhcGUoJHgsIG5vaXNlU2hhcGUpO1xuICBjb25zdCBrZWVwUHJvYiA9IDEgLSByYXRlO1xuICBjb25zdCBtdWx0aXBsaWVyID0gZGl2KFxuICAgICAgZmxvb3IoYWRkKHJhbmRvbVVuaWZvcm0oJG5vaXNlU2hhcGUsIDAsIDEsICdmbG9hdDMyJywgc2VlZCksIGtlZXBQcm9iKSksXG4gICAgICBrZWVwUHJvYik7XG5cbiAgcmV0dXJuIG11bCgkeCwgbXVsdGlwbGllcik7XG59XG5cbmV4cG9ydCBjb25zdCBkcm9wb3V0ID0gLyogQF9fUFVSRV9fICovIG9wKHtkcm9wb3V0X30pO1xuIl19