/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, SparseToDense, util } from '@tensorflow/tfjs-core';
|
import { scatterImplCPU } from '../kernel_utils/shared';
|
import { ScatterProgram } from '../scatter_gpu';
|
import { reshape } from './Reshape';
|
/**
 * WebGL kernel function for the `SparseToDense` op.
 *
 * When `sparseValues` has dtype `'string'` the scatter is performed by
 * `scatterImplCPU` on synchronously read buffers. For every other dtype a
 * `ScatterProgram` is run through the backend and its result is reshaped to
 * `outputShape`.
 *
 * @param args Kernel arguments: `inputs` (with `sparseIndices`,
 *     `sparseValues` and `defaultValue`), the `backend`, and `attrs` (with
 *     `outputShape`).
 * @returns The output tensor info: the value of `backend.makeTensorInfo` on
 *     the string path, or the value of `reshape` on the WebGL path.
 */
export function sparseToDense(args) {
  const {
    inputs: { sparseIndices: indices, sparseValues: values, defaultValue: fill },
    backend,
    attrs: { outputShape: denseShape },
  } = args;

  const {
    sliceRank,
    numUpdates,
    sliceSize,
    strides,
    outputSize,
  } = backend_util.calculateShapes(values, indices, denseShape);

  // Forwarded as the last argument of both scatter implementations below.
  const accumulateDuplicates = false;

  if (values.dtype !== 'string') {
    // The program writes an [outputSize, 1] result, which is reshaped to the
    // requested output shape afterwards.
    const scatter = new ScatterProgram(
      numUpdates,
      sliceRank,
      indices.shape.length,
      values.shape.length,
      strides,
      [outputSize, 1],
      accumulateDuplicates
    );
    const flat = backend.runWebGLProgram(scatter, [values, indices, fill], values.dtype);
    const dense = reshape({ inputs: { x: flat }, backend, attrs: { shape: denseShape } });

    // The un-reshaped program output is only an intermediate; release it.
    backend.disposeIntermediateTensorInfo(flat);
    return dense;
  }

  // String path: read the inputs synchronously and scatter via the shared
  // CPU implementation.
  const indexBuffer = backend.bufferSync(indices);
  const valueBuffer = backend.bufferSync(values);

  // Only the first element of the default value's data is used; it is decoded
  // with util.decodeString before being handed to the scatter routine.
  const decodedFill = util.decodeString(backend.readSync(fill.dataId)[0]);

  const scattered = scatterImplCPU(
    indexBuffer,
    valueBuffer,
    denseShape,
    outputSize,
    sliceSize,
    numUpdates,
    sliceRank,
    strides,
    decodedFill,
    accumulateDuplicates
  );
  return backend.makeTensorInfo(denseShape, scattered.dtype, scattered.values);
}
|
/**
 * Kernel configuration that associates the `SparseToDense` kernel name with
 * the `sparseToDense` function for the `'webgl'` backend.
 */
