gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
/**
 * @license
 * Copyright 2022 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { ENGINE } from '../engine';
import { RaggedGather } from '../kernel_names';
import { convertToTensor } from '../tensor_util_env';
import { op } from './operation';
/**
 * Gathers ragged slices from axis 0 of a ragged `params` tensor according to
 * `indices` by dispatching to the `RaggedGather` kernel.
 *
 * @param paramsNestedSplits Array of row-split values; each element is
 *     converted to an int32 tensor before being passed to the kernel.
 * @param paramsDenseValues Tensor or tensor-like holding the flat values of the
 *     params ragged tensor (converted without forcing a dtype).
 * @param indices Tensor or tensor-like of gather indices (converted to int32).
 * @param outputRaggedRank Forwarded unchanged to the kernel as the
 *     `outputRaggedRank` attribute.
 * @returns An object whose `outputNestedSplits` holds every kernel output
 *     except the last one, and whose `outputDenseValues` is the last output.
 */
function raggedGather_(paramsNestedSplits, paramsDenseValues, indices, outputRaggedRank) {
    const opName = 'raggedGather';
    // Inputs are converted in declaration order: splits, dense values, indices.
    const splitTensors = paramsNestedSplits.map((split, idx) => convertToTensor(split, `tensors${idx}`, opName, 'int32'));
    const denseValues = convertToTensor(paramsDenseValues, 'paramsDenseValues', opName);
    const indexTensor = convertToTensor(indices, 'indices', opName, 'int32');
    const kernelOutputs = ENGINE.runKernel(RaggedGather, {
        paramsNestedSplits: splitTensors,
        paramsDenseValues: denseValues,
        indices: indexTensor,
    }, { outputRaggedRank });
    // The kernel result is laid out as [...outputSplits, denseValues].
    const denseIndex = kernelOutputs.length - 1;
    return {
        outputNestedSplits: kernelOutputs.slice(0, denseIndex),
        outputDenseValues: kernelOutputs[denseIndex],
    };
}
// Public entry point: `op` wraps the underscore-suffixed implementation.
export const raggedGather = /* @__PURE__ */ op({ raggedGather_ });
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoicmFnZ2VkX2dhdGhlci5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtY29yZS9zcmMvb3BzL3JhZ2dlZF9nYXRoZXIudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNqQyxPQUFPLEVBQUMsWUFBWSxFQUF3QyxNQUFNLGlCQUFpQixDQUFDO0FBRXBGLE9BQU8sRUFBQyxlQUFlLEVBQUMsTUFBTSxvQkFBb0IsQ0FBQztBQUVuRCxPQUFPLEVBQUMsRUFBRSxFQUFDLE1BQU0sYUFBYSxDQUFDO0FBNEIvQixTQUFTLGFBQWEsQ0FDbEIsa0JBQTRCLEVBQUUsaUJBQW9DLEVBQ2xFLE9BQTBCLEVBQUUsZ0JBQXdCO0lBQ3RELE1BQU0sbUJBQW1CLEdBQUcsa0JBQWtCLENBQUMsR0FBRyxDQUM5QyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsRUFBRSxDQUFDLGVBQWUsQ0FBQyxDQUFDLEVBQUUsVUFBVSxDQUFDLEVBQUUsRUFBRSxjQUFjLEVBQUUsT0FBTyxDQUFDLENBQUMsQ0FBQztJQUMxRSxNQUFNLGtCQUFrQixHQUNwQixlQUFlLENBQUMsaUJBQWlCLEVBQUUsbUJBQW1CLEVBQUUsY0FBYyxDQUFDLENBQUM7SUFDNUUsTUFBTSxRQUFRLEdBQUcsZUFBZSxDQUFDLE9BQU8sRUFBRSxTQUFTLEVBQUUsY0FBYyxFQUFFLE9BQU8sQ0FBQyxDQUFDO0lBRTlFLE1BQU0sTUFBTSxHQUF1QjtRQUNqQyxrQkFBa0IsRUFBRSxtQkFBbUI7UUFDdkMsaUJBQWlCLEVBQUUsa0JBQWtCO1FBQ3JDLE9BQU8sRUFBRSxRQUFRO0tBQ2xCLENBQUM7SUFDRixNQUFNLEtBQUssR0FBc0IsRUFBQyxnQkFBZ0IsRUFBQyxDQUFDO0lBRXBELE1BQU0sTUFBTSxHQUNSLE1BQU0sQ0FBQyxTQUFTLENBQUMsWUFBWSxFQUFFLE1BQVksRUFBRSxLQUFXLENBQUMsQ0FBQztJQUM5RCxPQUFPO1FBQ0wsa0JBQWtCLEVBQUUsTUFBTSxDQUFDLEtBQUssQ0FBQyxDQUFDLEVBQUUsTUFBTSxDQUFDLE1BQU0sR0FBRyxDQUFDLENBQUM7UUFDdEQsaUJBQWlCLEVBQUUsTUFBTSxDQUFDLE1BQU0sQ0FBQyxNQUFNLEdBQUcsQ0FBQyxDQUFDO0tBQzdDLENBQUM7QUFDSixDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sWUFBWSxHQUFHLGVBQWUsQ0FBQyxFQUFFLENBQUMsRUFBQyxhQUFhLEVBQUMsQ0FBQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjIgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYW
NoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge0VOR0lORX0gZnJvbSAnLi4vZW5naW5lJztcbmltcG9ydCB7UmFnZ2VkR2F0aGVyLCBSYWdnZWRHYXRoZXJBdHRycywgUmFnZ2VkR2F0aGVySW5wdXRzfSBmcm9tICcuLi9rZXJuZWxfbmFtZXMnO1xuaW1wb3J0IHtUZW5zb3J9IGZyb20gJy4uL3RlbnNvcic7XG5pbXBvcnQge2NvbnZlcnRUb1RlbnNvcn0gZnJvbSAnLi4vdGVuc29yX3V0aWxfZW52JztcbmltcG9ydCB7VGVuc29yTGlrZX0gZnJvbSAnLi4vdHlwZXMnO1xuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuXG4vKipcbiAqIEdhdGhlciByYWdnZWQgc2xpY2VzIGZyb20gcGFyYW1zIGF4aXMgMCBhY2NvcmRpbmcgdG8gaW5kaWNlcy5cbiAqXG4gKiBAcGFyYW0gcGFyYW1zTmVzdGVkU3BsaXRzOiBBIGxpc3Qgb2YgYXQgbGVhc3QgMSBUZW5zb3Igd2l0aCB0eXBlICdpbnQzMicgVGhlXG4gKiAgICAgbmVzdGVkUm93U3BsaXRzIHRlbnNvcnMgdGhhdCBkZWZpbmUgdGhlIHJvdy1wYXJ0aXRpb25pbmcgZm9yIHRoZSBwYXJhbXNcbiAqICAgICBSYWdnZWRUZW5zb3IgaW5wdXQuXG4gKiBAcGFyYW0gcGFyYW1zRGVuc2VWYWx1ZXM6IEEgVGVuc29yLiBUaGUgZmxhdFZhbHVlcyBmb3IgdGhlIHBhcmFtc1xuICogICAgIFJhZ2dlZFRlbnNvci5cbiAqIEBwYXJhbSBpbmRpY2VzOiBBIFRlbnNvci4gTXVzdCBiZSBvbmUgb2YgdHlwZTogaW50MzIuIEluZGljZXMgaW4gdGhlXG4gKiAgICAgb3V0ZXJtb3N0IGRpbWVuc2lvbiBvZiBwYXJhbXMgb2YgdGhlIHZhbHVlcyB0aGF0IHNob3VsZCBiZSBnYXRoZXJlZC5cbiAqIEBwYXJhbSBvdXRwdXRSYWdnZWRSYW5rOiBBbiBpbnQgdGhhdCBpcyA+PSAwLiBUaGUgcmFnZ2VkIHJhbmsgb2YgdGhlIG91dHB1dFxuICogICAgIFJhZ2dlZFRlbnNvci4gb3V0cHV0TmVzdGVkU3BsaXRzIHdpbGwgY29udGFpbiB0aGlzIG51bWJlciBvZiByb3dTcGxpdHNcbiAqICAgICB0ZW5zb3JzLiBUaGlzIHZhbHVlIHNob3VsZCBlcXVhbCBpbmRpY2VzLnNoYXBlLm5kaW1zICsgcGFyYW1zLnJhZ2dlZFJhbmtcbiAqICAgICAtIDEuXG4gKiBAcmV0dX
JuIEEgbWFwIHdpdGggdGhlIGZvbGxvd2luZyBwcm9wZXJ0aWVzOlxuICogICAgIC0gb3V0cHV0TmVzdGVkU3BsaXRzOiBBIGxpc3Qgb2Ygb3V0cHV0UmFnZ2VkUmFuayBUZW5zb3Igb2JqZWN0cyB3aXRoIHRoZVxuICogc2FtZSB0eXBlIGFzIHBhcmFtc05lc3RlZFNwbGl0cy5cbiAqICAgICAtIG91dHB1dERlbnNlVmFsdWVzOiBBIFRlbnNvci4gSGFzIHRoZSBzYW1lIHR5cGUgYXMgcGFyYW1zRGVuc2VWYWx1ZXMuXG4gKiBAZG9jIHtoZWFkaW5nOiAnT3BlcmF0aW9ucycsIHN1YmhlYWRpbmc6ICdSYWdnZWQnfVxuICovXG5cbmludGVyZmFjZSBSYWdnZWRHYXRoZXJNYXAge1xuICBvdXRwdXROZXN0ZWRTcGxpdHM6IFRlbnNvcltdO1xuICBvdXRwdXREZW5zZVZhbHVlczogVGVuc29yO1xufVxuXG5mdW5jdGlvbiByYWdnZWRHYXRoZXJfKFxuICAgIHBhcmFtc05lc3RlZFNwbGl0czogVGVuc29yW10sIHBhcmFtc0RlbnNlVmFsdWVzOiBUZW5zb3J8VGVuc29yTGlrZSxcbiAgICBpbmRpY2VzOiBUZW5zb3J8VGVuc29yTGlrZSwgb3V0cHV0UmFnZ2VkUmFuazogbnVtYmVyKTogUmFnZ2VkR2F0aGVyTWFwIHtcbiAgY29uc3QgJHBhcmFtc05lc3RlZFNwbGl0cyA9IHBhcmFtc05lc3RlZFNwbGl0cy5tYXAoXG4gICAgICAodCwgaSkgPT4gY29udmVydFRvVGVuc29yKHQsIGB0ZW5zb3JzJHtpfWAsICdyYWdnZWRHYXRoZXInLCAnaW50MzInKSk7XG4gIGNvbnN0ICRwYXJhbXNEZW5zZVZhbHVlcyA9XG4gICAgICBjb252ZXJ0VG9UZW5zb3IocGFyYW1zRGVuc2VWYWx1ZXMsICdwYXJhbXNEZW5zZVZhbHVlcycsICdyYWdnZWRHYXRoZXInKTtcbiAgY29uc3QgJGluZGljZXMgPSBjb252ZXJ0VG9UZW5zb3IoaW5kaWNlcywgJ2luZGljZXMnLCAncmFnZ2VkR2F0aGVyJywgJ2ludDMyJyk7XG5cbiAgY29uc3QgaW5wdXRzOiBSYWdnZWRHYXRoZXJJbnB1dHMgPSB7XG4gICAgcGFyYW1zTmVzdGVkU3BsaXRzOiAkcGFyYW1zTmVzdGVkU3BsaXRzLFxuICAgIHBhcmFtc0RlbnNlVmFsdWVzOiAkcGFyYW1zRGVuc2VWYWx1ZXMsXG4gICAgaW5kaWNlczogJGluZGljZXMsXG4gIH07XG4gIGNvbnN0IGF0dHJzOiBSYWdnZWRHYXRoZXJBdHRycyA9IHtvdXRwdXRSYWdnZWRSYW5rfTtcblxuICBjb25zdCByZXN1bHQ6IFRlbnNvcltdID1cbiAgICAgIEVOR0lORS5ydW5LZXJuZWwoUmFnZ2VkR2F0aGVyLCBpbnB1dHMgYXMge30sIGF0dHJzIGFzIHt9KTtcbiAgcmV0dXJuIHtcbiAgICBvdXRwdXROZXN0ZWRTcGxpdHM6IHJlc3VsdC5zbGljZSgwLCByZXN1bHQubGVuZ3RoIC0gMSksXG4gICAgb3V0cHV0RGVuc2VWYWx1ZXM6IHJlc3VsdFtyZXN1bHQubGVuZ3RoIC0gMV0sXG4gIH07XG59XG5cbmV4cG9ydCBjb25zdCByYWdnZWRHYXRoZXIgPSAvKiBAX19QVVJFX18gKi8gb3Aoe3JhZ2dlZEdhdGhlcl99KTtcbiJdfQ==