/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { convertToTensor } from '../../tensor_util_env';
|
import { assertShapesMatch } from '../../util';
|
import { abs } from '../abs';
|
import { add } from '../add';
|
import { exp } from '../exp';
|
import { log1p } from '../log1p';
|
import { Reduction } from '../loss_ops_utils';
|
import { mul } from '../mul';
|
import { neg } from '../neg';
|
import { op } from '../operation';
|
import { relu } from '../relu';
|
import { scalar } from '../scalar';
|
import { sub } from '../sub';
|
import { computeWeightedLoss } from './compute_weighted_loss';
|
function sigmoidCrossEntropyWithLogits_(labels, logits) {
|
const $labels = convertToTensor(labels, 'labels', 'sigmoidCrossEntropyWithLogits');
|
const $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropyWithLogits');
|
assertShapesMatch($labels.shape, $logits.shape, 'Error in sigmoidCrossEntropyWithLogits: ');
|
/**
|
* Implementation Details:
|
*
|
* For brevity, let `x = logits`, `z = labels`. The logistic loss is
|
* z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
* = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
|
* = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
|
* = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
|
* = (1 - z) * x + log(1 + exp(-x))
|
* = x - x * z + log(1 + exp(-x))
|
*
|
* For x < 0, to avoid overflow in exp(-x), we reformulate the above
|
* x - x * z + log(1 + exp(-x))
|
* = log(exp(x)) - x * z + log(1 + exp(-x))
|
* = - x * z + log(1 + exp(x))
|
*
|
* Hence, to ensure stability and avoid overflow, the implementation uses
|
* this equivalent formulation:
|
* max(x, 0) - x * z + log(1 + exp(-abs(x)))
|
*/
|
const maxOutput = relu($logits);
|
const outputXTarget = mul($logits, $labels);
|
const sigmoidOutput = log1p(exp(neg(abs($logits))));
|
return add(sub(maxOutput, outputXTarget), sigmoidOutput);
|
}
|
/**
|
* Computes the sigmoid cross entropy loss between two tensors.
|
*
|
* If labelSmoothing is nonzero, smooth the labels towards 1/2:
|
*
|
* newMulticlassLabels = multiclassLabels * (1 - labelSmoothing)
|
* + 0.5 * labelSmoothing
|
*
|
* @param multiClassLabels The ground truth output tensor of shape
|
* [batch_size, num_classes], same dimensions as 'predictions'.
|
* @param logits The predicted outputs.
|
* @param weights Tensor whose rank is either 0, or the same rank as
|
* `labels`, and must be broadcastable to `labels` (i.e., all dimensions
|
* must be either `1`, or the same as the corresponding `losses`
|
* dimension).
|
* @param labelSmoothing If greater than 0, then smooth the labels.
|
* @param reduction Type of reduction to apply to loss. Should be of type
|
* `Reduction`
|
*
|
* @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
|
*/
|
function sigmoidCrossEntropy_(multiClassLabels, logits, weights, labelSmoothing = 0, reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) {
|
let $multiClassLabels = convertToTensor(multiClassLabels, 'multiClassLabels', 'sigmoidCrossEntropy');
|
const $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropy');
|
let $weights = null;
|
if (weights != null) {
|
$weights = convertToTensor(weights, 'weights', 'sigmoidCrossEntropy');
|
}
|
assertShapesMatch($multiClassLabels.shape, $logits.shape, 'Error in sigmoidCrossEntropy: ');
|
if (labelSmoothing > 0) {
|
const labelSmoothingScalar = scalar(labelSmoothing);
|
const one = scalar(1);
|
const half = scalar(0.5);
|
$multiClassLabels =
|
add(mul($multiClassLabels, sub(one, labelSmoothingScalar)), mul(half, labelSmoothingScalar));
|
}
|
const losses = sigmoidCrossEntropyWithLogits_($multiClassLabels, $logits);
|
return computeWeightedLoss(losses, $weights, reduction);
|
}
|
// Public export: the private `sigmoidCrossEntropy_` implementation is handed
// to `op()` (imported from '../operation') inside a shorthand object literal,
// so the property key is the function's own name.
// NOTE(review): the `@__PURE__` annotation presumably lets bundlers drop this
// call when the export is unused — keep it directly in front of the call.
export const sigmoidCrossEntropy = /* @__PURE__ */ op({ sigmoidCrossEntropy_ });
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"sigmoid_cross_entropy.js","sourceRoot":"","sources":["../../../../../../../tfjs-core/src/ops/losses/sigmoid_cross_entropy.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,OAAO,EAAC,eAAe,EAAC,MAAM,uBAAuB,CAAC;AAEtD,OAAO,EAAC,iBAAiB,EAAC,MAAM,YAAY,CAAC;AAC7C,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,KAAK,EAAC,MAAM,UAAU,CAAC;AAC/B,OAAO,EAAC,SAAS,EAAC,MAAM,mBAAmB,CAAC;AAC5C,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,EAAE,EAAC,MAAM,cAAc,CAAC;AAChC,OAAO,EAAC,IAAI,EAAC,MAAM,SAAS,CAAC;AAC7B,OAAO,EAAC,MAAM,EAAC,MAAM,WAAW,CAAC;AACjC,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAE3B,OAAO,EAAC,mBAAmB,EAAC,MAAM,yBAAyB,CAAC;AAE5D,SAAS,8BAA8B,CACnC,MAAoB,EAAE,MAAoB;IAC5C,MAAM,OAAO,GACT,eAAe,CAAC,MAAM,EAAE,QAAQ,EAAE,+BAA+B,CAAC,CAAC;IACvE,MAAM,OAAO,GACT,eAAe,CAAC,MAAM,EAAE,QAAQ,EAAE,+BAA+B,CAAC,CAAC;IACvE,iBAAiB,CACb,OAAO,CAAC,KAAK,EAAE,OAAO,CAAC,KAAK,EAAE,0CAA0C,CAAC,CAAC;IAE9E;;;;;;;;;;;;;;;;;;;OAmBG;IACH,MAAM,SAAS,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC;IAChC,MAAM,aAAa,GAAG,GAAG,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;IAC5C,MAAM,aAAa,GAAG,KAAK,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC;IAEpD,OAAO,GAAG,CAAC,GAAG,CAAC,SAAS,EAAE,aAAa,CAAC,EAAE,aAAa,CAAC,CAAC;AAC3D,CAAC;AAED;;;;;;;;;;;;;;;;;;;;GAoBG;AACH,SAAS,oBAAoB,CACzB,gBAA8B,EAAE,MAAoB,EACpD,OAA2B,EAAE,cAAc,GAAG,CAAC,EAC/C,SAAS,GAAG,SAAS,CAAC,sBAAsB;IAC9C,IAAI,iBAAiB,GAAG,eAAe,CACnC,gBAAgB,EAAE,kBAAkB,EAAE,qBAAqB,CAAC,CAAC;IACjE,MAAM,OAAO,GAAG,eAAe,CAAC,MAAM,EAAE,QAAQ,EAAE,qBAAqB,CAAC,CAAC;IACzE,IAAI,QAAQ,GAAW,IAAI,CAAC;IAC5B,IAAI,OAAO,IAAI,IAAI,EAAE;QACnB,QAAQ,GAAG,eAAe,CAAC,OAAO,EAAE,SAAS,EAAE,qBAAqB,CAAC,CAAC;KACvE;IACD,iBAAiB,CACb,iBAAiB,CAAC,KAAK,EAAE,OAAO,CAAC,KAAK,EAAE,gCAAgC,CAAC,CAAC;IAE9E,IAAI,cAAc,GAAG,CAAC,EAAE;QACtB,MAAM,oBAAoB,GAAG,MAAM,CAAC,cAAc,CAAC,CAAC;QACpD,MAAM,GAAG,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QACtB
,MAAM,IAAI,GAAG,MAAM,CAAC,GAAG,CAAC,CAAC;QAEzB,iBAAiB;YACb,GAAG,CAAC,GAAG,CAAC,iBAAiB,EAAE,GAAG,CAAC,GAAG,EAAE,oBAAoB,CAAC,CAAC,EACtD,GAAG,CAAC,IAAI,EAAE,oBAAoB,CAAC,CAAC,CAAC;KAC1C;IACD,MAAM,MAAM,GAAG,8BAA8B,CAAC,iBAAiB,EAAE,OAAO,CAAC,CAAC;IAE1E,OAAO,mBAAmB,CAAC,MAAM,EAAE,QAAQ,EAAE,SAAS,CAAC,CAAC;AAC1D,CAAC;AAED,MAAM,CAAC,MAAM,mBAAmB,GAAG,eAAe,CAAC,EAAE,CAAC,EAAC,oBAAoB,EAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '../../tensor';\nimport {convertToTensor} from '../../tensor_util_env';\nimport {TensorLike} from '../../types';\nimport {assertShapesMatch} from '../../util';\nimport {abs} from '../abs';\nimport {add} from '../add';\nimport {exp} from '../exp';\nimport {log1p} from '../log1p';\nimport {Reduction} from '../loss_ops_utils';\nimport {mul} from '../mul';\nimport {neg} from '../neg';\nimport {op} from '../operation';\nimport {relu} from '../relu';\nimport {scalar} from '../scalar';\nimport {sub} from '../sub';\n\nimport {computeWeightedLoss} from './compute_weighted_loss';\n\nfunction sigmoidCrossEntropyWithLogits_<T extends Tensor, O extends Tensor>(\n    labels: T|TensorLike, logits: T|TensorLike): O {\n  const $labels =\n      convertToTensor(labels, 'labels', 'sigmoidCrossEntropyWithLogits');\n  const $logits =\n      
convertToTensor(logits, 'logits', 'sigmoidCrossEntropyWithLogits');\n  assertShapesMatch(\n      $labels.shape, $logits.shape, 'Error in sigmoidCrossEntropyWithLogits: ');\n\n  /**\n   * Implementation Details:\n   *\n   * For brevity, let `x = logits`, `z = labels`.  The logistic loss is\n   *     z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))\n   *   = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))\n   *   = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))\n   *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))\n   *   = (1 - z) * x + log(1 + exp(-x))\n   *   = x - x * z + log(1 + exp(-x))\n   *\n   *   For x < 0, to avoid overflow in exp(-x), we reformulate the above\n   *     x - x * z + log(1 + exp(-x))\n   *   = log(exp(x)) - x * z + log(1 + exp(-x))\n   *   = - x * z + log(1 + exp(x))\n   *\n   * Hence, to ensure stability and avoid overflow, the implementation uses\n   * this equivalent formulation:\n   *     max(x, 0) - x * z + log(1 + exp(-abs(x)))\n   */\n  const maxOutput = relu($logits);\n  const outputXTarget = mul($logits, $labels);\n  const sigmoidOutput = log1p(exp(neg(abs($logits))));\n\n  return add(sub(maxOutput, outputXTarget), sigmoidOutput);\n}\n\n/**\n * Computes the sigmoid cross entropy loss between two tensors.\n *\n * If labelSmoothing is nonzero, smooth the labels towards 1/2:\n *\n *   newMulticlassLabels = multiclassLabels * (1 - labelSmoothing)\n *                         + 0.5 * labelSmoothing\n *\n * @param multiClassLabels The ground truth output tensor of shape\n * [batch_size, num_classes], same dimensions as 'predictions'.\n * @param logits The predicted outputs.\n * @param weights Tensor whose rank is either 0, or the same rank as\n *    `labels`, and must be broadcastable to `labels` (i.e., all dimensions\n *    must be either `1`, or the same as the corresponding `losses`\n *    dimension).\n * @param labelSmoothing If greater than 0, then smooth the labels.\n * 
@param reduction Type of reduction to apply to loss. Should be of type\n *    `Reduction`\n *\n * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }\n */\nfunction sigmoidCrossEntropy_<T extends Tensor, O extends Tensor>(\n    multiClassLabels: T|TensorLike, logits: T|TensorLike,\n    weights?: Tensor|TensorLike, labelSmoothing = 0,\n    reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {\n  let $multiClassLabels = convertToTensor(\n      multiClassLabels, 'multiClassLabels', 'sigmoidCrossEntropy');\n  const $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropy');\n  let $weights: Tensor = null;\n  if (weights != null) {\n    $weights = convertToTensor(weights, 'weights', 'sigmoidCrossEntropy');\n  }\n  assertShapesMatch(\n      $multiClassLabels.shape, $logits.shape, 'Error in sigmoidCrossEntropy: ');\n\n  if (labelSmoothing > 0) {\n    const labelSmoothingScalar = scalar(labelSmoothing);\n    const one = scalar(1);\n    const half = scalar(0.5);\n\n    $multiClassLabels =\n        add(mul($multiClassLabels, sub(one, labelSmoothingScalar)),\n            mul(half, labelSmoothingScalar));\n  }\n  const losses = sigmoidCrossEntropyWithLogits_($multiClassLabels, $logits);\n\n  return computeWeightedLoss(losses, $weights, reduction);\n}\n\nexport const sigmoidCrossEntropy = /* @__PURE__ */ op({sigmoidCrossEntropy_});\n"]}
|