/** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ import { backend_util, Softmax, util } from '@tensorflow/tfjs-core'; import { exp } from './Exp'; import { max } from './Max'; import { realDiv } from './RealDiv'; import { reshape } from './Reshape'; import { sub } from './Sub'; import { sum } from './Sum'; export function softmax(args) { const { inputs, backend, attrs } = args; const { logits } = inputs; const { dim } = attrs; const axes = util.parseAxisParam([dim], logits.shape); const maxLogit = max({ inputs: { x: logits }, backend, attrs: { reductionIndices: axes, keepDims: false } }); const expandedShape = backend_util.expandShapeToKeepDim(maxLogit.shape, axes); const maxLogitsReshaped = reshape({ inputs: { x: maxLogit }, backend, attrs: { shape: expandedShape } }); const a = sub({ inputs: { a: logits, b: maxLogitsReshaped }, backend }); const b = exp({ inputs: { x: a }, backend }); const sumExp = sum({ inputs: { x: b }, backend, attrs: { axis: axes, keepDims: false } }); const sumExpReshaped = reshape({ inputs: { x: sumExp }, backend, attrs: { shape: expandedShape } }); const res = realDiv({ inputs: { a: b, b: sumExpReshaped }, backend }); backend.disposeIntermediateTensorInfo(maxLogit); backend.disposeIntermediateTensorInfo(maxLogitsReshaped); backend.disposeIntermediateTensorInfo(a); backend.disposeIntermediateTensorInfo(b); backend.disposeIntermediateTensorInfo(sumExp); backend.disposeIntermediateTensorInfo(sumExpReshaped); return res; } export const softmaxConfig = { kernelName: Softmax, backendName: 'webgl', kernelFunc: softmax }; //# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU29mdG1heC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMva2VybmVscy9Tb2Z0bWF4LnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQTRCLE9BQU8sRUFBMkMsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFJckksT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLE9BQU8sQ0FBQztBQUMxQixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDbEMsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNsQyxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxPQUFPLENBQUM7QUFFMUIsTUFBTSxVQUFVLE9BQU8sQ0FBQyxJQUl2QjtJQUNDLE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztJQUN0QyxNQUFNLEVBQUMsTUFBTSxFQUFDLEdBQUcsTUFBTSxDQUFDO0lBQ3hCLE1BQU0sRUFBQyxHQUFHLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFFcEIsTUFBTSxJQUFJLEdBQUcsSUFBSSxDQUFDLGNBQWMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxFQUFFLE1BQU0sQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUV0RCxNQUFNLFFBQVEsR0FBRyxHQUFHLENBQUM7UUFDbkIsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBQztRQUNuQixPQUFPO1FBQ1AsS0FBSyxFQUFFLEVBQUMsZ0JBQWdCLEVBQUUsSUFBSSxFQUFFLFFBQVEsRUFBRSxLQUFLLEVBQUM7S0FDakQsQ0FBQyxDQUFDO0lBRUgsTUFBTSxhQUFhLEdBQUcsWUFBWSxDQUFDLG9CQUFvQixDQUFDLFFBQVEsQ0FBQyxLQUFLLEVBQUUsSUFBSSxDQUFDLENBQUM7SUFFOUUsTUFBTSxpQkFBaUIsR0FDbkIsT0FBTyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLFFBQVEsRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxLQUFLLEVBQUUsYUFBYSxFQUFDLEVBQUMsQ0FBQyxDQUFDO0lBQzdFLE1BQU0sQ0FBQyxHQUNILEdBQUcsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxNQUFNLEVBQUUsQ0FBQyxFQUFFLGlCQUFpQixFQUFDLEVBQUUsT0FBTyxFQUFDLENBQWUsQ0FBQztJQUM1RSxNQUFNLENBQUMsR0FBRyxHQUFHLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsQ0FBQyxFQUFDLEVBQUUsT0FBTyxFQUFDLENBQWUsQ0FBQztJQUN2RCxNQUFNLE1BQU0sR0FDUixHQUFHLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsQ0FBQyxFQUFDLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxFQUFDLElBQUksRUFBRSxJQUFJLEVBQUUsUUFBUSxFQUFFLEtBQUssRUFBQyxFQUFDLENBQUMsQ0FBQztJQUN6RSxNQUFNLGNBQWMsR0FDaEIsT0FBTyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxLQUFLLEVBQUUsYUFBYSxFQUFDLEVBQUMsQ0FBQyxDQUFDO0lBRTNFLE1BQU0sR0FBRyxHQUNMLE9BQU8sQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxFQUFFLGNBQWMsRUFBQyxFQUFFLE9BQU8sRUFBQyxDQUFlLENBQUM7SUFFeEUsT0FBTyxDQUFDLDZCQUE2QixDQUFDLFFBQVEsQ0FBQyxDQUFDO0lBQ2hELE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxpQkFBaUIsQ0FBQyxDQUFDO0lBQ3pELE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUN6QyxPQUFPLENBQUMsNkJBQTZCLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFDekMsT0FBTyxDQUFDLDZCQUE2QixDQUFDLE1BQU0sQ0FBQyxDQUFDO0lBQzlDLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxjQUFjLENBQUMsQ0FBQztJQUV0RCxPQUFPLEdBQUcsQ0FBQztBQUNiLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxhQUFhLEdBQWlCO0lBQ3pDLFVBQVUsRUFBRSxPQUFPO0lBQ25CLFdBQVcsRUFBRSxPQUFPO0lBQ3BCLFVBQVUsRUFBRSxPQUFnQztDQUM3QyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBTb2Z0bWF4LCBTb2Z0bWF4QXR0cnMsIFNvZnRtYXhJbnB1dHMsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5cbmltcG9ydCB7ZXhwfSBmcm9tICcuL0V4cCc7XG5pbXBvcnQge21heH0gZnJvbSAnLi9NYXgnO1xuaW1wb3J0IHtyZWFsRGl2fSBmcm9tICcuL1JlYWxEaXYnO1xuaW1wb3J0IHtyZXNoYXBlfSBmcm9tICcuL1Jlc2hhcGUnO1xuaW1wb3J0IHtzdWJ9IGZyb20gJy4vU3ViJztcbmltcG9ydCB7c3VtfSBmcm9tICcuL1N1bSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBzb2Z0bWF4KGFyZ3M6IHtcbiAgaW5wdXRzOiBTb2Z0bWF4SW5wdXRzLFxuICBiYWNrZW5kOiBNYXRoQmFja2VuZFdlYkdMLFxuICBhdHRyczogU29mdG1heEF0dHJzXG59KTogVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHtsb2dpdHN9ID0gaW5wdXRzO1xuICBjb25zdCB7ZGltfSA9IGF0dHJzO1xuXG4gIGNvbnN0IGF4ZXMgPSB1dGlsLnBhcnNlQXhpc1BhcmFtKFtkaW1dLCBsb2dpdHMuc2hhcGUpO1xuXG4gIGNvbnN0IG1heExvZ2l0ID0gbWF4KHtcbiAgICBpbnB1dHM6IHt4OiBsb2dpdHN9LFxuICAgIGJhY2tlbmQsXG4gICAgYXR0cnM6IHtyZWR1Y3Rpb25JbmRpY2VzOiBheGVzLCBrZWVwRGltczogZmFsc2V9XG4gIH0pO1xuXG4gIGNvbnN0IGV4cGFuZGVkU2hhcGUgPSBiYWNrZW5kX3V0aWwuZXhwYW5kU2hhcGVUb0tlZXBEaW0obWF4TG9naXQuc2hhcGUsIGF4ZXMpO1xuXG4gIGNvbnN0IG1heExvZ2l0c1Jlc2hhcGVkID1cbiAgICAgIHJlc2hhcGUoe2lucHV0czoge3g6IG1heExvZ2l0fSwgYmFja2VuZCwgYXR0cnM6IHtzaGFwZTogZXhwYW5kZWRTaGFwZX19KTtcbiAgY29uc3QgYSA9XG4gICAgICBzdWIoe2lucHV0czoge2E6IGxvZ2l0cywgYjogbWF4TG9naXRzUmVzaGFwZWR9LCBiYWNrZW5kfSkgYXMgVGVuc29ySW5mbztcbiAgY29uc3QgYiA9IGV4cCh7aW5wdXRzOiB7eDogYX0sIGJhY2tlbmR9KSBhcyBUZW5zb3JJbmZvO1xuICBjb25zdCBzdW1FeHAgPVxuICAgICAgc3VtKHtpbnB1dHM6IHt4OiBifSwgYmFja2VuZCwgYXR0cnM6IHtheGlzOiBheGVzLCBrZWVwRGltczogZmFsc2V9fSk7XG4gIGNvbnN0IHN1bUV4cFJlc2hhcGVkID1cbiAgICAgIHJlc2hhcGUoe2lucHV0czoge3g6IHN1bUV4cH0sIGJhY2tlbmQsIGF0dHJzOiB7c2hhcGU6IGV4cGFuZGVkU2hhcGV9fSk7XG5cbiAgY29uc3QgcmVzID1cbiAgICAgIHJlYWxEaXYoe2lucHV0czoge2E6IGIsIGI6IHN1bUV4cFJlc2hhcGVkfSwgYmFja2VuZH0pIGFzIFRlbnNvckluZm87XG5cbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhtYXhMb2dpdCk7XG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8obWF4TG9naXRzUmVzaGFwZWQpO1xuICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKGEpO1xuICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKGIpO1xuICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHN1bUV4cCk7XG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oc3VtRXhwUmVzaGFwZWQpO1xuXG4gIHJldHVybiByZXM7XG59XG5cbmV4cG9ydCBjb25zdCBzb2Z0bWF4Q29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IFNvZnRtYXgsXG4gIGJhY2tlbmROYW1lOiAnd2ViZ2wnLFxuICBrZXJuZWxGdW5jOiBzb2Z0bWF4IGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==