gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the License);
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { _FusedMatMul } from '@tensorflow/tfjs-core';
import { applyActivation } from '../utils/fused_utils';
import { add } from './Add';
import { batchMatMul } from './BatchMatMul';
/**
 * CPU implementation of the `_FusedMatMul` kernel: a batched matrix multiply,
 * optionally followed by a bias add and then an activation.
 *
 * Each stage consumes the previous stage's output. Every output that a later
 * stage replaces is handed to `backend.disposeIntermediateTensorInfo` after
 * the whole chain has run, so only the final result stays alive.
 *
 * @param {object} args
 * @param {object} args.inputs - `a` and `b` (matmul operands), optional
 *     `bias`, and optional `preluActivationWeights` (forwarded to
 *     `applyActivation`).
 * @param {object} args.backend - backend forwarded to every stage and used to
 *     dispose intermediates.
 * @param {object} args.attrs - `transposeA`/`transposeB` (forwarded to
 *     `batchMatMul`), `activation` and `leakyreluAlpha` (forwarded to
 *     `applyActivation`).
 * @returns {object} the output of the last stage that ran.
 */
export function _fusedMatMul(args) {
    const { inputs, backend, attrs } = args;
    const { a, b, bias, preluActivationWeights } = inputs;
    const { transposeA, transposeB, activation, leakyreluAlpha } = attrs;
    // Stage outputs that a later stage consumed. They are released only once
    // the final result exists.
    const superseded = [];
    let result = batchMatMul({ inputs: { a, b }, attrs: { transposeA, transposeB }, backend });
    if (bias) {
        const biased = add({ inputs: { a: result, b: bias }, backend });
        superseded.push(result);
        result = biased;
    }
    if (activation) {
        const activated = applyActivation(backend, result, activation, preluActivationWeights, leakyreluAlpha);
        superseded.push(result);
        result = activated;
    }
    superseded.forEach((info) => backend.disposeIntermediateTensorInfo(info));
    return result;
}
/**
 * Kernel configuration that binds `_fusedMatMul` to the `_FusedMatMul` kernel
 * name on the 'cpu' backend. Presumably consumed by kernel registration
 * elsewhere in the package (not visible here).
 */
