import { convertToTensor } from '../../tensor_util_env';
import { cast } from '../cast';
import { div } from '../div';
import { Reduction } from '../loss_ops_utils';
import { mean } from '../mean';
import { mul } from '../mul';
import { notEqual } from '../not_equal';
import { ones } from '../ones';
import { op } from '../operation';
import { scalar } from '../scalar';
import { sum } from '../sum';
/**
 * Computes the weighted loss between two tensors.
 *
 * @param losses Tensor of shape `[batch_size, d1, ..., dN]`.
 * @param weights Tensor whose rank is either 0, or the same rank as
 *    `losses`, and must be broadcastable to `losses` (i.e., all
 *    dimensions must be either `1`, or the same as the corresponding
 *    `losses` dimension). Optional; when omitted (null/undefined) the
 *    losses are used unweighted.
 * @param reduction How to reduce the weighted losses. Defaults to
 *    `Reduction.SUM_BY_NONZERO_WEIGHTS`.
 *    - `NONE`: return the element-wise weighted losses unreduced.
 *    - `SUM`: return the sum of the weighted losses.
 *    - `MEAN`: without weights, the plain mean; with weights, the weighted
 *      sum divided by the sum of the weights broadcast to `losses`' size.
 *    - `SUM_BY_NONZERO_WEIGHTS`: the weighted sum divided by the number of
 *      non-zero (broadcast) weights, or by `losses.size` when no weights
 *      are given.
 * @returns The weighted (and possibly reduced) loss tensor.
 * @throws {Error} If `reduction` is not one of the values above.
 *
 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
 */
