/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { SparseFillEmptyRows } from '@tensorflow/tfjs-core';
import { sparseFillEmptyRowsImplCPU } from '../kernel_utils/shared';

/**
 * WebGL kernel for SparseFillEmptyRows.
 *
 * Validates the rank of every input, reads all four inputs synchronously via
 * `backend.readSync`, hands the raw data to the shared CPU implementation
 * (`sparseFillEmptyRowsImplCPU`), and wraps its results in new tensor infos
 * created with `backend.makeTensorInfo`.
 *
 * @param {object} args - `{inputs, backend}`. `inputs` carries `indices`,
 *     `values`, `denseShape` and `defaultValue`, each exposing `shape`,
 *     `dtype` and `dataId`.
 * @returns {Array} Four tensor infos, in order: output indices (dtype of
 *     `indices`), output values (dtype of `values`), empty-row indicator
 *     (dtype `'bool'`), and reverse index map (dtype of `indices`).
 * @throws {Error} If `denseShape` is not rank 1, `indices` is not rank 2,
 *     `values` is not rank 1, or `defaultValue` is not rank 0. The checks
 *     run in that order and the first failure is reported.
 */
export function sparseFillEmptyRows(args) {
  const {
    inputs: { indices, values, denseShape, defaultValue },
    backend,
  } = args;

  // [input, required rank, message factory]. The message is built lazily so
  // shapes are only stringified when a check actually fails.
  const rankRequirements = [
    [denseShape, 1, () => `Dense shape must be a vector, saw: ${denseShape.shape}`],
    [indices, 2, () => `Indices must be a matrix, saw: ${indices.shape}`],
    [values, 1, () => `Values must be a vector, saw: ${values.shape}`],
    [defaultValue, 0, () => `Default value must be a scalar, saw: ${defaultValue.shape}`],
  ];
  for (const [input, requiredRank, describeFailure] of rankRequirements) {
    if (input.shape.length !== requiredRank) {
      throw new Error(describeFailure());
    }
  }

  // Read order (indices, values, denseShape, defaultValue) matches the
  // order in which the inputs are passed to the CPU implementation.
  const [indicesData, valuesData, denseShapeData] = [indices, values, denseShape]
    .map((input) => backend.readSync(input.dataId));
  // defaultValue is rank 0 (checked above), so its single element is the fill value.
  const fillValue = backend.readSync(defaultValue.dataId)[0];

  const [
    outputIndices,
    outputIndicesShape,
    outputValues,
    emptyRowIndicator,
    reverseIndexMap,
  ] = sparseFillEmptyRowsImplCPU(
    indicesData,
    indices.shape,
    indices.dtype,
    valuesData,
    values.dtype,
    denseShapeData,
    fillValue,
  );

  return [
    backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices),
    // One output value per output index row.
    backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues),
    // 'bool' tensors are stored as bytes, so each indicator entry is
    // converted to a number (true -> 1, false -> 0) before packing.
    backend.makeTensorInfo(
      [emptyRowIndicator.length],
      'bool',
      Uint8Array.from(emptyRowIndicator, Number),
    ),
    backend.makeTensorInfo(
      [reverseIndexMap.length],
      indices.dtype,
      Int32Array.from(reverseIndexMap),
    ),
  ];
}

/**
 * Kernel configuration binding the `SparseFillEmptyRows` kernel name to
 * `sparseFillEmptyRows` for the 'webgl' backend.
 */
export const sparseFillEmptyRowsConfig = {
  kernelName: SparseFillEmptyRows,
  backendName: 'webgl',
  kernelFunc: sparseFillEmptyRows,
};