/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
|
import * as tf from '@tensorflow/tfjs-core';
|
import { Cast, util } from '@tensorflow/tfjs-core';
|
import { castImplCPU } from '../kernel_utils/shared';
|
import { complex } from './Complex';
|
import { identity } from './Identity';
|
import { notEqual } from './NotEqual';
|
import { real } from './Real';
|
import { int } from '../kernel_utils/int';
|
/**
 * Kernel function for `Cast`: produces a tensor info that represents `x`
 * under the requested `dtype`.
 *
 * The branches are tried in this order:
 *  1. Target is complex64: `identity` when `x` is already complex64;
 *     otherwise `x` is cast to float32 and combined through `complex` with
 *     an imaginary part created by `tf.zeros(x.shape)`.
 *  2. Source is complex64: the `real` part is extracted and cast recursively.
 *  3. `util.hasEncodingLoss(x.dtype, dtype)` is false: the output of
 *     `identity` is returned with only its dtype label replaced.
 *  4. `backend.shouldExecuteOnCPU([x])` is true: `castImplCPU` is run on the
 *     values stored in `backend.texData` for `x`.
 *  5. Target is int32: delegated to the `int` helper.
 *  6. Target is bool: `notEqual` of `x` against a freshly allocated
 *     one-element bool scalar (expected to hold 0).
 *
 * Every intermediate created here is released in a `finally` block, so it is
 * freed even when a downstream kernel throws.
 *
 * @param {Object} args
 * @param {Object} args.inputs Must carry the tensor info `x` to convert.
 * @param {Object} args.backend Backend used to run the helper kernels and to
 *     create/dispose intermediate tensor infos.
 * @param {Object} args.attrs Must carry the target `dtype`.
 * @returns {Object} Tensor info of the converted tensor.
 * @throws {Error} When none of the branches above handles the conversion.
 */
