gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import { convertToTensor } from '../../tensor_util_env';
import { cast } from '../cast';
import { div } from '../div';
import { Reduction } from '../loss_ops_utils';
import { mean } from '../mean';
import { mul } from '../mul';
import { notEqual } from '../not_equal';
import { ones } from '../ones';
import { op } from '../operation';
import { scalar } from '../scalar';
import { sum } from '../sum';
/**
 * Applies optional weights to a tensor of losses and reduces the result
 * according to `reduction`.
 *
 * Reduction modes:
 * - `Reduction.NONE`: the (weighted) losses are returned unreduced.
 * - `Reduction.SUM`: the sum of the (weighted) losses.
 * - `Reduction.MEAN`: without weights, the mean of the losses; with
 *    weights, the sum of the weighted losses divided by the sum of the
 *    weights, and additionally divided by `losses.size / weights.size`
 *    when that ratio is greater than 1.
 * - `Reduction.SUM_BY_NONZERO_WEIGHTS`: without weights, the sum of the
 *    losses divided by `losses.size`; with weights, the sum of the weighted
 *    losses divided by the number of non-zero entries of the weights after
 *    they have been expanded to the shape of `losses`.
 *
 * @param losses Tensor of shape `[batch_size, d1, ..., dN]`.
 * @param weights Tensor whose rank is either 0, or the same rank as
 *    `losses`, and must be broadcastable to `losses` (i.e., all
 *    dimensions must be either `1`, or the same as the corresponding
 *    `losses` dimension). May be omitted, in which case the losses are
 *    used unweighted.
 * @param reduction One of the `Reduction` values; defaults to
 *    `Reduction.SUM_BY_NONZERO_WEIGHTS`.
 * @returns The weighted losses, reduced as described above.
 * @throws Error when `reduction` is not one of the handled values.
 *
 * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
 */
