/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { SparseFillEmptyRows } from '@tensorflow/tfjs-core';
|
import { sparseFillEmptyRowsImplCPU } from '../kernel_utils/shared';
|
/**
 * WebGL kernel function for `SparseFillEmptyRows`.
 *
 * Checks the rank of each input, reads the input data synchronously from the
 * backend, delegates the computation to `sparseFillEmptyRowsImplCPU`, and
 * wraps the results with `backend.makeTensorInfo`.
 *
 * @param {Object} args
 * @param {Object} args.inputs Tensor infos `indices` (rank 2), `values`
 *     (rank 1), `denseShape` (rank 1) and `defaultValue` (rank 0).
 * @param {Object} args.backend Backend providing `readSync` and
 *     `makeTensorInfo`.
 * @returns {Array} Four tensor infos, in order: output indices (dtype of
 *     `indices`), output values (dtype of `values`), the empty-row indicator
 *     (`'bool'`), and the reverse index map (dtype of `indices`).
 * @throws {Error} If any input does not have the rank listed above.
 */
export function sparseFillEmptyRows(args) {
  const { inputs, backend } = args;
  const { indices, values, denseShape, defaultValue } = inputs;

  // Rank validation. Messages are kept on a single line so the offending
  // shape directly follows "saw:" without embedded line breaks.
  if (denseShape.shape.length !== 1) {
    throw new Error(
      `Dense shape must be a vector, saw: ${denseShape.shape}`);
  }
  if (indices.shape.length !== 2) {
    throw new Error(`Indices must be a matrix, saw: ${indices.shape}`);
  }
  if (values.shape.length !== 1) {
    throw new Error(`Values must be a vector, saw: ${values.shape}`);
  }
  if (defaultValue.shape.length !== 0) {
    throw new Error(
      `Default value must be a scalar, saw: ${defaultValue.shape}`);
  }

  const $indices = backend.readSync(indices.dataId);
  const $values = backend.readSync(values.dataId);
  const $denseShape = backend.readSync(denseShape.dataId);
  // The default value is a scalar, so only its first element is used.
  const $defaultValue = backend.readSync(defaultValue.dataId)[0];

  const [
    outputIndices,
    outputIndicesShape,
    outputValues,
    emptyRowIndicator,
    reverseIndexMap,
  ] = sparseFillEmptyRowsImplCPU(
    $indices, indices.shape, indices.dtype, $values, values.dtype,
    $denseShape, $defaultValue);

  // Booleans are stored as 0/1 bytes for the 'bool' tensor.
  const emptyRowBytes =
    new Uint8Array(emptyRowIndicator.map((value) => Number(value)));

  return [
    backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices),
    backend.makeTensorInfo(
      [outputIndicesShape[0]], values.dtype, outputValues),
    backend.makeTensorInfo([emptyRowIndicator.length], 'bool', emptyRowBytes),
    backend.makeTensorInfo(
      [reverseIndexMap.length], indices.dtype,
      new Int32Array(reverseIndexMap)),
  ];
}
|
/**
 * Kernel configuration pairing the `SparseFillEmptyRows` kernel name with the
 * `'webgl'` backend name and the `sparseFillEmptyRows` kernel function.
 * NOTE(review): presumably consumed by the backend's kernel registration
 * code elsewhere; that is not visible from this file.
 */