export function cast(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { dtype } = attrs;

  // Casting to complex64.
  if (dtype === 'complex64') {
    if (x.dtype === 'complex64') {
      return identity({ inputs: { x }, backend });
    }

    // TODO(annxingyuan): Import kernel function once zeros is modularized.
    const zerosTensor = tf.zeros(x.shape);
    // Stays null if the recursive cast throws before producing a value.
    let floatX = null;
    try {
      floatX = cast({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
      return complex({ inputs: { real: floatX, imag: zerosTensor }, backend });
    } finally {
      // Runs after the return value is computed, and also on failure.
      zerosTensor.dispose();
      if (floatX !== null) {
        backend.disposeIntermediateTensorInfo(floatX);
      }
    }
  }

  // Casting from complex64.
  if (x.dtype === 'complex64') {
    const realPart = real({ inputs: { input: x }, backend });
    try {
      return cast({ inputs: { x: realPart }, backend, attrs: { dtype } });
    } finally {
      backend.disposeIntermediateTensorInfo(realPart);
    }
  }

  if (!util.hasEncodingLoss(x.dtype, dtype)) {
    // We don't change the underlying data, since we cast to higher
    // precision.
    const result = identity({ inputs: { x }, backend });
    return { dataId: result.dataId, shape: result.shape, dtype };
  }

  if (backend.shouldExecuteOnCPU([x])) {
    const values = backend.texData.get(x.dataId).values;
    const [resultShape, resultType, resultData] =
        castImplCPU(values, x.shape, x.dtype, dtype);
    return backend.makeTensorInfo(resultShape, resultType, resultData);
  }

  if (dtype === 'int32') {
    return int(x, backend);
  }

  if (dtype === 'bool') {
    const zerosTensorInfo = backend.makeTensorInfo(
        [], 'bool', util.getTypedArrayFromDType('bool', 1));
    try {
      const binaryInputs = { a: x, b: zerosTensorInfo };
      return notEqual({ inputs: binaryInputs, backend });
    } finally {
      backend.disposeIntermediateTensorInfo(zerosTensorInfo);
    }
  }

  throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`);
}
|
/**
 * Kernel configuration record for `Cast`: associates the `Cast` kernel name
 * with the `cast` function of this file for the 'webgl' backend.
 */
export const castConfig = {
  kernelName: Cast,
  backendName: 'webgl',
  kernelFunc: cast,
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQ2FzdC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMva2VybmVscy9DYXN0LnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUNILE9BQU8sS0FBSyxFQUFFLE1BQU0sdUJBQXVCLENBQUM7QUFDNUMsT0FBTyxFQUFlLElBQUksRUFBMkUsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFHeEksT0FBTyxFQUFDLFdBQVcsRUFBQyxNQUFNLHdCQUF3QixDQUFDO0FBQ25ELE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDbEMsT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLFlBQVksQ0FBQztBQUNwQyxPQUFPLEVBQUMsUUFBUSxFQUFDLE1BQU0sWUFBWSxDQUFDO0FBQ3BDLE9BQU8sRUFBQyxJQUFJLEVBQUMsTUFBTSxRQUFRLENBQUM7QUFFNUIsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLHFCQUFxQixDQUFDO0FBRXhDLE1BQU0sVUFBVSxJQUFJLENBQ2hCLElBQXVFO0lBRXpFLE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztJQUN0QyxNQUFNLEVBQUMsQ0FBQyxFQUFDLEdBQUcsTUFBTSxDQUFDO0lBQ25CLE1BQU0sRUFBQyxLQUFLLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFFdEIsd0JBQXdCO0lBQ3hCLElBQUksS0FBSyxLQUFLLFdBQVcsRUFBRTtRQUN6QixJQUFJLENBQUMsQ0FBQyxLQUFLLEtBQUssV0FBVyxFQUFFO1lBQzNCLE9BQU8sUUFBUSxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFDLEVBQUUsT0FBTyxFQUFDLENBQUMsQ0FBQztTQUN6QztRQUVELHVFQUF1RTtRQUN2RSxNQUFNLFdBQVcsR0FBRyxFQUFFLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztRQUN0QyxNQUFNLE1BQU0sR0FBRyxJQUFJLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLFNBQVMsRUFBQyxFQUFDLENBQUMsQ0FBQztRQUV2RSxNQUFNLE1BQU0sR0FDUixPQUFPLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxJQUFJLEVBQUUsTUFBTSxFQUFFLElBQUksRUFBRSxXQUFXLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQUFDO1FBRWxFLFdBQVcsQ0FBQyxPQUFPLEVBQUUsQ0FBQztRQUN0QixPQUFPLENBQUMsNkJBQTZCLENBQUMsTUFBTSxDQUFDLENBQUM7UUFFOUMsT0FBTyxNQUFNLENBQUM7S0FDZjtJQUVELHlCQUF5QjtJQUN6QixJQUFJLENBQUMsQ0FBQyxLQUFLLEtBQUssV0FBVyxFQUFFO1FBQzNCLE1BQU0sUUFBUSxHQUFHLElBQUksQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLEtBQUssRUFBRSxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQUFDO1FBQ3JELE1BQU0sTUFBTSxHQUFHLElBQUksQ0FBQyxFQUFDLE1BQU0sRUFBRS
xFQUFDLENBQUMsRUFBRSxRQUFRLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFDLEVBQUMsQ0FBQyxDQUFDO1FBQ3RFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxRQUFRLENBQUMsQ0FBQztRQUNoRCxPQUFPLE1BQU0sQ0FBQztLQUNmO0lBRUQsSUFBSSxDQUFDLElBQUksQ0FBQyxlQUFlLENBQUMsQ0FBQyxDQUFDLEtBQUssRUFBRSxLQUFLLENBQUMsRUFBRTtRQUN6QywrREFBK0Q7UUFDL0QsYUFBYTtRQUNiLE1BQU0sTUFBTSxHQUFHLFFBQVEsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBQyxFQUFFLE9BQU8sRUFBQyxDQUFDLENBQUM7UUFDaEQsT0FBTyxFQUFDLE1BQU0sRUFBRSxNQUFNLENBQUMsTUFBTSxFQUFFLEtBQUssRUFBRSxNQUFNLENBQUMsS0FBSyxFQUFFLEtBQUssRUFBQyxDQUFDO0tBQzVEO0lBRUQsSUFBSSxPQUFPLENBQUMsa0JBQWtCLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxFQUFFO1FBQ25DLE1BQU0sTUFBTSxHQUFHLE9BQU8sQ0FBQyxPQUFPLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDO1FBQ2xFLE1BQU0sQ0FBQyxXQUFXLEVBQUUsVUFBVSxFQUFFLFVBQVUsQ0FBQyxHQUN2QyxXQUFXLENBQUMsTUFBTSxFQUFFLENBQUMsQ0FBQyxLQUFLLEVBQUUsQ0FBQyxDQUFDLEtBQUssRUFBRSxLQUFLLENBQUMsQ0FBQztRQUNqRCxPQUFPLE9BQU8sQ0FBQyxjQUFjLENBQUMsV0FBVyxFQUFFLFVBQVUsRUFBRSxVQUFVLENBQUMsQ0FBQztLQUNwRTtJQUVELElBQUksS0FBSyxLQUFLLE9BQU8sRUFBRTtRQUNyQixPQUFPLEdBQUcsQ0FBQyxDQUFDLEVBQUUsT0FBTyxDQUFDLENBQUM7S0FDeEI7SUFFRCxJQUFJLEtBQUssS0FBSyxNQUFNLEVBQUU7UUFDcEIsTUFBTSxlQUFlLEdBQUcsT0FBTyxDQUFDLGNBQWMsQ0FDMUMsRUFBRSxFQUFFLE1BQU0sRUFBRSxJQUFJLENBQUMsc0JBQXNCLENBQUMsTUFBTSxFQUFFLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFFeEQsTUFBTSxZQUFZLEdBQWlCLEVBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLEVBQUUsZUFBZSxFQUFDLENBQUM7UUFFOUQsTUFBTSxNQUFNLEdBQUcsUUFBUSxDQUFDLEVBQUMsTUFBTSxFQUFFLFlBQVksRUFBRSxPQUFPLEVBQUMsQ0FBZSxDQUFDO1FBQ3ZFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxlQUFlLENBQUMsQ0FBQztRQUN2RCxPQUFPLE1BQU0sQ0FBQztLQUNmO0lBRUQsTUFBTSxJQUFJLEtBQUssQ0FBQyxpQ0FBaUMsQ0FBQyxDQUFDLEtBQUssT0FBTyxLQUFLLEVBQUUsQ0FBQyxDQUFDO0FBQzFFLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxVQUFVLEdBQWlCO0lBQ3RDLFVBQVUsRUFBRSxJQUFJO0lBQ2hCLFdBQVcsRUFBRSxPQUFPO0lBQ3BCLFVBQVUsRUFBRSxJQUE2QjtDQUMxQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2
VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuaW1wb3J0ICogYXMgdGYgZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcbmltcG9ydCB7QmluYXJ5SW5wdXRzLCBDYXN0LCBDYXN0QXR0cnMsIENhc3RJbnB1dHMsIEtlcm5lbENvbmZpZywgS2VybmVsRnVuYywgVGVuc29ySW5mbywgVHlwZWRBcnJheSwgdXRpbH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtNYXRoQmFja2VuZFdlYkdMfSBmcm9tICcuLi9iYWNrZW5kX3dlYmdsJztcbmltcG9ydCB7Y2FzdEltcGxDUFV9IGZyb20gJy4uL2tlcm5lbF91dGlscy9zaGFyZWQnO1xuaW1wb3J0IHtjb21wbGV4fSBmcm9tICcuL0NvbXBsZXgnO1xuaW1wb3J0IHtpZGVudGl0eX0gZnJvbSAnLi9JZGVudGl0eSc7XG5pbXBvcnQge25vdEVxdWFsfSBmcm9tICcuL05vdEVxdWFsJztcbmltcG9ydCB7cmVhbH0gZnJvbSAnLi9SZWFsJztcblxuaW1wb3J0IHtpbnR9IGZyb20gJy4uL2tlcm5lbF91dGlscy9pbnQnO1xuXG5leHBvcnQgZnVuY3Rpb24gY2FzdChcbiAgICBhcmdzOiB7aW5wdXRzOiBDYXN0SW5wdXRzLCBiYWNrZW5kOiBNYXRoQmFja2VuZFdlYkdMLCBhdHRyczogQ2FzdEF0dHJzfSk6XG4gICAgVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHt4fSA9IGlucHV0cztcbiAgY29uc3Qge2R0eXBlfSA9IGF0dHJzO1xuXG4gIC8vIENhc3RpbmcgdG8gY29tcGxleDY0LlxuICBpZiAoZHR5cGUgPT09ICdjb21wbGV4NjQnKSB7XG4gICAgaWYgKHguZHR5cGUgPT09ICdjb21wbGV4NjQnKSB7XG4gICAgICByZXR1cm4gaWRlbnRpdHkoe2lucHV0czoge3h9LCBiYWNrZW5kfSk7XG4gICAgfVxuXG4gICAgLy8gVE9ETy
hhbm54aW5neXVhbik6IEltcG9ydCBrZXJuZWwgZnVuY3Rpb24gb25jZSB6ZXJvcyBpcyBtb2R1bGFyaXplZC5cbiAgICBjb25zdCB6ZXJvc1RlbnNvciA9IHRmLnplcm9zKHguc2hhcGUpO1xuICAgIGNvbnN0IGZsb2F0WCA9IGNhc3Qoe2lucHV0czoge3h9LCBiYWNrZW5kLCBhdHRyczoge2R0eXBlOiAnZmxvYXQzMid9fSk7XG5cbiAgICBjb25zdCByZXN1bHQgPVxuICAgICAgICBjb21wbGV4KHtpbnB1dHM6IHtyZWFsOiBmbG9hdFgsIGltYWc6IHplcm9zVGVuc29yfSwgYmFja2VuZH0pO1xuXG4gICAgemVyb3NUZW5zb3IuZGlzcG9zZSgpO1xuICAgIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oZmxvYXRYKTtcblxuICAgIHJldHVybiByZXN1bHQ7XG4gIH1cblxuICAvLyBDYXN0aW5nIGZyb20gY29tcGxleDY0XG4gIGlmICh4LmR0eXBlID09PSAnY29tcGxleDY0Jykge1xuICAgIGNvbnN0IHJlYWxQYXJ0ID0gcmVhbCh7aW5wdXRzOiB7aW5wdXQ6IHh9LCBiYWNrZW5kfSk7XG4gICAgY29uc3QgcmVzdWx0ID0gY2FzdCh7aW5wdXRzOiB7eDogcmVhbFBhcnR9LCBiYWNrZW5kLCBhdHRyczoge2R0eXBlfX0pO1xuICAgIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8ocmVhbFBhcnQpO1xuICAgIHJldHVybiByZXN1bHQ7XG4gIH1cblxuICBpZiAoIXV0aWwuaGFzRW5jb2RpbmdMb3NzKHguZHR5cGUsIGR0eXBlKSkge1xuICAgIC8vIFdlIGRvbid0IGNoYW5nZSB0aGUgdW5kZXJseWluZyBkYXRhLCBzaW5jZSB3ZSBjYXN0IHRvIGhpZ2hlclxuICAgIC8vIHByZWNpc2lvbi5cbiAgICBjb25zdCByZXN1bHQgPSBpZGVudGl0eSh7aW5wdXRzOiB7eH0sIGJhY2tlbmR9KTtcbiAgICByZXR1cm4ge2RhdGFJZDogcmVzdWx0LmRhdGFJZCwgc2hhcGU6IHJlc3VsdC5zaGFwZSwgZHR5cGV9O1xuICB9XG5cbiAgaWYgKGJhY2tlbmQuc2hvdWxkRXhlY3V0ZU9uQ1BVKFt4XSkpIHtcbiAgICBjb25zdCB2YWx1ZXMgPSBiYWNrZW5kLnRleERhdGEuZ2V0KHguZGF0YUlkKS52YWx1ZXMgYXMgVHlwZWRBcnJheTtcbiAgICBjb25zdCBbcmVzdWx0U2hhcGUsIHJlc3VsdFR5cGUsIHJlc3VsdERhdGFdID1cbiAgICAgICAgY2FzdEltcGxDUFUodmFsdWVzLCB4LnNoYXBlLCB4LmR0eXBlLCBkdHlwZSk7XG4gICAgcmV0dXJuIGJhY2tlbmQubWFrZVRlbnNvckluZm8ocmVzdWx0U2hhcGUsIHJlc3VsdFR5cGUsIHJlc3VsdERhdGEpO1xuICB9XG5cbiAgaWYgKGR0eXBlID09PSAnaW50MzInKSB7XG4gICAgcmV0dXJuIGludCh4LCBiYWNrZW5kKTtcbiAgfVxuXG4gIGlmIChkdHlwZSA9PT0gJ2Jvb2wnKSB7XG4gICAgY29uc3QgemVyb3NUZW5zb3JJbmZvID0gYmFja2VuZC5tYWtlVGVuc29ySW5mbyhcbiAgICAgICAgW10sICdib29sJywgdXRpbC5nZXRUeXBlZEFycmF5RnJvbURUeXBlKCdib29sJywgMSkpO1xuXG4gICAgY29uc3QgYmluYXJ5SW5wdXRzOiBCaW5hcnlJbnB1dHMgPSB7YTogeCwgYjogemVyb3NUZW5zb3
JJbmZvfTtcblxuICAgIGNvbnN0IHJlc3VsdCA9IG5vdEVxdWFsKHtpbnB1dHM6IGJpbmFyeUlucHV0cywgYmFja2VuZH0pIGFzIFRlbnNvckluZm87XG4gICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyh6ZXJvc1RlbnNvckluZm8pO1xuICAgIHJldHVybiByZXN1bHQ7XG4gIH1cblxuICB0aHJvdyBuZXcgRXJyb3IoYEVycm9yIGluIENhc3Q6IGZhaWxlZCB0byBjYXN0ICR7eC5kdHlwZX0gdG8gJHtkdHlwZX1gKTtcbn1cblxuZXhwb3J0IGNvbnN0IGNhc3RDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogQ2FzdCxcbiAgYmFja2VuZE5hbWU6ICd3ZWJnbCcsXG4gIGtlcm5lbEZ1bmM6IGNhc3QgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19
|