gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { convertToTensor } from '../../tensor_util_env';
import { assertShapesMatch } from '../../util';
import { abs } from '../abs';
import { add } from '../add';
import { exp } from '../exp';
import { log1p } from '../log1p';
import { Reduction } from '../loss_ops_utils';
import { mul } from '../mul';
import { neg } from '../neg';
import { op } from '../operation';
import { relu } from '../relu';
import { scalar } from '../scalar';
import { sub } from '../sub';
import { computeWeightedLoss } from './compute_weighted_loss';
/**
 * Element-wise sigmoid cross entropy (logistic loss) computed directly from
 * logits, using a formulation that never exponentiates a positive number.
 *
 * @param labels Target values; converted to a tensor and required to have the
 *    same shape as `logits`.
 * @param logits Unnormalized predictions; converted to a tensor.
 * @returns The un-reduced loss `max(x, 0) - x * z + log1p(exp(-abs(x)))`,
 *    where `x = logits` and `z = labels`.
 */
function sigmoidCrossEntropyWithLogits_(labels, logits) {
    const opName = 'sigmoidCrossEntropyWithLogits';
    const z = convertToTensor(labels, 'labels', opName);
    const x = convertToTensor(logits, 'logits', opName);
    assertShapesMatch(z.shape, x.shape, 'Error in sigmoidCrossEntropyWithLogits: ');
    /*
     * Derivation (x = logits, z = labels):
     *
     *     z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
     *   = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
     *   = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
     *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
     *   = (1 - z) * x + log(1 + exp(-x))
     *   = x - x * z + log(1 + exp(-x))
     *
     * When x < 0, exp(-x) can overflow, so that case is rewritten as
     *     x - x * z + log(1 + exp(-x))
     *   = log(exp(x)) - x * z + log(1 + exp(-x))
     *   = -x * z + log(1 + exp(x))
     *
     * Both cases are covered by the single overflow-safe expression
     *     max(x, 0) - x * z + log(1 + exp(-abs(x)))
     * which is what the statements below evaluate, term by term.
     */
    // max(x, 0)
    const positivePart = relu(x);
    // x * z
    const logitTimesLabel = mul(x, z);
    // log(1 + exp(-|x|)); the exponent is never positive, so exp cannot overflow.
    const negMagnitude = neg(abs(x));
    const softplusTerm = log1p(exp(negMagnitude));
    const linearTerm = sub(positivePart, logitTimesLabel);
    return add(linearTerm, softplusTerm);
}
/**
 * Computes the sigmoid cross entropy loss between two tensors.
 *
 * When `labelSmoothing` is greater than 0 the labels are pulled towards 1/2
 * before the loss is evaluated:
 *
 *   smoothedLabels = multiClassLabels * (1 - labelSmoothing)
 *                    + 0.5 * labelSmoothing
 *
 * @param multiClassLabels The ground truth output tensor of shape
 *    [batch_size, num_classes]; must have the same shape as `logits`.
 * @param logits The predicted outputs.
 * @param weights Optional. Tensor whose rank is either 0, or the same rank as
 *    the labels, and which must be broadcastable to them (i.e., every
 *    dimension is either `1` or equal to the corresponding dimension of the
 *    losses). When omitted or `null`, `null` is forwarded to
 *    `computeWeightedLoss`.
 * @param labelSmoothing If greater than 0, then smooth the labels. Defaults
 *    to 0 (no smoothing).
 * @param reduction Type of reduction to apply to loss. Should be of type
 *    `Reduction`. Defaults to `Reduction.SUM_BY_NONZERO_WEIGHTS`.
 *
 * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
 */
