gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Softmax, util } from '@tensorflow/tfjs-core';
import { exp } from './Exp';
import { max } from './Max';
import { div } from './RealDiv';
import { reshape } from './Reshape';
import { sub } from './Sub';
import { sum } from './Sum';
/**
 * CPU kernel for Softmax, evaluated in the max-shifted form
 * exp(x - max(x)) / sum(exp(x - max(x))) along the last axis of `logits`.
 *
 * Only the innermost dimension is supported; a `dim` of -1 is treated as an
 * alias for it. Every intermediate TensorInfo created along the way is handed
 * to `backend.disposeIntermediateTensorInfo` before the result is returned.
 *
 * @param {{inputs: {logits: Object}, backend: Object, attrs: {dim: number}}} args
 *     Kernel arguments: the `logits` tensor info, the CPU backend, and the
 *     `dim` attribute selecting the softmax axis.
 * @returns {Object} The TensorInfo produced by the final division (the
 *     normalized exponentials).
 * @throws {Error} If `dim` does not resolve to the last dimension of `logits`.
 */
export function softmax(args) {
    const { inputs: { logits }, backend, attrs: { dim } } = args;
    const rank = logits.shape.length;
    const lastAxis = rank - 1;
    // -1 is shorthand for the last axis; any other value is checked as given.
    const resolvedDim = dim === -1 ? lastAxis : dim;
    if (resolvedDim !== lastAxis) {
        throw new Error(`Softmax along a non-last dimension is not yet supported. ` +
            `Logits was rank ${rank} and dim was ${resolvedDim}`);
    }
    const reduceAxes = util.parseAxisParam([resolvedDim], logits.shape);
    // Subtract the maximum along the reduced axis so exp() is applied to
    // values <= 0.
    const axisMax = max({
        inputs: { x: logits },
        backend,
        attrs: { reductionIndices: reduceAxes, keepDims: false }
    });
    // The reduced shape with the reduced axis re-inserted (as size 1), used to
    // reshape the max and the sum so they line up with `logits`.
    const keepDimShape = backend_util.expandShapeToKeepDim(axisMax.shape, reduceAxes);
    const axisMaxKept = reshape({ inputs: { x: axisMax }, backend, attrs: { shape: keepDimShape } });
    const shifted = sub({ inputs: { a: logits, b: axisMaxKept }, backend });
    const exponentials = exp({ inputs: { x: shifted }, backend });
    const axisSum = sum({
        inputs: { x: exponentials },
        backend,
        attrs: { axis: reduceAxes, keepDims: false }
    });
    const axisSumKept = reshape({ inputs: { x: axisSum }, backend, attrs: { shape: keepDimShape } });
    const normalized = div({ inputs: { a: exponentials, b: axisSumKept }, backend });
    const intermediates = [axisMax, axisMaxKept, shifted, exponentials, axisSum, axisSumKept];
    for (const info of intermediates) {
        backend.disposeIntermediateTensorInfo(info);
    }
    return normalized;
}
/**
 * Kernel registration entry that binds the `Softmax` kernel name to the
 * `softmax` implementation for the 'cpu' backend.
 */