export const _fusedMatMulConfig = {
    kernelName: _FusedMatMul,
    backendName: 'cpu',
    kernelFunc: _fusedMatMul
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiX0Z1c2VkTWF0TXVsLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLWNwdS9zcmMva2VybmVscy9fRnVzZWRNYXRNdWwudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBOEUsTUFBTSx1QkFBdUIsQ0FBQztBQUdoSSxPQUFPLEVBQUMsZUFBZSxFQUFDLE1BQU0sc0JBQXNCLENBQUM7QUFFckQsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLE9BQU8sQ0FBQztBQUMxQixPQUFPLEVBQUMsV0FBVyxFQUFDLE1BQU0sZUFBZSxDQUFDO0FBRTFDLE1BQU0sVUFBVSxZQUFZLENBQUMsSUFJNUI7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsSUFBSSxFQUFFLHNCQUFzQixFQUFDLEdBQUcsTUFBTSxDQUFDO0lBQ3BELE1BQU0sRUFBQyxVQUFVLEVBQUUsVUFBVSxFQUFFLFVBQVUsRUFBRSxjQUFjLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFFbkUsSUFBSSxPQUFPLENBQUM7SUFDWixJQUFJLE1BQU0sQ0FBQztJQUNYLElBQUksYUFBYSxDQUFDO0lBRWxCLE1BQU0sYUFBYSxHQUFpQixFQUFFLENBQUM7SUFFdkMsTUFBTSxTQUFTLEdBQ1gsV0FBVyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLENBQUMsRUFBQyxFQUFFLEtBQUssRUFBRSxFQUFDLFVBQVUsRUFBRSxVQUFVLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQUFDO0lBQzVFLE9BQU8sR0FBRyxTQUFTLENBQUM7SUFFcEIsSUFBSSxJQUFJLEVBQUU7UUFDUixNQUFNLEdBQUcsR0FBRyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLE9BQU8sRUFBRSxDQUFDLEVBQUUsSUFBSSxFQUFDLEVBQUUsT0FBTyxFQUFDLENBQWUsQ0FBQztRQUNyRSxhQUFhLENBQUMsSUFBSSxDQUFDLE9BQU8sQ0FBQyxDQUFDO1FBQzVCLE9BQU8sR0FBRyxNQUFNLENBQUM7S0FDbEI7SUFDRCxJQUFJLFVBQVUsRUFBRTtRQUNkLGFBQWEsR0FBRyxlQUFlLENBQzNCLE9BQU8sRUFBRSxPQUFPLEVBQUUsVUFBVSxFQUFFLHNCQUFzQixFQUFFLGNBQWMsQ0FBQyxDQUFDO1FBQzFFLGFBQWEsQ0FBQyxJQUFJLENBQUMsT0FBTyxDQUFDLENBQUM7UUFDNUIsT0FBTyxHQUFHLGFBQWEsQ0FBQztLQUN6QjtJQUVELEtBQUssTUFBTSxDQUFDLElBQUksYUFBYSxFQUFFO1FBQzdCLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxDQUFDLENBQUMsQ0FBQztLQUMxQztJQUVELE9BQU8sT0FBTyxDQUFDO0FBQ2pCLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxrQkFBa0IsR0FBaUI7SUFDOUMsVUFBVSxFQUFFLFlBQVk7SUFDeEIsV0FBVyxFQUFFLEtBQUs7SUFDbEIsVUFBVSxFQUFFLFlBQXFDO0NBQ2xELENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKi
pcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIExpY2Vuc2UpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gQVMgSVMgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge19GdXNlZE1hdE11bCwgX0Z1c2VkTWF0TXVsQXR0cnMsIF9GdXNlZE1hdE11bElucHV0cywgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBUZW5zb3JJbmZvfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQge01hdGhCYWNrZW5kQ1BVfSBmcm9tICcuLi9iYWNrZW5kX2NwdSc7XG5pbXBvcnQge2FwcGx5QWN0aXZhdGlvbn0gZnJvbSAnLi4vdXRpbHMvZnVzZWRfdXRpbHMnO1xuXG5pbXBvcnQge2FkZH0gZnJvbSAnLi9BZGQnO1xuaW1wb3J0IHtiYXRjaE1hdE11bH0gZnJvbSAnLi9CYXRjaE1hdE11bCc7XG5cbmV4cG9ydCBmdW5jdGlvbiBfZnVzZWRNYXRNdWwoYXJnczoge1xuICBpbnB1dHM6IF9GdXNlZE1hdE11bElucHV0cyxcbiAgYXR0cnM6IF9GdXNlZE1hdE11bEF0dHJzLFxuICBiYWNrZW5kOiBNYXRoQmFja2VuZENQVVxufSkge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7YSwgYiwgYmlhcywgcHJlbHVBY3RpdmF0aW9uV2VpZ2h0c30gPSBpbnB1dHM7XG4gIGNvbnN0IHt0cmFuc3Bvc2VBLCB0cmFuc3Bvc2VCLCBhY3RpdmF0aW9uLCBsZWFreXJlbHVBbHBoYX0gPSBhdHRycztcblxuICBsZXQgY3VycmVudDtcbiAgbGV0IGFkZFJlcztcbiAgbGV0IGFjdGl2YXRpb25SZXM7XG5cbiAgY29uc3QgaW50ZXJtZWRpYXRlczogVGVuc29ySW5mb1tdID0gW107XG5cbiAgY29uc3QgbWF0TXVsUmVzID1cbiAgICAgIGJhdGNoTWF0TXVsKHtpbnB1dHM6IHthLCBifSwgYXR0cnM6IHt0cmFuc3Bvc2
VBLCB0cmFuc3Bvc2VCfSwgYmFja2VuZH0pO1xuICBjdXJyZW50ID0gbWF0TXVsUmVzO1xuXG4gIGlmIChiaWFzKSB7XG4gICAgYWRkUmVzID0gYWRkKHtpbnB1dHM6IHthOiBjdXJyZW50LCBiOiBiaWFzfSwgYmFja2VuZH0pIGFzIFRlbnNvckluZm87XG4gICAgaW50ZXJtZWRpYXRlcy5wdXNoKGN1cnJlbnQpO1xuICAgIGN1cnJlbnQgPSBhZGRSZXM7XG4gIH1cbiAgaWYgKGFjdGl2YXRpb24pIHtcbiAgICBhY3RpdmF0aW9uUmVzID0gYXBwbHlBY3RpdmF0aW9uKFxuICAgICAgICBiYWNrZW5kLCBjdXJyZW50LCBhY3RpdmF0aW9uLCBwcmVsdUFjdGl2YXRpb25XZWlnaHRzLCBsZWFreXJlbHVBbHBoYSk7XG4gICAgaW50ZXJtZWRpYXRlcy5wdXNoKGN1cnJlbnQpO1xuICAgIGN1cnJlbnQgPSBhY3RpdmF0aW9uUmVzO1xuICB9XG5cbiAgZm9yIChjb25zdCBpIG9mIGludGVybWVkaWF0ZXMpIHtcbiAgICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKGkpO1xuICB9XG5cbiAgcmV0dXJuIGN1cnJlbnQ7XG59XG5cbmV4cG9ydCBjb25zdCBfZnVzZWRNYXRNdWxDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogX0Z1c2VkTWF0TXVsLFxuICBiYWNrZW5kTmFtZTogJ2NwdScsXG4gIGtlcm5lbEZ1bmM6IF9mdXNlZE1hdE11bCBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmMsXG59O1xuIl19