gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Softmax, util } from '@tensorflow/tfjs-core';
import { exp } from './Exp';
import { max } from './Max';
import { realDiv } from './RealDiv';
import { reshape } from './Reshape';
import { sub } from './Sub';
import { sum } from './Sum';
/**
 * Softmax kernel: computes `exp(logits - max) / sum(exp(logits - max))`
 * along the axis given by `attrs.dim`, by composing the max, reshape, sub,
 * exp, sum and realDiv kernels imported by this file.
 *
 * The reduced maximum is subtracted before exponentiating, so `exp` only
 * sees the shifted values (`logits - max`).
 *
 * @param {object} args
 * @param {object} args.inputs - Must contain `logits`, a tensor info whose
 *     `shape` is read to resolve the axis.
 * @param {object} args.backend - Backend forwarded to every sub-kernel; its
 *     `disposeIntermediateTensorInfo` is used to release temporaries.
 * @param {object} args.attrs - Must contain `dim`, the axis to normalize over.
 * @returns {object} The tensor info produced by `realDiv`. It is not disposed
 *     here.
 * @throws Re-throws whatever a sub-kernel throws, after releasing every
 *     intermediate that had already been created.
 */
export function softmax(args) {
    const { inputs, backend, attrs } = args;
    const { logits } = inputs;
    const { dim } = attrs;
    const axes = util.parseAxisParam([dim], logits.shape);

    // Every temporary is registered as soon as it exists so that the
    // `finally` block can release it even when a later kernel call throws.
    const intermediates = [];
    const track = (tensorInfo) => {
        intermediates.push(tensorInfo);
        return tensorInfo;
    };

    try {
        const maxLogit = track(max({
            inputs: { x: logits },
            backend,
            attrs: { reductionIndices: axes, keepDims: false }
        }));
        // The reductions drop the reduced axes (keepDims: false); this shape
        // is used to reshape both reduction results before they are combined
        // with the full tensor in `sub` and `realDiv`.
        const expandedShape = backend_util.expandShapeToKeepDim(maxLogit.shape, axes);
        const maxLogitsReshaped = track(reshape({
            inputs: { x: maxLogit },
            backend,
            attrs: { shape: expandedShape }
        }));
        const a = track(sub({ inputs: { a: logits, b: maxLogitsReshaped }, backend }));
        const b = track(exp({ inputs: { x: a }, backend }));
        const sumExp = track(sum({
            inputs: { x: b },
            backend,
            attrs: { axis: axes, keepDims: false }
        }));
        const sumExpReshaped = track(reshape({
            inputs: { x: sumExp },
            backend,
            attrs: { shape: expandedShape }
        }));
        return realDiv({ inputs: { a: b, b: sumExpReshaped }, backend });
    } finally {
        // Released in creation order, which matches the original disposal
        // order on the success path.
        for (const tensorInfo of intermediates) {
            backend.disposeIntermediateTensorInfo(tensorInfo);
        }
    }
}
/**
 * Kernel configuration record for softmax: pairs the `Softmax` kernel name
 * imported from '@tensorflow/tfjs-core' with the `softmax` function defined
 * in this file, for the backend named 'webgl'.
 *
 * NOTE(review): presumably consumed by tfjs-core's kernel registration;
 * confirm at the call site, which is not visible in this file.
 */
