/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, util } from '@tensorflow/tfjs-core';
|
import { assertNotComplex } from '../cpu_util';
|
import { cast } from '../kernels/Cast';
|
import { complex } from '../kernels/Complex';
|
/**
 * Returns `vals` decoded to JS strings when `dtype` is `'string'` (string
 * tensor values are stored as byte arrays); any other dtype is returned as-is.
 * @param vals Backing values read from the CPU backend's data store.
 * @param dtype Dtype of the tensor that owns `vals`.
 */
function decodeIfString(vals, dtype) {
  return dtype === 'string' ? backend_util.fromUint8ToStringArray(vals) : vals;
}

/**
 * Reads the real and imaginary value arrays backing a complex64 tensor info
 * through its `complexTensorInfos` entry in the CPU backend's data store.
 * @param cpuBackend The CPU backend holding the tensor data.
 * @param complexInfo A complex64 tensor info.
 * @returns `[realVals, imagVals]`.
 */
function readComplexParts(cpuBackend, complexInfo) {
  const { real, imag } =
      cpuBackend.data.get(complexInfo.dataId).complexTensorInfos;
  return [
    cpuBackend.data.get(real.dataId).values,
    cpuBackend.data.get(imag.dataId).values,
  ];
}

/**
 * Template that creates a `KernelFunc` for binary ops.
 *
 * The returned kernel reads inputs `a` and `b` from `inputs` and produces a
 * tensor info through `backend.makeTensorInfo`.
 *
 * @param name Kernel name (passed to `assertNotComplex`).
 * @param simpleImpl A `SimpleBinaryKernelImpl`, called as
 *     `simpleImpl(aShape, bShape, aVals, bVals, dtype)` and returning
 *     `[resultData, resultShape]`.
 * @param complexImpl Optional `ComplexBinaryKernelImpl`.
 *     - If omitted, complex64 inputs are checked with `assertNotComplex`.
 *     - If omitted, string inputs are decoded before calling `simpleImpl`.
 *     - If present, it is used whenever either input is complex64. Both
 *       inputs are then cast to complex64, and it is called with their
 *       real/imaginary parts.
 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
 *     result has the same dtype as the first input. This is mainly used in
 *     comparison kernels, such as Equal, Less, Greater, etc. It does not apply
 *     to the complex64 path, whose result is always complex64.
 */
export function binaryKernelFunc(name, simpleImpl, complexImpl, dtype) {
  if (complexImpl == null) {
    return ({ inputs, backend }) => {
      const { a, b } = inputs;
      const cpuBackend = backend;

      assertNotComplex([a, b], name);

      const aVals = cpuBackend.data.get(a.dataId).values;
      const bVals = cpuBackend.data.get(b.dataId).values;

      // Each operand is decoded according to its own dtype.
      const decodedAVals = decodeIfString(aVals, a.dtype);
      const decodedBVals = decodeIfString(bVals, b.dtype);
      const $dtype = dtype || a.dtype;

      const [resultData, resultShape] =
          simpleImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype);

      return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
    };
  }

  return ({ inputs, backend }) => {
    const { a, b } = inputs;
    const cpuBackend = backend;

    if (a.dtype !== 'complex64' && b.dtype !== 'complex64') {
      const aVals = cpuBackend.data.get(a.dataId).values;
      const bVals = cpuBackend.data.get(b.dataId).values;

      const $dtype = dtype || a.dtype;

      const [resultData, resultShape] =
          simpleImpl(a.shape, b.shape, aVals, bVals, $dtype);

      return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
    }

    // Every intermediate tensor info is recorded as soon as it is created,
    // so it is disposed even if a later step throws. On success they are
    // disposed after `complex` has built the result.
    const intermediates = [];
    try {
      const $aComplex = cast(
          { inputs: { x: a }, backend: cpuBackend, attrs: { dtype: 'complex64' } });
      intermediates.push($aComplex);
      const [aRealVals, aImagVals] = readComplexParts(cpuBackend, $aComplex);

      const $bComplex = cast(
          { inputs: { x: b }, backend: cpuBackend, attrs: { dtype: 'complex64' } });
      intermediates.push($bComplex);
      const [bRealVals, bImagVals] = readComplexParts(cpuBackend, $bComplex);

      const [resultRealData, resultImagData, resultShape] = complexImpl(
          a.shape, b.shape, aRealVals, aImagVals, bRealVals, bImagVals);

      const resultReal =
          cpuBackend.makeTensorInfo(resultShape, 'float32', resultRealData);
      intermediates.push(resultReal);

      const resultImag =
          cpuBackend.makeTensorInfo(resultShape, 'float32', resultImagData);
      intermediates.push(resultImag);

      return complex(
          { inputs: { real: resultReal, imag: resultImag }, backend: cpuBackend });
    } finally {
      intermediates.forEach(
          (info) => cpuBackend.disposeIntermediateTensorInfo(info));
    }
  };
}
|
/**
 * Template that creates the complex type implementation for binary ops.
 * Supports broadcast.
 *
 * @param op Element-wise complex operation, called as
 *     `op(aReal, aImag, bReal, bImag)`; must return `{real, imag}`.
 * @returns A function
 *     `(aShape, bShape, aRealVals, aImagVals, bRealVals, bImagVals)`. It
 *     returns `[resultRealVals, resultImagVals, resultShape]`, where the two
 *     result arrays are float32 typed arrays sized to the broadcast shape.
 */
