/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { customGrad } from '../../gradients';
|
import { convertToTensor } from '../../tensor_util_env';
|
import { assertShapesMatch } from '../../util';
|
import { add } from '../add';
|
import { expandShapeToKeepDim } from '../axis_util';
|
import { cast } from '../cast';
|
import { div } from '../div';
|
import { exp } from '../exp';
|
import { logSumExp } from '../log_sum_exp';
|
import { Reduction } from '../loss_ops_utils';
|
import { mul } from '../mul';
|
import { neg } from '../neg';
|
import { op } from '../operation';
|
import { reshape } from '../reshape';
|
import { scalar } from '../scalar';
|
import { sub } from '../sub';
|
import { sum } from '../sum';
|
import { computeWeightedLoss } from './compute_weighted_loss';
|
/**
 * Computes softmax cross entropy between logits and labels.
 *
 * Measures the probability error in discrete classification tasks in which
 * the classes are mutually exclusive (each entry is in exactly one class).
 * For example, each CIFAR-10 image is labeled with one and only one label: an
 * image can be a dog or a truck, but not both.
 *
 * `NOTE`: While the classes are mutually exclusive, their probabilities need
 * not be. All that is required is that each row of labels is a valid
 * probability distribution. If they are not, the computation of the gradient
 * will be incorrect.
 *
 * `WARNING`: This op expects unscaled logits, since it performs a softmax on
 * logits internally for efficiency. Do not call this op with the output of
 * softmax, as it will produce incorrect results.
 *
 * logits and labels must have the same shape, e.g. [batch_size, num_classes]
 * and the same dtype.
 *
 * The per-element loss is `-sum(labels * logSoftmax(logits), dim)`. Its
 * gradients, with the upstream `dy` re-expanded along `dim`, are:
 *   d(loss)/d(labels) = dy * -logSoftmax(logits)
 *   d(loss)/d(logits) = dy * (softmax(logits) - labels)
 *
 * @param labels The labels array.
 * @param logits The logits array.
 * @param dim The dimension softmax would be performed on. Defaults to `-1`
 *     which indicates the last dimension. Only the last dimension is
 *     currently supported.
 * @returns The loss with `dim` summed away.
 * @throws Error if `dim` does not refer to the last dimension of `logits`.
 */
function softmaxCrossEntropyWithLogits_(labels, logits, dim = -1) {
  if (dim === -1) {
    dim = logits.rank - 1;
  }
  if (dim !== logits.rank - 1) {
    throw Error(`Softmax cross entropy along a non-last dimension is not yet ` +
        `supported. Labels / logits was rank ${logits.rank} ` +
        `and dim was ${dim}`);
  }
  // Use a custom gradient for numerical stability.
  const customOp = customGrad((labels, logits, save) => {
    // Reference:
    //   1. http://cs231n.github.io/linear-classify/#softmax
    //   2. https://blog.feedly.com/tricks-of-the-trade-logsumexp/
    const keepDims = true;
    const lse = logSumExp(logits, [dim], keepDims);
    // log(softmax(logits)) computed as logits - logSumExp(logits), which
    // avoids overflowing exp() on large logits.
    const logResult = sub(cast(logits, 'float32'), lse);
    save([labels, logResult]);
    const costVector = neg(mul(logResult, labels));
    const value = sum(costVector, [dim]);
    const gradFunc = (dy, saved) => {
      const [labels, logResult] = saved;
      // `dy` has `dim` reduced away; re-insert it with size 1 so it
      // broadcasts against the full labels/logits shape.
      const dyExpanded = reshape(dy, expandShapeToKeepDim(dy.shape, [dim]));
      return [
        // The loss is linear in labels and logResult does not depend on
        // labels, so d(loss)/d(labels) = -logSoftmax(logits).
        mul(dyExpanded, neg(logResult)),
        // d(loss)/d(logits) = softmax(logits) - labels.
        mul(dyExpanded, sub(exp(logResult), cast(labels, 'float32'))),
      ];
    };
    return { value, gradFunc };
  });
  return customOp(labels, logits);
}
|
/**
 * Computes the softmax cross entropy loss between two tensors.
 *
 * If labelSmoothing is nonzero, smooth the labels towards the uniform
 * distribution (1 / numClasses):
 *
 *   newOnehotLabels = onehotLabels * (1 - labelSmoothing)
 *                         + labelSmoothing / numClasses
 *
 * where numClasses is the size of the last (class) dimension.
 *
 * @param onehotLabels One hot encoded labels
 *    [batch_size, num_classes], same dimensions as 'logits'.
 * @param logits The predicted outputs.
 * @param weights Tensor whose rank is either 0, or 1, and must be
 *    broadcastable to `loss` of shape [batch_size]
 * @param labelSmoothing If greater than 0, then smooth the labels.
 * @param reduction Type of reduction to apply to loss. Should be of type
 *    `Reduction`
 *
 * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
 */
