gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, env, Multiply } from '@tensorflow/tfjs-core';
import * as binaryop_complex_gpu from '../binaryop_complex_gpu';
import { BinaryOpComplexProgram } from '../binaryop_complex_gpu';
import { BinaryOpProgram } from '../binaryop_gpu';
import { BinaryOpPackedProgram } from '../binaryop_packed_gpu';
import { multiplyImplCPU as cpuMultiply } from '../kernel_utils/shared';
import { complex } from './Complex';
/**
 * Op body handed to `BinaryOpProgram` / `BinaryOpPackedProgram` by `multiply`;
 * `a` and `b` name the two operand values (presumably spliced into the
 * generated WebGL shader by those program classes — see their definitions).
 */
const MUL = 'return a * b;';
/**
 * WebGL kernel for `Multiply`: element-wise product of `inputs.a` and `inputs.b`.
 *
 * The first matching strategy wins:
 *  1. `a.dtype === 'complex64'`: two `BinaryOpComplexProgram`s (the REAL and
 *     IMAG halves of `COMPLEX_MULTIPLY`) run over the operands' component
 *     tensors, and the two float32 results are combined with `complex`.
 *  2. `backend.shouldExecuteOnCPU([a, b])`: `cpuMultiply` computes the values
 *     from the operands' texData `values`, which are then stored on a new
 *     TensorInfo created by `backend.makeTensorInfo`.
 *  3. Otherwise `MUL` runs on the GPU through `BinaryOpPackedProgram` when the
 *     `WEBGL_PACK_BINARY_OPERATIONS` flag is set, else `BinaryOpProgram`.
 *
 * @param args.inputs `{a, b}`: the operand TensorInfos.
 * @param args.backend WebGL backend whose `texData` holds the operands.
 * @returns The product TensorInfo. Strategies 2 and 3 pass
 *     `backend_util.upcastType(a.dtype, b.dtype)` as the output dtype;
 *     strategy 1 returns whatever `complex` produces.
 */
export function multiply(args) {
    const { inputs: { a, b }, backend } = args;
    const outDtype = backend_util.upcastType(a.dtype, b.dtype);
    if (a.dtype === 'complex64') {
        // NOTE(review): b's complexTensorInfos are read unconditionally, so this
        // path assumes b is complex64 as well — confirm callers guarantee it.
        const aInfo = backend.texData.get(a.dataId);
        const bInfo = backend.texData.get(b.dataId);
        const { REAL, IMAG } = binaryop_complex_gpu.COMPLEX_MULTIPLY;
        const realProgram = new BinaryOpComplexProgram(REAL, a.shape, b.shape);
        const imagProgram = new BinaryOpComplexProgram(IMAG, a.shape, b.shape);
        // One operand list, ordered (a.real, a.imag, b.real, b.imag), shared by
        // both programs; each component carries its parent tensor's shape.
        const componentInputs = [
            [aInfo.complexTensorInfos.real, a.shape],
            [aInfo.complexTensorInfos.imag, a.shape],
            [bInfo.complexTensorInfos.real, b.shape],
            [bInfo.complexTensorInfos.imag, b.shape],
        ].map(([part, shape]) => ({ dataId: part.dataId, dtype: part.dtype, shape }));
        const realPart = backend.runWebGLProgram(realProgram, componentInputs, 'float32');
        const imagPart = backend.runWebGLProgram(imagProgram, componentInputs, 'float32');
        const result = complex({ inputs: { real: realPart, imag: imagPart }, backend });
        // Dispose the per-part intermediates after `complex` has been built from them.
        backend.disposeIntermediateTensorInfo(realPart);
        backend.disposeIntermediateTensorInfo(imagPart);
        // TODO(annxingyuan): CPU forwarding for complex inputs.
        return result;
    }
    if (backend.shouldExecuteOnCPU([a, b])) {
        const aValues = backend.texData.get(a.dataId).values;
        const bValues = backend.texData.get(b.dataId).values;
        const [values, shape] = cpuMultiply(a.shape, b.shape, aValues, bValues, outDtype);
        const out = backend.makeTensorInfo(shape, outDtype);
        // Attach the CPU-computed values directly to the new tensor's texData entry.
        backend.texData.get(out.dataId).values = values;
        return out;
    }
    const usePacked = env().getBool('WEBGL_PACK_BINARY_OPERATIONS');
    const program = usePacked ?
        new BinaryOpPackedProgram(MUL, a.shape, b.shape) :
        new BinaryOpProgram(MUL, a.shape, b.shape);
    return backend.runWebGLProgram(program, [a, b], outDtype);
}
/**
 * Kernel config pairing the `Multiply` kernel name with the `multiply`
 * implementation for the 'webgl' backend.
 */
