gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { Cast, util } from '@tensorflow/tfjs-core';
import { createSimpleBinaryKernelImpl } from '../utils/binary_impl';
import { zeros } from '../utils/zeros_impl';
import { complex } from './Complex';
import { identity } from './Identity';
import { real } from './Real';
/**
 * Converts a flat buffer of tensor values to the requested target dtype.
 *
 * Only two targets are handled here:
 *  - `'int32'`: the values are copied into a new `Int32Array`.
 *  - `'bool'`: every value is compared against a zero scalar of
 *    `inputType`, yielding 1 where it differs from zero and 0 otherwise.
 *
 * @param values Source values to convert.
 * @param shape Shape that accompanies `values`; returned as-is for `'int32'`.
 * @param inputType dtype of `values`; used to build the zero scalar for the
 *     `'bool'` path and reported in the error message.
 * @param dtype Target dtype.
 * @returns A `[shape, dtype, data]` triple describing the converted values.
 * @throws {Error} If `dtype` is neither `'int32'` nor `'bool'`.
 */
export function castImpl(values, shape, inputType, dtype) {
    switch (dtype) {
        case 'int32':
            return [shape, 'int32', Int32Array.from(values)];
        case 'bool': {
            // Equivalent to notEqual(x, 0). The notEqual kernel itself is not
            // used so as to avoid a circular dependency, i.e. binary_utils ->
            // cast -> notEqual -> binary_utils.
            const zeroScalar = util.toTypedArray([0], inputType);
            const differsFromZero = createSimpleBinaryKernelImpl((a, b) => (a !== b) ? 1 : 0);
            const [boolData, boolShape] = differsFromZero(shape, [], values, zeroScalar, 'bool');
            return [boolShape, 'bool', boolData];
        }
        default:
            throw new Error(`Error in Cast: failed to cast ${inputType} to ${dtype}`);
    }
}
/**
 * CPU kernel function for `Cast`: returns a tensor info for `inputs.x`
 * converted to `attrs.dtype`.
 *
 * Cases are tried in this order:
 *  1. Target is `'complex64'`: a complex64 input is passed through
 *     `identity`; any other input is cast to float32 and combined via
 *     `complex` with a tensor produced by the `zeros` helper.
 *  2. Input is `'complex64'`: its real part is extracted with `real` and
 *     cast to the target dtype.
 *  3. `util.hasEncodingLoss` reports no loss: the `identity` result is
 *     re-labelled with the target dtype, leaving the data untouched.
 *  4. Otherwise: the raw values are converted by `castImpl` and wrapped
 *     with `backend.makeTensorInfo`.
 *
 * @param args Object carrying `inputs` (with `x`), `backend`, and `attrs`
 *     (with `dtype`).
 * @returns The tensor info describing the cast result.
 */