function sigmoidCrossEntropy_(multiClassLabels, logits, weights, labelSmoothing = 0, reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) {
    const opName = 'sigmoidCrossEntropy';
    let targets = convertToTensor(multiClassLabels, 'multiClassLabels', opName);
    const predictions = convertToTensor(logits, 'logits', opName);
    // `weights` is optional: only convert it when the caller supplied one.
    const lossWeights = weights != null ? convertToTensor(weights, 'weights', opName) : null;
    assertShapesMatch(targets.shape, predictions.shape, 'Error in sigmoidCrossEntropy: ');
    if (labelSmoothing > 0) {
        const smoothing = scalar(labelSmoothing);
        const one = scalar(1);
        const half = scalar(0.5);
        // targets * (1 - smoothing) + 0.5 * smoothing
        const keepFraction = sub(one, smoothing);
        const scaledTargets = mul(targets, keepFraction);
        const offset = mul(half, smoothing);
        targets = add(scaledTargets, offset);
    }
    // Per-element loss first, then weighting and reduction.
    const perElementLoss = sigmoidCrossEntropyWithLogits_(targets, predictions);
    return computeWeightedLoss(perElementLoss, lossWeights, reduction);
}
// Public export: the value returned by `op` (imported from '../operation') for
// `sigmoidCrossEntropy_`. The `@__PURE__` annotation is the bundler convention
// for marking the call as side-effect free, so the export can be dropped when
// it is unused.
export const sigmoidCrossEntropy = /* @__PURE__ */ op({ sigmoidCrossEntropy_ });
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"sigmoid_cross_entropy.js","sourceRoot":"","sources":["../../../../../../../tfjs-core/src/ops/losses/sigmoid_cross_entropy.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,OAAO,EAAC,eAAe,EAAC,MAAM,uBAAuB,CAAC;AAEtD,OAAO,EAAC,iBAAiB,EAAC,MAAM,YAAY,CAAC;AAC7C,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,KAAK,EAAC,MAAM,UAAU,CAAC;AAC/B,OAAO,EAAC,SAAS,EAAC,MAAM,mBAAmB,CAAC;AAC5C,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,EAAE,EAAC,MAAM,cAAc,CAAC;AAChC,OAAO,EAAC,IAAI,EAAC,MAAM,SAAS,CAAC;AAC7B,OAAO,EAAC,MAAM,EAAC,MAAM,WAAW,CAAC;AACjC,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAE3B,OAAO,EAAC,mBAAmB,EAAC,MAAM,yBAAyB,CAAC;AAE5D,SAAS,8BAA8B,CACnC,MAAoB,EAAE,MAAoB;IAC5C,MAAM,OAAO,GACT,eAAe,CAAC,MAAM,EAAE,QAAQ,EAAE,+BAA+B,CAAC,CAAC;IACvE,MAAM,OAAO,GACT,eAAe,CAAC,MAAM,EAAE,QAAQ,EAAE,+BAA+B,CAAC,CAAC;IACvE,iBAAiB,CACb,OAAO,CAAC,KAAK,EAAE,OAAO,CAAC,KAAK,EAAE,0CAA0C,CAAC,CAAC;IAE9E;;;;;;;;;;;;;;;;;;;OAmBG;IACH,MAAM,SAAS,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC;IAChC,MAAM,aAAa,GAAG,GAAG,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;IAC5C,MAAM,aAAa,GAAG,KAAK,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC;IAEpD,OAAO,GAAG,CAAC,GAAG,CAAC,SAAS,EAAE,aAAa,CAAC,EAAE,aAAa,CAAC,CAAC;AAC3D,CAAC;AAED;;;;;;;;;;;;;;;;;;;;GAoBG;AACH,SAAS,oBAAoB,CACzB,gBAA8B,EAAE,MAAoB,EACpD,OAA2B,EAAE,cAAc,GAAG,CAAC,EAC/C,SAAS,GAAG,SAAS,CAAC,sBAAsB;IAC9C,IAAI,iBAAiB,GAAG,eAAe,CACnC,gBAAgB,EAAE,kBAAkB,EAAE,qBAAqB,CAAC,CAAC;IACjE,MAAM,OAAO,GAAG,eAAe,CAAC,MAAM,EAAE,QAAQ,EAAE,qBAAqB,CAAC,CAAC;IACzE,IAAI,QAAQ,GAAW,IAAI,CAAC;IAC5B,IAAI,OAAO,IAAI,IAAI,EAAE;QACnB,QAAQ,GAAG,eAAe,CAAC,OAAO,EAAE,SAAS,EAAE,qBAAqB,CAAC,CAAC;KACvE;IACD,iBAAiB,CACb,iBAAiB,CAAC,KAAK,EAAE,OAAO,CAAC,KAAK,EAAE,gCAAgC,CAAC,CAAC;IAE9E,IAAI,cAAc,GAAG,CAAC,EAAE;QACtB,MAAM,oBAAoB,GAAG,MAAM,CAAC,cAAc,CAAC,CAAC;QACpD,MAAM,GAAG,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QACtB
,MAAM,IAAI,GAAG,MAAM,CAAC,GAAG,CAAC,CAAC;QAEzB,iBAAiB;YACb,GAAG,CAAC,GAAG,CAAC,iBAAiB,EAAE,GAAG,CAAC,GAAG,EAAE,oBAAoB,CAAC,CAAC,EACtD,GAAG,CAAC,IAAI,EAAE,oBAAoB,CAAC,CAAC,CAAC;KAC1C;IACD,MAAM,MAAM,GAAG,8BAA8B,CAAC,iBAAiB,EAAE,OAAO,CAAC,CAAC;IAE1E,OAAO,mBAAmB,CAAC,MAAM,EAAE,QAAQ,EAAE,SAAS,CAAC,CAAC;AAC1D,CAAC;AAED,MAAM,CAAC,MAAM,mBAAmB,GAAG,eAAe,CAAC,EAAE,CAAC,EAAC,oBAAoB,EAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../../tensor';\nimport {convertToTensor} from '../../tensor_util_env';\nimport {TensorLike} from '../../types';\nimport {assertShapesMatch} from '../../util';\nimport {abs} from '../abs';\nimport {add} from '../add';\nimport {exp} from '../exp';\nimport {log1p} from '../log1p';\nimport {Reduction} from '../loss_ops_utils';\nimport {mul} from '../mul';\nimport {neg} from '../neg';\nimport {op} from '../operation';\nimport {relu} from '../relu';\nimport {scalar} from '../scalar';\nimport {sub} from '../sub';\n\nimport {computeWeightedLoss} from './compute_weighted_loss';\n\nfunction sigmoidCrossEntropyWithLogits_<T extends Tensor, O extends Tensor>(\n    labels: T|TensorLike, logits: T|TensorLike): O {\n  const $labels =\n      convertToTensor(labels, 'labels', 'sigmoidCrossEntropyWithLogits');\n  const $logits =\n      
convertToTensor(logits, 'logits', 'sigmoidCrossEntropyWithLogits');\n  assertShapesMatch(\n      $labels.shape, $logits.shape, 'Error in sigmoidCrossEntropyWithLogits: ');\n\n  /**\n   * Implementation Details:\n   *\n   * For brevity, let `x = logits`, `z = labels`.  The logistic loss is\n   *     z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))\n   *   = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))\n   *   = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))\n   *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))\n   *   = (1 - z) * x + log(1 + exp(-x))\n   *   = x - x * z + log(1 + exp(-x))\n   *\n   *   For x < 0, to avoid overflow in exp(-x), we reformulate the above\n   *     x - x * z + log(1 + exp(-x))\n   *   = log(exp(x)) - x * z + log(1 + exp(-x))\n   *   = - x * z + log(1 + exp(x))\n   *\n   * Hence, to ensure stability and avoid overflow, the implementation uses\n   * this equivalent formulation:\n   *     max(x, 0) - x * z + log(1 + exp(-abs(x)))\n   */\n  const maxOutput = relu($logits);\n  const outputXTarget = mul($logits, $labels);\n  const sigmoidOutput = log1p(exp(neg(abs($logits))));\n\n  return add(sub(maxOutput, outputXTarget), sigmoidOutput);\n}\n\n/**\n * Computes the sigmoid cross entropy loss between two tensors.\n *\n * If labelSmoothing is nonzero, smooth the labels towards 1/2:\n *\n *   newMulticlassLabels = multiclassLabels * (1 - labelSmoothing)\n *                         + 0.5 * labelSmoothing\n *\n * @param multiClassLabels The ground truth output tensor of shape\n * [batch_size, num_classes], same dimensions as 'predictions'.\n * @param logits The predicted outputs.\n * @param weights Tensor whose rank is either 0, or the same rank as\n *    `labels`, and must be broadcastable to `labels` (i.e., all dimensions\n *    must be either `1`, or the same as the corresponding `losses`\n *    dimension).\n * @param labelSmoothing If greater than 0, then smooth the labels.\n * 
@param reduction Type of reduction to apply to loss. Should be of type\n *    `Reduction`\n *\n * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }\n */\nfunction sigmoidCrossEntropy_<T extends Tensor, O extends Tensor>(\n    multiClassLabels: T|TensorLike, logits: T|TensorLike,\n    weights?: Tensor|TensorLike, labelSmoothing = 0,\n    reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {\n  let $multiClassLabels = convertToTensor(\n      multiClassLabels, 'multiClassLabels', 'sigmoidCrossEntropy');\n  const $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropy');\n  let $weights: Tensor = null;\n  if (weights != null) {\n    $weights = convertToTensor(weights, 'weights', 'sigmoidCrossEntropy');\n  }\n  assertShapesMatch(\n      $multiClassLabels.shape, $logits.shape, 'Error in sigmoidCrossEntropy: ');\n\n  if (labelSmoothing > 0) {\n    const labelSmoothingScalar = scalar(labelSmoothing);\n    const one = scalar(1);\n    const half = scalar(0.5);\n\n    $multiClassLabels =\n        add(mul($multiClassLabels, sub(one, labelSmoothingScalar)),\n            mul(half, labelSmoothingScalar));\n  }\n  const losses = sigmoidCrossEntropyWithLogits_($multiClassLabels, $logits);\n\n  return computeWeightedLoss(losses, $weights, reduction);\n}\n\nexport const sigmoidCrossEntropy = /* @__PURE__ */ op({sigmoidCrossEntropy_});\n"]}