export const multiplyConfig = { kernelName: Multiply, backendName: 'webgl', kernelFunc: multiply };
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"Multiply.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-webgl/src/kernels/Multiply.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAgB,GAAG,EAAgB,QAAQ,EAAyB,MAAM,uBAAuB,CAAC;AAGtH,OAAO,KAAK,oBAAoB,MAAM,yBAAyB,CAAC;AAChE,OAAO,EAAC,sBAAsB,EAAC,MAAM,yBAAyB,CAAC;AAC/D,OAAO,EAAC,eAAe,EAAC,MAAM,iBAAiB,CAAC;AAChD,OAAO,EAAC,qBAAqB,EAAC,MAAM,wBAAwB,CAAC;AAC7D,OAAO,EAAC,eAAe,IAAI,WAAW,EAAC,MAAM,wBAAwB,CAAC;AAEtE,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAElC,MAAM,GAAG,GAAG,eAAe,CAAC;AAE5B,MAAM,UAAU,QAAQ,CACpB,IAAuD;IACzD,MAAM,EAAC,MAAM,EAAE,OAAO,EAAC,GAAG,IAAI,CAAC;IAC/B,MAAM,EAAC,CAAC,EAAE,CAAC,EAAC,GAAG,MAAM,CAAC;IACtB,MAAM,KAAK,GAAG,YAAY,CAAC,UAAU,CAAC,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC;IAExD,IAAI,CAAC,CAAC,KAAK,KAAK,WAAW,EAAE;QAC3B,MAAM,KAAK,GAAG,OAAO,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;QAC5C,MAAM,KAAK,GAAG,OAAO,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;QAE5C,MAAM,WAAW,GAAG,IAAI,sBAAsB,CAC1C,oBAAoB,CAAC,gBAAgB,CAAC,IAAI,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC;QAClE,MAAM,WAAW,GAAG,IAAI,sBAAsB,CAC1C,oBAAoB,CAAC,gBAAgB,CAAC,IAAI,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC;QAElE,MAAM,MAAM,GAAG;YACb;gBACE,MAAM,EAAE,KAAK,CAAC,kBAAkB,CAAC,IAAI,CAAC,MAAM;gBAC5C,KAAK,EAAE,KAAK,CAAC,kBAAkB,CAAC,IAAI,CAAC,KAAK;gBAC1C,KAAK,EAAE,CAAC,CAAC,KAAK;aACf;YACD;gBACE,MAAM,EAAE,KAAK,CAAC,kBAAkB,CAAC,IAAI,CAAC,MAAM;gBAC5C,KAAK,EAAE,KAAK,CAAC,kBAAkB,CAAC,IAAI,CAAC,KAAK;gBAC1C,KAAK,EAAE,CAAC,CAAC,KAAK;aACf;YACD;gBACE,MAAM,EAAE,KAAK,CAAC,kBAAkB,CAAC,IAAI,CAAC,MAAM;gBAC5C,KAAK,EAAE,KAAK,CAAC,kBAAkB,CAAC,IAAI,CAAC,KAAK;gBAC1C,KAAK,EAAE,CAAC,CAAC,KAAK;aACf;YACD;gBACE,MAAM,EAAE,KAAK,CAAC,kBAAkB,CAAC,IAAI,CAAC,MAAM;gBAC5C,KAAK,EAAE,KAAK,CAAC,kBAAkB,CAAC,IAAI,CAAC,KAAK;gBAC1C,KAAK,EAAE,CAAC,CAAC,KAAK;aACf;SACF,CAAC;QAEF,MAAM,QAAQ,GAAG,OAAO,CAAC,eAAe,CAAC,WAAW,EAAE,MAAM,EAAE,SAAS,CAAC,CAAC;QACzE,MAAM,QAAQ,GAAG,OAAO,CAAC,eAAe,CAAC,WAAW,EAAE,MAAM,EAAE,SA
AS,CAAC,CAAC;QAEzE,MAAM,aAAa,GACf,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,IAAI,EAAE,QAAQ,EAAE,IAAI,EAAE,QAAQ,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;QAEjE,OAAO,CAAC,6BAA6B,CAAC,QAAQ,CAAC,CAAC;QAChD,OAAO,CAAC,6BAA6B,CAAC,QAAQ,CAAC,CAAC;QAEhD,wDAAwD;QACxD,OAAO,aAAa,CAAC;KACtB;IAED,IAAI,OAAO,CAAC,kBAAkB,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE;QACtC,MAAM,KAAK,GAAG,OAAO,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;QAC5C,MAAM,KAAK,GAAG,OAAO,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;QAC5C,MAAM,CAAC,SAAS,EAAE,QAAQ,CAAC,GAAG,WAAW,CACrC,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,MAAoB,EAC5C,KAAK,CAAC,MAAoB,EAAE,KAAK,CAAC,CAAC;QAEvC,MAAM,GAAG,GAAG,OAAO,CAAC,cAAc,CAAC,QAAQ,EAAE,KAAK,CAAC,CAAC;QACpD,MAAM,OAAO,GAAG,OAAO,CAAC,OAAO,CAAC,GAAG,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC;QAChD,OAAO,CAAC,MAAM,GAAG,SAAS,CAAC;QAC3B,OAAO,GAAG,CAAC;KACZ;IAED,IAAI,OAA8C,CAAC;IACnD,IAAI,GAAG,EAAE,CAAC,OAAO,CAAC,8BAA8B,CAAC,EAAE;QACjD,OAAO,GAAG,IAAI,qBAAqB,CAAC,GAAG,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC;KAC5D;SAAM;QACL,OAAO,GAAG,IAAI,eAAe,CAAC,GAAG,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC;KACtD;IAED,OAAO,OAAO,CAAC,eAAe,CAAC,OAAO,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,KAAK,CAAC,CAAC;AACzD,CAAC;AAED,MAAM,CAAC,MAAM,cAAc,GAAiB;IAC1C,UAAU,EAAE,QAAQ;IACpB,WAAW,EAAE,OAAO;IACpB,UAAU,EAAE,QAAQ;CACrB,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, BinaryInputs, env, KernelConfig, Multiply, TensorInfo, TypedArray} from '@tensorflow/tfjs-core';\n\nimport {MathBackendWebGL} from '../backend_webgl';\nimport * as binaryop_complex_gpu from '../binaryop_complex_gpu';\nimport {BinaryOpComplexProgram} from '../binaryop_complex_gpu';\nimport {BinaryOpProgram} from '../binaryop_gpu';\nimport {BinaryOpPackedProgram} from '../binaryop_packed_gpu';\nimport {multiplyImplCPU as cpuMultiply} from '../kernel_utils/shared';\n\nimport {complex} from './Complex';\n\nconst MUL = 'return a * b;';\n\nexport function multiply(\n    args: {inputs: BinaryInputs, backend: MathBackendWebGL}): TensorInfo {\n  const {inputs, backend} = args;\n  const {a, b} = inputs;\n  const dtype = backend_util.upcastType(a.dtype, b.dtype);\n\n  if (a.dtype === 'complex64') {\n    const aData = backend.texData.get(a.dataId);\n    const bData = backend.texData.get(b.dataId);\n\n    const realProgram = new BinaryOpComplexProgram(\n        binaryop_complex_gpu.COMPLEX_MULTIPLY.REAL, a.shape, b.shape);\n    const imagProgram = new BinaryOpComplexProgram(\n        binaryop_complex_gpu.COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);\n\n    const inputs = [\n      {\n        dataId: aData.complexTensorInfos.real.dataId,\n        dtype: aData.complexTensorInfos.real.dtype,\n        
shape: a.shape\n      },\n      {\n        dataId: aData.complexTensorInfos.imag.dataId,\n        dtype: aData.complexTensorInfos.imag.dtype,\n        shape: a.shape\n      },\n      {\n        dataId: bData.complexTensorInfos.real.dataId,\n        dtype: bData.complexTensorInfos.real.dtype,\n        shape: b.shape\n      },\n      {\n        dataId: bData.complexTensorInfos.imag.dataId,\n        dtype: bData.complexTensorInfos.imag.dtype,\n        shape: b.shape\n      }\n    ];\n\n    const realPart = backend.runWebGLProgram(realProgram, inputs, 'float32');\n    const imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32');\n\n    const complexOutput =\n        complex({inputs: {real: realPart, imag: imagPart}, backend});\n\n    backend.disposeIntermediateTensorInfo(realPart);\n    backend.disposeIntermediateTensorInfo(imagPart);\n\n    // TODO(annxingyuan): CPU forwarding for complex inputs.\n    return complexOutput;\n  }\n\n  if (backend.shouldExecuteOnCPU([a, b])) {\n    const aData = backend.texData.get(a.dataId);\n    const bData = backend.texData.get(b.dataId);\n    const [outValues, outShape] = cpuMultiply(\n        a.shape, b.shape, aData.values as TypedArray,\n        bData.values as TypedArray, dtype);\n\n    const out = backend.makeTensorInfo(outShape, dtype);\n    const outData = backend.texData.get(out.dataId);\n    outData.values = outValues;\n    return out;\n  }\n\n  let program: BinaryOpProgram|BinaryOpPackedProgram;\n  if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {\n    program = new BinaryOpPackedProgram(MUL, a.shape, b.shape);\n  } else {\n    program = new BinaryOpProgram(MUL, a.shape, b.shape);\n  }\n\n  return backend.runWebGLProgram(program, [a, b], dtype);\n}\n\nexport const multiplyConfig: KernelConfig = {\n  kernelName: Multiply,\n  backendName: 'webgl',\n  kernelFunc: multiply\n};\n"]}