function computeWeightedLoss_(losses, weights, reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) {
    const $losses = convertToTensor(losses, 'losses', 'computeWeightedLoss');
    // `weights` is optional; a null `$weights` means "no weighting".
    const $weights = (weights != null) ?
        convertToTensor(weights, 'weights', 'computeWeightedLoss') :
        null;
    const weightedLoss = ($weights == null) ? $losses : mul($losses, $weights);
    if (reduction === Reduction.NONE) {
        return weightedLoss;
    }
    if (reduction === Reduction.SUM) {
        return sum(weightedLoss);
    }
    if (reduction === Reduction.MEAN) {
        if ($weights == null) {
            return mean(weightedLoss);
        }
        // When a smaller weights tensor is broadcast across `losses`, each
        // weight is applied losses.size / weights.size times, so sum($weights)
        // alone under-counts the denominator; dividing by that factor makes
        // the denominator the sum of the broadcast weights.
        const broadcastFactor = $losses.size / $weights.size;
        const result = div(sum(weightedLoss), sum($weights));
        return broadcastFactor > 1 ? div(result, scalar(broadcastFactor)) : result;
    }
    if (reduction === Reduction.SUM_BY_NONZERO_WEIGHTS) {
        if ($weights == null) {
            return div(sum(weightedLoss), scalar($losses.size));
        }
        // Expand weights to the full losses shape so every loss element whose
        // weight is non-zero is counted once.
        const broadcastedWeights = mul($weights, ones($losses.shape));
        const numNonZeros = cast(sum(notEqual(broadcastedWeights, scalar(0))), 'float32');
        // NOTE(review): if every weight is zero, numNonZeros is 0 and this is a
        // division by zero (NaN when the weighted sum is also 0) — confirm that
        // callers expect this rather than a 0 result.
        return div(sum(weightedLoss), numNonZeros);
    }
    throw new Error(`Unknown reduction: ${reduction}`);
}
export const computeWeightedLoss = /* @__PURE__ */ op({ computeWeightedLoss_ });
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiY29tcHV0ZV93ZWlnaHRlZF9sb3NzLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvbG9zc2VzL2NvbXB1dGVfd2VpZ2h0ZWRfbG9zcy50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFpQkEsT0FBTyxFQUFDLGVBQWUsRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBR3RELE9BQU8sRUFBQyxJQUFJLEVBQUMsTUFBTSxTQUFTLENBQUM7QUFDN0IsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLFFBQVEsQ0FBQztBQUMzQixPQUFPLEVBQUMsU0FBUyxFQUFDLE1BQU0sbUJBQW1CLENBQUM7QUFDNUMsT0FBTyxFQUFDLElBQUksRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUM3QixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sUUFBUSxDQUFDO0FBQzNCLE9BQU8sRUFBQyxRQUFRLEVBQUMsTUFBTSxjQUFjLENBQUM7QUFDdEMsT0FBTyxFQUFDLElBQUksRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUM3QixPQUFPLEVBQUMsRUFBRSxFQUFDLE1BQU0sY0FBYyxDQUFDO0FBQ2hDLE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDakMsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLFFBQVEsQ0FBQztBQUUzQjs7Ozs7Ozs7OztHQVVHO0FBQ0gsU0FBUyxvQkFBb0IsQ0FDekIsTUFBb0IsRUFBRSxPQUEyQixFQUNqRCxTQUFTLEdBQUcsU0FBUyxDQUFDLHNCQUFzQjtJQUM5QyxNQUFNLE9BQU8sR0FBRyxlQUFlLENBQUMsTUFBTSxFQUFFLFFBQVEsRUFBRSxxQkFBcUIsQ0FBQyxDQUFDO0lBQ3pFLElBQUksUUFBUSxHQUFXLElBQUksQ0FBQztJQUM1QixJQUFJLE9BQU8sSUFBSSxJQUFJLEVBQUU7UUFDbkIsUUFBUSxHQUFHLGVBQWUsQ0FBQyxPQUFPLEVBQUUsU0FBUyxFQUFFLHFCQUFxQixDQUFDLENBQUM7S0FDdkU7SUFFRCxNQUFNLFlBQVksR0FBRyxDQUFDLFFBQVEsSUFBSSxJQUFJLENBQUMsQ0FBQyxDQUFDLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsT0FBTyxFQUFFLFFBQVEsQ0FBQyxDQUFDO0lBRTNFLElBQUksU0FBUyxLQUFLLFNBQVMsQ0FBQyxJQUFJLEVBQUU7UU
FDaEMsT0FBTyxZQUFpQixDQUFDO0tBQzFCO0lBQ0QsSUFBSSxTQUFTLEtBQUssU0FBUyxDQUFDLEdBQUcsRUFBRTtRQUMvQixPQUFPLEdBQUcsQ0FBQyxZQUFZLENBQUMsQ0FBQztLQUMxQjtJQUNELElBQUksU0FBUyxLQUFLLFNBQVMsQ0FBQyxJQUFJLEVBQUU7UUFDaEMsSUFBSSxRQUFRLElBQUksSUFBSSxFQUFFO1lBQ3BCLE9BQU8sSUFBSSxDQUFDLFlBQVksQ0FBQyxDQUFDO1NBQzNCO2FBQU07WUFDTCxNQUFNLGVBQWUsR0FBRyxPQUFPLENBQUMsSUFBSSxHQUFHLFFBQVEsQ0FBQyxJQUFJLENBQUM7WUFDckQsTUFBTSxNQUFNLEdBQUcsR0FBRyxDQUFDLEdBQUcsQ0FBQyxZQUFZLENBQUMsRUFBRSxHQUFHLENBQUMsUUFBUSxDQUFDLENBQUMsQ0FBQztZQUNyRCxPQUFPLGVBQWUsR0FBRyxDQUFDLENBQUMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxNQUFNLEVBQUUsTUFBTSxDQUFDLGVBQWUsQ0FBQyxDQUFDLENBQUMsQ0FBQztnQkFDdEMsTUFBVyxDQUFDO1NBQzFDO0tBQ0Y7SUFDRCxJQUFJLFNBQVMsS0FBSyxTQUFTLENBQUMsc0JBQXNCLEVBQUU7UUFDbEQsSUFBSSxRQUFRLElBQUksSUFBSSxFQUFFO1lBQ3BCLE9BQU8sR0FBRyxDQUFDLEdBQUcsQ0FBQyxZQUFZLENBQUMsRUFBRSxNQUFNLENBQUMsT0FBTyxDQUFDLElBQUksQ0FBQyxDQUFDLENBQUM7U0FDckQ7YUFBTTtZQUNMLE1BQU0sa0JBQWtCLEdBQUcsR0FBRyxDQUFDLFFBQVEsRUFBRSxJQUFJLENBQUMsT0FBTyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUM7WUFFOUQsTUFBTSxXQUFXLEdBQ2IsSUFBSSxDQUFDLEdBQUcsQ0FBQyxRQUFRLENBQUMsa0JBQWtCLEVBQUUsTUFBTSxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsRUFBRSxTQUFTLENBQUMsQ0FBQztZQUNsRSxPQUFPLEdBQUcsQ0FBQyxHQUFHLENBQUMsWUFBWSxDQUFDLEVBQUUsV0FBVyxDQUFDLENBQUM7U0FDNUM7S0FDRjtJQUVELE1BQU0sS0FBSyxDQUFDLHNCQUFzQixTQUFTLEVBQUUsQ0FBQyxDQUFDO0FBQ2pELENBQUM7QUFDRCxNQUFNLENBQUMsTUFBTSxtQkFBbUIsR0FBRyxlQUFlLENBQUMsRUFBRSxDQUFDLEVBQUMsb0JBQW9CLEVBQUMsQ0FBQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW
5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuaW1wb3J0IHtUZW5zb3J9IGZyb20gJy4uLy4uL3RlbnNvcic7XG5pbXBvcnQge2NvbnZlcnRUb1RlbnNvcn0gZnJvbSAnLi4vLi4vdGVuc29yX3V0aWxfZW52JztcbmltcG9ydCB7VGVuc29yTGlrZX0gZnJvbSAnLi4vLi4vdHlwZXMnO1xuXG5pbXBvcnQge2Nhc3R9IGZyb20gJy4uL2Nhc3QnO1xuaW1wb3J0IHtkaXZ9IGZyb20gJy4uL2Rpdic7XG5pbXBvcnQge1JlZHVjdGlvbn0gZnJvbSAnLi4vbG9zc19vcHNfdXRpbHMnO1xuaW1wb3J0IHttZWFufSBmcm9tICcuLi9tZWFuJztcbmltcG9ydCB7bXVsfSBmcm9tICcuLi9tdWwnO1xuaW1wb3J0IHtub3RFcXVhbH0gZnJvbSAnLi4vbm90X2VxdWFsJztcbmltcG9ydCB7b25lc30gZnJvbSAnLi4vb25lcyc7XG5pbXBvcnQge29wfSBmcm9tICcuLi9vcGVyYXRpb24nO1xuaW1wb3J0IHtzY2FsYXJ9IGZyb20gJy4uL3NjYWxhcic7XG5pbXBvcnQge3N1bX0gZnJvbSAnLi4vc3VtJztcblxuLyoqXG4gKiBDb21wdXRlcyB0aGUgd2VpZ2h0ZWQgbG9zcyBiZXR3ZWVuIHR3byB0ZW5zb3JzLlxuICpcbiAqIEBwYXJhbSBsb3NzZXMgVGVuc29yIG9mIHNoYXBlIGBbYmF0Y2hfc2l6ZSwgZDEsIC4uLiwgZE5dYC5cbiAqIEBwYXJhbSB3ZWlnaHRzIFRlbnNvciB3aG9zZSByYW5rIGlzIGVpdGhlciAwLCBvciB0aGUgc2FtZSByYW5rIGFzXG4gKiAgICBgbG9zc2VzYCwgYW5kIG11c3QgYmUgYnJvYWRjYXN0YWJsZSB0byBgbG9zc2VzYCAoaS5lLiwgYWxsXG4gKiAgICBkaW1lbnNpb25zIG11c3QgYmUgZWl0aGVyIGAxYCwgb3IgdGhlIHNhbWUgYXMgdGhlIGNvcnJlc3BvbmRpbmdcbiAqICAgIGBsb3NzZXNgIGRpbWVuc2lvbikuXG4gKlxuICogQGRvYyB7aGVhZGluZzogJ1RyYWluaW5nJywgc3ViaGVhZGluZzogJ0xvc3NlcycsIG5hbWVzcGFjZTogJ2xvc3Nlcyd9XG4gKi9cbmZ1bmN0aW9uIGNvbXB1dGVXZWlnaHRlZExvc3NfPFQgZXh0ZW5kcyBUZW5zb3IsIE8gZXh0ZW5kcyBUZW5zb3I+KFxuICAgIGxvc3NlczogVHxUZW5zb3JMaWtlLCB3ZWlnaHRzPzogVGVuc29yfFRlbnNvckxpa2UsXG4gICAgcmVkdWN0aW9uID0gUmVkdWN0aW9uLlNVTV9CWV9OT05aRVJPX1dFSUdIVFMpOiBPIHtcbiAgY29uc3QgJGxvc3NlcyA9IGNvbnZlcnRUb1RlbnNvcihsb3NzZXMsICdsb3NzZXMnLCAnY29tcHV0ZVdlaWdodGVkTG9zcycpO1xuICBsZXQgJHdlaWdodH
M6IFRlbnNvciA9IG51bGw7XG4gIGlmICh3ZWlnaHRzICE9IG51bGwpIHtcbiAgICAkd2VpZ2h0cyA9IGNvbnZlcnRUb1RlbnNvcih3ZWlnaHRzLCAnd2VpZ2h0cycsICdjb21wdXRlV2VpZ2h0ZWRMb3NzJyk7XG4gIH1cblxuICBjb25zdCB3ZWlnaHRlZExvc3MgPSAoJHdlaWdodHMgPT0gbnVsbCkgPyAkbG9zc2VzIDogbXVsKCRsb3NzZXMsICR3ZWlnaHRzKTtcblxuICBpZiAocmVkdWN0aW9uID09PSBSZWR1Y3Rpb24uTk9ORSkge1xuICAgIHJldHVybiB3ZWlnaHRlZExvc3MgYXMgTztcbiAgfVxuICBpZiAocmVkdWN0aW9uID09PSBSZWR1Y3Rpb24uU1VNKSB7XG4gICAgcmV0dXJuIHN1bSh3ZWlnaHRlZExvc3MpO1xuICB9XG4gIGlmIChyZWR1Y3Rpb24gPT09IFJlZHVjdGlvbi5NRUFOKSB7XG4gICAgaWYgKCR3ZWlnaHRzID09IG51bGwpIHtcbiAgICAgIHJldHVybiBtZWFuKHdlaWdodGVkTG9zcyk7XG4gICAgfSBlbHNlIHtcbiAgICAgIGNvbnN0IGJyb2FkY2FzdEZhY3RvciA9ICRsb3NzZXMuc2l6ZSAvICR3ZWlnaHRzLnNpemU7XG4gICAgICBjb25zdCByZXN1bHQgPSBkaXYoc3VtKHdlaWdodGVkTG9zcyksIHN1bSgkd2VpZ2h0cykpO1xuICAgICAgcmV0dXJuIGJyb2FkY2FzdEZhY3RvciA+IDEgPyBkaXYocmVzdWx0LCBzY2FsYXIoYnJvYWRjYXN0RmFjdG9yKSkgOlxuICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXN1bHQgYXMgTztcbiAgICB9XG4gIH1cbiAgaWYgKHJlZHVjdGlvbiA9PT0gUmVkdWN0aW9uLlNVTV9CWV9OT05aRVJPX1dFSUdIVFMpIHtcbiAgICBpZiAoJHdlaWdodHMgPT0gbnVsbCkge1xuICAgICAgcmV0dXJuIGRpdihzdW0od2VpZ2h0ZWRMb3NzKSwgc2NhbGFyKCRsb3NzZXMuc2l6ZSkpO1xuICAgIH0gZWxzZSB7XG4gICAgICBjb25zdCBicm9hZGNhc3RlZFdlaWdodHMgPSBtdWwoJHdlaWdodHMsIG9uZXMoJGxvc3Nlcy5zaGFwZSkpO1xuXG4gICAgICBjb25zdCBudW1Ob25aZXJvcyA9XG4gICAgICAgICAgY2FzdChzdW0obm90RXF1YWwoYnJvYWRjYXN0ZWRXZWlnaHRzLCBzY2FsYXIoMCkpKSwgJ2Zsb2F0MzInKTtcbiAgICAgIHJldHVybiBkaXYoc3VtKHdlaWdodGVkTG9zcyksIG51bU5vblplcm9zKTtcbiAgICB9XG4gIH1cblxuICB0aHJvdyBFcnJvcihgVW5rbm93biByZWR1Y3Rpb246ICR7cmVkdWN0aW9ufWApO1xufVxuZXhwb3J0IGNvbnN0IGNvbXB1dGVXZWlnaHRlZExvc3MgPSAvKiBAX19QVVJFX18gKi8gb3Aoe2NvbXB1dGVXZWlnaHRlZExvc3NffSk7XG4iXX0=