export const softmaxConfig = {
    kernelName: Softmax,
    backendName: 'cpu',
    kernelFunc: softmax,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU29mdG1heC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC1jcHUvc3JjL2tlcm5lbHMvU29mdG1heC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsWUFBWSxFQUE0QixPQUFPLEVBQTJDLElBQUksRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBSXJJLE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxPQUFPLENBQUM7QUFDMUIsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLE9BQU8sQ0FBQztBQUMxQixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBQzlCLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDbEMsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLE9BQU8sQ0FBQztBQUMxQixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBRTFCLE1BQU0sVUFBVSxPQUFPLENBQ25CLElBQ3lFO0lBRTNFLE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztJQUN0QyxNQUFNLEVBQUMsTUFBTSxFQUFDLEdBQUcsTUFBTSxDQUFDO0lBQ3hCLE1BQU0sRUFBQyxHQUFHLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFFcEIsTUFBTSxVQUFVLEdBQUcsTUFBTSxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUM7SUFFdkMsSUFBSSxJQUFJLEdBQUcsR0FBRyxDQUFDO0lBQ2YsSUFBSSxJQUFJLEtBQUssQ0FBQyxDQUFDLEVBQUU7UUFDZixJQUFJLEdBQUcsVUFBVSxHQUFHLENBQUMsQ0FBQztLQUN2QjtJQUNELElBQUksSUFBSSxLQUFLLFVBQVUsR0FBRyxDQUFDLEVBQUU7UUFDM0IsTUFBTSxLQUFLLENBQ1AsMkRBQTJEO1lBQzNELG1CQUFtQixVQUFVLGdCQUFnQixJQUFJLEVBQUUsQ0FBQyxDQUFDO0tBQzFEO0lBRUQsTUFBTSxJQUFJLEdBQUcsSUFBSSxDQUFDLGNBQWMsQ0FBQyxDQUFDLElBQUksQ0FBQyxFQUFFLE1BQU0sQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUN2RCxNQUFNLFFBQVEsR0FBRyxHQUFHLENBQUM7UUFDbkIsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBQztRQUNuQixPQUFPO1FBQ1AsS0FBSyxFQUFFLEVBQUMsZ0JBQWdCLEVBQUUsSUFBSSxFQUFFLFFBQVEsRUFBRSxLQUFLLEVBQUM7S0FDakQsQ0FBQyxDQUFDO0lBQ0gsTUFBTSxhQUFhLEdBQUcsWUFBWSxDQUFDLG9CQUFvQixDQUFDLFFBQVEsQ0FBQyxLQUFLLEVBQUUsSUFBSSxDQUFDLENBQUM7SUFFOUUsTUFBTSxnQkFBZ0IsR0FDbEIsT0FBTyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLFFBQVEsRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxLQUFLLEVBQUUsYUFBYSxFQUFDLEVBQUMsQ0FBQyxDQUFDO0lBQzdFLE1BQU0sQ0FBQyxHQUNILEdBQUcsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxNQUFNLEVBQU
UsQ0FBQyxFQUFFLGdCQUFnQixFQUFDLEVBQUUsT0FBTyxFQUFDLENBQWUsQ0FBQztJQUMzRSxNQUFNLENBQUMsR0FBRyxHQUFHLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsQ0FBQyxFQUFDLEVBQUUsT0FBTyxFQUFDLENBQWUsQ0FBQztJQUN2RCxNQUFNLE1BQU0sR0FDUixHQUFHLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsQ0FBQyxFQUFDLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxFQUFDLElBQUksRUFBRSxJQUFJLEVBQUUsUUFBUSxFQUFFLEtBQUssRUFBQyxFQUFDLENBQUMsQ0FBQztJQUN6RSxNQUFNLFdBQVcsR0FDYixPQUFPLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsTUFBTSxFQUFDLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxFQUFDLEtBQUssRUFBRSxhQUFhLEVBQUMsRUFBQyxDQUFDLENBQUM7SUFFM0UsTUFBTSxNQUFNLEdBQUcsR0FBRyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLEVBQUUsV0FBVyxFQUFDLEVBQUUsT0FBTyxFQUFDLENBQWUsQ0FBQztJQUU1RSxPQUFPLENBQUMsNkJBQTZCLENBQUMsUUFBUSxDQUFDLENBQUM7SUFDaEQsT0FBTyxDQUFDLDZCQUE2QixDQUFDLGdCQUFnQixDQUFDLENBQUM7SUFDeEQsT0FBTyxDQUFDLDZCQUE2QixDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQ3pDLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUN6QyxPQUFPLENBQUMsNkJBQTZCLENBQUMsTUFBTSxDQUFDLENBQUM7SUFDOUMsT0FBTyxDQUFDLDZCQUE2QixDQUFDLFdBQVcsQ0FBQyxDQUFDO0lBRW5ELE9BQU8sTUFBTSxDQUFDO0FBQ2hCLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxhQUFhLEdBQWlCO0lBQ3pDLFVBQVUsRUFBRSxPQUFPO0lBQ25CLFdBQVcsRUFBRSxLQUFLO0lBQ2xCLFVBQVUsRUFBRSxPQUFnQztDQUM3QyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbi
AqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBTb2Z0bWF4LCBTb2Z0bWF4QXR0cnMsIFNvZnRtYXhJbnB1dHMsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcblxuaW1wb3J0IHtleHB9IGZyb20gJy4vRXhwJztcbmltcG9ydCB7bWF4fSBmcm9tICcuL01heCc7XG5pbXBvcnQge2Rpdn0gZnJvbSAnLi9SZWFsRGl2JztcbmltcG9ydCB7cmVzaGFwZX0gZnJvbSAnLi9SZXNoYXBlJztcbmltcG9ydCB7c3VifSBmcm9tICcuL1N1Yic7XG5pbXBvcnQge3N1bX0gZnJvbSAnLi9TdW0nO1xuXG5leHBvcnQgZnVuY3Rpb24gc29mdG1heChcbiAgICBhcmdzOlxuICAgICAgICB7aW5wdXRzOiBTb2Z0bWF4SW5wdXRzLCBiYWNrZW5kOiBNYXRoQmFja2VuZENQVSwgYXR0cnM6IFNvZnRtYXhBdHRyc30pOlxuICAgIFRlbnNvckluZm8ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7bG9naXRzfSA9IGlucHV0cztcbiAgY29uc3Qge2RpbX0gPSBhdHRycztcblxuICBjb25zdCBsb2dpdHNSYW5rID0gbG9naXRzLnNoYXBlLmxlbmd0aDtcblxuICBsZXQgJGRpbSA9IGRpbTtcbiAgaWYgKCRkaW0gPT09IC0xKSB7XG4gICAgJGRpbSA9IGxvZ2l0c1JhbmsgLSAxO1xuICB9XG4gIGlmICgkZGltICE9PSBsb2dpdHNSYW5rIC0gMSkge1xuICAgIHRocm93IEVycm9yKFxuICAgICAgICAnU29mdG1heCBhbG9uZyBhIG5vbi1sYXN0IGRpbWVuc2lvbiBpcyBub3QgeWV0IHN1cHBvcnRlZC4gJyArXG4gICAgICAgIGBMb2dpdHMgd2FzIHJhbmsgJHtsb2dpdHNSYW5rfSBhbmQgZGltIHdhcyAkeyRkaW19YCk7XG4gIH1cblxuICBjb25zdCBheGVzID0gdXRpbC5wYXJzZUF4aXNQYXJhbShbJGRpbV0sIGxvZ2l0cy5zaGFwZSk7XG4gIGNvbnN0IG1heExvZ2l0ID0gbWF4KHtcbiAgICBpbnB1dHM6IHt4OiBsb2dpdHN9LFxuICAgIGJhY2tlbmQsXG4gICAgYXR0cnM6IHtyZWR1Y3Rpb25JbmRpY2VzOiBheGVzLCBrZWVwRGltczogZmFsc2V9XG4gIH0pO1xuICBjb25zdCBleHBhbmRlZFNoYXBlID0gYmFja2VuZF91dGlsLmV4cGFuZFNoYXBlVG9LZWVwRGltKG1heExvZ2l0LnNoYXBlLCBheGVzKTtcblxuICBjb25zdCBtYXhMb2dpdFJlc2hhcGVkID1cbiAgICAgIHJlc2hhcGUoe2lucHV0czoge3g6IG1heExvZ2l0fSwgYmFja2VuZCwgYXR0cnM6IHtzaGFwZTogZXhwYW5kZWRTaGFwZX19KTtcbiAgY29uc3QgYSA9XG
4gICAgICBzdWIoe2lucHV0czoge2E6IGxvZ2l0cywgYjogbWF4TG9naXRSZXNoYXBlZH0sIGJhY2tlbmR9KSBhcyBUZW5zb3JJbmZvO1xuICBjb25zdCBiID0gZXhwKHtpbnB1dHM6IHt4OiBhfSwgYmFja2VuZH0pIGFzIFRlbnNvckluZm87XG4gIGNvbnN0IHN1bUV4cCA9XG4gICAgICBzdW0oe2lucHV0czoge3g6IGJ9LCBiYWNrZW5kLCBhdHRyczoge2F4aXM6IGF4ZXMsIGtlZXBEaW1zOiBmYWxzZX19KTtcbiAgY29uc3Qgc3VtUmVzaGFwZWQgPVxuICAgICAgcmVzaGFwZSh7aW5wdXRzOiB7eDogc3VtRXhwfSwgYmFja2VuZCwgYXR0cnM6IHtzaGFwZTogZXhwYW5kZWRTaGFwZX19KTtcblxuICBjb25zdCByZXN1bHQgPSBkaXYoe2lucHV0czoge2E6IGIsIGI6IHN1bVJlc2hhcGVkfSwgYmFja2VuZH0pIGFzIFRlbnNvckluZm87XG5cbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhtYXhMb2dpdCk7XG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8obWF4TG9naXRSZXNoYXBlZCk7XG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oYSk7XG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oYik7XG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oc3VtRXhwKTtcbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhzdW1SZXNoYXBlZCk7XG5cbiAgcmV0dXJuIHJlc3VsdDtcbn1cblxuZXhwb3J0IGNvbnN0IHNvZnRtYXhDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogU29mdG1heCxcbiAgYmFja2VuZE5hbWU6ICdjcHUnLFxuICBrZXJuZWxGdW5jOiBzb2Z0bWF4IGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==