/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
|
import { SparseFillEmptyRows } from '@tensorflow/tfjs-core';
|
import { sparseFillEmptyRowsImpl } from './SparseFillEmptyRows_impl';
|
/**
 * CPU kernel function for the SparseFillEmptyRows op.
 *
 * Checks the rank of every input, reads the raw buffers from the backend's
 * data storage, delegates the computation to `sparseFillEmptyRowsImpl`, and
 * wraps the results as tensor infos via `backend.makeTensorInfo`.
 *
 * @param {Object} args
 * @param {Object} args.inputs Holds `indices` (rank 2), `values` (rank 1),
 *     `denseShape` (rank 1) and `defaultValue` (rank 0).
 * @param {Object} args.backend Backend exposing `data.get(dataId)` and
 *     `makeTensorInfo(shape, dtype, values)`.
 * @returns {Array} Four tensor infos, in order: output indices, output
 *     values, empty-row indicator (dtype `'bool'`), reverse index map.
 * @throws {Error} If any of the inputs does not have the required rank.
 */
export function sparseFillEmptyRows(args) {
  const { inputs, backend } = args;
  const { indices, values, denseShape, defaultValue } = inputs;

  // Rank validation. Messages are kept on a single line so they stay
  // readable in logs; the leading text is unchanged for prefix matching.
  if (denseShape.shape.length !== 1) {
    throw new Error(`Dense shape must be a vector, saw: ${denseShape.shape}`);
  }
  if (indices.shape.length !== 2) {
    throw new Error(`Indices must be a matrix, saw: ${indices.shape}`);
  }
  if (values.shape.length !== 1) {
    throw new Error(`Values must be a vector, saw: ${values.shape}`);
  }
  if (defaultValue.shape.length !== 0) {
    throw new Error(`Default value must be a scalar, saw: ${defaultValue.shape}`);
  }

  // Pull the underlying buffers out of the backend's data storage.
  const $indices = backend.data.get(indices.dataId).values;
  const $values = backend.data.get(values.dataId).values;
  const $denseShape = backend.data.get(denseShape.dataId).values;
  // The default value is a scalar, so only its first element is read.
  const $defaultValue = backend.data.get(defaultValue.dataId).values[0];

  const [
    outputIndices,
    outputIndicesShape,
    outputValues,
    emptyRowIndicator,
    reverseIndexMap,
  ] = sparseFillEmptyRowsImpl(
      $indices, indices.shape, indices.dtype, $values, values.dtype,
      $denseShape, $defaultValue);

  return [
    backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices),
    // One output value per output index row.
    backend.makeTensorInfo(
        [outputIndicesShape[0]], values.dtype, outputValues),
    // Booleans are stored as 0/1 bytes for the 'bool' dtype.
    backend.makeTensorInfo(
        [emptyRowIndicator.length], 'bool',
        Uint8Array.from(emptyRowIndicator, (isEmpty) => Number(isEmpty))),
    // NOTE(review): the buffer is always an Int32Array while the dtype label
    // follows `indices.dtype` — presumably indices are int32; confirm.
    backend.makeTensorInfo(
        [reverseIndexMap.length], indices.dtype,
        Int32Array.from(reverseIndexMap)),
  ];
}
|
/**
 * Kernel registration entry that binds the `SparseFillEmptyRows` kernel name
 * to the `sparseFillEmptyRows` implementation for the 'cpu' backend.
 */