export const sparseToDenseConfig = {
  kernelName: SparseToDense,
  backendName: 'webgl',
  kernelFunc: sparseToDense,
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3BhcnNlVG9EZW5zZS5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMva2VybmVscy9TcGFyc2VUb0RlbnNlLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQWtDLGFBQWEsRUFBdUQsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFHN0osT0FBTyxFQUFDLGNBQWMsRUFBQyxNQUFNLHdCQUF3QixDQUFDO0FBQ3RELE9BQU8sRUFBQyxjQUFjLEVBQUMsTUFBTSxnQkFBZ0IsQ0FBQztBQUU5QyxPQUFPLEVBQUMsT0FBTyxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBRWxDLE1BQU0sVUFBVSxhQUFhLENBQUMsSUFJN0I7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLGFBQWEsRUFBRSxZQUFZLEVBQUUsWUFBWSxFQUFDLEdBQUcsTUFBTSxDQUFDO0lBQzNELE1BQU0sRUFBQyxXQUFXLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFFNUIsTUFBTSxFQUFDLFNBQVMsRUFBRSxVQUFVLEVBQUUsU0FBUyxFQUFFLE9BQU8sRUFBRSxVQUFVLEVBQUMsR0FDekQsWUFBWSxDQUFDLGVBQWUsQ0FBQyxZQUFZLEVBQUUsYUFBYSxFQUFFLFdBQVcsQ0FBQyxDQUFDO0lBQzNFLE1BQU0sY0FBYyxHQUFHLEtBQUssQ0FBQztJQUU3QixJQUFJLFlBQVksQ0FBQyxLQUFLLEtBQUssUUFBUSxFQUFFO1FBQ25DLE1BQU0sVUFBVSxHQUFHLE9BQU8sQ0FBQyxVQUFVLENBQWdCLGFBQWEsQ0FBQyxDQUFDO1FBQ3BFLE1BQU0sVUFBVSxHQUFHLE9BQU8sQ0FBQyxVQUFVLENBQWlCLFlBQVksQ0FBQyxDQUFDO1FBQ3BFLE1BQU0sYUFBYSxHQUFHLElBQUksQ0FBQyxZQUFZLENBQ25DLE9BQU8sQ0FBQyxRQUFRLENBQUMsWUFBWSxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUMsQ0FBZSxDQUFDLENBQUM7UUFDNUQsTUFBTSxNQUFNLEdBQUcsY0FBYyxDQUN6QixVQUFVLEVBQUUsVUFBVSxFQUFFLFdBQVcsRUFBRSxVQUFVLEVBQUUsU0FBUyxFQUFFLFVBQVUsRUFDdEUsU0FBUyxFQUFFLE9BQU8sRUFBRSxhQUFhLEVBQUUsY0FBYyxDQUFDLENBQUM7UUFDdkQsT0FBTyxPQUFPLENBQUMsY0FBYyxDQUFDLFdBQVcsRUFBRSxNQUFNLENBQUMsS0FBSyxFQUFFLE1BQU0sQ0FBQyxNQUFNLENBQUMsQ0FBQztLQUN6RTtJQUNELE1BQU0sT0FBTyxHQUFHLElBQUksY0FBYyxDQUM5QixVQUFVLEVBQUUsU0FBUyxFQUFFLGFBQWEsQ0FBQyxLQUFLLENBQUMsTUFBTSxFQUNqRCxZQUFZLENBQUMsS0FBSyxDQUFDLE1BQU0sRUFBRSxPQUFPLEVBQUUsQ0FBQyxVQUFVLEVBQUUsQ0FBQyxDQUFDLEVBQUUsY0FBYyxDQUFDLENBQUM7SUFFekUsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLGVBQWUsQ0FDL0IsT0FBTyxFQUFFLENBQUMsWUFBWSxFQUFFLGFBQWEsRU
FBRSxZQUFZLENBQUMsRUFBRSxZQUFZLENBQUMsS0FBSyxDQUFDLENBQUM7SUFFOUUsTUFBTSxRQUFRLEdBQ1YsT0FBTyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLEdBQUcsRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxLQUFLLEVBQUUsV0FBVyxFQUFDLEVBQUMsQ0FBQyxDQUFDO0lBRXRFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxHQUFHLENBQUMsQ0FBQztJQUMzQyxPQUFPLFFBQVEsQ0FBQztBQUNsQixDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sbUJBQW1CLEdBQWlCO0lBQy9DLFVBQVUsRUFBRSxhQUFhO0lBQ3pCLFdBQVcsRUFBRSxPQUFPO0lBQ3BCLFVBQVUsRUFBRSxhQUFzQztDQUNuRCxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBSYW5rLCBTcGFyc2VUb0RlbnNlLCBTcGFyc2VUb0RlbnNlQXR0cnMsIFNwYXJzZVRvRGVuc2VJbnB1dHMsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5pbXBvcnQge3NjYXR0ZXJJbXBsQ1BVfSBmcm9tICcuLi9rZXJuZWxfdXRpbHMvc2hhcmVkJztcbmltcG9ydCB7U2NhdHRlclByb2dyYW19IGZyb20gJy4uL3NjYXR0ZXJfZ3B1JztcblxuaW1wb3J0IHtyZXNoYXBlfSBmcm9tICcuL1Jlc2hhcGUnO1xuXG5leHBvcnQgZnVuY3Rpb24gc3BhcnNlVG9EZW5zZShhcmdzOi
B7XG4gIGlucHV0czogU3BhcnNlVG9EZW5zZUlucHV0cyxcbiAgYmFja2VuZDogTWF0aEJhY2tlbmRXZWJHTCxcbiAgYXR0cnM6IFNwYXJzZVRvRGVuc2VBdHRyc1xufSk6IFRlbnNvckluZm8ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7c3BhcnNlSW5kaWNlcywgc3BhcnNlVmFsdWVzLCBkZWZhdWx0VmFsdWV9ID0gaW5wdXRzO1xuICBjb25zdCB7b3V0cHV0U2hhcGV9ID0gYXR0cnM7XG5cbiAgY29uc3Qge3NsaWNlUmFuaywgbnVtVXBkYXRlcywgc2xpY2VTaXplLCBzdHJpZGVzLCBvdXRwdXRTaXplfSA9XG4gICAgICBiYWNrZW5kX3V0aWwuY2FsY3VsYXRlU2hhcGVzKHNwYXJzZVZhbHVlcywgc3BhcnNlSW5kaWNlcywgb3V0cHV0U2hhcGUpO1xuICBjb25zdCBzdW1EdXBlSW5kaWNlcyA9IGZhbHNlO1xuXG4gIGlmIChzcGFyc2VWYWx1ZXMuZHR5cGUgPT09ICdzdHJpbmcnKSB7XG4gICAgY29uc3QgaW5kaWNlc0J1ZiA9IGJhY2tlbmQuYnVmZmVyU3luYzxSYW5rLCAnaW50MzInPihzcGFyc2VJbmRpY2VzKTtcbiAgICBjb25zdCB1cGRhdGVzQnVmID0gYmFja2VuZC5idWZmZXJTeW5jPFJhbmssICdzdHJpbmcnPihzcGFyc2VWYWx1ZXMpO1xuICAgIGNvbnN0ICRkZWZhdWx0VmFsdWUgPSB1dGlsLmRlY29kZVN0cmluZyhcbiAgICAgICAgYmFja2VuZC5yZWFkU3luYyhkZWZhdWx0VmFsdWUuZGF0YUlkKVswXSBhcyBVaW50OEFycmF5KTtcbiAgICBjb25zdCBvdXRCdWYgPSBzY2F0dGVySW1wbENQVShcbiAgICAgICAgaW5kaWNlc0J1ZiwgdXBkYXRlc0J1Ziwgb3V0cHV0U2hhcGUsIG91dHB1dFNpemUsIHNsaWNlU2l6ZSwgbnVtVXBkYXRlcyxcbiAgICAgICAgc2xpY2VSYW5rLCBzdHJpZGVzLCAkZGVmYXVsdFZhbHVlLCBzdW1EdXBlSW5kaWNlcyk7XG4gICAgcmV0dXJuIGJhY2tlbmQubWFrZVRlbnNvckluZm8ob3V0cHV0U2hhcGUsIG91dEJ1Zi5kdHlwZSwgb3V0QnVmLnZhbHVlcyk7XG4gIH1cbiAgY29uc3QgcHJvZ3JhbSA9IG5ldyBTY2F0dGVyUHJvZ3JhbShcbiAgICAgIG51bVVwZGF0ZXMsIHNsaWNlUmFuaywgc3BhcnNlSW5kaWNlcy5zaGFwZS5sZW5ndGgsXG4gICAgICBzcGFyc2VWYWx1ZXMuc2hhcGUubGVuZ3RoLCBzdHJpZGVzLCBbb3V0cHV0U2l6ZSwgMV0sIHN1bUR1cGVJbmRpY2VzKTtcblxuICBjb25zdCByZXMgPSBiYWNrZW5kLnJ1bldlYkdMUHJvZ3JhbShcbiAgICAgIHByb2dyYW0sIFtzcGFyc2VWYWx1ZXMsIHNwYXJzZUluZGljZXMsIGRlZmF1bHRWYWx1ZV0sIHNwYXJzZVZhbHVlcy5kdHlwZSk7XG5cbiAgY29uc3QgcmVzaGFwZWQgPVxuICAgICAgcmVzaGFwZSh7aW5wdXRzOiB7eDogcmVzfSwgYmFja2VuZCwgYXR0cnM6IHtzaGFwZTogb3V0cHV0U2hhcGV9fSk7XG5cbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhyZXMpO1xuICByZXR1cm4gcmVzaGFwZWQ7XG59XG5cbmV4cG9ydCBjb25zdCBzcGFyc2VUb0RlbnNlQ29uZmlnOiBLZX
JuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IFNwYXJzZVRvRGVuc2UsXG4gIGJhY2tlbmROYW1lOiAnd2ViZ2wnLFxuICBrZXJuZWxGdW5jOiBzcGFyc2VUb0RlbnNlIGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==
|