/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { Tensor } from '../tensor';
|
import { convertToTensor } from '../tensor_util_env';
|
import * as util from '../util';
|
import { add } from './add';
|
import { div } from './div';
|
import { getNoiseShape } from './dropout_util';
|
import { floor } from './floor';
|
import { mul } from './mul';
|
import { op } from './operation';
|
import { randomUniform } from './random_uniform';
|
/**
 * Computes dropout.
 *
 * A keep/drop mask is built from uniform random noise: an element is kept
 * when `noise + (1 - rate)` floors to 1 and dropped when it floors to 0.
 * Kept elements are divided by `1 - rate`; dropped elements become 0.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 2, 1]);
 * const rate = 0.75;
 * const output = tf.dropout(x, rate);
 * output.print();
 * ```
 *
 * @param x A floating point Tensor or TensorLike.
 * @param rate A float in the range [0, 1). The probability that each element
 *   of x is discarded.
 * @param noiseShape An array of numbers of type int32, representing the
 *   shape for randomly generated keep/drop flags. If the noiseShape has null
 *   value, it will be automatically replaced with the x's relative dimension
 *   size. Optional.
 * @param seed Used to create random seeds. Optional.
 * @returns A Tensor of the same shape of x.
 * @throws If the converted `x` is not of dtype 'float32', or if `rate` is
 *   outside [0, 1) (both reported through `util.assert`).
 *
 * @doc {heading: 'Operations', subheading: 'Dropout'}
 */
