gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util } from '@tensorflow/tfjs-core';
import { assertNotComplex } from '../cpu_util';
import { createSimpleUnaryImpl } from './unary_impl';
/**
 * Template that creates a `KernelFunc` for unary ops.
 * @param name Kernel name.
 * @param op A `SimpleUnaryOperation` for the kernel.
 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
 *     result has the same dtype as the input. This is mainly used in certain
 *     kernels that return bool type, such as isFinite, isInf, etc.
 */
export function unaryKernelFunc(name, op, dtype) {
    // Wrap the per-element `op` via `createSimpleUnaryImpl`, then hand the
    // resulting impl to the impl-based factory, which returns the kernel
    // function.
    return unaryKernelFuncFromImpl(name, createSimpleUnaryImpl(op), dtype);
}
/**
 * Template that creates a `KernelFunc` for unary ops from the given
 * `SimpleUnaryImpl`.
 * @param name Kernel name.
 * @param unaryImpl A `SimpleUnaryImpl` that implements the op.
 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
 *     result has the same dtype as the input. This is mainly used in certain
 *     kernels that return bool type, such as isFinite, isInf, etc.
 */
export function unaryKernelFuncFromImpl(name, unaryImpl, dtype) {
    return ({ inputs, attrs, backend }) => {
        const input = inputs.x;
        // Complex inputs are rejected up front; `name` is forwarded so the
        // check can identify the kernel.
        assertNotComplex(input, name);
        const stored = backend.data.get(input.dataId).values;
        const isStringInput = input.dtype === 'string';
        // String tensors must be backed by a plain Array, whose entries are
        // decoded from bytes below.
        if (isStringInput && !Array.isArray(stored)) {
            throw new Error('String tensor\'s value was not an instance of Array');
        }
        // Strings are decoded before reaching the impl; every other dtype is
        // passed through as stored.
        const operand = isStringInput ?
            backend_util.fromUint8ToStringArray(stored) :
            stored;
        // An explicit `dtype` overrides the input's dtype for the result.
        const resultDtype = dtype || input.dtype;
        const resultValues = unaryImpl(operand, resultDtype, attrs);
        return backend.makeTensorInfo(input.shape, resultDtype, resultValues);
    };
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoidW5hcnlfdXRpbHMuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy91dGlscy91bmFyeV91dGlscy50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsWUFBWSxFQUF1QyxNQUFNLHVCQUF1QixDQUFDO0FBR3pGLE9BQU8sRUFBQyxnQkFBZ0IsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUM3QyxPQUFPLEVBQUMscUJBQXFCLEVBQUMsTUFBTSxjQUFjLENBQUM7QUFJbkQ7Ozs7Ozs7R0FPRztBQUNILE1BQU0sVUFBVSxlQUFlLENBRTdCLElBQVksRUFBRSxFQUE4QixFQUM1QyxLQUFzQjtJQUV0QixNQUFNLElBQUksR0FBRyxxQkFBcUIsQ0FBTyxFQUFFLENBQUMsQ0FBQztJQUU3QyxPQUFPLHVCQUF1QixDQUFPLElBQUksRUFBRSxJQUFJLEVBQUUsS0FBSyxDQUFDLENBQUM7QUFDMUQsQ0FBQztBQUVEOzs7Ozs7OztHQVFHO0FBQ0gsTUFBTSxVQUFVLHVCQUF1QixDQUVyQyxJQUFZLEVBQUUsU0FBZ0MsRUFDOUMsS0FBc0I7SUFFdEIsT0FBTyxDQUFDLEVBQUMsTUFBTSxFQUFFLEtBQUssRUFBRSxPQUFPLEVBQUMsRUFBRSxFQUFFO1FBQ2xDLE1BQU0sRUFBQyxDQUFDLEVBQUMsR0FBRyxNQUFxQixDQUFDO1FBQ2xDLGdCQUFnQixDQUFDLENBQUMsRUFBRSxJQUFJLENBQUMsQ0FBQztRQUUxQixNQUFNLFVBQVUsR0FBRyxPQUF5QixDQUFDO1FBQzdDLE1BQU0sTUFBTSxHQUFHLFVBQVUsQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFNLENBQUM7UUFDcEQsSUFBSSxPQUFxQixDQUFDO1FBQzFCLElBQUksQ0FBQyxDQUFDLEtBQUssS0FBSyxRQUFRLEVBQUU7WUFDeEIsSUFBSSxDQUFDLEtBQUssQ0FBQyxPQUFPLENBQUMsTUFBTSxDQUFDLEVBQUU7Z0JBQzFCLE1BQU0sSUFBSSxLQUFLLENBQUMscURBQXFELENBQUMsQ0FBQzthQUN4RTtZQUNELE9BQU8sR0FBRyxZQUFZLENBQUMsc0JBQXNCLENBQUMsTUFBTSxDQUN0QyxDQUFDO1NBQ2hCO2FBQU07WUFDTCxPQUFPLEdBQUcsTUFBaUMsQ0FBQztTQUM3QztRQUVELE1BQU0sTUFBTSxHQUFHLEtBQUssSUFBSSxDQUFDLENBQUMsS0FBdUIsQ0FBQztRQUNsRCxNQUFNLFNBQVMsR0FBRyxTQUFTLENBQUMsT0FBTyxFQUFFLE1BQU0sRUFBRSxLQUFLLENBQUMsQ0FBQztRQUNwRCxPQUFPLFVBQVUsQ0FBQyxjQUFjLENBQUMsQ0FBQyxDQUFDLEtBQUssRUFBRSxNQUFNLEVBQUUsU0FBUyxDQUFDLENBQUM7SUFDL0QsQ0FBQyxDQUFDO0FBQ0osQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXC
IpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWwsIERhdGFUeXBlRm9yLCBLZXJuZWxGdW5jLCBVbmFyeUlucHV0c30gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtNYXRoQmFja2VuZENQVX0gZnJvbSAnLi4vYmFja2VuZF9jcHUnO1xuaW1wb3J0IHthc3NlcnROb3RDb21wbGV4fSBmcm9tICcuLi9jcHVfdXRpbCc7XG5pbXBvcnQge2NyZWF0ZVNpbXBsZVVuYXJ5SW1wbH0gZnJvbSAnLi91bmFyeV9pbXBsJztcblxuaW1wb3J0IHtTaW1wbGVVbmFyeUltcGwsIFNpbXBsZVVuYXJ5T3BlcmF0aW9ufSBmcm9tICcuL3VuYXJ5X3R5cGVzJztcblxuLyoqXG4gKiBUZW1wbGF0ZSB0aGF0IGNyZWF0ZXMgYSBgS2VybmVsRnVuY2AgZm9yIHVuYXJ5IG9wcy5cbiAqIEBwYXJhbSBuYW1lIEtlcm5lbCBuYW1lLlxuICogQHBhcmFtIG9wIEEgYFNpbXBsZVVuYXJ5T3BlcmF0aW9uYCBmb3IgdGhlIGtlcm5lbC5cbiAqIEBwYXJhbSBkdHlwZSBPcHRpb25hbC4gSWYgc2V0LCB0aGUgcmVzdWx0IGhhcyB0aGlzIGR0eXBlLiBPdGhlcndpc2UsIHRoZVxuICogICAgIHJlc3VsdCBoYXMgdGhlIHNhbWUgZHR5cGUgYXMgdGhlIGlucHV0LiBUaGlzIGlzIG1haW5seSB1c2VkIGluIGNlcnRhaW5cbiAqICAgICBrZXJuZWxzIHRoYXQgcmV0dXJuIGJvb2wgdHlwZSwgc3VjaCBhcyBpc0Zpbml0ZSwgaXNJbmYsIGV0Yy5cbiAqL1xuZXhwb3J0IGZ1bmN0aW9uIHVuYXJ5S2VybmVsRnVuYzxJIGV4dGVuZHMgbnVtYmVyIHwgc3RyaW5nID0gbnVtYmVyLFxuICBPIGV4dGVuZHMgbnVtYmVyIHwgc3RyaW5nID0gbnVtYmVyPihcbiAgbmFtZTogc3RyaW5nLCBvcDogU2ltcGxlVW5hcnlPcGVyYXRpb248SSwgTz4sXG4gIGR0eXBlPzogRGF0YVR5cGVGb3I8Tz4pOiBLZXJuZWxGdW5jIHtcblxuICBjb25zdCBpbXBsID0gY3JlYX
RlU2ltcGxlVW5hcnlJbXBsPEksIE8+KG9wKTtcblxuICByZXR1cm4gdW5hcnlLZXJuZWxGdW5jRnJvbUltcGw8SSwgTz4obmFtZSwgaW1wbCwgZHR5cGUpO1xufVxuXG4vKipcbiAqIFRlbXBsYXRlIHRoYXQgY3JlYXRlcyBhIGBLZXJuZWxGdW5jYCBmb3IgdW5hcnkgb3BzIGZyb20gdGhlIGdpdmVuXG4gKiBgU2ltcGxlVW5hcnlJbXBsYC4uXG4gKiBAcGFyYW0gbmFtZSBLZXJuZWwgbmFtZS5cbiAqIEBwYXJhbSB1bmFyeUltcGwgQSBgU2ltcGxlVW5hcnlJbXBsYCB0aGF0IGltcGxlbWVudHMgdGhlIG9wLlxuICogQHBhcmFtIGR0eXBlIE9wdGlvbmFsLiBJZiBzZXQsIHRoZSByZXN1bHQgaGFzIHRoaXMgZHR5cGUuIE90aGVyd2lzZSwgdGhlXG4gKiAgICAgcmVzdWx0IGhhcyB0aGUgc2FtZSBkdHlwZSBhcyB0aGUgaW5wdXQuIFRoaXMgaXMgbWFpbmx5IHVzZWQgaW4gY2VydGFpblxuICogICAgIGtlcm5lbHMgdGhhdCByZXR1cm4gYm9vbCB0eXBlLCBzdWNoIGFzIGlzRmluaXRlLCBpc0luZiwgZXRjLlxuICovXG5leHBvcnQgZnVuY3Rpb24gdW5hcnlLZXJuZWxGdW5jRnJvbUltcGw8SSBleHRlbmRzIG51bWJlciB8IHN0cmluZyA9IG51bWJlcixcbiAgTyBleHRlbmRzIG51bWJlciB8IHN0cmluZyA9IG51bWJlcj4oXG4gIG5hbWU6IHN0cmluZywgdW5hcnlJbXBsOiBTaW1wbGVVbmFyeUltcGw8SSwgTz4sXG4gIGR0eXBlPzogRGF0YVR5cGVGb3I8Tz4pOiBLZXJuZWxGdW5jIHtcblxuICByZXR1cm4gKHtpbnB1dHMsIGF0dHJzLCBiYWNrZW5kfSkgPT4ge1xuICAgIGNvbnN0IHt4fSA9IGlucHV0cyBhcyBVbmFyeUlucHV0cztcbiAgICBhc3NlcnROb3RDb21wbGV4KHgsIG5hbWUpO1xuXG4gICAgY29uc3QgY3B1QmFja2VuZCA9IGJhY2tlbmQgYXMgTWF0aEJhY2tlbmRDUFU7XG4gICAgY29uc3QgdmFsdWVzID0gY3B1QmFja2VuZC5kYXRhLmdldCh4LmRhdGFJZCkudmFsdWVzO1xuICAgIGxldCBkZWNvZGVkOiBBcnJheUxpa2U8ST47XG4gICAgaWYgKHguZHR5cGUgPT09ICdzdHJpbmcnKSB7XG4gICAgICBpZiAoIUFycmF5LmlzQXJyYXkodmFsdWVzKSkge1xuICAgICAgICB0aHJvdyBuZXcgRXJyb3IoJ1N0cmluZyB0ZW5zb3JcXCdzIHZhbHVlIHdhcyBub3QgYW4gaW5zdGFuY2Ugb2YgQXJyYXknKTtcbiAgICAgIH1cbiAgICAgIGRlY29kZWQgPSBiYWNrZW5kX3V0aWwuZnJvbVVpbnQ4VG9TdHJpbmdBcnJheSh2YWx1ZXMpIGFzIHVua25vd24gYXNcbiAgICAgICAgQXJyYXlMaWtlPEk+O1xuICAgIH0gZWxzZSB7XG4gICAgICBkZWNvZGVkID0gdmFsdWVzIGFzIHVua25vd24gYXMgQXJyYXlMaWtlPEk+O1xuICAgIH1cblxuICAgIGNvbnN0ICRkdHlwZSA9IGR0eXBlIHx8IHguZHR5cGUgYXMgRGF0YVR5cGVGb3I8Tz47XG4gICAgY29uc3QgbmV3VmFsdWVzID0gdW5hcnlJbXBsKGRlY29kZWQsICRkdHlwZSwgYXR0cnMpO1xuICAgIHJldHVybiBjcHVCYWNrZW5kLm1ha2VUZW5zb3JJbmZvKHguc2hhcGUsICRkdHlwZSwgbmV3VmFsdWVzKT
tcbiAgfTtcbn1cbiJdfQ==