/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { Cast, util } from '@tensorflow/tfjs-core';
import { createSimpleBinaryKernelImpl } from '../utils/binary_impl';
import { zeros } from '../utils/zeros_impl';
import { complex } from './Complex';
import { identity } from './Identity';
import { real } from './Real';
/**
 * Converts a flat buffer of `inputType` values to `dtype`.
 *
 * Only two targets are handled here:
 *  - 'int32': values are copied through `Int32Array.from`, i.e. JavaScript's
 *    ToInt32 conversion (fractional parts are truncated).
 *  - 'bool': every element is compared with a zero scalar (shape `[]`) of
 *    `inputType`; elements not strictly equal to it map to 1, others to 0.
 *
 * @param values Raw element buffer of the input tensor.
 * @param shape Shape of the input tensor.
 * @param inputType dtype of `values`.
 * @param dtype Requested output dtype.
 * @returns A `[shape, dtype, values]` triple describing the cast result.
 * @throws Error when `dtype` is neither 'int32' nor 'bool'.
 */
export function castImpl(values, shape, inputType, dtype) {
    switch (dtype) {
        case 'int32': {
            const truncated = Int32Array.from(values);
            return [shape, 'int32', truncated];
        }
        case 'bool': {
            // Equivalent to notEqual(x, 0). The notEqual kernel is deliberately
            // not used: it would introduce a circular dependency
            // binary_utils -> cast -> notEqual -> binary_utils.
            const zeroScalar = util.toTypedArray([0], inputType);
            const notEqualOp = (a, b) => (a !== b ? 1 : 0);
            const notEqualKernel = createSimpleBinaryKernelImpl(notEqualOp);
            // The binary kernel impl yields [values, shape]; reorder below.
            const [boolValues, boolShape] =
                notEqualKernel(shape, [], values, zeroScalar, 'bool');
            return [boolShape, 'bool', boolValues];
        }
        default:
            throw new Error(`Error in Cast: failed to cast ${inputType} to ${dtype}`);
    }
}
/**
 * CPU kernel for the Cast op: produces a TensorInfo holding `inputs.x`
 * converted to `attrs.dtype`.
 *
 * Dispatch order:
 *  1. Target complex64: an input that is already complex64 is passed through
 *     `identity`; otherwise a complex tensor is assembled from x cast to
 *     float32 (real part) and a zeros tensor (imaginary part).
 *  2. Source complex64: the real part is extracted and cast recursively.
 *  3. Casts that `util.hasEncodingLoss` reports as lossy rewrite the stored
 *     values through `castImpl` into a new tensor.
 *  4. Any other cast reuses the input's data via `identity` and only changes
 *     the dtype label on the returned TensorInfo.
 *
 * Intermediate TensorInfos created along the way are released with
 * `backend.disposeIntermediateTensorInfo`.
 *
 * @param args `{inputs: {x}, backend, attrs: {dtype}}`.
 * @returns The cast TensorInfo.
 */
