/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { buffer, TensorBuffer } from '@tensorflow/tfjs-core';
|
/**
 * Scatters slices of `updates` into an output buffer at the positions named by
 * the index tuples in `indices`.
 *
 * @param {TensorBuffer} indices Buffer whose `values` hold `numUpdates` index
 *     tuples of length `sliceRank`, stored contiguously.
 * @param {TensorBuffer} updates Buffer holding the values to scatter. When
 *     `updates.rank === 0` its single value is used for every written
 *     position, both when assigning and when summing.
 * @param {number[]} shape Output shape. Used to bounds-check each index
 *     coordinate, in error messages, and as the shape of the empty result.
 * @param {number} outputSize Total number of output elements.
 * @param {number} sliceSize Number of contiguous elements written per index.
 * @param {number} numUpdates Number of index tuples in `indices`.
 * @param {number} sliceRank Length of each index tuple.
 * @param {number[]} strides Per-coordinate multipliers that flatten an index
 *     tuple into a slice number.
 * @param {TensorBuffer|number|string|boolean} defaultValue Either a buffer that
 *     is written into directly and returned, or a scalar used to fill a fresh
 *     `[outputSize / sliceSize, sliceSize]` buffer (booleans are filled as 1/0).
 * @param {boolean} sumDupeIndices When true, values landing on the same output
 *     position are added together; otherwise later writes overwrite earlier
 *     ones.
 * @returns {TensorBuffer} The output buffer. When `outputSize` is 0 this is an
 *     empty buffer with shape `shape`.
 * @throws {Error} If an index tuple does not address a valid slice of `shape`.
 */