export const softmaxConfig = {
    kernelName: Softmax,
    backendName: 'webgl',
    kernelFunc: softmax
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU29mdG1heC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMva2VybmVscy9Tb2Z0bWF4LnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQTRCLE9BQU8sRUFBMkMsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFJckksT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLE9BQU8sQ0FBQztBQUMxQixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDbEMsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNsQyxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxPQUFPLENBQUM7QUFFMUIsTUFBTSxVQUFVLE9BQU8sQ0FBQyxJQUl2QjtJQUNDLE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztJQUN0QyxNQUFNLEVBQUMsTUFBTSxFQUFDLEdBQUcsTUFBTSxDQUFDO0lBQ3hCLE1BQU0sRUFBQyxHQUFHLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFFcEIsTUFBTSxJQUFJLEdBQUcsSUFBSSxDQUFDLGNBQWMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxFQUFFLE1BQU0sQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUV0RCxNQUFNLFFBQVEsR0FBRyxHQUFHLENBQUM7UUFDbkIsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBQztRQUNuQixPQUFPO1FBQ1AsS0FBSyxFQUFFLEVBQUMsZ0JBQWdCLEVBQUUsSUFBSSxFQUFFLFFBQVEsRUFBRSxLQUFLLEVBQUM7S0FDakQsQ0FBQyxDQUFDO0lBRUgsTUFBTSxhQUFhLEdBQUcsWUFBWSxDQUFDLG9CQUFvQixDQUFDLFFBQVEsQ0FBQyxLQUFLLEVBQUUsSUFBSSxDQUFDLENBQUM7SUFFOUUsTUFBTSxpQkFBaUIsR0FDbkIsT0FBTyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLFFBQVEsRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxLQUFLLEVBQUUsYUFBYSxFQUFDLEVBQUMsQ0FBQyxDQUFDO0lBQzdFLE1BQU0sQ0FBQyxHQUNILEdBQUcsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxNQUFNLEVBQUUsQ0FBQyxFQUFFLGlCQUFpQixFQUFDLEVBQUUsT0FBTyxFQUFDLENBQWUsQ0FBQztJQUM1RSxNQUFNLENBQUMsR0FBRyxHQUFHLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsQ0FBQyxFQUFDLEVBQUUsT0FBTyxFQUFDLENBQWUsQ0FBQztJQUN2RCxNQUFNLE1BQU0sR0FDUixHQUFHLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsQ0FBQyxFQUFDLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxFQUFDLElBQUksRUFBRSxJQUFJLEVBQUUsUUFBUSxFQUFFLE
tBQUssRUFBQyxFQUFDLENBQUMsQ0FBQztJQUN6RSxNQUFNLGNBQWMsR0FDaEIsT0FBTyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxLQUFLLEVBQUUsYUFBYSxFQUFDLEVBQUMsQ0FBQyxDQUFDO0lBRTNFLE1BQU0sR0FBRyxHQUNMLE9BQU8sQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxFQUFFLGNBQWMsRUFBQyxFQUFFLE9BQU8sRUFBQyxDQUFlLENBQUM7SUFFeEUsT0FBTyxDQUFDLDZCQUE2QixDQUFDLFFBQVEsQ0FBQyxDQUFDO0lBQ2hELE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxpQkFBaUIsQ0FBQyxDQUFDO0lBQ3pELE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUN6QyxPQUFPLENBQUMsNkJBQTZCLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFDekMsT0FBTyxDQUFDLDZCQUE2QixDQUFDLE1BQU0sQ0FBQyxDQUFDO0lBQzlDLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxjQUFjLENBQUMsQ0FBQztJQUV0RCxPQUFPLEdBQUcsQ0FBQztBQUNiLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxhQUFhLEdBQWlCO0lBQ3pDLFVBQVUsRUFBRSxPQUFPO0lBQ25CLFdBQVcsRUFBRSxPQUFPO0lBQ3BCLFVBQVUsRUFBRSxPQUFnQztDQUM3QyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBTb2Z0bWF4LCBTb2Z0bWF4QXR0cnMsIFNvZnRtYXhJbn
B1dHMsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5cbmltcG9ydCB7ZXhwfSBmcm9tICcuL0V4cCc7XG5pbXBvcnQge21heH0gZnJvbSAnLi9NYXgnO1xuaW1wb3J0IHtyZWFsRGl2fSBmcm9tICcuL1JlYWxEaXYnO1xuaW1wb3J0IHtyZXNoYXBlfSBmcm9tICcuL1Jlc2hhcGUnO1xuaW1wb3J0IHtzdWJ9IGZyb20gJy4vU3ViJztcbmltcG9ydCB7c3VtfSBmcm9tICcuL1N1bSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBzb2Z0bWF4KGFyZ3M6IHtcbiAgaW5wdXRzOiBTb2Z0bWF4SW5wdXRzLFxuICBiYWNrZW5kOiBNYXRoQmFja2VuZFdlYkdMLFxuICBhdHRyczogU29mdG1heEF0dHJzXG59KTogVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHtsb2dpdHN9ID0gaW5wdXRzO1xuICBjb25zdCB7ZGltfSA9IGF0dHJzO1xuXG4gIGNvbnN0IGF4ZXMgPSB1dGlsLnBhcnNlQXhpc1BhcmFtKFtkaW1dLCBsb2dpdHMuc2hhcGUpO1xuXG4gIGNvbnN0IG1heExvZ2l0ID0gbWF4KHtcbiAgICBpbnB1dHM6IHt4OiBsb2dpdHN9LFxuICAgIGJhY2tlbmQsXG4gICAgYXR0cnM6IHtyZWR1Y3Rpb25JbmRpY2VzOiBheGVzLCBrZWVwRGltczogZmFsc2V9XG4gIH0pO1xuXG4gIGNvbnN0IGV4cGFuZGVkU2hhcGUgPSBiYWNrZW5kX3V0aWwuZXhwYW5kU2hhcGVUb0tlZXBEaW0obWF4TG9naXQuc2hhcGUsIGF4ZXMpO1xuXG4gIGNvbnN0IG1heExvZ2l0c1Jlc2hhcGVkID1cbiAgICAgIHJlc2hhcGUoe2lucHV0czoge3g6IG1heExvZ2l0fSwgYmFja2VuZCwgYXR0cnM6IHtzaGFwZTogZXhwYW5kZWRTaGFwZX19KTtcbiAgY29uc3QgYSA9XG4gICAgICBzdWIoe2lucHV0czoge2E6IGxvZ2l0cywgYjogbWF4TG9naXRzUmVzaGFwZWR9LCBiYWNrZW5kfSkgYXMgVGVuc29ySW5mbztcbiAgY29uc3QgYiA9IGV4cCh7aW5wdXRzOiB7eDogYX0sIGJhY2tlbmR9KSBhcyBUZW5zb3JJbmZvO1xuICBjb25zdCBzdW1FeHAgPVxuICAgICAgc3VtKHtpbnB1dHM6IHt4OiBifSwgYmFja2VuZCwgYXR0cnM6IHtheGlzOiBheGVzLCBrZWVwRGltczogZmFsc2V9fSk7XG4gIGNvbnN0IHN1bUV4cFJlc2hhcGVkID1cbiAgICAgIHJlc2hhcGUoe2lucHV0czoge3g6IHN1bUV4cH0sIGJhY2tlbmQsIGF0dHJzOiB7c2hhcGU6IGV4cGFuZGVkU2hhcGV9fSk7XG5cbiAgY29uc3QgcmVzID1cbiAgICAgIHJlYWxEaXYoe2lucHV0czoge2E6IGIsIGI6IHN1bUV4cFJlc2hhcGVkfSwgYmFja2VuZH0pIGFzIFRlbnNvckluZm87XG5cbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhtYXhMb2dpdCk7XG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8obWF4TG9naXRzUmVzaGFwZWQpO1xuICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3
JJbmZvKGEpO1xuICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKGIpO1xuICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHN1bUV4cCk7XG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oc3VtRXhwUmVzaGFwZWQpO1xuXG4gIHJldHVybiByZXM7XG59XG5cbmV4cG9ydCBjb25zdCBzb2Z0bWF4Q29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IFNvZnRtYXgsXG4gIGJhY2tlbmROYW1lOiAnd2ViZ2wnLFxuICBrZXJuZWxGdW5jOiBzb2Z0bWF4IGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==