/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { Cast, util } from '@tensorflow/tfjs-core';
import { createSimpleBinaryKernelImpl } from '../utils/binary_impl';
import { zeros } from '../utils/zeros_impl';
import { complex } from './Complex';
import { identity } from './Identity';
import { real } from './Real';

/**
 * Converts a flat array of tensor values to `dtype`.
 *
 * Only 'int32' and 'bool' targets are handled here; any other target throws.
 *
 * @param values Flat values of the source tensor.
 * @param shape Shape of the source tensor.
 * @param inputType Dtype of `values`.
 * @param dtype Target dtype.
 * @returns A `[shape, dtype, values]` triple describing the converted data.
 * @throws Error when `dtype` is neither 'int32' nor 'bool'.
 */
export function castImpl(values, shape, inputType, dtype) {
  switch (dtype) {
    case 'int32':
      // Int32Array.from applies the standard numeric conversion per element.
      return [shape, 'int32', Int32Array.from(values)];
    case 'bool': {
      // Equivalent to notEqual(x, 0). The notEqual kernel is deliberately not
      // used, because that would form a circular dependency:
      // binary_utils -> cast -> notEqual -> binary_utils.
      const zero = util.toTypedArray([0], inputType);
      const notEqualToScalar =
          createSimpleBinaryKernelImpl((a, b) => (a !== b) ? 1 : 0);
      const [boolData, boolShape] =
          notEqualToScalar(shape, [], values, zero, 'bool');
      return [boolShape, 'bool', boolData];
    }
    default:
      throw new Error(`Error in Cast: failed to cast ${inputType} to ${dtype}`);
  }
}

/**
 * Builds a complex64 tensor from a real-valued `x`.
 *
 * The real part is `x` cast to float32 and the imaginary part is all zeros.
 * Both intermediates are released once `complex` has consumed them.
 */