export function scatterImpl(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) {
  const flattenShape = [outputSize / sliceSize, sliceSize];

  const indicesData = indices.values;
  const updatesData = updates.values;

  if (outputSize === 0) {
    return buffer(shape, updates.dtype);
  }

  // NOTE(review): a TensorBuffer default is mutated in place and returned as
  // the result (no copy is made); confirm callers pass a buffer they own.
  const outBuf = (defaultValue instanceof TensorBuffer) ?
      defaultValue :
      buffer(flattenShape, updates.dtype);
  if (typeof defaultValue === 'string' || typeof defaultValue === 'number') {
    outBuf.values.fill(defaultValue);
  } else if (typeof defaultValue === 'boolean') {
    // Coerce to 1/0 before filling the numeric backing array.
    outBuf.values.fill(+defaultValue);
  }

  const outValues = outBuf.values;
  const numSlices = outputSize / sliceSize;
  // A rank-0 `updates` holds a single value that is reused for every written
  // position; this must apply in the summing mode too, otherwise reads past
  // index 0 yield `undefined` and the sums become NaN.
  const isScalarUpdate = updates.rank === 0;

  for (let i = 0; i < numUpdates; i++) {
    const index = [];
    let flattenIndex = 0;
    let inBounds = true;
    for (let j = 0; j < sliceRank; j++) {
      const dim = indicesData[i * sliceRank + j];
      index.push(dim);
      // Check every coordinate against its own dimension: an out-of-range
      // coordinate can still flatten to an in-range offset and would then
      // silently write into the wrong slice.
      if (dim < 0 || dim >= shape[j]) {
        inBounds = false;
      }
      flattenIndex += dim * strides[j];
    }

    if (!inBounds || flattenIndex < 0 || flattenIndex >= numSlices) {
      throw new Error(`Invalid indices: ${index} does not index into ${shape}`);
    }

    const outOffset = flattenIndex * sliceSize;
    const updateOffset = i * sliceSize;
    for (let k = 0; k < sliceSize; k++) {
      const value =
          isScalarUpdate ? updatesData[0] : updatesData[updateOffset + k];
      if (sumDupeIndices) {
        outValues[outOffset + k] += value;
      } else {
        outValues[outOffset + k] = value;
      }
    }
  }

  return outBuf;
}
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU2NhdHRlcl9pbXBsLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLWNwdS9zcmMva2VybmVscy9TY2F0dGVyX2ltcGwudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBQ0gsT0FBTyxFQUFDLE1BQU0sRUFBa0IsWUFBWSxFQUFhLE1BQU0sdUJBQXVCLENBQUM7QUFTdkYsTUFBTSxVQUNOLFdBQVcsQ0FDUCxPQUFpQyxFQUFFLE9BQTJCLEVBQzlELEtBQWUsRUFBRSxVQUFrQixFQUFFLFNBQWlCLEVBQUUsVUFBa0IsRUFDMUUsU0FBaUIsRUFBRSxPQUFpQixFQUNwQyxZQUF1RCxFQUN2RCxjQUF1QjtJQUN6QixNQUFNLFlBQVksR0FBRyxDQUFDLFVBQVUsR0FBRyxTQUFTLEVBQUUsU0FBUyxDQUFDLENBQUM7SUFFekQsTUFBTSxXQUFXLEdBQUcsT0FBTyxDQUFDLE1BQW9CLENBQUM7SUFDakQsTUFBTSxXQUFXLEdBQUcsT0FBTyxDQUFDLE1BQU0sQ0FBQztJQUVuQyxJQUFJLFVBQVUsS0FBSyxDQUFDLEVBQUU7UUFDcEIsT0FBTyxNQUFNLENBQUMsS0FBb0IsRUFBRSxPQUFPLENBQUMsS0FBSyxDQUFDLENBQUM7S0FDcEQ7SUFFRCxNQUFNLE1BQU0sR0FBRyxDQUFDLFlBQVksWUFBWSxZQUFZLENBQUMsQ0FBQyxDQUFDO1FBQ25ELFlBQVksQ0FBQyxDQUFDO1FBQ2QsTUFBTSxDQUFDLFlBQVksRUFBRSxPQUFPLENBQUMsS0FBSyxDQUFDLENBQUM7SUFDeEMsSUFBSSxPQUFPLFlBQVksS0FBSyxRQUFRLEVBQUU7UUFDbkMsTUFBTSxDQUFDLE1BQW1CLENBQUMsSUFBSSxDQUFDLFlBQVksQ0FBQyxDQUFDO0tBQ2hEO1NBQU0sSUFBSSxPQUFPLFlBQVksS0FBSyxRQUFRLEVBQUU7UUFDMUMsTUFBTSxDQUFDLE1BQXFCLENBQUMsSUFBSSxDQUFDLFlBQVksQ0FBQyxDQUFDO0tBQ2xEO1NBQU0sSUFBSSxPQUFPLFlBQVksS0FBSyxTQUFTLEVBQUU7UUFDM0MsTUFBTSxDQUFDLE1BQXFCLENBQUMsSUFBSSxDQUFDLENBQUMsWUFBWSxDQUFDLENBQUM7S0FDbkQ7SUFFRCxLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsVUFBVSxFQUFFLENBQUMsRUFBRSxFQUFFO1FBQ25DLE1BQU0sS0FBSyxHQUFHLEVBQUUsQ0FBQztRQUNqQixJQUFJLFlBQVksR0FBRyxDQUFDLENBQUM7UUFDckIsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFNBQVMsRUFBRSxDQUFDLEVBQUUsRUFBRTtZQUNsQyxNQUFNLEdBQUcsR0FBRyxXQUFXLENBQUMsQ0FBQyxHQUFHLFNBQVMsR0FBRyxDQUFDLENBQUMsQ0FBQztZQUMzQyxLQUFLLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxDQUFDO1lBQ2hCLFlBQVksSUFBSSxHQUFHLEdBQUcsT0FBTyxDQUFDLENBQUMsQ0FBQyxDQUFDO1NBQ2xDO1FBRUQsSUFBSSxZQUFZLEdBQUcsQ0FBQyxJQUFJLFlBQVksSUFBSSxVQUFVLEdBQUcsU0FBUyxFQUFFO1lBQzlELE1BQU0sSUFBSSxLQUFLLE
NBQUMsb0JBQW9CLEtBQUssd0JBQXdCLEtBQUssRUFBRSxDQUFDLENBQUM7U0FDM0U7UUFFRCxLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsU0FBUyxFQUFFLENBQUMsRUFBRSxFQUFFO1lBQ2xDLElBQUksY0FBYyxFQUFFO2dCQUNqQixNQUFNLENBQUMsTUFBcUIsQ0FBQyxZQUFZLEdBQUcsU0FBUyxHQUFHLENBQUMsQ0FBQztvQkFDdEQsV0FBMEIsQ0FBQyxDQUFDLEdBQUcsU0FBUyxHQUFHLENBQUMsQ0FBQyxDQUFDO2FBQ3BEO2lCQUFNO2dCQUNMLE1BQU0sQ0FBQyxNQUFNLENBQUMsWUFBWSxHQUFHLFNBQVMsR0FBRyxDQUFDLENBQUMsR0FBRyxPQUFPLENBQUMsSUFBSSxLQUFLLENBQUMsQ0FBQyxDQUFDO29CQUM5RCxXQUFXLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztvQkFDaEIsV0FBVyxDQUFDLENBQUMsR0FBRyxTQUFTLEdBQUcsQ0FBQyxDQUFDLENBQUM7YUFDcEM7U0FDRjtLQUNGO0lBRUQsT0FBTyxNQUE0QixDQUFDO0FBQ3RDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5pbXBvcnQge2J1ZmZlciwgUmFuaywgU2hhcGVNYXAsIFRlbnNvckJ1ZmZlciwgVHlwZWRBcnJheX0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW50ZXJmYWNlIERlZmF1bHRWYWx1ZVR5cGVNYXAge1xuICBib29sOiBib29sZWFuO1xuICBpbnQzMjogbnVtYmVyO1xuICBmbG9hdDMyOiBudW1iZXI7XG4gIHN0cmluZzogc3RyaW5nO1xufVxuXG5leHBvcnQgZnVuY3Rpb25cbnNjYXR0ZXJJbXBsPFIgZXh0ZW5kcyBSYW5rLCBEIGV4dGVuZHMgJ2
Zsb2F0MzInfCdpbnQzMid8J2Jvb2wnfCdzdHJpbmcnPihcbiAgICBpbmRpY2VzOiBUZW5zb3JCdWZmZXI8UiwgJ2ludDMyJz4sIHVwZGF0ZXM6IFRlbnNvckJ1ZmZlcjxSLCBEPixcbiAgICBzaGFwZTogbnVtYmVyW10sIG91dHB1dFNpemU6IG51bWJlciwgc2xpY2VTaXplOiBudW1iZXIsIG51bVVwZGF0ZXM6IG51bWJlcixcbiAgICBzbGljZVJhbms6IG51bWJlciwgc3RyaWRlczogbnVtYmVyW10sXG4gICAgZGVmYXVsdFZhbHVlOiBUZW5zb3JCdWZmZXI8UiwgRD58RGVmYXVsdFZhbHVlVHlwZU1hcFtEXSxcbiAgICBzdW1EdXBlSW5kaWNlczogYm9vbGVhbik6IFRlbnNvckJ1ZmZlcjxSLCBEPiB7XG4gIGNvbnN0IGZsYXR0ZW5TaGFwZSA9IFtvdXRwdXRTaXplIC8gc2xpY2VTaXplLCBzbGljZVNpemVdO1xuXG4gIGNvbnN0IGluZGljZXNEYXRhID0gaW5kaWNlcy52YWx1ZXMgYXMgVHlwZWRBcnJheTtcbiAgY29uc3QgdXBkYXRlc0RhdGEgPSB1cGRhdGVzLnZhbHVlcztcblxuICBpZiAob3V0cHV0U2l6ZSA9PT0gMCkge1xuICAgIHJldHVybiBidWZmZXIoc2hhcGUgYXMgU2hhcGVNYXBbUl0sIHVwZGF0ZXMuZHR5cGUpO1xuICB9XG5cbiAgY29uc3Qgb3V0QnVmID0gKGRlZmF1bHRWYWx1ZSBpbnN0YW5jZW9mIFRlbnNvckJ1ZmZlcikgP1xuICAgICAgZGVmYXVsdFZhbHVlIDpcbiAgICAgIGJ1ZmZlcihmbGF0dGVuU2hhcGUsIHVwZGF0ZXMuZHR5cGUpO1xuICBpZiAodHlwZW9mIGRlZmF1bHRWYWx1ZSA9PT0gJ3N0cmluZycpIHtcbiAgICAob3V0QnVmLnZhbHVlcyBhcyBzdHJpbmdbXSkuZmlsbChkZWZhdWx0VmFsdWUpO1xuICB9IGVsc2UgaWYgKHR5cGVvZiBkZWZhdWx0VmFsdWUgPT09ICdudW1iZXInKSB7XG4gICAgKG91dEJ1Zi52YWx1ZXMgYXMgVHlwZWRBcnJheSkuZmlsbChkZWZhdWx0VmFsdWUpO1xuICB9IGVsc2UgaWYgKHR5cGVvZiBkZWZhdWx0VmFsdWUgPT09ICdib29sZWFuJykge1xuICAgIChvdXRCdWYudmFsdWVzIGFzIFR5cGVkQXJyYXkpLmZpbGwoK2RlZmF1bHRWYWx1ZSk7XG4gIH1cblxuICBmb3IgKGxldCBpID0gMDsgaSA8IG51bVVwZGF0ZXM7IGkrKykge1xuICAgIGNvbnN0IGluZGV4ID0gW107XG4gICAgbGV0IGZsYXR0ZW5JbmRleCA9IDA7XG4gICAgZm9yIChsZXQgaiA9IDA7IGogPCBzbGljZVJhbms7IGorKykge1xuICAgICAgY29uc3QgZGltID0gaW5kaWNlc0RhdGFbaSAqIHNsaWNlUmFuayArIGpdO1xuICAgICAgaW5kZXgucHVzaChkaW0pO1xuICAgICAgZmxhdHRlbkluZGV4ICs9IGRpbSAqIHN0cmlkZXNbal07XG4gICAgfVxuXG4gICAgaWYgKGZsYXR0ZW5JbmRleCA8IDAgfHwgZmxhdHRlbkluZGV4ID49IG91dHB1dFNpemUgLyBzbGljZVNpemUpIHtcbiAgICAgIHRocm93IG5ldyBFcnJvcihgSW52YWxpZCBpbmRpY2VzOiAke2luZGV4fSBkb2VzIG5vdCBpbmRleCBpbnRvICR7c2hhcGV9YCk7XG4gICAgfVxuXG4gICAgZm9yIChsZXQgayA9IDA7IGsgPCBzbGljZVNpemU7IGsrKykge1xuICAgIC
AgaWYgKHN1bUR1cGVJbmRpY2VzKSB7XG4gICAgICAgIChvdXRCdWYudmFsdWVzIGFzIFR5cGVkQXJyYXkpW2ZsYXR0ZW5JbmRleCAqIHNsaWNlU2l6ZSArIGtdICs9XG4gICAgICAgICAgICAodXBkYXRlc0RhdGEgYXMgVHlwZWRBcnJheSlbaSAqIHNsaWNlU2l6ZSArIGtdO1xuICAgICAgfSBlbHNlIHtcbiAgICAgICAgb3V0QnVmLnZhbHVlc1tmbGF0dGVuSW5kZXggKiBzbGljZVNpemUgKyBrXSA9IHVwZGF0ZXMucmFuayA9PT0gMCA/XG4gICAgICAgICAgICB1cGRhdGVzRGF0YVswXSA6XG4gICAgICAgICAgICB1cGRhdGVzRGF0YVtpICogc2xpY2VTaXplICsga107XG4gICAgICB9XG4gICAgfVxuICB9XG5cbiAgcmV0dXJuIG91dEJ1ZiBhcyBUZW5zb3JCdWZmZXI8UiwgRD47XG59XG4iXX0=
|