export function cast(args) {
    const { inputs: { x }, backend, attrs: { dtype } } = args;
    // Target is complex64.
    if (dtype === 'complex64') {
        if (x.dtype === 'complex64') {
            return identity({ inputs: { x }, backend });
        }
        // NOTE(review): the imaginary zeros use x.dtype rather than float32;
        // all values are zero so this is presumably harmless — confirm against
        // the Complex kernel.
        const imagZeros = zeros(backend, x.shape, x.dtype);
        const realAsFloat = cast({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
        const complexOut = complex({ inputs: { real: realAsFloat, imag: imagZeros }, backend });
        backend.disposeIntermediateTensorInfo(imagZeros);
        backend.disposeIntermediateTensorInfo(realAsFloat);
        return complexOut;
    }
    // Source is complex64: only the real component is carried into the cast.
    if (x.dtype === 'complex64') {
        const realComponent = real({ inputs: { input: x }, backend });
        const castedReal = cast({ inputs: { x: realComponent }, backend, attrs: { dtype } });
        backend.disposeIntermediateTensorInfo(realComponent);
        return castedReal;
    }
    if (util.hasEncodingLoss(x.dtype, dtype)) {
        const { values } = backend.data.get(x.dataId);
        const [outShape, outDtype, outValues] = castImpl(values, x.shape, x.dtype, dtype);
        return backend.makeTensorInfo(outShape, outDtype, outValues);
    }
    // No encoding loss (cast to a wider type): the underlying data is shared
    // unchanged and only the reported dtype differs.
    const aliased = identity({ inputs: { x }, backend });
    return { dataId: aliased.dataId, shape: aliased.shape, dtype };
}
/** Kernel config binding `cast` to the Cast kernel on the 'cpu' backend. */
export const castConfig = {
    kernelName: Cast,
    backendName: 'cpu',
    kernelFunc: cast
};
//#
sourceMappingURL=data:application/json;base64,{"version":3,"file":"Cast.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-cpu/src/kernels/Cast.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AACH,OAAO,EAAC,IAAI,EAAqF,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAGpI,OAAO,EAAC,4BAA4B,EAAC,MAAM,sBAAsB,CAAC;AAClE,OAAO,EAAC,KAAK,EAAC,MAAM,qBAAqB,CAAC;AAE1C,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAClC,OAAO,EAAC,QAAQ,EAAC,MAAM,YAAY,CAAC;AACpC,OAAO,EAAC,IAAI,EAAC,MAAM,QAAQ,CAAC;AAE5B,MAAM,UAAU,QAAQ,CACpB,MAAkB,EAAE,KAAe,EAAE,SAAmB,EACxD,KAAe;IACjB,IAAI,KAAK,KAAK,OAAO,EAAE;QACrB,MAAM,YAAY,GAAG,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;QAC7C,OAAO,CAAC,KAAK,EAAE,OAAO,EAAE,YAAY,CAAC,CAAC;KACvC;IAED,IAAI,KAAK,KAAK,MAAM,EAAE;QACpB,mEAAmE;QACnE,qEAAqE;QACrE,oCAAoC;QACpC,MAAM,IAAI,GAAG,IAAI,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC;QAE/C,MAAM,CAAC,UAAU,EAAE,WAAW,CAAC,GAAG,4BAA4B,CAC1D,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,EAAE,EAAE,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,CAAC,CAAC;QAElE,OAAO,CAAC,WAAW,EAAE,MAAM,EAAE,UAAU,CAAC,CAAC;KAC1C;IACD,MAAM,IAAI,KAAK,CAAC,iCAAiC,SAAS,OAAO,KAAK,EAAE,CAAC,CAAC;AAC5E,CAAC;AAED,MAAM,UAAU,IAAI,CAChB,IAAqE;IAEvE,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAC,GAAG,MAAM,CAAC;IACnB,MAAM,EAAC,KAAK,EAAC,GAAG,KAAK,CAAC;IAEtB,wBAAwB;IACxB,IAAI,KAAK,KAAK,WAAW,EAAE;QACzB,IAAI,CAAC,CAAC,KAAK,KAAK,WAAW,EAAE;YAC3B,OAAO,QAAQ,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;SACzC;QAED,MAAM,eAAe,GAAG,KAAK,CAAC,OAAO,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC;QACzD,MAAM,MAAM,GAAG,IAAI,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,SAAS,EAAC,EAAC,CAAC,CAAC;QAEvE,MAAM,MAAM,GACR,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,eAAe,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;QAEtE,OAAO,CAAC,6BAA6B,CAAC,eAAe,CAAC,CAAC;QACvD,OAAO,CAAC,6BAA6B,CAAC,MAAM,CAAC,CAAC;QAE9C,OAAO,MAAM,CAAC;KACf;IAED,yBAAyB;IACzB,IAAI,CAAC,CAAC,KAAK,KAAK,W
AAW,EAAE;QAC3B,MAAM,QAAQ,GAAG,IAAI,CAAC,EAAC,MAAM,EAAE,EAAC,KAAK,EAAE,CAAC,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;QACrD,MAAM,MAAM,GAAG,IAAI,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,QAAQ,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAC,EAAC,CAAC,CAAC;QAEtE,OAAO,CAAC,6BAA6B,CAAC,QAAQ,CAAC,CAAC;QAEhD,OAAO,MAAM,CAAC;KACf;IAED,IAAI,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,EAAE;QACzC,+DAA+D;QAC/D,aAAa;QACb,MAAM,MAAM,GAAG,QAAQ,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;QAChD,OAAO,EAAC,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,KAAK,EAAE,MAAM,CAAC,KAAK,EAAE,KAAK,EAAC,CAAC;KAC5D;IAED,MAAM,MAAM,GAAG,OAAO,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,MAAoB,CAAC;IAC/D,MAAM,CAAC,WAAW,EAAE,UAAU,EAAE,UAAU,CAAC,GACvC,QAAQ,CAAC,MAAM,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC;IAC9C,OAAO,OAAO,CAAC,cAAc,CAAC,WAAW,EAAE,UAAU,EAAE,UAAU,CAAC,CAAC;AACrE,CAAC;AAED,MAAM,CAAC,MAAM,UAAU,GAAiB;IACtC,UAAU,EAAE,IAAI;IAChB,WAAW,EAAE,KAAK;IAClB,UAAU,EAAE,IAA6B;CAC1C,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {Cast, CastAttrs, CastInputs, DataType, KernelConfig, KernelFunc, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core';\n\nimport {MathBackendCPU} from '../backend_cpu';\nimport {createSimpleBinaryKernelImpl} from '../utils/binary_impl';\nimport {zeros} from '../utils/zeros_impl';\n\nimport {complex} from './Complex';\nimport {identity} from './Identity';\nimport {real} from './Real';\n\nexport function castImpl(\n    values: TypedArray, shape: number[], inputType: DataType,\n    dtype: DataType): [number[], DataType, TypedArray] {\n  if (dtype === 'int32') {\n    const resultValues = Int32Array.from(values);\n    return [shape, 'int32', resultValues];\n  }\n\n  if (dtype === 'bool') {\n    // This is essentially the result of notEqual(x, 0). We avoid using\n    // kernel notEqual to avoid circular dependency, i.e. binary_utils ->\n    // cast -> notEqual -> binary_utils.\n    const zero = util.toTypedArray([0], inputType);\n\n    const [resultData, resultShape] = createSimpleBinaryKernelImpl(\n        (a, b) => (a !== b) ? 
1 : 0)(shape, [], values, zero, 'bool');\n\n    return [resultShape, 'bool', resultData];\n  }\n  throw new Error(`Error in Cast: failed to cast ${inputType} to ${dtype}`);\n}\n\nexport function cast(\n    args: {inputs: CastInputs, backend: MathBackendCPU, attrs: CastAttrs}):\n    TensorInfo {\n  const {inputs, backend, attrs} = args;\n  const {x} = inputs;\n  const {dtype} = attrs;\n\n  // Casting to complex64.\n  if (dtype === 'complex64') {\n    if (x.dtype === 'complex64') {\n      return identity({inputs: {x}, backend});\n    }\n\n    const zerosTensorInfo = zeros(backend, x.shape, x.dtype);\n    const floatX = cast({inputs: {x}, backend, attrs: {dtype: 'float32'}});\n\n    const result =\n        complex({inputs: {real: floatX, imag: zerosTensorInfo}, backend});\n\n    backend.disposeIntermediateTensorInfo(zerosTensorInfo);\n    backend.disposeIntermediateTensorInfo(floatX);\n\n    return result;\n  }\n\n  // Casting from complex64\n  if (x.dtype === 'complex64') {\n    const realPart = real({inputs: {input: x}, backend});\n    const result = cast({inputs: {x: realPart}, backend, attrs: {dtype}});\n\n    backend.disposeIntermediateTensorInfo(realPart);\n\n    return result;\n  }\n\n  if (!util.hasEncodingLoss(x.dtype, dtype)) {\n    // We don't change the underlying data, since we cast to higher\n    // precision.\n    const result = identity({inputs: {x}, backend});\n    return {dataId: result.dataId, shape: result.shape, dtype};\n  }\n\n  const values = backend.data.get(x.dataId).values as TypedArray;\n  const [resultShape, resultType, resultData] =\n      castImpl(values, x.shape, x.dtype, dtype);\n  return backend.makeTensorInfo(resultShape, resultType, resultData);\n}\n\nexport const castConfig: KernelConfig = {\n  kernelName: Cast,\n  backendName: 'cpu',\n  kernelFunc: cast as unknown as KernelFunc\n};\n"]}