export const sparseFillEmptyRowsConfig = {
  kernelName: SparseFillEmptyRows,
  backendName: 'cpu',
  kernelFunc: sparseFillEmptyRows,
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3BhcnNlRmlsbEVtcHR5Um93cy5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC1jcHUvc3JjL2tlcm5lbHMvU3BhcnNlRmlsbEVtcHR5Um93cy50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQTJCLG1CQUFtQixFQUFvRCxNQUFNLHVCQUF1QixDQUFDO0FBSXZJLE9BQU8sRUFBQyx1QkFBdUIsRUFBQyxNQUFNLDRCQUE0QixDQUFDO0FBRW5FLE1BQU0sVUFBVSxtQkFBbUIsQ0FBQyxJQUduQztJQUNDLE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQy9CLE1BQU0sRUFBQyxPQUFPLEVBQUUsTUFBTSxFQUFFLFVBQVUsRUFBRSxZQUFZLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDM0QsSUFBSSxVQUFVLENBQUMsS0FBSyxDQUFDLE1BQU0sS0FBSyxDQUFDLEVBQUU7UUFDakMsTUFBTSxJQUFJLEtBQUssQ0FBQztVQUNWLFVBQVUsQ0FBQyxLQUFLLEVBQUUsQ0FBQyxDQUFDO0tBQzNCO0lBQ0QsSUFBSSxPQUFPLENBQUMsS0FBSyxDQUFDLE1BQU0sS0FBSyxDQUFDLEVBQUU7UUFDOUIsTUFBTSxJQUFJLEtBQUssQ0FBQztVQUNWLE9BQU8sQ0FBQyxLQUFLLEVBQUUsQ0FBQyxDQUFDO0tBQ3hCO0lBQ0QsSUFBSSxNQUFNLENBQUMsS0FBSyxDQUFDLE1BQU0sS0FBSyxDQUFDLEVBQUU7UUFDN0IsTUFBTSxJQUFJLEtBQUssQ0FBQztVQUNWLE1BQU0sQ0FBQyxLQUFLLEVBQUUsQ0FBQyxDQUFDO0tBQ3ZCO0lBQ0QsSUFBSSxZQUFZLENBQUMsS0FBSyxDQUFDLE1BQU0sS0FBSyxDQUFDLEVBQUU7UUFDbkMsTUFBTSxJQUFJLEtBQUssQ0FBQztVQUNWLFlBQVksQ0FBQyxLQUFLLEVBQUUsQ0FBQyxDQUFDO0tBQzdCO0lBRUQsTUFBTSxRQUFRLEdBQUcsT0FBTyxDQUFDLElBQUksQ0FBQyxHQUFHLENBQUMsT0FBTyxDQUFDLE1BQU0sQ0FBQyxDQUFDLE1BQW9CLENBQUM7SUFDdkUsTUFBTSxPQUFPLEdBQUcsT0FBTyxDQUFDLElBQUksQ0FBQyxHQUFHLENBQUMsTUFBTSxDQUFDLE1BQU0sQ0FBQyxDQUFDLE1BQW9CLENBQUM7SUFDckUsTUFBTSxXQUFXLEdBQUcsT0FBTyxDQUFDLElBQUksQ0FBQyxHQUFHLENBQUMsVUFBVSxDQUFDLE1BQU0sQ0FBQyxDQUFDLE1BQW9CLENBQUM7SUFDN0UsTUFBTSxhQUFhLEdBQ2YsT0FBTyxDQUFDLElBQUksQ0FBQyxHQUFHLENBQUMsWUFBWSxDQUFDLE1BQU0sQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQVcsQ0FBQztJQUU5RCxNQUFNLENBQUMsYUFBYSxFQUFFLGtCQUFrQixFQUFFLFlBQVksRUFDL0MsaUJBQWlCLEVBQUUsZUFBZSxDQUFDLEdBQ3RDLHVCQUF1QixDQUNuQixRQUFRLEVBQUUsT0FBTyxDQUFDLEtBQUssRUFBRSxPQUFPLENBQUMsS0FBSyxFQUFFLE9BQU8sRUFBRSxNQUFNLENBQUMsS0FBSyxFQUM3RCxXQUFXLEVBQUUsYUFBYSxDQUFDLENBQU
M7SUFDcEMsT0FBTztRQUNMLE9BQU8sQ0FBQyxjQUFjLENBQUMsa0JBQWtCLEVBQUUsT0FBTyxDQUFDLEtBQUssRUFBRSxhQUFhLENBQUM7UUFDeEUsT0FBTyxDQUFDLGNBQWMsQ0FDbEIsQ0FBQyxrQkFBa0IsQ0FBQyxDQUFDLENBQUMsQ0FBQyxFQUFFLE1BQU0sQ0FBQyxLQUFLLEVBQUUsWUFBWSxDQUFDO1FBQ3hELE9BQU8sQ0FBQyxjQUFjLENBQ2xCLENBQUMsaUJBQWlCLENBQUMsTUFBTSxDQUFDLEVBQUUsTUFBTSxFQUNsQyxJQUFJLFVBQVUsQ0FDVixpQkFBaUIsQ0FBQyxHQUFHLENBQUMsQ0FBQyxLQUFjLEVBQUUsRUFBRSxDQUFDLE1BQU0sQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDbEUsT0FBTyxDQUFDLGNBQWMsQ0FDbEIsQ0FBQyxlQUFlLENBQUMsTUFBTSxDQUFDLEVBQUUsT0FBTyxDQUFDLEtBQUssRUFDdkMsSUFBSSxVQUFVLENBQUMsZUFBZSxDQUFDLENBQUM7S0FDckMsQ0FBQztBQUNKLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSx5QkFBeUIsR0FBaUI7SUFDckQsVUFBVSxFQUFFLG1CQUFtQjtJQUMvQixXQUFXLEVBQUUsS0FBSztJQUNsQixVQUFVLEVBQUUsbUJBQTRDO0NBQ3pELENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMSBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7S2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBTcGFyc2VGaWxsRW1wdHlSb3dzLCBTcGFyc2VGaWxsRW1wdHlSb3dzSW5wdXRzLCBUZW5zb3JJbmZvLCBUeXBlZEFycmF5fSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQge01hdGhCYWNrZW5kQ1BVfSBmcm9tICcuLi9iYW
NrZW5kX2NwdSc7XG5cbmltcG9ydCB7c3BhcnNlRmlsbEVtcHR5Um93c0ltcGx9IGZyb20gJy4vU3BhcnNlRmlsbEVtcHR5Um93c19pbXBsJztcblxuZXhwb3J0IGZ1bmN0aW9uIHNwYXJzZUZpbGxFbXB0eVJvd3MoYXJnczoge1xuICBpbnB1dHM6IFNwYXJzZUZpbGxFbXB0eVJvd3NJbnB1dHMsXG4gIGJhY2tlbmQ6IE1hdGhCYWNrZW5kQ1BVXG59KTogW1RlbnNvckluZm8sIFRlbnNvckluZm8sIFRlbnNvckluZm8sIFRlbnNvckluZm9dIHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZH0gPSBhcmdzO1xuICBjb25zdCB7aW5kaWNlcywgdmFsdWVzLCBkZW5zZVNoYXBlLCBkZWZhdWx0VmFsdWV9ID0gaW5wdXRzO1xuICBpZiAoZGVuc2VTaGFwZS5zaGFwZS5sZW5ndGggIT09IDEpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoYERlbnNlIHNoYXBlIG11c3QgYmUgYSB2ZWN0b3IsIHNhdzpcbiAgICAgICAgJHtkZW5zZVNoYXBlLnNoYXBlfWApO1xuICB9XG4gIGlmIChpbmRpY2VzLnNoYXBlLmxlbmd0aCAhPT0gMikge1xuICAgIHRocm93IG5ldyBFcnJvcihgSW5kaWNlcyBtdXN0IGJlIGEgbWF0cml4LCBzYXc6XG4gICAgICAgICR7aW5kaWNlcy5zaGFwZX1gKTtcbiAgfVxuICBpZiAodmFsdWVzLnNoYXBlLmxlbmd0aCAhPT0gMSkge1xuICAgIHRocm93IG5ldyBFcnJvcihgVmFsdWVzIG11c3QgYmUgYSB2ZWN0b3IsIHNhdzpcbiAgICAgICAgJHt2YWx1ZXMuc2hhcGV9YCk7XG4gIH1cbiAgaWYgKGRlZmF1bHRWYWx1ZS5zaGFwZS5sZW5ndGggIT09IDApIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoYERlZmF1bHQgdmFsdWUgbXVzdCBiZSBhIHNjYWxhciwgc2F3OlxuICAgICAgICAke2RlZmF1bHRWYWx1ZS5zaGFwZX1gKTtcbiAgfVxuXG4gIGNvbnN0ICRpbmRpY2VzID0gYmFja2VuZC5kYXRhLmdldChpbmRpY2VzLmRhdGFJZCkudmFsdWVzIGFzIFR5cGVkQXJyYXk7XG4gIGNvbnN0ICR2YWx1ZXMgPSBiYWNrZW5kLmRhdGEuZ2V0KHZhbHVlcy5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5O1xuICBjb25zdCAkZGVuc2VTaGFwZSA9IGJhY2tlbmQuZGF0YS5nZXQoZGVuc2VTaGFwZS5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5O1xuICBjb25zdCAkZGVmYXVsdFZhbHVlID1cbiAgICAgIGJhY2tlbmQuZGF0YS5nZXQoZGVmYXVsdFZhbHVlLmRhdGFJZCkudmFsdWVzWzBdIGFzIG51bWJlcjtcblxuICBjb25zdCBbb3V0cHV0SW5kaWNlcywgb3V0cHV0SW5kaWNlc1NoYXBlLCBvdXRwdXRWYWx1ZXMsXG4gICAgICAgICBlbXB0eVJvd0luZGljYXRvciwgcmV2ZXJzZUluZGV4TWFwXSA9XG4gICAgICBzcGFyc2VGaWxsRW1wdHlSb3dzSW1wbChcbiAgICAgICAgICAkaW5kaWNlcywgaW5kaWNlcy5zaGFwZSwgaW5kaWNlcy5kdHlwZSwgJHZhbHVlcywgdmFsdWVzLmR0eXBlLFxuICAgICAgICAgICRkZW5zZVNoYXBlLCAkZGVmYXVsdFZhbHVlKTtcbiAgcmV0dXJuIFtcbiAgICBiYWNrZW5kLm1ha2VUZW5zb3JJbmZvKG91dHB1dEluZGljZXNTaG
FwZSwgaW5kaWNlcy5kdHlwZSwgb3V0cHV0SW5kaWNlcyksXG4gICAgYmFja2VuZC5tYWtlVGVuc29ySW5mbyhcbiAgICAgICAgW291dHB1dEluZGljZXNTaGFwZVswXV0sIHZhbHVlcy5kdHlwZSwgb3V0cHV0VmFsdWVzKSxcbiAgICBiYWNrZW5kLm1ha2VUZW5zb3JJbmZvKFxuICAgICAgICBbZW1wdHlSb3dJbmRpY2F0b3IubGVuZ3RoXSwgJ2Jvb2wnLFxuICAgICAgICBuZXcgVWludDhBcnJheShcbiAgICAgICAgICAgIGVtcHR5Um93SW5kaWNhdG9yLm1hcCgodmFsdWU6IGJvb2xlYW4pID0+IE51bWJlcih2YWx1ZSkpKSksXG4gICAgYmFja2VuZC5tYWtlVGVuc29ySW5mbyhcbiAgICAgICAgW3JldmVyc2VJbmRleE1hcC5sZW5ndGhdLCBpbmRpY2VzLmR0eXBlLFxuICAgICAgICBuZXcgSW50MzJBcnJheShyZXZlcnNlSW5kZXhNYXApKSxcbiAgXTtcbn1cblxuZXhwb3J0IGNvbnN0IHNwYXJzZUZpbGxFbXB0eVJvd3NDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogU3BhcnNlRmlsbEVtcHR5Um93cyxcbiAgYmFja2VuZE5hbWU6ICdjcHUnLFxuICBrZXJuZWxGdW5jOiBzcGFyc2VGaWxsRW1wdHlSb3dzIGFzIHVua25vd24gYXMgS2VybmVsRnVuYyxcbn07XG4iXX0=
|