function computeWeightedLoss_(losses, weights, reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) {
    const lossTensor = convertToTensor(losses, 'losses', 'computeWeightedLoss');
    const weightTensor = (weights == null) ?
        null :
        convertToTensor(weights, 'weights', 'computeWeightedLoss');
    const hasWeights = weightTensor != null;
    // The element-wise product is formed up front, before the reduction mode
    // is inspected.
    const weighted = hasWeights ? mul(lossTensor, weightTensor) : lossTensor;
    switch (reduction) {
        case Reduction.NONE:
            return weighted;
        case Reduction.SUM:
            return sum(weighted);
        case Reduction.MEAN: {
            if (!hasWeights) {
                return mean(weighted);
            }
            // How many loss elements each weight element covers.
            const elementsPerWeight = lossTensor.size / weightTensor.size;
            const weightedMean = div(sum(weighted), sum(weightTensor));
            if (elementsPerWeight > 1) {
                return div(weightedMean, scalar(elementsPerWeight));
            }
            return weightedMean;
        }
        case Reduction.SUM_BY_NONZERO_WEIGHTS: {
            if (!hasWeights) {
                return div(sum(weighted), scalar(lossTensor.size));
            }
            // Multiplying by ones expands the weights to the shape of the
            // losses so that every covered element is counted.
            const expandedWeights = mul(weightTensor, ones(lossTensor.shape));
            const nonZeroMask = notEqual(expandedWeights, scalar(0));
            const nonZeroCount = cast(sum(nonZeroMask), 'float32');
            return div(sum(weighted), nonZeroCount);
        }
        default:
            throw Error(`Unknown reduction: ${reduction}`);
    }
}
// Public export: the value returned by `op` (imported from '../operation')
// when given the implementation above, keyed by its function name. The
// `@__PURE__` annotation is kept so bundlers that honor it can treat the call
// as side-effect free and drop this export when it is unused.
export const computeWeightedLoss = /* @__PURE__ */ op({ computeWeightedLoss_ });
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiY29tcHV0ZV93ZWlnaHRlZF9sb3NzLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvbG9zc2VzL2NvbXB1dGVfd2VpZ2h0ZWRfbG9zcy50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFpQkEsT0FBTyxFQUFDLGVBQWUsRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBR3RELE9BQU8sRUFBQyxJQUFJLEVBQUMsTUFBTSxTQUFTLENBQUM7QUFDN0IsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLFFBQVEsQ0FBQztBQUMzQixPQUFPLEVBQUMsU0FBUyxFQUFDLE1BQU0sbUJBQW1CLENBQUM7QUFDNUMsT0FBTyxFQUFDLElBQUksRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUM3QixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sUUFBUSxDQUFDO0FBQzNCLE9BQU8sRUFBQyxRQUFRLEVBQUMsTUFBTSxjQUFjLENBQUM7QUFDdEMsT0FBTyxFQUFDLElBQUksRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUM3QixPQUFPLEVBQUMsRUFBRSxFQUFDLE1BQU0sY0FBYyxDQUFDO0FBQ2hDLE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDakMsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLFFBQVEsQ0FBQztBQUUzQjs7Ozs7Ozs7OztHQVVHO0FBQ0gsU0FBUyxvQkFBb0IsQ0FDekIsTUFBb0IsRUFBRSxPQUEyQixFQUNqRCxTQUFTLEdBQUcsU0FBUyxDQUFDLHNCQUFzQjtJQUM5QyxNQUFNLE9BQU8sR0FBRyxlQUFlLENBQUMsTUFBTSxFQUFFLFFBQVEsRUFBRSxxQkFBcUIsQ0FBQyxDQUFDO0lBQ3pFLElBQUksUUFBUSxHQUFXLElBQUksQ0FBQztJQUM1QixJQUFJLE9BQU8sSUFBSSxJQUFJLEVBQUU7UUFDbkIsUUFBUSxHQUFHLGVBQWUsQ0FBQyxPQUFPLEVBQUUsU0FBUyxFQUFFLHFCQUFxQixDQUFDLENBQUM7S0FDdkU7SUFFRCxNQUFNLFlBQVksR0FBRyxDQUFDLFFBQVEsSUFBSSxJQUFJLENBQUMsQ0FBQyxDQUFDLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsT0FBTyxFQUFFLFFBQVEsQ0FBQyxDQUFDO0lBRTNFLElBQUksU0FBUyxLQUFLLFNBQVMsQ0FBQyxJQUFJLEVBQUU7UUFDaEMsT0FBTyxZQUFpQixDQUFDO0tBQzFCO0lBQ0QsSUFBSSxTQUFTLEtBQUssU0FBUyxDQUFDLEdBQUcsRUFBRTtRQUMvQixPQUFPLEdBQUcsQ0FBQyxZQUFZLENBQUMsQ0FBQztLQUMxQjtJQUNELElBQUksU0FBUyxLQUFLLFNBQVMsQ0FBQyxJQUFJLEVBQUU7UUFDaEMsSUFBSSxRQUFRLElBQUksSUFBSSxFQUFFO1lBQ3BCLE9BQU8sSUFBSSxDQUFDLFlBQVksQ0FBQyxDQUFDO1NBQzNCO2FBQU07WUFDTCxNQUFNLGVBQWUsR0FBRyxPQUFPLENBQUMsSUFBSSxHQUFHLFFBQVEsQ0FBQyxJQUFJLENBQUM7WUFDckQsTUFBTSxNQUFNLEdBQUcsR0FBRyxDQUFDLEdBQUcsQ0FBQyxZQUFZLENBQUMsRUFBRSxHQUFHLENBQUMsUUFBUSxDQUFDLENBQUMsQ0FBQztZQUNyRCxPQUFPLGVBQW
UsR0FBRyxDQUFDLENBQUMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxNQUFNLEVBQUUsTUFBTSxDQUFDLGVBQWUsQ0FBQyxDQUFDLENBQUMsQ0FBQztnQkFDdEMsTUFBVyxDQUFDO1NBQzFDO0tBQ0Y7SUFDRCxJQUFJLFNBQVMsS0FBSyxTQUFTLENBQUMsc0JBQXNCLEVBQUU7UUFDbEQsSUFBSSxRQUFRLElBQUksSUFBSSxFQUFFO1lBQ3BCLE9BQU8sR0FBRyxDQUFDLEdBQUcsQ0FBQyxZQUFZLENBQUMsRUFBRSxNQUFNLENBQUMsT0FBTyxDQUFDLElBQUksQ0FBQyxDQUFDLENBQUM7U0FDckQ7YUFBTTtZQUNMLE1BQU0sa0JBQWtCLEdBQUcsR0FBRyxDQUFDLFFBQVEsRUFBRSxJQUFJLENBQUMsT0FBTyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUM7WUFFOUQsTUFBTSxXQUFXLEdBQ2IsSUFBSSxDQUFDLEdBQUcsQ0FBQyxRQUFRLENBQUMsa0JBQWtCLEVBQUUsTUFBTSxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsRUFBRSxTQUFTLENBQUMsQ0FBQztZQUNsRSxPQUFPLEdBQUcsQ0FBQyxHQUFHLENBQUMsWUFBWSxDQUFDLEVBQUUsV0FBVyxDQUFDLENBQUM7U0FDNUM7S0FDRjtJQUVELE1BQU0sS0FBSyxDQUFDLHNCQUFzQixTQUFTLEVBQUUsQ0FBQyxDQUFDO0FBQ2pELENBQUM7QUFDRCxNQUFNLENBQUMsTUFBTSxtQkFBbUIsR0FBRyxlQUFlLENBQUMsRUFBRSxDQUFDLEVBQUMsb0JBQW9CLEVBQUMsQ0FBQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuaW1wb3J0IHtUZW5zb3J9IGZyb20gJy4uLy4uL3RlbnNvcic7XG5pbXBvcnQge2NvbnZlcnRUb1RlbnNvcn0gZnJvbSAnLi4vLi
4vdGVuc29yX3V0aWxfZW52JztcbmltcG9ydCB7VGVuc29yTGlrZX0gZnJvbSAnLi4vLi4vdHlwZXMnO1xuXG5pbXBvcnQge2Nhc3R9IGZyb20gJy4uL2Nhc3QnO1xuaW1wb3J0IHtkaXZ9IGZyb20gJy4uL2Rpdic7XG5pbXBvcnQge1JlZHVjdGlvbn0gZnJvbSAnLi4vbG9zc19vcHNfdXRpbHMnO1xuaW1wb3J0IHttZWFufSBmcm9tICcuLi9tZWFuJztcbmltcG9ydCB7bXVsfSBmcm9tICcuLi9tdWwnO1xuaW1wb3J0IHtub3RFcXVhbH0gZnJvbSAnLi4vbm90X2VxdWFsJztcbmltcG9ydCB7b25lc30gZnJvbSAnLi4vb25lcyc7XG5pbXBvcnQge29wfSBmcm9tICcuLi9vcGVyYXRpb24nO1xuaW1wb3J0IHtzY2FsYXJ9IGZyb20gJy4uL3NjYWxhcic7XG5pbXBvcnQge3N1bX0gZnJvbSAnLi4vc3VtJztcblxuLyoqXG4gKiBDb21wdXRlcyB0aGUgd2VpZ2h0ZWQgbG9zcyBiZXR3ZWVuIHR3byB0ZW5zb3JzLlxuICpcbiAqIEBwYXJhbSBsb3NzZXMgVGVuc29yIG9mIHNoYXBlIGBbYmF0Y2hfc2l6ZSwgZDEsIC4uLiwgZE5dYC5cbiAqIEBwYXJhbSB3ZWlnaHRzIFRlbnNvciB3aG9zZSByYW5rIGlzIGVpdGhlciAwLCBvciB0aGUgc2FtZSByYW5rIGFzXG4gKiAgICBgbG9zc2VzYCwgYW5kIG11c3QgYmUgYnJvYWRjYXN0YWJsZSB0byBgbG9zc2VzYCAoaS5lLiwgYWxsXG4gKiAgICBkaW1lbnNpb25zIG11c3QgYmUgZWl0aGVyIGAxYCwgb3IgdGhlIHNhbWUgYXMgdGhlIGNvcnJlc3BvbmRpbmdcbiAqICAgIGBsb3NzZXNgIGRpbWVuc2lvbikuXG4gKlxuICogQGRvYyB7aGVhZGluZzogJ1RyYWluaW5nJywgc3ViaGVhZGluZzogJ0xvc3NlcycsIG5hbWVzcGFjZTogJ2xvc3Nlcyd9XG4gKi9cbmZ1bmN0aW9uIGNvbXB1dGVXZWlnaHRlZExvc3NfPFQgZXh0ZW5kcyBUZW5zb3IsIE8gZXh0ZW5kcyBUZW5zb3I+KFxuICAgIGxvc3NlczogVHxUZW5zb3JMaWtlLCB3ZWlnaHRzPzogVGVuc29yfFRlbnNvckxpa2UsXG4gICAgcmVkdWN0aW9uID0gUmVkdWN0aW9uLlNVTV9CWV9OT05aRVJPX1dFSUdIVFMpOiBPIHtcbiAgY29uc3QgJGxvc3NlcyA9IGNvbnZlcnRUb1RlbnNvcihsb3NzZXMsICdsb3NzZXMnLCAnY29tcHV0ZVdlaWdodGVkTG9zcycpO1xuICBsZXQgJHdlaWdodHM6IFRlbnNvciA9IG51bGw7XG4gIGlmICh3ZWlnaHRzICE9IG51bGwpIHtcbiAgICAkd2VpZ2h0cyA9IGNvbnZlcnRUb1RlbnNvcih3ZWlnaHRzLCAnd2VpZ2h0cycsICdjb21wdXRlV2VpZ2h0ZWRMb3NzJyk7XG4gIH1cblxuICBjb25zdCB3ZWlnaHRlZExvc3MgPSAoJHdlaWdodHMgPT0gbnVsbCkgPyAkbG9zc2VzIDogbXVsKCRsb3NzZXMsICR3ZWlnaHRzKTtcblxuICBpZiAocmVkdWN0aW9uID09PSBSZWR1Y3Rpb24uTk9ORSkge1xuICAgIHJldHVybiB3ZWlnaHRlZExvc3MgYXMgTztcbiAgfVxuICBpZiAocmVkdWN0aW9uID09PSBSZWR1Y3Rpb24uU1VNKSB7XG4gICAgcmV0dXJuIHN1bSh3ZWlnaHRlZExvc3MpO1xuICB9XG4gIGlmIChyZWR1Y3Rpb24gPT09IFJlZHVjdGlvbi
5NRUFOKSB7XG4gICAgaWYgKCR3ZWlnaHRzID09IG51bGwpIHtcbiAgICAgIHJldHVybiBtZWFuKHdlaWdodGVkTG9zcyk7XG4gICAgfSBlbHNlIHtcbiAgICAgIGNvbnN0IGJyb2FkY2FzdEZhY3RvciA9ICRsb3NzZXMuc2l6ZSAvICR3ZWlnaHRzLnNpemU7XG4gICAgICBjb25zdCByZXN1bHQgPSBkaXYoc3VtKHdlaWdodGVkTG9zcyksIHN1bSgkd2VpZ2h0cykpO1xuICAgICAgcmV0dXJuIGJyb2FkY2FzdEZhY3RvciA+IDEgPyBkaXYocmVzdWx0LCBzY2FsYXIoYnJvYWRjYXN0RmFjdG9yKSkgOlxuICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICByZXN1bHQgYXMgTztcbiAgICB9XG4gIH1cbiAgaWYgKHJlZHVjdGlvbiA9PT0gUmVkdWN0aW9uLlNVTV9CWV9OT05aRVJPX1dFSUdIVFMpIHtcbiAgICBpZiAoJHdlaWdodHMgPT0gbnVsbCkge1xuICAgICAgcmV0dXJuIGRpdihzdW0od2VpZ2h0ZWRMb3NzKSwgc2NhbGFyKCRsb3NzZXMuc2l6ZSkpO1xuICAgIH0gZWxzZSB7XG4gICAgICBjb25zdCBicm9hZGNhc3RlZFdlaWdodHMgPSBtdWwoJHdlaWdodHMsIG9uZXMoJGxvc3Nlcy5zaGFwZSkpO1xuXG4gICAgICBjb25zdCBudW1Ob25aZXJvcyA9XG4gICAgICAgICAgY2FzdChzdW0obm90RXF1YWwoYnJvYWRjYXN0ZWRXZWlnaHRzLCBzY2FsYXIoMCkpKSwgJ2Zsb2F0MzInKTtcbiAgICAgIHJldHVybiBkaXYoc3VtKHdlaWdodGVkTG9zcyksIG51bU5vblplcm9zKTtcbiAgICB9XG4gIH1cblxuICB0aHJvdyBFcnJvcihgVW5rbm93biByZWR1Y3Rpb246ICR7cmVkdWN0aW9ufWApO1xufVxuZXhwb3J0IGNvbnN0IGNvbXB1dGVXZWlnaHRlZExvc3MgPSAvKiBAX19QVVJFX18gKi8gb3Aoe2NvbXB1dGVXZWlnaHRlZExvc3NffSk7XG4iXX0=