function softmaxCrossEntropy_(onehotLabels, logits, weights, labelSmoothing = 0, reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) {
  let $onehotLabels = convertToTensor(onehotLabels, 'onehotLabels', 'softmaxCrossEntropy');
  const $logits = convertToTensor(logits, 'logits', 'softmaxCrossEntropy');
  let $weights = null;
  if (weights != null) {
    $weights = convertToTensor(weights, 'weights', 'softmaxCrossEntropy');
  }
  assertShapesMatch($onehotLabels.shape, $logits.shape, 'Error in softmaxCrossEntropy: ');
  if (labelSmoothing > 0) {
    const labelSmoothingScalar = scalar(labelSmoothing);
    const one = scalar(1);
    // The softmax (and hence the class axis) is always the last dimension,
    // so read the class count from there. This matches shape[1] for rank-2
    // inputs and is also correct for rank-1 and higher-rank inputs.
    const numClasses = scalar($onehotLabels.shape[$onehotLabels.rank - 1]);
    $onehotLabels =
        add(mul($onehotLabels, sub(one, labelSmoothingScalar)), div(labelSmoothingScalar, numClasses));
  }
  const losses = softmaxCrossEntropyWithLogits_($onehotLabels, $logits);
  return computeWeightedLoss(losses, $weights, reduction);
}
|
// Public entry point: `softmaxCrossEntropy_` wrapped by `op`. The
// `/* @__PURE__ */` annotation marks the call as side-effect free for
// bundlers, so an unused export can be dropped.
export const softmaxCrossEntropy = /* @__PURE__ */ op({
  softmaxCrossEntropy_: softmaxCrossEntropy_,
});
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"softmax_cross_entropy.js","sourceRoot":"","sources":["../../../../../../../tfjs-core/src/ops/losses/softmax_cross_entropy.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AACH,OAAO,EAAC,UAAU,EAAC,MAAM,iBAAiB,CAAC;AAG3C,OAAO,EAAC,eAAe,EAAC,MAAM,uBAAuB,CAAC;AAEtD,OAAO,EAAC,iBAAiB,EAAC,MAAM,YAAY,CAAC;AAC7C,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,oBAAoB,EAAC,MAAM,cAAc,CAAC;AAClD,OAAO,EAAC,IAAI,EAAC,MAAM,SAAS,CAAC;AAC7B,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,SAAS,EAAC,MAAM,gBAAgB,CAAC;AACzC,OAAO,EAAC,SAAS,EAAC,MAAM,mBAAmB,CAAC;AAC5C,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,EAAE,EAAC,MAAM,cAAc,CAAC;AAChC,OAAO,EAAC,OAAO,EAAC,MAAM,YAAY,CAAC;AACnC,OAAO,EAAC,MAAM,EAAC,MAAM,WAAW,CAAC;AACjC,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAE3B,OAAO,EAAC,mBAAmB,EAAC,MAAM,yBAAyB,CAAC;AAE5D;;;;;;;;;;;;;;;;;;;;;;;GAuBG;AACH,SAAS,8BAA8B,CACnC,MAAS,EAAE,MAAS,EAAE,GAAG,GAAG,CAAC,CAAC;IAChC,IAAI,GAAG,KAAK,CAAC,CAAC,EAAE;QACd,GAAG,GAAG,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;KACvB;IAED,IAAI,GAAG,KAAK,MAAM,CAAC,IAAI,GAAG,CAAC,EAAE;QAC3B,MAAM,KAAK,CACP,8DAA8D;YAC9D,uCAAuC,MAAM,CAAC,IAAI,GAAG;YACrD,eAAe,GAAG,EAAE,CAAC,CAAC;KAC3B;IACD,iDAAiD;IACjD,MAAM,QAAQ,GACV,UAAU,CAAC,CAAC,MAAc,EAAE,MAAc,EAAE,IAAkB,EAAE,EAAE;QAChE,aAAa;QACb,wDAAwD;QACxD,8DAA8D;QAC9D,MAAM,QAAQ,GAAG,IAAI,CAAC;QACtB,MAAM,GAAG,GAAG,SAAS,CAAC,MAAM,EAAE,CAAC,GAAG,CAAC,EAAE,QAAQ,CAAC,CAAC;QAC/C,MAAM,SAAS,GAAG,GAAG,CAAC,IAAI,CAAC,MAAM,EAAE,SAAS,CAAC,EAAE,GAAG,CAAC,CAAC;QACpD,IAAI,CAAC,CAAC,MAAM,EAAE,SAAS,CAAC,CAAC,CAAC;QAE1B,MAAM,UAAU,GAAG,GAAG,CAAC,GAAG,CAAC,SAAS,EAAE,MAAM,CAAC,CAAC,CAAC;QAC/C,MAAM,KAAK,GAAM,GAAG,CAAC,UAAU,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC;QAExC,MAAM,QAAQ,GAAG,CAAC,EAAK,EAAE,KAAe,EAAE,EAAE;YAC1C,MAAM,CAAC,MAAM,EAAE,SAAS,CAAC,GAAG,KAAK,CAAC;YAClC,MAAM,OAAO,GAAG,oBAAoB,CAAC,EAAE,CAAC,KAAK,EAAE,CAAC,GAAG,CAAC,CAAC
,CAAC;YACtD,OAAO;gBACL,GAAG,CAAC,OAAO,CAAC,EAAE,EAAE,OAAO,CAAC,EACpB,GAAG,CAAC,IAAI,CAAC,MAAM,EAAE,SAAS,CAAC,EAAE,GAAG,CAAC,SAAS,CAAC,CAAC,CAAC;gBACjD,GAAG,CAAC,OAAO,CAAC,EAAE,EAAE,OAAO,CAAC,EACpB,GAAG,CAAC,GAAG,CAAC,SAAS,CAAC,EAAE,IAAI,CAAC,MAAM,EAAE,SAAS,CAAC,CAAC,CAAC;aAClD,CAAC;QACJ,CAAC,CAAC;QACF,OAAO,EAAC,KAAK,EAAE,QAAQ,EAAC,CAAC;IAC3B,CAAC,CAAC,CAAC;IAEP,OAAO,QAAQ,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;AAClC,CAAC;AAED;;;;;;;;;;;;;;;;;;GAkBG;AACH,SAAS,oBAAoB,CACzB,YAA0B,EAAE,MAAoB,EAChD,OAA2B,EAAE,cAAc,GAAG,CAAC,EAC/C,SAAS,GAAG,SAAS,CAAC,sBAAsB;IAC9C,IAAI,aAAa,GACb,eAAe,CAAC,YAAY,EAAE,cAAc,EAAE,qBAAqB,CAAC,CAAC;IACzE,MAAM,OAAO,GAAG,eAAe,CAAC,MAAM,EAAE,QAAQ,EAAE,qBAAqB,CAAC,CAAC;IACzE,IAAI,QAAQ,GAAW,IAAI,CAAC;IAE5B,IAAI,OAAO,IAAI,IAAI,EAAE;QACnB,QAAQ,GAAG,eAAe,CAAC,OAAO,EAAE,SAAS,EAAE,qBAAqB,CAAC,CAAC;KACvE;IAED,iBAAiB,CACb,aAAa,CAAC,KAAK,EAAE,OAAO,CAAC,KAAK,EAAE,gCAAgC,CAAC,CAAC;IAE1E,IAAI,cAAc,GAAG,CAAC,EAAE;QACtB,MAAM,oBAAoB,GAAG,MAAM,CAAC,cAAc,CAAC,CAAC;QACpD,MAAM,GAAG,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QACtB,MAAM,UAAU,GAAG,MAAM,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;QAElD,aAAa;YACT,GAAG,CAAC,GAAG,CAAC,aAAa,EAAE,GAAG,CAAC,GAAG,EAAE,oBAAoB,CAAC,CAAC,EAClD,GAAG,CAAC,oBAAoB,EAAE,UAAU,CAAC,CAAC,CAAC;KAChD;IAED,MAAM,MAAM,GAAG,8BAA8B,CAAC,aAAa,EAAE,OAAO,CAAC,CAAC;IAEtE,OAAO,mBAAmB,CAAC,MAAM,EAAE,QAAQ,EAAE,SAAS,CAAC,CAAC;AAC1D,CAAC;AAED,MAAM,CAAC,MAAM,mBAAmB,GAAG,eAAe,CAAC,EAAE,CAAC,EAAC,oBAAoB,EAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {customGrad} from '../../gradients';\nimport {Tensor} from '../../tensor';\nimport {GradSaveFunc} from '../../tensor_types';\nimport {convertToTensor} from '../../tensor_util_env';\nimport {TensorLike} from '../../types';\nimport {assertShapesMatch} from '../../util';\nimport {add} from '../add';\nimport {expandShapeToKeepDim} from '../axis_util';\nimport {cast} from '../cast';\nimport {div} from '../div';\nimport {exp} from '../exp';\nimport {logSumExp} from '../log_sum_exp';\nimport {Reduction} from '../loss_ops_utils';\nimport {mul} from '../mul';\nimport {neg} from '../neg';\nimport {op} from '../operation';\nimport {reshape} from '../reshape';\nimport {scalar} from '../scalar';\nimport {sub} from '../sub';\nimport {sum} from '../sum';\n\nimport {computeWeightedLoss} from './compute_weighted_loss';\n\n/**\n * Computes softmax cross entropy between logits and labels.\n *\n * Measures the probability error in discrete classification tasks in which\n * the classes are mutually exclusive (each entry is in exactly one class).\n * For example, each CIFAR-10 image is labeled with one and only one label: an\n * image can be a dog or a truck, but not both.\n *\n * `NOTE`: While the classes are mutually exclusive, their probabilities need\n * not be. 
All that is required is that each row of labels is a valid\n * probability distribution. If they are not, the computation of the gradient\n * will be incorrect.\n *\n * `WARNING`: This op expects unscaled logits, since it performs a softmax on\n * logits internally for efficiency. Do not call this op with the output of\n * softmax, as it will produce incorrect results.\n *\n * logits and labels must have the same shape, e.g. [batch_size, num_classes]\n * and the same dtype.\n * @param labels The labels array.\n * @param logits The logits array.\n * @param dim The dimension softmax would be performed on. Defaults to `-1`\n *     which indicates the last dimension.\n */\nfunction softmaxCrossEntropyWithLogits_<T extends Tensor, O extends Tensor>(\n    labels: T, logits: T, dim = -1): O {\n  if (dim === -1) {\n    dim = logits.rank - 1;\n  }\n\n  if (dim !== logits.rank - 1) {\n    throw Error(\n        `Softmax cross entropy along a non-last dimension is not yet ` +\n        `supported. Labels / logits was rank ${logits.rank} ` +\n        `and dim was ${dim}`);\n  }\n  // Use a custom gradient for numerical stability.\n  const customOp =\n      customGrad((labels: Tensor, logits: Tensor, save: GradSaveFunc) => {\n        // Reference:\n        //   1. http://cs231n.github.io/linear-classify/#softmax\n        //   2. 
https://blog.feedly.com/tricks-of-the-trade-logsumexp/\n        const keepDims = true;\n        const lse = logSumExp(logits, [dim], keepDims);\n        const logResult = sub(cast(logits, 'float32'), lse);\n        save([labels, logResult]);\n\n        const costVector = neg(mul(logResult, labels));\n        const value: O = sum(costVector, [dim]);\n\n        const gradFunc = (dy: O, saved: Tensor[]) => {\n          const [labels, logResult] = saved;\n          const dyShape = expandShapeToKeepDim(dy.shape, [dim]);\n          return [\n            mul(reshape(dy, dyShape),\n                sub(cast(labels, 'float32'), exp(logResult))),\n            mul(reshape(dy, dyShape),\n                sub(exp(logResult), cast(labels, 'float32'))),\n          ];\n        };\n        return {value, gradFunc};\n      });\n\n  return customOp(labels, logits);\n}\n\n/**\n * Computes the softmax cross entropy loss between two tensors.\n *\n * If labelSmoothing is nonzero, smooth the labels towards 1/2:\n *\n *   newOnehotLabels = onehotLabels * (1 - labelSmoothing)\n *                         + labelSmoothing / numClasses\n *\n * @param onehotLabels One hot encoded labels\n *    [batch_size, num_classes], same dimensions as 'predictions'.\n * @param logits The predicted outputs.\n * @param weights Tensor whose rank is either 0, or 1, and must be\n *    broadcastable to `loss`  of shape [batch_size]\n * @param labelSmoothing If greater than 0, then smooth the labels.\n * @param reduction Type of reduction to apply to loss. 
Should be of type\n *    `Reduction`\n *\n * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }\n */\nfunction softmaxCrossEntropy_<T extends Tensor, O extends Tensor>(\n    onehotLabels: T|TensorLike, logits: T|TensorLike,\n    weights?: Tensor|TensorLike, labelSmoothing = 0,\n    reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O {\n  let $onehotLabels =\n      convertToTensor(onehotLabels, 'onehotLabels', 'softmaxCrossEntropy');\n  const $logits = convertToTensor(logits, 'logits', 'softmaxCrossEntropy');\n  let $weights: Tensor = null;\n\n  if (weights != null) {\n    $weights = convertToTensor(weights, 'weights', 'softmaxCrossEntropy');\n  }\n\n  assertShapesMatch(\n      $onehotLabels.shape, $logits.shape, 'Error in softmaxCrossEntropy: ');\n\n  if (labelSmoothing > 0) {\n    const labelSmoothingScalar = scalar(labelSmoothing);\n    const one = scalar(1);\n    const numClasses = scalar($onehotLabels.shape[1]);\n\n    $onehotLabels =\n        add(mul($onehotLabels, sub(one, labelSmoothingScalar)),\n            div(labelSmoothingScalar, numClasses));\n  }\n\n  const losses = softmaxCrossEntropyWithLogits_($onehotLabels, $logits);\n\n  return computeWeightedLoss(losses, $weights, reduction);\n}\n\nexport const softmaxCrossEntropy = /* @__PURE__ */ op({softmaxCrossEntropy_});\n"]}
|