export const sparseFillEmptyRowsConfig = {
  kernelName: SparseFillEmptyRows,
  backendName: 'webgl',
  kernelFunc: sparseFillEmptyRows,
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3BhcnNlRmlsbEVtcHR5Um93cy5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMva2VybmVscy9TcGFyc2VGaWxsRW1wdHlSb3dzLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBMkIsbUJBQW1CLEVBQW9ELE1BQU0sdUJBQXVCLENBQUM7QUFHdkksT0FBTyxFQUFDLDBCQUEwQixFQUFDLE1BQU0sd0JBQXdCLENBQUM7QUFFbEUsTUFBTSxVQUFVLG1CQUFtQixDQUFDLElBR25DO0lBQ0MsTUFBTSxFQUFDLE1BQU0sRUFBRSxPQUFPLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDL0IsTUFBTSxFQUFDLE9BQU8sRUFBRSxNQUFNLEVBQUUsVUFBVSxFQUFFLFlBQVksRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUMzRCxJQUFJLFVBQVUsQ0FBQyxLQUFLLENBQUMsTUFBTSxLQUFLLENBQUMsRUFBRTtRQUNqQyxNQUFNLElBQUksS0FBSyxDQUFDO1dBQ1QsVUFBVSxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7S0FDNUI7SUFDRCxJQUFJLE9BQU8sQ0FBQyxLQUFLLENBQUMsTUFBTSxLQUFLLENBQUMsRUFBRTtRQUM5QixNQUFNLElBQUksS0FBSyxDQUFDO1dBQ1QsT0FBTyxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7S0FDekI7SUFDRCxJQUFJLE1BQU0sQ0FBQyxLQUFLLENBQUMsTUFBTSxLQUFLLENBQUMsRUFBRTtRQUM3QixNQUFNLElBQUksS0FBSyxDQUFDO1dBQ1QsTUFBTSxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7S0FDeEI7SUFDRCxJQUFJLFlBQVksQ0FBQyxLQUFLLENBQUMsTUFBTSxLQUFLLENBQUMsRUFBRTtRQUNuQyxNQUFNLElBQUksS0FBSyxDQUFDO1VBQ1YsWUFBWSxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7S0FDN0I7SUFFRCxNQUFNLFFBQVEsR0FBRyxPQUFPLENBQUMsUUFBUSxDQUFDLE9BQU8sQ0FBQyxNQUFNLENBQWUsQ0FBQztJQUNoRSxNQUFNLE9BQU8sR0FBRyxPQUFPLENBQUMsUUFBUSxDQUFDLE1BQU0sQ0FBQyxNQUFNLENBQWUsQ0FBQztJQUM5RCxNQUFNLFdBQVcsR0FBRyxPQUFPLENBQUMsUUFBUSxDQUFDLFVBQVUsQ0FBQyxNQUFNLENBQWUsQ0FBQztJQUN0RSxNQUFNLGFBQWEsR0FDZixPQUFPLENBQUMsUUFBUSxDQUFDLFlBQVksQ0FBQyxNQUFNLENBQUMsQ0FBQyxDQUFDLENBQVcsQ0FBQztJQUV2RCxNQUFNLENBQUMsYUFBYSxFQUFFLGtCQUFrQixFQUFFLFlBQVksRUFDL0MsaUJBQWlCLEVBQUUsZUFBZSxDQUFDLEdBQ3RDLDBCQUEwQixDQUN0QixRQUFRLEVBQUUsT0FBTyxDQUFDLEtBQUssRUFBRSxPQUFPLENBQUMsS0FBSyxFQUFFLE9BQU8sRUFBRSxNQUFNLENBQUMsS0FBSyxFQUM3RCxXQUFXLEVBQUUsYUFBYSxDQUFDLENBQUM7SUFDcEMsT0FBTztRQUNMLE9BQU8sQ0FBQyxjQUFjLENBQUMsa0JBQWtCLEVBQUUsT0FBTyxDQUFDLEtBQUssRUFBRSxhQUFhLENBQUM7UU
FDeEUsT0FBTyxDQUFDLGNBQWMsQ0FDbEIsQ0FBQyxrQkFBa0IsQ0FBQyxDQUFDLENBQUMsQ0FBQyxFQUFFLE1BQU0sQ0FBQyxLQUFLLEVBQUUsWUFBWSxDQUFDO1FBQ3hELE9BQU8sQ0FBQyxjQUFjLENBQ2xCLENBQUMsaUJBQWlCLENBQUMsTUFBTSxDQUFDLEVBQUUsTUFBTSxFQUNsQyxJQUFJLFVBQVUsQ0FDVixpQkFBaUIsQ0FBQyxHQUFHLENBQUMsQ0FBQyxLQUFjLEVBQUUsRUFBRSxDQUFDLE1BQU0sQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDbEUsT0FBTyxDQUFDLGNBQWMsQ0FDbEIsQ0FBQyxlQUFlLENBQUMsTUFBTSxDQUFDLEVBQUUsT0FBTyxDQUFDLEtBQUssRUFDdkMsSUFBSSxVQUFVLENBQUMsZUFBZSxDQUFDLENBQUM7S0FDckMsQ0FBQztBQUNKLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSx5QkFBeUIsR0FBaUI7SUFDckQsVUFBVSxFQUFFLG1CQUFtQjtJQUMvQixXQUFXLEVBQUUsT0FBTztJQUNwQixVQUFVLEVBQUUsbUJBQTRDO0NBQ3pELENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMSBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7S2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBTcGFyc2VGaWxsRW1wdHlSb3dzLCBTcGFyc2VGaWxsRW1wdHlSb3dzSW5wdXRzLCBUZW5zb3JJbmZvLCBUeXBlZEFycmF5fSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQge01hdGhCYWNrZW5kV2ViR0x9IGZyb20gJy4uL2JhY2tlbmRfd2ViZ2wnO1xuaW1wb3J0IHtzcGFyc2VGaWxsRW1wdHlSb3dzSW1wbENQVX0gZnJvbSAnLi4va2VybmVsX3V0aWxzL3NoYXJlZC
c7XG5cbmV4cG9ydCBmdW5jdGlvbiBzcGFyc2VGaWxsRW1wdHlSb3dzKGFyZ3M6IHtcbiAgaW5wdXRzOiBTcGFyc2VGaWxsRW1wdHlSb3dzSW5wdXRzLFxuICBiYWNrZW5kOiBNYXRoQmFja2VuZFdlYkdMXG59KTogW1RlbnNvckluZm8sIFRlbnNvckluZm8sIFRlbnNvckluZm8sIFRlbnNvckluZm9dIHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZH0gPSBhcmdzO1xuICBjb25zdCB7aW5kaWNlcywgdmFsdWVzLCBkZW5zZVNoYXBlLCBkZWZhdWx0VmFsdWV9ID0gaW5wdXRzO1xuICBpZiAoZGVuc2VTaGFwZS5zaGFwZS5sZW5ndGggIT09IDEpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoYERlbnNlIHNoYXBlIG11c3QgYmUgYSB2ZWN0b3IsIHNhdzpcbiAgICAgICAgICR7ZGVuc2VTaGFwZS5zaGFwZX1gKTtcbiAgfVxuICBpZiAoaW5kaWNlcy5zaGFwZS5sZW5ndGggIT09IDIpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoYEluZGljZXMgbXVzdCBiZSBhIG1hdHJpeCwgc2F3OlxuICAgICAgICAgJHtpbmRpY2VzLnNoYXBlfWApO1xuICB9XG4gIGlmICh2YWx1ZXMuc2hhcGUubGVuZ3RoICE9PSAxKSB7XG4gICAgdGhyb3cgbmV3IEVycm9yKGBWYWx1ZXMgbXVzdCBiZSBhIHZlY3Rvciwgc2F3OlxuICAgICAgICAgJHt2YWx1ZXMuc2hhcGV9YCk7XG4gIH1cbiAgaWYgKGRlZmF1bHRWYWx1ZS5zaGFwZS5sZW5ndGggIT09IDApIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoYERlZmF1bHQgdmFsdWUgbXVzdCBiZSBhIHNjYWxhciwgc2F3OlxuICAgICAgICAke2RlZmF1bHRWYWx1ZS5zaGFwZX1gKTtcbiAgfVxuXG4gIGNvbnN0ICRpbmRpY2VzID0gYmFja2VuZC5yZWFkU3luYyhpbmRpY2VzLmRhdGFJZCkgYXMgVHlwZWRBcnJheTtcbiAgY29uc3QgJHZhbHVlcyA9IGJhY2tlbmQucmVhZFN5bmModmFsdWVzLmRhdGFJZCkgYXMgVHlwZWRBcnJheTtcbiAgY29uc3QgJGRlbnNlU2hhcGUgPSBiYWNrZW5kLnJlYWRTeW5jKGRlbnNlU2hhcGUuZGF0YUlkKSBhcyBUeXBlZEFycmF5O1xuICBjb25zdCAkZGVmYXVsdFZhbHVlID1cbiAgICAgIGJhY2tlbmQucmVhZFN5bmMoZGVmYXVsdFZhbHVlLmRhdGFJZClbMF0gYXMgbnVtYmVyO1xuXG4gIGNvbnN0IFtvdXRwdXRJbmRpY2VzLCBvdXRwdXRJbmRpY2VzU2hhcGUsIG91dHB1dFZhbHVlcyxcbiAgICAgICAgIGVtcHR5Um93SW5kaWNhdG9yLCByZXZlcnNlSW5kZXhNYXBdID1cbiAgICAgIHNwYXJzZUZpbGxFbXB0eVJvd3NJbXBsQ1BVKFxuICAgICAgICAgICRpbmRpY2VzLCBpbmRpY2VzLnNoYXBlLCBpbmRpY2VzLmR0eXBlLCAkdmFsdWVzLCB2YWx1ZXMuZHR5cGUsXG4gICAgICAgICAgJGRlbnNlU2hhcGUsICRkZWZhdWx0VmFsdWUpO1xuICByZXR1cm4gW1xuICAgIGJhY2tlbmQubWFrZVRlbnNvckluZm8ob3V0cHV0SW5kaWNlc1NoYXBlLCBpbmRpY2VzLmR0eXBlLCBvdXRwdXRJbmRpY2VzKSxcbiAgICBiYWNrZW5kLm1ha2VUZW5zb3JJbmZvKFxuICAgICAgICBbb3V0cHV0SW5kaWNlc1NoYXBlWzBdXSwgdm
FsdWVzLmR0eXBlLCBvdXRwdXRWYWx1ZXMpLFxuICAgIGJhY2tlbmQubWFrZVRlbnNvckluZm8oXG4gICAgICAgIFtlbXB0eVJvd0luZGljYXRvci5sZW5ndGhdLCAnYm9vbCcsXG4gICAgICAgIG5ldyBVaW50OEFycmF5KFxuICAgICAgICAgICAgZW1wdHlSb3dJbmRpY2F0b3IubWFwKCh2YWx1ZTogYm9vbGVhbikgPT4gTnVtYmVyKHZhbHVlKSkpKSxcbiAgICBiYWNrZW5kLm1ha2VUZW5zb3JJbmZvKFxuICAgICAgICBbcmV2ZXJzZUluZGV4TWFwLmxlbmd0aF0sIGluZGljZXMuZHR5cGUsXG4gICAgICAgIG5ldyBJbnQzMkFycmF5KHJldmVyc2VJbmRleE1hcCkpLFxuICBdO1xufVxuXG5leHBvcnQgY29uc3Qgc3BhcnNlRmlsbEVtcHR5Um93c0NvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBTcGFyc2VGaWxsRW1wdHlSb3dzLFxuICBiYWNrZW5kTmFtZTogJ3dlYmdsJyxcbiAga2VybmVsRnVuYzogc3BhcnNlRmlsbEVtcHR5Um93cyBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmMsXG59O1xuIl19
|