export function cast(args) {
    const { inputs: { x }, backend, attrs: { dtype } } = args;
    const sourceIsComplex = x.dtype === 'complex64';

    if (dtype === 'complex64') {
        // complex64 -> complex64 needs no conversion.
        if (sourceIsComplex) {
            return identity({ inputs: { x }, backend });
        }
        // Non-complex input: x cast to float32 becomes the real part and the
        // `zeros` tensor becomes the imaginary part.
        const imagPart = zeros(backend, x.shape, x.dtype);
        const realPart = cast({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
        const complexInfo = complex({ inputs: { real: realPart, imag: imagPart }, backend });
        // Both parts were only needed to build the complex result.
        for (const intermediate of [imagPart, realPart]) {
            backend.disposeIntermediateTensorInfo(intermediate);
        }
        return complexInfo;
    }

    if (sourceIsComplex) {
        // Casting away from complex64 keeps only the real component.
        const realOnly = real({ inputs: { input: x }, backend });
        const converted = cast({ inputs: { x: realOnly }, backend, attrs: { dtype } });
        backend.disposeIntermediateTensorInfo(realOnly);
        return converted;
    }

    if (!util.hasEncodingLoss(x.dtype, dtype)) {
        // Casting to higher precision: keep the underlying data and only
        // report the new dtype.
        const { dataId, shape } = identity({ inputs: { x }, backend });
        return { dataId, shape, dtype };
    }

    // Lossy conversion: convert the stored values and wrap the result.
    const sourceValues = backend.data.get(x.dataId).values;
    return backend.makeTensorInfo(...castImpl(sourceValues, x.shape, x.dtype, dtype));
}
/**
 * Kernel configuration object tying the `Cast` kernel name to the `cast`
 * kernel function for the 'cpu' backend.
 */
export const castConfig = {
    kernelName: Cast,
    backendName: 'cpu',
    kernelFunc: cast
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQ2FzdC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC1jcHUvc3JjL2tlcm5lbHMvQ2FzdC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFDSCxPQUFPLEVBQUMsSUFBSSxFQUFxRixJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUdwSSxPQUFPLEVBQUMsNEJBQTRCLEVBQUMsTUFBTSxzQkFBc0IsQ0FBQztBQUNsRSxPQUFPLEVBQUMsS0FBSyxFQUFDLE1BQU0scUJBQXFCLENBQUM7QUFFMUMsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNsQyxPQUFPLEVBQUMsUUFBUSxFQUFDLE1BQU0sWUFBWSxDQUFDO0FBQ3BDLE9BQU8sRUFBQyxJQUFJLEVBQUMsTUFBTSxRQUFRLENBQUM7QUFFNUIsTUFBTSxVQUFVLFFBQVEsQ0FDcEIsTUFBa0IsRUFBRSxLQUFlLEVBQUUsU0FBbUIsRUFDeEQsS0FBZTtJQUNqQixJQUFJLEtBQUssS0FBSyxPQUFPLEVBQUU7UUFDckIsTUFBTSxZQUFZLEdBQUcsVUFBVSxDQUFDLElBQUksQ0FBQyxNQUFNLENBQUMsQ0FBQztRQUM3QyxPQUFPLENBQUMsS0FBSyxFQUFFLE9BQU8sRUFBRSxZQUFZLENBQUMsQ0FBQztLQUN2QztJQUVELElBQUksS0FBSyxLQUFLLE1BQU0sRUFBRTtRQUNwQixtRUFBbUU7UUFDbkUscUVBQXFFO1FBQ3JFLG9DQUFvQztRQUNwQyxNQUFNLElBQUksR0FBRyxJQUFJLENBQUMsWUFBWSxDQUFDLENBQUMsQ0FBQyxDQUFDLEVBQUUsU0FBUyxDQUFDLENBQUM7UUFFL0MsTUFBTSxDQUFDLFVBQVUsRUFBRSxXQUFXLENBQUMsR0FBRyw0QkFBNEIsQ0FDMUQsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxFQUFFLEVBQUUsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxLQUFLLEVBQUUsRUFBRSxFQUFFLE1BQU0sRUFBRSxJQUFJLEVBQUUsTUFBTSxDQUFDLENBQUM7UUFFbEUsT0FBTyxDQUFDLFdBQVcsRUFBRSxNQUFNLEVBQUUsVUFBVSxDQUFDLENBQUM7S0FDMUM7SUFDRCxNQUFNLElBQUksS0FBSyxDQUFDLGlDQUFpQyxTQUFTLE9BQU8sS0FBSyxFQUFFLENBQUMsQ0FBQztBQUM1RSxDQUFDO0FBRUQsTUFBTSxVQUFVLElBQUksQ0FDaEIsSUFBcUU7SUFFdkUsTUFBTSxFQUFDLE1BQU0sRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQ3RDLE1BQU0sRUFBQyxDQUFDLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDbkIsTUFBTSxFQUFDLEtBQUssRUFBQyxHQUFHLEtBQUssQ0FBQztJQUV0Qix3QkFBd0I7SUFDeEIsSUFBSSxLQUFLLEtBQUssV0FBVyxFQUFFO1FBQ3pCLElBQUksQ0FBQyxDQUFDLEtBQUssS0FBSyxXQUFXLEVBQUU7WUFDM0IsT0FBTyxRQUFRLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQU
FDO1NBQ3pDO1FBRUQsTUFBTSxlQUFlLEdBQUcsS0FBSyxDQUFDLE9BQU8sRUFBRSxDQUFDLENBQUMsS0FBSyxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztRQUN6RCxNQUFNLE1BQU0sR0FBRyxJQUFJLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLFNBQVMsRUFBQyxFQUFDLENBQUMsQ0FBQztRQUV2RSxNQUFNLE1BQU0sR0FDUixPQUFPLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxJQUFJLEVBQUUsTUFBTSxFQUFFLElBQUksRUFBRSxlQUFlLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQUFDO1FBRXRFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxlQUFlLENBQUMsQ0FBQztRQUN2RCxPQUFPLENBQUMsNkJBQTZCLENBQUMsTUFBTSxDQUFDLENBQUM7UUFFOUMsT0FBTyxNQUFNLENBQUM7S0FDZjtJQUVELHlCQUF5QjtJQUN6QixJQUFJLENBQUMsQ0FBQyxLQUFLLEtBQUssV0FBVyxFQUFFO1FBQzNCLE1BQU0sUUFBUSxHQUFHLElBQUksQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLEtBQUssRUFBRSxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQUFDO1FBQ3JELE1BQU0sTUFBTSxHQUFHLElBQUksQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxRQUFRLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFDLEVBQUMsQ0FBQyxDQUFDO1FBRXRFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxRQUFRLENBQUMsQ0FBQztRQUVoRCxPQUFPLE1BQU0sQ0FBQztLQUNmO0lBRUQsSUFBSSxDQUFDLElBQUksQ0FBQyxlQUFlLENBQUMsQ0FBQyxDQUFDLEtBQUssRUFBRSxLQUFLLENBQUMsRUFBRTtRQUN6QywrREFBK0Q7UUFDL0QsYUFBYTtRQUNiLE1BQU0sTUFBTSxHQUFHLFFBQVEsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBQyxFQUFFLE9BQU8sRUFBQyxDQUFDLENBQUM7UUFDaEQsT0FBTyxFQUFDLE1BQU0sRUFBRSxNQUFNLENBQUMsTUFBTSxFQUFFLEtBQUssRUFBRSxNQUFNLENBQUMsS0FBSyxFQUFFLEtBQUssRUFBQyxDQUFDO0tBQzVEO0lBRUQsTUFBTSxNQUFNLEdBQUcsT0FBTyxDQUFDLElBQUksQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLE1BQW9CLENBQUM7SUFDL0QsTUFBTSxDQUFDLFdBQVcsRUFBRSxVQUFVLEVBQUUsVUFBVSxDQUFDLEdBQ3ZDLFFBQVEsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUMsS0FBSyxFQUFFLEtBQUssQ0FBQyxDQUFDO0lBQzlDLE9BQU8sT0FBTyxDQUFDLGNBQWMsQ0FBQyxXQUFXLEVBQUUsVUFBVSxFQUFFLFVBQVUsQ0FBQyxDQUFDO0FBQ3JFLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxVQUFVLEdBQWlCO0lBQ3RDLFVBQVUsRUFBRSxJQUFJO0lBQ2hCLFdBQVcsRUFBRSxLQUFLO0lBQ2xCLFVBQVUsRUFBRSxJQUE2QjtDQUMxQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZV
xuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuaW1wb3J0IHtDYXN0LCBDYXN0QXR0cnMsIENhc3RJbnB1dHMsIERhdGFUeXBlLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvckluZm8sIFR5cGVkQXJyYXksIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcbmltcG9ydCB7Y3JlYXRlU2ltcGxlQmluYXJ5S2VybmVsSW1wbH0gZnJvbSAnLi4vdXRpbHMvYmluYXJ5X2ltcGwnO1xuaW1wb3J0IHt6ZXJvc30gZnJvbSAnLi4vdXRpbHMvemVyb3NfaW1wbCc7XG5cbmltcG9ydCB7Y29tcGxleH0gZnJvbSAnLi9Db21wbGV4JztcbmltcG9ydCB7aWRlbnRpdHl9IGZyb20gJy4vSWRlbnRpdHknO1xuaW1wb3J0IHtyZWFsfSBmcm9tICcuL1JlYWwnO1xuXG5leHBvcnQgZnVuY3Rpb24gY2FzdEltcGwoXG4gICAgdmFsdWVzOiBUeXBlZEFycmF5LCBzaGFwZTogbnVtYmVyW10sIGlucHV0VHlwZTogRGF0YVR5cGUsXG4gICAgZHR5cGU6IERhdGFUeXBlKTogW251bWJlcltdLCBEYXRhVHlwZSwgVHlwZWRBcnJheV0ge1xuICBpZiAoZHR5cGUgPT09ICdpbnQzMicpIHtcbiAgICBjb25zdCByZXN1bHRWYWx1ZXMgPSBJbnQzMkFycmF5LmZyb20odmFsdWVzKTtcbiAgICByZXR1cm4gW3NoYXBlLCAnaW50MzInLCByZXN1bHRWYWx1ZXNdO1xuICB9XG5cbiAgaWYgKGR0eXBlID09PSAnYm9vbCcpIHtcbiAgICAvLyBUaGlzIGlzIGVzc2VudGlhbGx5IHRoZSByZXN1bHQgb2Ygbm90RXF1YWwoeCwgMCkuIFdlIGF2b2lkIHVzaW5nXG4gICAgLy8ga2Vybm
VsIG5vdEVxdWFsIHRvIGF2b2lkIGNpcmN1bGFyIGRlcGVuZGVuY3ksIGkuZS4gYmluYXJ5X3V0aWxzIC0+XG4gICAgLy8gY2FzdCAtPiBub3RFcXVhbCAtPiBiaW5hcnlfdXRpbHMuXG4gICAgY29uc3QgemVybyA9IHV0aWwudG9UeXBlZEFycmF5KFswXSwgaW5wdXRUeXBlKTtcblxuICAgIGNvbnN0IFtyZXN1bHREYXRhLCByZXN1bHRTaGFwZV0gPSBjcmVhdGVTaW1wbGVCaW5hcnlLZXJuZWxJbXBsKFxuICAgICAgICAoYSwgYikgPT4gKGEgIT09IGIpID8gMSA6IDApKHNoYXBlLCBbXSwgdmFsdWVzLCB6ZXJvLCAnYm9vbCcpO1xuXG4gICAgcmV0dXJuIFtyZXN1bHRTaGFwZSwgJ2Jvb2wnLCByZXN1bHREYXRhXTtcbiAgfVxuICB0aHJvdyBuZXcgRXJyb3IoYEVycm9yIGluIENhc3Q6IGZhaWxlZCB0byBjYXN0ICR7aW5wdXRUeXBlfSB0byAke2R0eXBlfWApO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gY2FzdChcbiAgICBhcmdzOiB7aW5wdXRzOiBDYXN0SW5wdXRzLCBiYWNrZW5kOiBNYXRoQmFja2VuZENQVSwgYXR0cnM6IENhc3RBdHRyc30pOlxuICAgIFRlbnNvckluZm8ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7eH0gPSBpbnB1dHM7XG4gIGNvbnN0IHtkdHlwZX0gPSBhdHRycztcblxuICAvLyBDYXN0aW5nIHRvIGNvbXBsZXg2NC5cbiAgaWYgKGR0eXBlID09PSAnY29tcGxleDY0Jykge1xuICAgIGlmICh4LmR0eXBlID09PSAnY29tcGxleDY0Jykge1xuICAgICAgcmV0dXJuIGlkZW50aXR5KHtpbnB1dHM6IHt4fSwgYmFja2VuZH0pO1xuICAgIH1cblxuICAgIGNvbnN0IHplcm9zVGVuc29ySW5mbyA9IHplcm9zKGJhY2tlbmQsIHguc2hhcGUsIHguZHR5cGUpO1xuICAgIGNvbnN0IGZsb2F0WCA9IGNhc3Qoe2lucHV0czoge3h9LCBiYWNrZW5kLCBhdHRyczoge2R0eXBlOiAnZmxvYXQzMid9fSk7XG5cbiAgICBjb25zdCByZXN1bHQgPVxuICAgICAgICBjb21wbGV4KHtpbnB1dHM6IHtyZWFsOiBmbG9hdFgsIGltYWc6IHplcm9zVGVuc29ySW5mb30sIGJhY2tlbmR9KTtcblxuICAgIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oemVyb3NUZW5zb3JJbmZvKTtcbiAgICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKGZsb2F0WCk7XG5cbiAgICByZXR1cm4gcmVzdWx0O1xuICB9XG5cbiAgLy8gQ2FzdGluZyBmcm9tIGNvbXBsZXg2NFxuICBpZiAoeC5kdHlwZSA9PT0gJ2NvbXBsZXg2NCcpIHtcbiAgICBjb25zdCByZWFsUGFydCA9IHJlYWwoe2lucHV0czoge2lucHV0OiB4fSwgYmFja2VuZH0pO1xuICAgIGNvbnN0IHJlc3VsdCA9IGNhc3Qoe2lucHV0czoge3g6IHJlYWxQYXJ0fSwgYmFja2VuZCwgYXR0cnM6IHtkdHlwZX19KTtcblxuICAgIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8ocmVhbFBhcnQpO1xuXG4gICAgcmV0dXJuIHJlc3VsdDtcbiAgfVxuXG4gIGlmICghdXRpbC5oYXNFbmNvZGluZ0xvc3MoeC5kdH
lwZSwgZHR5cGUpKSB7XG4gICAgLy8gV2UgZG9uJ3QgY2hhbmdlIHRoZSB1bmRlcmx5aW5nIGRhdGEsIHNpbmNlIHdlIGNhc3QgdG8gaGlnaGVyXG4gICAgLy8gcHJlY2lzaW9uLlxuICAgIGNvbnN0IHJlc3VsdCA9IGlkZW50aXR5KHtpbnB1dHM6IHt4fSwgYmFja2VuZH0pO1xuICAgIHJldHVybiB7ZGF0YUlkOiByZXN1bHQuZGF0YUlkLCBzaGFwZTogcmVzdWx0LnNoYXBlLCBkdHlwZX07XG4gIH1cblxuICBjb25zdCB2YWx1ZXMgPSBiYWNrZW5kLmRhdGEuZ2V0KHguZGF0YUlkKS52YWx1ZXMgYXMgVHlwZWRBcnJheTtcbiAgY29uc3QgW3Jlc3VsdFNoYXBlLCByZXN1bHRUeXBlLCByZXN1bHREYXRhXSA9XG4gICAgICBjYXN0SW1wbCh2YWx1ZXMsIHguc2hhcGUsIHguZHR5cGUsIGR0eXBlKTtcbiAgcmV0dXJuIGJhY2tlbmQubWFrZVRlbnNvckluZm8ocmVzdWx0U2hhcGUsIHJlc3VsdFR5cGUsIHJlc3VsdERhdGEpO1xufVxuXG5leHBvcnQgY29uc3QgY2FzdENvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBDYXN0LFxuICBiYWNrZW5kTmFtZTogJ2NwdScsXG4gIGtlcm5lbEZ1bmM6IGNhc3QgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19