function castRealToComplex(x, backend) {
  // NOTE(review): the zero imaginary part is allocated with x.dtype rather
  // than 'float32'. Confirm that complex() tolerates non-float32 storage.
  const imagZeros = zeros(backend, x.shape, x.dtype);
  const realAsFloat =
      cast({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
  const complexResult =
      complex({ inputs: { real: realAsFloat, imag: imagZeros }, backend });
  backend.disposeIntermediateTensorInfo(imagZeros);
  backend.disposeIntermediateTensorInfo(realAsFloat);
  return complexResult;
}

/**
 * Casts a complex64 `x` to the non-complex `dtype`.
 *
 * The real part is extracted and cast; the imaginary part is discarded.
 */
function castComplexToReal(x, dtype, backend) {
  const realPart = real({ inputs: { input: x }, backend });
  const converted = cast({ inputs: { x: realPart }, backend, attrs: { dtype } });
  backend.disposeIntermediateTensorInfo(realPart);
  return converted;
}

/**
 * CPU kernel for the Cast op.
 *
 * @param args `{inputs: {x}, backend, attrs: {dtype}}`.
 * @returns TensorInfo of `x` converted to `dtype`. When no conversion of the
 *     stored values is needed, the result shares `x`'s data.
 */
export function cast(args) {
  const { inputs: { x }, backend, attrs: { dtype } } = args;

  if (dtype === 'complex64') {
    // complex64 -> complex64 needs no conversion.
    if (x.dtype === 'complex64') {
      return identity({ inputs: { x }, backend });
    }
    return castRealToComplex(x, backend);
  }

  if (x.dtype === 'complex64') {
    return castComplexToReal(x, dtype, backend);
  }

  if (util.hasEncodingLoss(x.dtype, dtype)) {
    // Lossy conversion: materialize new values on the CPU.
    const { values } = backend.data.get(x.dataId);
    const [outShape, outDtype, outValues] =
        castImpl(values, x.shape, x.dtype, dtype);
    return backend.makeTensorInfo(outShape, outDtype, outValues);
  }

  // Casting to a type of equal or higher precision: the underlying data is
  // reused as-is and only the dtype label changes.
  const { dataId, shape } = identity({ inputs: { x }, backend });
  return { dataId, shape, dtype };
}

export const castConfig = {
  kernelName: Cast,
  backendName: 'cpu',
  kernelFunc: cast,
};
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQ2FzdC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC1jcHUvc3JjL2tlcm5lbHMvQ2FzdC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFDSCxPQUFPLEVBQUMsSUFBSSxFQUFxRixJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUdwSSxPQUFPLEVBQUMsNEJBQTRCLEVBQUMsTUFBTSxzQkFBc0IsQ0FBQztBQUNsRSxPQUFPLEVBQUMsS0FBSyxFQUFDLE1BQU0scUJBQXFCLENBQUM7QUFFMUMsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNsQyxPQUFPLEVBQUMsUUFBUSxFQUFDLE1BQU0sWUFBWSxDQUFDO0FBQ3BDLE9BQU8sRUFBQyxJQUFJLEVBQUMsTUFBTSxRQUFRLENBQUM7QUFFNUIsTUFBTSxVQUFVLFFBQVEsQ0FDcEIsTUFBa0IsRUFBRSxLQUFlLEVBQUUsU0FBbUIsRUFDeEQsS0FBZTtJQUNqQixJQUFJLEtBQUssS0FBSyxPQUFPLEVBQUU7UUFDckIsTUFBTSxZQUFZLEdBQUcsVUFBVSxDQUFDLElBQUksQ0FBQyxNQUFNLENBQUMsQ0FBQztRQUM3QyxPQUFPLENBQUMsS0FBSyxFQUFFLE9BQU8sRUFBRSxZQUFZLENBQUMsQ0FBQztLQUN2QztJQUVELElBQUksS0FBSyxLQUFLLE1BQU0sRUFBRTtRQUNwQixtRUFBbUU7UUFDbkUscUVBQXFFO1FBQ3JFLG9DQUFvQztRQUNwQyxNQUFNLElBQUksR0FBRyxJQUFJLENBQUMsWUFBWSxDQUFDLENBQUMsQ0FBQyxDQUFDLEVBQUUsU0FBUyxDQUFDLENBQUM7UUFFL0MsTUFBTSxDQUFDLFVBQVUsRUFBRSxXQUFXLENBQUMsR0FBRyw0QkFBNEIsQ0FDMUQsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxFQUFFLEVBQUUsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxLQUFLLEVBQUUsRUFBRSxFQUFFLE1BQU0sRUFBRSxJQUFJLEVBQUUsTUFBTSxDQUFDLENBQUM7UUFFbEUsT0FBTyxDQUFDLFdBQVcsRUFBRSxNQUFNLEVBQUUsVUFBVSxDQUFDLENBQUM7S0FDMUM7SUFDRCxNQUFNLElBQUksS0FBSyxDQUFDLGlDQUFpQyxTQUFTLE9BQU8sS0FBSyxFQUFFLENBQUMsQ0FBQztBQUM1RSxDQUFDO0FBRUQsTUFBTSxVQUFVLElBQUksQ0FDaEIsSUFBcUU7SUFFdkUsTUFBTSxFQUFDLE1BQU0sRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQ3RDLE1BQU0sRUFBQyxDQUFDLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDbkIsTUFBTSxFQUFDLEtBQUssRUFBQyxHQUFHLEtBQUssQ0FBQztJQUV0Qix3QkFBd0I7SUFDeEIsSUFBSSxLQUFLLEtBQUssV0FBVyxFQUFFO1FBQ3pCLElBQUksQ0FBQyxDQUFDLEtBQUssS0FBSyxXQUFXLEVBQUU7WUFDM0IsT0FBTyxRQUFRLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQUFDO1
NBQ3pDO1FBRUQsTUFBTSxlQUFlLEdBQUcsS0FBSyxDQUFDLE9BQU8sRUFBRSxDQUFDLENBQUMsS0FBSyxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztRQUN6RCxNQUFNLE1BQU0sR0FBRyxJQUFJLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLFNBQVMsRUFBQyxFQUFDLENBQUMsQ0FBQztRQUV2RSxNQUFNLE1BQU0sR0FDUixPQUFPLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxJQUFJLEVBQUUsTUFBTSxFQUFFLElBQUksRUFBRSxlQUFlLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQUFDO1FBRXRFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxlQUFlLENBQUMsQ0FBQztRQUN2RCxPQUFPLENBQUMsNkJBQTZCLENBQUMsTUFBTSxDQUFDLENBQUM7UUFFOUMsT0FBTyxNQUFNLENBQUM7S0FDZjtJQUVELHlCQUF5QjtJQUN6QixJQUFJLENBQUMsQ0FBQyxLQUFLLEtBQUssV0FBVyxFQUFFO1FBQzNCLE1BQU0sUUFBUSxHQUFHLElBQUksQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLEtBQUssRUFBRSxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQUFDO1FBQ3JELE1BQU0sTUFBTSxHQUFHLElBQUksQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxRQUFRLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFDLEVBQUMsQ0FBQyxDQUFDO1FBRXRFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxRQUFRLENBQUMsQ0FBQztRQUVoRCxPQUFPLE1BQU0sQ0FBQztLQUNmO0lBRUQsSUFBSSxDQUFDLElBQUksQ0FBQyxlQUFlLENBQUMsQ0FBQyxDQUFDLEtBQUssRUFBRSxLQUFLLENBQUMsRUFBRTtRQUN6QywrREFBK0Q7UUFDL0QsYUFBYTtRQUNiLE1BQU0sTUFBTSxHQUFHLFFBQVEsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBQyxFQUFFLE9BQU8sRUFBQyxDQUFDLENBQUM7UUFDaEQsT0FBTyxFQUFDLE1BQU0sRUFBRSxNQUFNLENBQUMsTUFBTSxFQUFFLEtBQUssRUFBRSxNQUFNLENBQUMsS0FBSyxFQUFFLEtBQUssRUFBQyxDQUFDO0tBQzVEO0lBRUQsTUFBTSxNQUFNLEdBQUcsT0FBTyxDQUFDLElBQUksQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLE1BQW9CLENBQUM7SUFDL0QsTUFBTSxDQUFDLFdBQVcsRUFBRSxVQUFVLEVBQUUsVUFBVSxDQUFDLEdBQ3ZDLFFBQVEsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUMsS0FBSyxFQUFFLEtBQUssQ0FBQyxDQUFDO0lBQzlDLE9BQU8sT0FBTyxDQUFDLGNBQWMsQ0FBQyxXQUFXLEVBQUUsVUFBVSxFQUFFLFVBQVUsQ0FBQyxDQUFDO0FBQ3JFLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxVQUFVLEdBQWlCO0lBQ3RDLFVBQVUsRUFBRSxJQUFJO0lBQ2hCLFdBQVcsRUFBRSxLQUFLO0lBQ2xCLFVBQVUsRUFBRSxJQUE2QjtDQUMxQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuIC
ogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuaW1wb3J0IHtDYXN0LCBDYXN0QXR0cnMsIENhc3RJbnB1dHMsIERhdGFUeXBlLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvckluZm8sIFR5cGVkQXJyYXksIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcbmltcG9ydCB7Y3JlYXRlU2ltcGxlQmluYXJ5S2VybmVsSW1wbH0gZnJvbSAnLi4vdXRpbHMvYmluYXJ5X2ltcGwnO1xuaW1wb3J0IHt6ZXJvc30gZnJvbSAnLi4vdXRpbHMvemVyb3NfaW1wbCc7XG5cbmltcG9ydCB7Y29tcGxleH0gZnJvbSAnLi9Db21wbGV4JztcbmltcG9ydCB7aWRlbnRpdHl9IGZyb20gJy4vSWRlbnRpdHknO1xuaW1wb3J0IHtyZWFsfSBmcm9tICcuL1JlYWwnO1xuXG5leHBvcnQgZnVuY3Rpb24gY2FzdEltcGwoXG4gICAgdmFsdWVzOiBUeXBlZEFycmF5LCBzaGFwZTogbnVtYmVyW10sIGlucHV0VHlwZTogRGF0YVR5cGUsXG4gICAgZHR5cGU6IERhdGFUeXBlKTogW251bWJlcltdLCBEYXRhVHlwZSwgVHlwZWRBcnJheV0ge1xuICBpZiAoZHR5cGUgPT09ICdpbnQzMicpIHtcbiAgICBjb25zdCByZXN1bHRWYWx1ZXMgPSBJbnQzMkFycmF5LmZyb20odmFsdWVzKTtcbiAgICByZXR1cm4gW3NoYXBlLCAnaW50MzInLCByZXN1bHRWYWx1ZXNdO1xuICB9XG5cbiAgaWYgKGR0eXBlID09PSAnYm9vbCcpIHtcbiAgICAvLyBUaGlzIGlzIGVzc2VudGlhbGx5IHRoZSByZXN1bHQgb2Ygbm90RXF1YWwoeCwgMCkuIFdlIGF2b2lkIHVzaW5nXG4gICAgLy8ga2VybmVsIG
5vdEVxdWFsIHRvIGF2b2lkIGNpcmN1bGFyIGRlcGVuZGVuY3ksIGkuZS4gYmluYXJ5X3V0aWxzIC0+XG4gICAgLy8gY2FzdCAtPiBub3RFcXVhbCAtPiBiaW5hcnlfdXRpbHMuXG4gICAgY29uc3QgemVybyA9IHV0aWwudG9UeXBlZEFycmF5KFswXSwgaW5wdXRUeXBlKTtcblxuICAgIGNvbnN0IFtyZXN1bHREYXRhLCByZXN1bHRTaGFwZV0gPSBjcmVhdGVTaW1wbGVCaW5hcnlLZXJuZWxJbXBsKFxuICAgICAgICAoYSwgYikgPT4gKGEgIT09IGIpID8gMSA6IDApKHNoYXBlLCBbXSwgdmFsdWVzLCB6ZXJvLCAnYm9vbCcpO1xuXG4gICAgcmV0dXJuIFtyZXN1bHRTaGFwZSwgJ2Jvb2wnLCByZXN1bHREYXRhXTtcbiAgfVxuICB0aHJvdyBuZXcgRXJyb3IoYEVycm9yIGluIENhc3Q6IGZhaWxlZCB0byBjYXN0ICR7aW5wdXRUeXBlfSB0byAke2R0eXBlfWApO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gY2FzdChcbiAgICBhcmdzOiB7aW5wdXRzOiBDYXN0SW5wdXRzLCBiYWNrZW5kOiBNYXRoQmFja2VuZENQVSwgYXR0cnM6IENhc3RBdHRyc30pOlxuICAgIFRlbnNvckluZm8ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7eH0gPSBpbnB1dHM7XG4gIGNvbnN0IHtkdHlwZX0gPSBhdHRycztcblxuICAvLyBDYXN0aW5nIHRvIGNvbXBsZXg2NC5cbiAgaWYgKGR0eXBlID09PSAnY29tcGxleDY0Jykge1xuICAgIGlmICh4LmR0eXBlID09PSAnY29tcGxleDY0Jykge1xuICAgICAgcmV0dXJuIGlkZW50aXR5KHtpbnB1dHM6IHt4fSwgYmFja2VuZH0pO1xuICAgIH1cblxuICAgIGNvbnN0IHplcm9zVGVuc29ySW5mbyA9IHplcm9zKGJhY2tlbmQsIHguc2hhcGUsIHguZHR5cGUpO1xuICAgIGNvbnN0IGZsb2F0WCA9IGNhc3Qoe2lucHV0czoge3h9LCBiYWNrZW5kLCBhdHRyczoge2R0eXBlOiAnZmxvYXQzMid9fSk7XG5cbiAgICBjb25zdCByZXN1bHQgPVxuICAgICAgICBjb21wbGV4KHtpbnB1dHM6IHtyZWFsOiBmbG9hdFgsIGltYWc6IHplcm9zVGVuc29ySW5mb30sIGJhY2tlbmR9KTtcblxuICAgIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oemVyb3NUZW5zb3JJbmZvKTtcbiAgICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKGZsb2F0WCk7XG5cbiAgICByZXR1cm4gcmVzdWx0O1xuICB9XG5cbiAgLy8gQ2FzdGluZyBmcm9tIGNvbXBsZXg2NFxuICBpZiAoeC5kdHlwZSA9PT0gJ2NvbXBsZXg2NCcpIHtcbiAgICBjb25zdCByZWFsUGFydCA9IHJlYWwoe2lucHV0czoge2lucHV0OiB4fSwgYmFja2VuZH0pO1xuICAgIGNvbnN0IHJlc3VsdCA9IGNhc3Qoe2lucHV0czoge3g6IHJlYWxQYXJ0fSwgYmFja2VuZCwgYXR0cnM6IHtkdHlwZX19KTtcblxuICAgIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8ocmVhbFBhcnQpO1xuXG4gICAgcmV0dXJuIHJlc3VsdDtcbiAgfVxuXG4gIGlmICghdXRpbC5oYXNFbmNvZGluZ0xvc3MoeC5kdHlwZS
wgZHR5cGUpKSB7XG4gICAgLy8gV2UgZG9uJ3QgY2hhbmdlIHRoZSB1bmRlcmx5aW5nIGRhdGEsIHNpbmNlIHdlIGNhc3QgdG8gaGlnaGVyXG4gICAgLy8gcHJlY2lzaW9uLlxuICAgIGNvbnN0IHJlc3VsdCA9IGlkZW50aXR5KHtpbnB1dHM6IHt4fSwgYmFja2VuZH0pO1xuICAgIHJldHVybiB7ZGF0YUlkOiByZXN1bHQuZGF0YUlkLCBzaGFwZTogcmVzdWx0LnNoYXBlLCBkdHlwZX07XG4gIH1cblxuICBjb25zdCB2YWx1ZXMgPSBiYWNrZW5kLmRhdGEuZ2V0KHguZGF0YUlkKS52YWx1ZXMgYXMgVHlwZWRBcnJheTtcbiAgY29uc3QgW3Jlc3VsdFNoYXBlLCByZXN1bHRUeXBlLCByZXN1bHREYXRhXSA9XG4gICAgICBjYXN0SW1wbCh2YWx1ZXMsIHguc2hhcGUsIHguZHR5cGUsIGR0eXBlKTtcbiAgcmV0dXJuIGJhY2tlbmQubWFrZVRlbnNvckluZm8ocmVzdWx0U2hhcGUsIHJlc3VsdFR5cGUsIHJlc3VsdERhdGEpO1xufVxuXG5leHBvcnQgY29uc3QgY2FzdENvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBDYXN0LFxuICBiYWNrZW5kTmFtZTogJ2NwdScsXG4gIGtlcm5lbEZ1bmM6IGNhc3QgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19