export function createComplexBinaryKernelImpl(op) {
  return (aShape, bShape, aRealVals, aImagVals, bRealVals, bImagVals) => {
    const resultShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
    const resultSize = util.sizeFromShape(resultShape);
    const resultRank = resultShape.length;
    const resultStrides = util.computeStrides(resultShape);

    const resultRealVals = util.getTypedArrayFromDType('float32', resultSize);
    const resultImagVals = util.getTypedArrayFromDType('float32', resultSize);

    const aBroadcastDims = backend_util.getBroadcastDims(aShape, resultShape);
    const bBroadcastDims = backend_util.getBroadcastDims(bShape, resultShape);

    // Interleaved [re0, im0, re1, im1, ...] copies of each operand, so that
    // complex element k lives at indices 2k and 2k + 1.
    const aVals = backend_util.mergeRealAndImagArrays(aRealVals, aImagVals);
    const bVals = backend_util.mergeRealAndImagArrays(bRealVals, bImagVals);

    // Number of complex elements per operand. The fast path below must wrap
    // by this count, not by the interleaved length (which is twice as long).
    const aSize = aRealVals.length;
    const bSize = bRealVals.length;

    const aRank = aShape.length;
    const aStrides = util.computeStrides(aShape);

    const bRank = bShape.length;
    const bStrides = util.computeStrides(bShape);

    if (aBroadcastDims.length + bBroadcastDims.length === 0) {
      // No size-1 dimension is stretched. Each operand therefore only lacks
      // leading dimensions, and its flat buffer repeats cyclically across
      // the result.
      for (let i = 0; i < resultRealVals.length; i++) {
        const aIdx = i % aSize;
        const bIdx = i % bSize;

        const result = op(
            aVals[aIdx * 2], aVals[aIdx * 2 + 1], bVals[bIdx * 2],
            bVals[bIdx * 2 + 1]);

        resultRealVals[i] = result.real;
        resultImagVals[i] = result.imag;
      }
    } else {
      for (let i = 0; i < resultRealVals.length; i++) {
        const loc = util.indexToLoc(i, resultRank, resultStrides);

        // Map the result coordinate back into each operand: keep the
        // trailing `rank` coordinates and pin broadcast dims to 0.
        const aLoc = loc.slice(-aRank);
        aBroadcastDims.forEach((d) => aLoc[d] = 0);
        const aIndex = util.locToIndex(aLoc, aRank, aStrides);

        const bLoc = loc.slice(-bRank);
        bBroadcastDims.forEach((d) => bLoc[d] = 0);
        const bIndex = util.locToIndex(bLoc, bRank, bStrides);

        const opResult = op(
            aVals[aIndex * 2], aVals[aIndex * 2 + 1], bVals[bIndex * 2],
            bVals[bIndex * 2 + 1]);

        resultRealVals[i] = opResult.real;
        resultImagVals[i] = opResult.imag;
      }
    }
    return [resultRealVals, resultImagVals, resultShape];
  };
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"binary_utils.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-cpu/src/utils/binary_utils.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAkD,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAGzG,OAAO,EAAC,gBAAgB,EAAC,MAAM,aAAa,CAAC;AAC7C,OAAO,EAAC,IAAI,EAAC,MAAM,iBAAiB,CAAC;AACrC,OAAO,EAAC,OAAO,EAAC,MAAM,oBAAoB,CAAC;AAI3C;;;;;;;;;;GAUG;AACH,MAAM,UAAU,gBAAgB,CAC5B,IAAY,EAAE,UAAkC,EAChD,WAAqC,EAAE,KAAgB;IACzD,IAAI,WAAW,IAAI,IAAI,EAAE;QACvB,OAAO,CAAC,EAAC,MAAM,EAAE,OAAO,EAAC,EAAE,EAAE;YAC3B,MAAM,EAAC,CAAC,EAAE,CAAC,EAAC,GAAG,MAAsB,CAAC;YACtC,MAAM,UAAU,GAAG,OAAyB,CAAC;YAE7C,gBAAgB,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC;YAE/B,MAAM,KAAK,GAAG,UAAU,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,MAAoB,CAAC;YACjE,MAAM,KAAK,GAAG,UAAU,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,MAAoB,CAAC;YAEjE,MAAM,YAAY,GAAG,CAAC,CAAC,KAAK,KAAK,QAAQ,CAAC,CAAC;gBACvC,mCAAmC;gBACnC,YAAY,CAAC,sBAAsB,CAAC,KAA4B,CAAC,CAAC,CAAC;gBACnE,KAAK,CAAC;YACV,MAAM,YAAY,GAAG,CAAC,CAAC,KAAK,KAAK,QAAQ,CAAC,CAAC;gBACvC,mCAAmC;gBACnC,YAAY,CAAC,sBAAsB,CAAC,KAA4B,CAAC,CAAC,CAAC;gBACnE,KAAK,CAAC;YACV,MAAM,MAAM,GAAG,KAAK,IAAI,CAAC,CAAC,KAAK,CAAC;YAEhC,MAAM,CAAC,UAAU,EAAE,WAAW,CAAC,GAC3B,UAAU,CAAC,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,EAAE,YAAY,EAAE,YAAY,EAAE,MAAM,CAAC,CAAC;YAErE,OAAO,UAAU,CAAC,cAAc,CAAC,WAAW,EAAE,MAAM,EAAE,UAAU,CAAC,CAAC;QACpE,CAAC,CAAC;KACH;IAED,OAAO,CAAC,EAAC,MAAM,EAAE,OAAO,EAAC,EAAE,EAAE;QAC3B,MAAM,EAAC,CAAC,EAAE,CAAC,EAAC,GAAG,MAAsB,CAAC;QACtC,MAAM,UAAU,GAAG,OAAyB,CAAC;QAE7C,IAAI,CAAC,CAAC,KAAK,KAAK,WAAW,IAAI,CAAC,CAAC,KAAK,KAAK,WAAW,EAAE;YACtD,MAAM,SAAS,GAAG,IAAI,CAClB,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,CAAC,EAAC,EAAE,OAAO,EAAE,UAAU,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,WAAW,EAAC,EAAC,CAAC,CAAC;YAExE,MAAM,aAAa,GAAG,UAAU,CAAC,IAAI,CAAC,GAAG,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC;YAE5D,MAAM,KAAK,GAAG,aAAa,CAAC,kBAAkB,CAAC,IAAI,CAAC;YACpD,MAAM,KAAK,GAAG,aAAa,CAAC,kBAAkB,CAAC,IAAI,CAAC;YAEpD,MAAM,SAAS,GACX,UAAU,CAAC,
IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,MAAsB,CAAC;YAC7D,MAAM,SAAS,GACX,UAAU,CAAC,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,MAAsB,CAAC;YAE7D,MAAM,SAAS,GAAG,IAAI,CAClB,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,CAAC,EAAC,EAAE,OAAO,EAAE,UAAU,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,WAAW,EAAC,EAAC,CAAC,CAAC;YAExE,MAAM,aAAa,GAAG,UAAU,CAAC,IAAI,CAAC,GAAG,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC;YAE5D,MAAM,KAAK,GAAG,aAAa,CAAC,kBAAkB,CAAC,IAAI,CAAC;YACpD,MAAM,KAAK,GAAG,aAAa,CAAC,kBAAkB,CAAC,IAAI,CAAC;YAEpD,MAAM,SAAS,GACX,UAAU,CAAC,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,MAAsB,CAAC;YAC7D,MAAM,SAAS,GACX,UAAU,CAAC,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,MAAsB,CAAC;YAE7D,MAAM,CAAC,cAAc,EAAE,cAAc,EAAE,WAAW,CAAC,GAAG,WAAW,CAC7D,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAElE,MAAM,UAAU,GACZ,UAAU,CAAC,cAAc,CAAC,WAAW,EAAE,SAAS,EAAE,cAAc,CAAC,CAAC;YAEtE,MAAM,UAAU,GACZ,UAAU,CAAC,cAAc,CAAC,WAAW,EAAE,SAAS,EAAE,cAAc,CAAC,CAAC;YAEtE,MAAM,MAAM,GAAG,OAAO,CAClB,EAAC,MAAM,EAAE,EAAC,IAAI,EAAE,UAAU,EAAE,IAAI,EAAE,UAAU,EAAC,EAAE,OAAO,EAAE,UAAU,EAAC,CAAC,CAAC;YAEzE,UAAU,CAAC,6BAA6B,CAAC,SAAS,CAAC,CAAC;YACpD,UAAU,CAAC,6BAA6B,CAAC,SAAS,CAAC,CAAC;YACpD,UAAU,CAAC,6BAA6B,CAAC,UAAU,CAAC,CAAC;YACrD,UAAU,CAAC,6BAA6B,CAAC,UAAU,CAAC,CAAC;YAErD,OAAO,MAAM,CAAC;SACf;aAAM;YACL,MAAM,KAAK,GAAG,UAAU,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,MAAoB,CAAC;YACjE,MAAM,KAAK,GAAG,UAAU,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,MAAoB,CAAC;YAEjE,MAAM,MAAM,GAAG,KAAK,IAAI,CAAC,CAAC,KAAK,CAAC;YAEhC,MAAM,CAAC,UAAU,EAAE,WAAW,CAAC,GAC3B,UAAU,CAAC,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,CAAC,CAAC;YAEvD,OAAO,UAAU,CAAC,cAAc,CAAC,WAAW,EAAE,MAAM,EAAE,UAAU,CAAC,CAAC;SACnE;IACH,CAAC,CAAC;AACJ,CAAC;AAED;;;GAGG;AACH,MAAM,UAAU,6BAA6B,CAAC,EAA0B;IAEtE,OAAO,CAAC,MAAgB,EAAE,MAAgB,EAAE,SAAuB,EAC3D,SAAuB,EAAE,SAAuB,EAChD,SAAuB,EAAsC,EAAE;QACrE,MAAM,WAAW,GAAG,YAAY,CAAC,0BAA0B,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;QAC5E,MAAM,UAAU,GAAG,IAAI,CAAC,aAAa,CAAC,WAAW,CAAC,CAAC;QACnD,MAAM,UAAU,GA
AG,WAAW,CAAC,MAAM,CAAC;QACtC,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,WAAW,CAAC,CAAC;QAEvD,MAAM,cAAc,GAAG,IAAI,CAAC,sBAAsB,CAAC,SAAS,EAAE,UAAU,CAAC,CAAC;QAC1E,MAAM,cAAc,GAAG,IAAI,CAAC,sBAAsB,CAAC,SAAS,EAAE,UAAU,CAAC,CAAC;QAE1E,MAAM,cAAc,GAAG,YAAY,CAAC,gBAAgB,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAC1E,MAAM,cAAc,GAAG,YAAY,CAAC,gBAAgB,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAE1E,MAAM,KAAK,GAAG,YAAY,CAAC,sBAAsB,CAAC,SAAS,EAAE,SAAS,CAAC,CAAC;QACxE,MAAM,KAAK,GAAG,YAAY,CAAC,sBAAsB,CAAC,SAAS,EAAE,SAAS,CAAC,CAAC;QAExE,MAAM,KAAK,GAAG,MAAM,CAAC,MAAM,CAAC;QAC5B,MAAM,QAAQ,GAAG,IAAI,CAAC,cAAc,CAAC,MAAM,CAAC,CAAC;QAE7C,MAAM,KAAK,GAAG,MAAM,CAAC,MAAM,CAAC;QAC5B,MAAM,QAAQ,GAAG,IAAI,CAAC,cAAc,CAAC,MAAM,CAAC,CAAC;QAE7C,IAAI,cAAc,CAAC,MAAM,GAAG,cAAc,CAAC,MAAM,KAAK,CAAC,EAAE;YACvD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,cAAc,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;gBAC9C,MAAM,IAAI,GAAG,CAAC,GAAG,KAAK,CAAC,MAAM,CAAC;gBAC9B,MAAM,IAAI,GAAG,CAAC,GAAG,KAAK,CAAC,MAAM,CAAC;gBAE9B,MAAM,MAAM,GACR,EAAE,CAAC,KAAK,CAAC,IAAI,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,IAAI,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,IAAI,GAAG,CAAC,CAAC,EACrD,KAAK,CAAC,IAAI,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;gBAE5B,cAAc,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,IAAI,CAAC;gBAChC,cAAc,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,IAAI,CAAC;aACjC;SACF;aAAM;YACL,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,cAAc,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;gBAC9C,MAAM,GAAG,GAAG,IAAI,CAAC,UAAU,CAAC,CAAC,EAAE,UAAU,EAAE,aAAa,CAAC,CAAC;gBAE1D,MAAM,IAAI,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC;gBAC/B,cAAc,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;gBACzC,MAAM,MAAM,GAAG,IAAI,CAAC,UAAU,CAAC,IAAI,EAAE,KAAK,EAAE,QAAQ,CAAC,CAAC;gBAEtD,MAAM,IAAI,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC;gBAC/B,cAAc,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;gBACzC,MAAM,MAAM,GAAG,IAAI,CAAC,UAAU,CAAC,IAAI,EAAE,KAAK,EAAE,QAAQ,CAAC,CAAC;gBAEtD,MAAM,QAAQ,GACV,EAAE,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,MAAM,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,EAC3
D,KAAK,CAAC,MAAM,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;gBAE9B,cAAc,CAAC,CAAC,CAAC,GAAG,QAAQ,CAAC,IAAI,CAAC;gBAClC,cAAc,CAAC,CAAC,CAAC,GAAG,QAAQ,CAAC,IAAI,CAAC;aACnC;SACF;QACD,OAAO,CAAC,cAAc,EAAE,cAAc,EAAE,WAAW,CAAC,CAAC;IACvD,CAAC,CAAC;AACJ,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, BinaryInputs, DataType, KernelFunc, TypedArray, util} from '@tensorflow/tfjs-core';\n\nimport {MathBackendCPU} from '../backend_cpu';\nimport {assertNotComplex} from '../cpu_util';\nimport {cast} from '../kernels/Cast';\nimport {complex} from '../kernels/Complex';\n\nimport {ComplexBinaryKernelImpl, ComplexBinaryOperation, SimpleBinaryKernelImpl} from './binary_types';\n\n/**\n * Template that creates a `KernelFunc` for binary ops.\n * @param name Kernel name.\n * @param binaryKernelImpl A `SimpleBinaryKernelImpl` for the kernel.\n * @param binaryKernelComplexImpl Optional. If exists, represents a\n *     `ComplexBinaryKernelImpl` for the kernel, will be used when input dtype\n *     is `complex64`.\n * @param dtype Optional. If set, the result has this dtype. Otherwise, the\n *     result has the same dtype as the first input. 
This is mainly used in\n *     comparison kernels, such as Equal, Less, Greater, etc.\n */\nexport function binaryKernelFunc(\n    name: string, simpleImpl: SimpleBinaryKernelImpl,\n    complexImpl?: ComplexBinaryKernelImpl, dtype?: DataType): KernelFunc {\n  if (complexImpl == null) {\n    return ({inputs, backend}) => {\n      const {a, b} = inputs as BinaryInputs;\n      const cpuBackend = backend as MathBackendCPU;\n\n      assertNotComplex([a, b], name);\n\n      const aVals = cpuBackend.data.get(a.dataId).values as TypedArray;\n      const bVals = cpuBackend.data.get(b.dataId).values as TypedArray;\n\n      const decodedAVals = a.dtype === 'string' ?\n          // tslint:disable-next-line: no-any\n          backend_util.fromUint8ToStringArray(aVals as any as Uint8Array[]) :\n          aVals;\n      const decodedBVals = a.dtype === 'string' ?\n          // tslint:disable-next-line: no-any\n          backend_util.fromUint8ToStringArray(bVals as any as Uint8Array[]) :\n          bVals;\n      const $dtype = dtype || a.dtype;\n\n      const [resultData, resultShape] =\n          simpleImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype);\n\n      return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);\n    };\n  }\n\n  return ({inputs, backend}) => {\n    const {a, b} = inputs as BinaryInputs;\n    const cpuBackend = backend as MathBackendCPU;\n\n    if (a.dtype === 'complex64' || b.dtype === 'complex64') {\n      const $aComplex = cast(\n          {inputs: {x: a}, backend: cpuBackend, attrs: {dtype: 'complex64'}});\n\n      const $aComplexVals = cpuBackend.data.get($aComplex.dataId);\n\n      const aReal = $aComplexVals.complexTensorInfos.real;\n      const aImag = $aComplexVals.complexTensorInfos.imag;\n\n      const aRealVals =\n          cpuBackend.data.get(aReal.dataId).values as Float32Array;\n      const aImagVals =\n          cpuBackend.data.get(aImag.dataId).values as Float32Array;\n\n      const $bComplex = cast(\n          {inputs: {x: 
b}, backend: cpuBackend, attrs: {dtype: 'complex64'}});\n\n      const $bComplexVals = cpuBackend.data.get($bComplex.dataId);\n\n      const bReal = $bComplexVals.complexTensorInfos.real;\n      const bImag = $bComplexVals.complexTensorInfos.imag;\n\n      const bRealVals =\n          cpuBackend.data.get(bReal.dataId).values as Float32Array;\n      const bImagVals =\n          cpuBackend.data.get(bImag.dataId).values as Float32Array;\n\n      const [resultRealData, resultImagData, resultShape] = complexImpl(\n          a.shape, b.shape, aRealVals, aImagVals, bRealVals, bImagVals);\n\n      const resultReal =\n          cpuBackend.makeTensorInfo(resultShape, 'float32', resultRealData);\n\n      const resultImag =\n          cpuBackend.makeTensorInfo(resultShape, 'float32', resultImagData);\n\n      const result = complex(\n          {inputs: {real: resultReal, imag: resultImag}, backend: cpuBackend});\n\n      cpuBackend.disposeIntermediateTensorInfo($aComplex);\n      cpuBackend.disposeIntermediateTensorInfo($bComplex);\n      cpuBackend.disposeIntermediateTensorInfo(resultReal);\n      cpuBackend.disposeIntermediateTensorInfo(resultImag);\n\n      return result;\n    } else {\n      const aVals = cpuBackend.data.get(a.dataId).values as TypedArray;\n      const bVals = cpuBackend.data.get(b.dataId).values as TypedArray;\n\n      const $dtype = dtype || a.dtype;\n\n      const [resultData, resultShape] =\n          simpleImpl(a.shape, b.shape, aVals, bVals, $dtype);\n\n      return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);\n    }\n  };\n}\n\n/**\n * Template that creates the complex type implementation for binary ops.\n * Supports broadcast.\n */\nexport function createComplexBinaryKernelImpl(op: ComplexBinaryOperation):\n    ComplexBinaryKernelImpl {\n  return (aShape: number[], bShape: number[], aRealVals: Float32Array,\n          aImagVals: Float32Array, bRealVals: Float32Array,\n          bImagVals: Float32Array): [TypedArray, TypedArray, 
number[]] => {\n    const resultShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);\n    const resultSize = util.sizeFromShape(resultShape);\n    const resultRank = resultShape.length;\n    const resultStrides = util.computeStrides(resultShape);\n\n    const resultRealVals = util.getTypedArrayFromDType('float32', resultSize);\n    const resultImagVals = util.getTypedArrayFromDType('float32', resultSize);\n\n    const aBroadcastDims = backend_util.getBroadcastDims(aShape, resultShape);\n    const bBroadcastDims = backend_util.getBroadcastDims(bShape, resultShape);\n\n    const aVals = backend_util.mergeRealAndImagArrays(aRealVals, aImagVals);\n    const bVals = backend_util.mergeRealAndImagArrays(bRealVals, bImagVals);\n\n    const aRank = aShape.length;\n    const aStrides = util.computeStrides(aShape);\n\n    const bRank = bShape.length;\n    const bStrides = util.computeStrides(bShape);\n\n    if (aBroadcastDims.length + bBroadcastDims.length === 0) {\n      for (let i = 0; i < resultRealVals.length; i++) {\n        const aIdx = i % aVals.length;\n        const bIdx = i % bVals.length;\n\n        const result =\n            op(aVals[aIdx * 2], aVals[aIdx * 2 + 1], bVals[bIdx * 2],\n               bVals[bIdx * 2 + 1]);\n\n        resultRealVals[i] = result.real;\n        resultImagVals[i] = result.imag;\n      }\n    } else {\n      for (let i = 0; i < resultRealVals.length; i++) {\n        const loc = util.indexToLoc(i, resultRank, resultStrides);\n\n        const aLoc = loc.slice(-aRank);\n        aBroadcastDims.forEach(d => aLoc[d] = 0);\n        const aIndex = util.locToIndex(aLoc, aRank, aStrides);\n\n        const bLoc = loc.slice(-bRank);\n        bBroadcastDims.forEach(d => bLoc[d] = 0);\n        const bIndex = util.locToIndex(bLoc, bRank, bStrides);\n\n        const opResult =\n            op(aVals[aIndex * 2], aVals[aIndex * 2 + 1], bVals[bIndex * 2],\n               bVals[bIndex * 2 + 1]);\n\n        resultRealVals[i] = opResult.real;\n     
   resultImagVals[i] = opResult.imag;\n      }\n    }\n    return [resultRealVals, resultImagVals, resultShape];\n  };\n}\n"]}
|