function dropout_(x, rate, noiseShape, seed) {
  const $x = convertToTensor(x, 'x', 'dropout');

  util.assert(
    $x.dtype === 'float32',
    () =>
      `x has to be a floating point tensor since it's going to be ` +
      `scaled, but got a ${$x.dtype} tensor instead.`
  );
  util.assert(
    rate >= 0 && rate < 1,
    () => `rate must be a float in the range [0, 1), but got ${rate}.`
  );

  // With rate 0 nothing is dropped, so no noise is generated. A caller-owned
  // Tensor is cloned rather than returned as-is; a TensorLike input was
  // already turned into a fresh tensor by convertToTensor.
  if (rate === 0) {
    return x instanceof Tensor ? $x.clone() : $x;
  }

  const $noiseShape = getNoiseShape($x, noiseShape);
  const keepProb = 1 - rate;

  // floor(noise + keepProb) is 1 where the element survives and 0 where it is
  // dropped (assuming the noise lies in [0, 1) — bounds 0 and 1 are passed).
  const uniformNoise = randomUniform($noiseShape, 0, 1, 'float32', seed);
  const keepMask = floor(add(uniformNoise, keepProb));

  // Rescale survivors by 1 / keepProb.
  const multiplier = div(keepMask, keepProb);

  return mul($x, multiplier);
}
|
// Public `dropout` export: the private `dropout_` implementation passed through
// `op` (imported from './operation'). The function is handed over under its own
// name as the object key — presumably `op` uses that key to name the operation;
// confirm against './operation'. The `/* @__PURE__ */` annotation is the
// conventional hint to bundlers/minifiers that this call can be removed when the
// export is unused.
export const dropout = /* @__PURE__ */ op({ dropout_ });
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiZHJvcG91dC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtY29yZS9zcmMvb3BzL2Ryb3BvdXQudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNqQyxPQUFPLEVBQUMsZUFBZSxFQUFDLE1BQU0sb0JBQW9CLENBQUM7QUFFbkQsT0FBTyxLQUFLLElBQUksTUFBTSxTQUFTLENBQUM7QUFFaEMsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLE9BQU8sQ0FBQztBQUMxQixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxhQUFhLEVBQUMsTUFBTSxnQkFBZ0IsQ0FBQztBQUM3QyxPQUFPLEVBQUMsS0FBSyxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBQzlCLE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxPQUFPLENBQUM7QUFDMUIsT0FBTyxFQUFDLEVBQUUsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUMvQixPQUFPLEVBQUMsYUFBYSxFQUFDLE1BQU0sa0JBQWtCLENBQUM7QUFFL0M7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7OztHQXFCRztBQUNILFNBQVMsUUFBUSxDQUNiLENBQW9CLEVBQUUsSUFBWSxFQUFFLFVBQXFCLEVBQ3pELElBQW9CO0lBQ3RCLE1BQU0sRUFBRSxHQUFHLGVBQWUsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLFNBQVMsQ0FBQyxDQUFDO0lBRTlDLElBQUksQ0FBQyxNQUFNLENBQ1AsRUFBRSxDQUFDLEtBQUssS0FBSyxTQUFTLEVBQ3RCLEdBQUcsRUFBRSxDQUFDLDZEQUE2RDtRQUMvRCxxQkFBcUIsRUFBRSxDQUFDLEtBQUssa0JBQWtCLENBQUMsQ0FBQztJQUN6RCxJQUFJLENBQUMsTUFBTSxDQUNQLElBQUksSUFBSSxDQUFDLElBQUksSUFBSSxHQUFHLENBQUMsRUFDckIsR0FBRyxFQUFFLENBQUMscURBQXFELElBQUksR0FBRyxDQUFDLENBQUM7SUFFeEUsSUFBSSxJQUFJLEtBQUssQ0FBQyxFQUFFO1FBQ2QsT0FBTyxDQUFDLFlBQVksTUFBTSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsS0FBSyxFQUFFLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQztLQUM5QztJQUVELE1BQU0sV0FBVyxHQUFHLGFBQWEsQ0FBQyxFQUFFLEVBQUUsVUFBVSxDQUFDLENBQUM7SUFDbEQsTUFBTSxRQUFRLEdBQUcsQ0FBQyxHQUFHLElBQUksQ0FBQztJQUMxQixNQUFNLFVBQVUsR0FBRyxHQUFHLENBQ2xCLEtBQUssQ0FBQyxHQUFHLENBQUMsYUFBYSxDQUFDLFdBQVcsRUFBRSxDQUFDLEVBQUUsQ0FBQyxFQUFFLFNBQVMsRUFBRSxJQUFJLENBQUMsRUFBRSxRQUFRLENBQUMsQ0FBQyxFQUN2RSxRQUFRLENBQUMsQ0FBQztJQUVkLE9BQU8sR0FBRyxDQUFDLEVBQUUsRUFBRSxVQUFVLENBQUMsQ0FBQztBQUM3QixDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sT0FBTyxHQUFHLGVBQWUsQ0FBQyxFQUFFLENBQUMsRUFBQyxRQUFRLEVBQUMsQ0FBQyxDQUFDIiwic2
91cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTggR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge1RlbnNvcn0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCB7Y29udmVydFRvVGVuc29yfSBmcm9tICcuLi90ZW5zb3JfdXRpbF9lbnYnO1xuaW1wb3J0IHtUZW5zb3JMaWtlfSBmcm9tICcuLi90eXBlcyc7XG5pbXBvcnQgKiBhcyB1dGlsIGZyb20gJy4uL3V0aWwnO1xuXG5pbXBvcnQge2FkZH0gZnJvbSAnLi9hZGQnO1xuaW1wb3J0IHtkaXZ9IGZyb20gJy4vZGl2JztcbmltcG9ydCB7Z2V0Tm9pc2VTaGFwZX0gZnJvbSAnLi9kcm9wb3V0X3V0aWwnO1xuaW1wb3J0IHtmbG9vcn0gZnJvbSAnLi9mbG9vcic7XG5pbXBvcnQge211bH0gZnJvbSAnLi9tdWwnO1xuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuaW1wb3J0IHtyYW5kb21Vbmlmb3JtfSBmcm9tICcuL3JhbmRvbV91bmlmb3JtJztcblxuLyoqXG4gKiBDb21wdXRlcyBkcm9wb3V0LlxuICpcbiAqIGBgYGpzXG4gKiBjb25zdCB4ID0gdGYudGVuc29yMWQoWzEsIDIsIDIsIDFdKTtcbiAqIGNvbnN0IHJhdGUgPSAwLjc1O1xuICogY29uc3Qgb3V0cHV0ID0gdGYuZHJvcG91dCh4LCByYXRlKTtcbiAqIG91dHB1dC5wcmludCgpO1xuICogYGBgXG4gKlxuICogQHBhcmFtIHggQSBmbG9hdGluZyBwb2ludCBUZW5zb3Igb3IgVGVuc29yTGlrZS5cbiAqIEBwYXJhbSByYXRlIEEgZmxvYXQgaW4gdGhlIHJhbmdlIFswLCAxKS4gVGhlIHByb2JhYmlsaXR5IHRoYXQgZWFjaCBlbGVtZW50XG4gKiAgIG9mIHggaXMgZGlzY2FyZGVkLlxuICogQHBhcmFtIG5vaXNlU2
hhcGUgQW4gYXJyYXkgb2YgbnVtYmVycyBvZiB0eXBlIGludDMyLCByZXByZXNlbnRpbmcgdGhlXG4gKiBzaGFwZSBmb3IgcmFuZG9tbHkgZ2VuZXJhdGVkIGtlZXAvZHJvcCBmbGFncy4gSWYgdGhlIG5vaXNlU2hhcGUgaGFzIG51bGxcbiAqIHZhbHVlLCBpdCB3aWxsIGJlIGF1dG9tYXRpY2FsbHkgcmVwbGFjZWQgd2l0aCB0aGUgeCdzIHJlbGF0aXZlIGRpbWVuc2lvblxuICogc2l6ZS4gT3B0aW9uYWwuXG4gKiBAcGFyYW0gc2VlZCBVc2VkIHRvIGNyZWF0ZSByYW5kb20gc2VlZHMuIE9wdGlvbmFsLlxuICogQHJldHVybnMgQSBUZW5zb3Igb2YgdGhlIHNhbWUgc2hhcGUgb2YgeC5cbiAqXG4gKiBAZG9jIHtoZWFkaW5nOiAnT3BlcmF0aW9ucycsIHN1YmhlYWRpbmc6ICdEcm9wb3V0J31cbiAqL1xuZnVuY3Rpb24gZHJvcG91dF8oXG4gICAgeDogVGVuc29yfFRlbnNvckxpa2UsIHJhdGU6IG51bWJlciwgbm9pc2VTaGFwZT86IG51bWJlcltdLFxuICAgIHNlZWQ/OiBudW1iZXJ8c3RyaW5nKTogVGVuc29yIHtcbiAgY29uc3QgJHggPSBjb252ZXJ0VG9UZW5zb3IoeCwgJ3gnLCAnZHJvcG91dCcpO1xuXG4gIHV0aWwuYXNzZXJ0KFxuICAgICAgJHguZHR5cGUgPT09ICdmbG9hdDMyJyxcbiAgICAgICgpID0+IGB4IGhhcyB0byBiZSBhIGZsb2F0aW5nIHBvaW50IHRlbnNvciBzaW5jZSBpdCdzIGdvaW5nIHRvIGJlIGAgK1xuICAgICAgICAgIGBzY2FsZWQsIGJ1dCBnb3QgYSAkeyR4LmR0eXBlfSB0ZW5zb3IgaW5zdGVhZC5gKTtcbiAgdXRpbC5hc3NlcnQoXG4gICAgICByYXRlID49IDAgJiYgcmF0ZSA8IDEsXG4gICAgICAoKSA9PiBgcmF0ZSBtdXN0IGJlIGEgZmxvYXQgaW4gdGhlIHJhbmdlIFswLCAxKSwgYnV0IGdvdCAke3JhdGV9LmApO1xuXG4gIGlmIChyYXRlID09PSAwKSB7XG4gICAgcmV0dXJuIHggaW5zdGFuY2VvZiBUZW5zb3IgPyAkeC5jbG9uZSgpIDogJHg7XG4gIH1cblxuICBjb25zdCAkbm9pc2VTaGFwZSA9IGdldE5vaXNlU2hhcGUoJHgsIG5vaXNlU2hhcGUpO1xuICBjb25zdCBrZWVwUHJvYiA9IDEgLSByYXRlO1xuICBjb25zdCBtdWx0aXBsaWVyID0gZGl2KFxuICAgICAgZmxvb3IoYWRkKHJhbmRvbVVuaWZvcm0oJG5vaXNlU2hhcGUsIDAsIDEsICdmbG9hdDMyJywgc2VlZCksIGtlZXBQcm9iKSksXG4gICAgICBrZWVwUHJvYik7XG5cbiAgcmV0dXJuIG11bCgkeCwgbXVsdGlwbGllcik7XG59XG5cbmV4cG9ydCBjb25zdCBkcm9wb3V0ID0gLyogQF9fUFVSRV9fICovIG9wKHtkcm9wb3V0X30pO1xuIl19
|