/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import * as tf from '@tensorflow/tfjs-core';
|
import { Cast, util } from '@tensorflow/tfjs-core';
|
import { castImplCPU } from '../kernel_utils/shared';
|
import { complex } from './Complex';
|
import { identity } from './Identity';
|
import { notEqual } from './NotEqual';
|
import { real } from './Real';
|
import { int } from '../kernel_utils/int';
|
/**
 * WebGL implementation of the `Cast` kernel: returns `inputs.x` converted to
 * `attrs.dtype`.
 *
 * Conversion paths, tried in order:
 *  - to complex64: `identity(x)` if `x` is already complex64; otherwise
 *    `complex(real = cast(x, 'float32'), imag = tf.zeros(x.shape))`.
 *  - from complex64: the real part of `x` is extracted and cast recursively.
 *  - when `util.hasEncodingLoss(x.dtype, dtype)` is false, the `dataId` returned
 *    by `identity(x)` is reused and only the reported dtype changes.
 *  - when `backend.shouldExecuteOnCPU([x])` is true, the cast is done by
 *    `castImplCPU` on the stored values.
 *  - to int32 via `int(x, backend)`.
 *  - to bool via `notEqual(x, <bool scalar>)`, where the scalar comes from
 *    `util.getTypedArrayFromDType('bool', 1)` (presumably zero-filled).
 *
 * Intermediate tensors created here are released in `finally` blocks, so they
 * are not leaked when a downstream kernel throws.
 *
 * @param args `{inputs: {x}, backend, attrs: {dtype}}`.
 * @returns TensorInfo holding the cast result.
 * @throws Error when no conversion path exists for `x.dtype` -> `dtype`.
 */
export function cast(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { dtype } = attrs;
    // Casting to complex64.
    if (dtype === 'complex64') {
        if (x.dtype === 'complex64') {
            return identity({ inputs: { x }, backend });
        }
        // TODO(annxingyuan): Import kernel function once zeros is modularized.
        const zerosTensor = tf.zeros(x.shape);
        let floatX = null;
        try {
            floatX = cast({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
            return complex({ inputs: { real: floatX, imag: zerosTensor }, backend });
        } finally {
            // Release both intermediates even if `cast` or `complex` throws;
            // the order matches the success path of the previous version.
            zerosTensor.dispose();
            if (floatX !== null) {
                backend.disposeIntermediateTensorInfo(floatX);
            }
        }
    }
    // Casting from complex64: only the real part survives.
    if (x.dtype === 'complex64') {
        const realPart = real({ inputs: { input: x }, backend });
        try {
            return cast({ inputs: { x: realPart }, backend, attrs: { dtype } });
        } finally {
            backend.disposeIntermediateTensorInfo(realPart);
        }
    }
    if (!util.hasEncodingLoss(x.dtype, dtype)) {
        // We don't change the underlying data, since we cast to higher
        // precision.
        const result = identity({ inputs: { x }, backend });
        return { dataId: result.dataId, shape: result.shape, dtype };
    }
    if (backend.shouldExecuteOnCPU([x])) {
        const values = backend.texData.get(x.dataId).values;
        const [resultShape, resultType, resultData] = castImplCPU(values, x.shape, x.dtype, dtype);
        return backend.makeTensorInfo(resultShape, resultType, resultData);
    }
    if (dtype === 'int32') {
        return int(x, backend);
    }
    if (dtype === 'bool') {
        const zerosTensorInfo = backend.makeTensorInfo([], 'bool', util.getTypedArrayFromDType('bool', 1));
        try {
            return notEqual({ inputs: { a: x, b: zerosTensorInfo }, backend });
        } finally {
            // Freed even when `notEqual` throws.
            backend.disposeIntermediateTensorInfo(zerosTensorInfo);
        }
    }
    throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`);
}
|
/**
 * Kernel config that binds the `Cast` kernel name to `cast` for the
 * 'webgl' backend.
 */
export const castConfig = { kernelName: Cast, backendName: 'webgl', kernelFunc: cast };
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"Cast.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-webgl/src/kernels/Cast.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AACH,OAAO,KAAK,EAAE,MAAM,uBAAuB,CAAC;AAC5C,OAAO,EAAe,IAAI,EAA2E,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAGxI,OAAO,EAAC,WAAW,EAAC,MAAM,wBAAwB,CAAC;AACnD,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAClC,OAAO,EAAC,QAAQ,EAAC,MAAM,YAAY,CAAC;AACpC,OAAO,EAAC,QAAQ,EAAC,MAAM,YAAY,CAAC;AACpC,OAAO,EAAC,IAAI,EAAC,MAAM,QAAQ,CAAC;AAE5B,OAAO,EAAC,GAAG,EAAC,MAAM,qBAAqB,CAAC;AAExC,MAAM,UAAU,IAAI,CAChB,IAAuE;IAEzE,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAC,GAAG,MAAM,CAAC;IACnB,MAAM,EAAC,KAAK,EAAC,GAAG,KAAK,CAAC;IAEtB,wBAAwB;IACxB,IAAI,KAAK,KAAK,WAAW,EAAE;QACzB,IAAI,CAAC,CAAC,KAAK,KAAK,WAAW,EAAE;YAC3B,OAAO,QAAQ,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;SACzC;QAED,uEAAuE;QACvE,MAAM,WAAW,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QACtC,MAAM,MAAM,GAAG,IAAI,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,SAAS,EAAC,EAAC,CAAC,CAAC;QAEvE,MAAM,MAAM,GACR,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,WAAW,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;QAElE,WAAW,CAAC,OAAO,EAAE,CAAC;QACtB,OAAO,CAAC,6BAA6B,CAAC,MAAM,CAAC,CAAC;QAE9C,OAAO,MAAM,CAAC;KACf;IAED,yBAAyB;IACzB,IAAI,CAAC,CAAC,KAAK,KAAK,WAAW,EAAE;QAC3B,MAAM,QAAQ,GAAG,IAAI,CAAC,EAAC,MAAM,EAAE,EAAC,KAAK,EAAE,CAAC,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;QACrD,MAAM,MAAM,GAAG,IAAI,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,QAAQ,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAC,EAAC,CAAC,CAAC;QACtE,OAAO,CAAC,6BAA6B,CAAC,QAAQ,CAAC,CAAC;QAChD,OAAO,MAAM,CAAC;KACf;IAED,IAAI,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,EAAE;QACzC,+DAA+D;QAC/D,aAAa;QACb,MAAM,MAAM,GAAG,QAAQ,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;QAChD,OAAO,EAAC,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,KAAK,EAAE,MAAM,CAAC,KAAK,EAAE,KAAK,EAAC,CAAC;KAC5D;IAED,IAAI,OAAO,CAAC,kBAAkB,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE;QACnC
,MAAM,MAAM,GAAG,OAAO,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,MAAoB,CAAC;QAClE,MAAM,CAAC,WAAW,EAAE,UAAU,EAAE,UAAU,CAAC,GACvC,WAAW,CAAC,MAAM,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC;QACjD,OAAO,OAAO,CAAC,cAAc,CAAC,WAAW,EAAE,UAAU,EAAE,UAAU,CAAC,CAAC;KACpE;IAED,IAAI,KAAK,KAAK,OAAO,EAAE;QACrB,OAAO,GAAG,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;KACxB;IAED,IAAI,KAAK,KAAK,MAAM,EAAE;QACpB,MAAM,eAAe,GAAG,OAAO,CAAC,cAAc,CAC1C,EAAE,EAAE,MAAM,EAAE,IAAI,CAAC,sBAAsB,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC;QAExD,MAAM,YAAY,GAAiB,EAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,eAAe,EAAC,CAAC;QAE9D,MAAM,MAAM,GAAG,QAAQ,CAAC,EAAC,MAAM,EAAE,YAAY,EAAE,OAAO,EAAC,CAAe,CAAC;QACvE,OAAO,CAAC,6BAA6B,CAAC,eAAe,CAAC,CAAC;QACvD,OAAO,MAAM,CAAC;KACf;IAED,MAAM,IAAI,KAAK,CAAC,iCAAiC,CAAC,CAAC,KAAK,OAAO,KAAK,EAAE,CAAC,CAAC;AAC1E,CAAC;AAED,MAAM,CAAC,MAAM,UAAU,GAAiB;IACtC,UAAU,EAAE,IAAI;IAChB,WAAW,EAAE,OAAO;IACpB,UAAU,EAAE,IAA6B;CAC1C,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport * as tf from '@tensorflow/tfjs-core';\nimport {BinaryInputs, Cast, CastAttrs, CastInputs, KernelConfig, KernelFunc, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core';\n\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {castImplCPU} from '../kernel_utils/shared';\nimport {complex} from './Complex';\nimport 
{identity} from './Identity';\nimport {notEqual} from './NotEqual';\nimport {real} from './Real';\n\nimport {int} from '../kernel_utils/int';\n\nexport function cast(\n    args: {inputs: CastInputs, backend: MathBackendWebGL, attrs: CastAttrs}):\n    TensorInfo {\n  const {inputs, backend, attrs} = args;\n  const {x} = inputs;\n  const {dtype} = attrs;\n\n  // Casting to complex64.\n  if (dtype === 'complex64') {\n    if (x.dtype === 'complex64') {\n      return identity({inputs: {x}, backend});\n    }\n\n    // TODO(annxingyuan): Import kernel function once zeros is modularized.\n    const zerosTensor = tf.zeros(x.shape);\n    const floatX = cast({inputs: {x}, backend, attrs: {dtype: 'float32'}});\n\n    const result =\n        complex({inputs: {real: floatX, imag: zerosTensor}, backend});\n\n    zerosTensor.dispose();\n    backend.disposeIntermediateTensorInfo(floatX);\n\n    return result;\n  }\n\n  // Casting from complex64\n  if (x.dtype === 'complex64') {\n    const realPart = real({inputs: {input: x}, backend});\n    const result = cast({inputs: {x: realPart}, backend, attrs: {dtype}});\n    backend.disposeIntermediateTensorInfo(realPart);\n    return result;\n  }\n\n  if (!util.hasEncodingLoss(x.dtype, dtype)) {\n    // We don't change the underlying data, since we cast to higher\n    // precision.\n    const result = identity({inputs: {x}, backend});\n    return {dataId: result.dataId, shape: result.shape, dtype};\n  }\n\n  if (backend.shouldExecuteOnCPU([x])) {\n    const values = backend.texData.get(x.dataId).values as TypedArray;\n    const [resultShape, resultType, resultData] =\n        castImplCPU(values, x.shape, x.dtype, dtype);\n    return backend.makeTensorInfo(resultShape, resultType, resultData);\n  }\n\n  if (dtype === 'int32') {\n    return int(x, backend);\n  }\n\n  if (dtype === 'bool') {\n    const zerosTensorInfo = backend.makeTensorInfo(\n        [], 'bool', util.getTypedArrayFromDType('bool', 1));\n\n    const binaryInputs: BinaryInputs 
= {a: x, b: zerosTensorInfo};\n\n    const result = notEqual({inputs: binaryInputs, backend}) as TensorInfo;\n    backend.disposeIntermediateTensorInfo(zerosTensorInfo);\n    return result;\n  }\n\n  throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`);\n}\n\nexport const castConfig: KernelConfig = {\n  kernelName: Cast,\n  backendName: 'webgl',\n  kernelFunc: cast as unknown as KernelFunc\n};\n"]}
|