/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { Multinomial, util } from '@tensorflow/tfjs-core';
|
import * as seedrandom from 'seedrandom';
|
import { assertNotComplex } from '../cpu_util';
|
import { softmax } from './Softmax';
|
/**
 * CPU kernel for the Multinomial op: for every row of `logits`, draws
 * `numSamples` event indices according to that row's probabilities.
 *
 * @param {Object} args Kernel arguments.
 * @param {Object} args.inputs Holds `logits`, a tensor info whose `shape[0]`
 *     is read as the batch size and `shape[1]` as the number of events.
 * @param {Object} args.backend Backend used to read the tensor values, run
 *     softmax, dispose the intermediate and allocate the result.
 * @param {Object} args.attrs Kernel attributes:
 *     - `numSamples`: number of draws per row.
 *     - `seed`: stringified and used to seed the `alea` generator.
 *     - `normalized`: when truthy, `logits` is used directly as
 *       probabilities; otherwise softmax over the last dimension is applied.
 * @returns {Object} Tensor info of dtype 'int32' and shape
 *     [batchSize, numSamples] holding the sampled event indices.
 */
export function multinomial(args) {
  const { inputs, backend, attrs } = args;
  const { logits } = inputs;
  const { numSamples, seed, normalized } = attrs;

  assertNotComplex(logits, 'multinomial');

  const probabilities = normalized ?
      logits :
      softmax({ inputs: { logits }, backend, attrs: { dim: -1 } });

  const batchSize = probabilities.shape[0];
  const numEvents = probabilities.shape[1];
  const probVals = backend.data.get(probabilities.dataId).values;
  const resShape = [batchSize, numSamples];
  const resVals =
      util.makeZerosTypedArray(util.sizeFromShape(resShape), 'int32');

  // A single generator is shared by all rows. Re-seeding it per row would
  // replay the exact same random stream for every row, making the rows
  // perfectly correlated (identical rows would yield identical samples).
  // Row 0 still consumes the same leading values as before for a given seed.
  const random = seedrandom.alea(seed.toString());

  for (let b = 0; b < batchSize; ++b) {
    const offset = b * numEvents;
    // The cdf won't include the last event. It will be implicit if no other
    // event happened.
    const cdf = new Float32Array(numEvents - 1);
    // With a single event the cdf is empty; skip the write instead of
    // relying on out-of-bounds typed-array stores being silently dropped.
    if (cdf.length > 0) {
      cdf[0] = probVals[offset];
    }
    for (let event = 1; event < cdf.length; ++event) {
      cdf[event] = cdf[event - 1] + probVals[offset + event];
    }

    const outOffset = b * numSamples;
    for (let sampleId = 0; sampleId < numSamples; ++sampleId) {
      const r = random();

      // Assume last event happened by default.
      resVals[outOffset + sampleId] = cdf.length;

      // Pick the first event whose cumulative probability exceeds the draw.
      for (let event = 0; event < cdf.length; event++) {
        if (r < cdf[event]) {
          resVals[outOffset + sampleId] = event;
          break;
        }
      }
    }
  }

  // The softmax output is owned by this kernel; `logits` belongs to the
  // caller and must not be disposed here.
  if (!normalized) {
    backend.disposeIntermediateTensorInfo(probabilities);
  }

  return backend.makeTensorInfo(resShape, 'int32', resVals);
}
|
/**
 * Kernel config that pairs the `Multinomial` kernel name with the
 * `multinomial` implementation for the 'cpu' backend.
 */
const multinomialConfig = {
  kernelName: Multinomial,
  backendName: 'cpu',
  kernelFunc: multinomial,
};

export { multinomialConfig };
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiTXVsdGlub21pYWwuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy9rZXJuZWxzL011bHRpbm9taWFsLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBMkIsV0FBVyxFQUErRCxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUMvSSxPQUFPLEtBQUssVUFBVSxNQUFNLFlBQVksQ0FBQztBQUd6QyxPQUFPLEVBQUMsZ0JBQWdCLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFFN0MsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUVsQyxNQUFNLFVBQVUsV0FBVyxDQUFDLElBSTNCO0lBQ0MsTUFBTSxFQUFDLE1BQU0sRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQ3RDLE1BQU0sRUFBQyxNQUFNLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDeEIsTUFBTSxFQUFDLFVBQVUsRUFBRSxJQUFJLEVBQUUsVUFBVSxFQUFDLEdBQUcsS0FBSyxDQUFDO0lBRTdDLGdCQUFnQixDQUFDLE1BQU0sRUFBRSxhQUFhLENBQUMsQ0FBQztJQUV4QyxNQUFNLGFBQWEsR0FBRyxVQUFVLENBQUMsQ0FBQztRQUM5QixNQUFNLENBQUMsQ0FBQztRQUNSLE9BQU8sQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLE1BQU0sRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxHQUFHLEVBQUUsQ0FBQyxDQUFDLEVBQUMsRUFBQyxDQUFDLENBQUM7SUFFM0QsTUFBTSxTQUFTLEdBQUcsYUFBYSxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUN6QyxNQUFNLFNBQVMsR0FBRyxhQUFhLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQ3pDLE1BQU0sUUFBUSxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLGFBQWEsQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDO0lBQzdFLE1BQU0sUUFBUSxHQUFHLENBQUMsU0FBUyxFQUFFLFVBQVUsQ0FBQyxDQUFDO0lBQ3pDLE1BQU0sT0FBTyxHQUNULElBQUksQ0FBQyxtQkFBbUIsQ0FBQyxJQUFJLENBQUMsYUFBYSxDQUFDLFFBQVEsQ0FBQyxFQUFFLE9BQU8sQ0FBQyxDQUFDO0lBRXBFLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxTQUFTLEVBQUUsRUFBRSxDQUFDLEVBQUU7UUFDbEMsTUFBTSxNQUFNLEdBQUcsQ0FBQyxHQUFHLFNBQVMsQ0FBQztRQUM3Qix3RUFBd0U7UUFDeEUsa0JBQWtCO1FBQ2xCLE1BQU0sR0FBRyxHQUFHLElBQUksWUFBWSxDQUFDLFNBQVMsR0FBRyxDQUFDLENBQUMsQ0FBQztRQUM1QyxHQUFHLENBQUMsQ0FBQyxDQUFDLEdBQUcsUUFBUSxDQUFDLE1BQU0sQ0FBQyxDQUFDO1FBQzFCLEtBQUssSUFBSSxLQUFLLEdBQUcsQ0FBQyxFQUFFLEtBQUssR0FBRyxHQUFHLENBQUMsTUFBTSxFQUFFLEVBQUUsS0FBSyxFQUFFO1lBQy9DLEdBQUcsQ0FBQyxLQUFLLENBQU
MsR0FBRyxHQUFHLENBQUMsS0FBSyxHQUFHLENBQUMsQ0FBQyxHQUFHLFFBQVEsQ0FBQyxNQUFNLEdBQUcsS0FBSyxDQUFDLENBQUM7U0FDeEQ7UUFFRCxNQUFNLE1BQU0sR0FBRyxVQUFVLENBQUMsSUFBSSxDQUFDLElBQUksQ0FBQyxRQUFRLEVBQUUsQ0FBQyxDQUFDO1FBQ2hELE1BQU0sU0FBUyxHQUFHLENBQUMsR0FBRyxVQUFVLENBQUM7UUFDakMsS0FBSyxJQUFJLFFBQVEsR0FBRyxDQUFDLEVBQUUsUUFBUSxHQUFHLFVBQVUsRUFBRSxFQUFFLFFBQVEsRUFBRTtZQUN4RCxNQUFNLENBQUMsR0FBRyxNQUFNLEVBQUUsQ0FBQztZQUVuQix5Q0FBeUM7WUFDekMsT0FBTyxDQUFDLFNBQVMsR0FBRyxRQUFRLENBQUMsR0FBRyxHQUFHLENBQUMsTUFBTSxDQUFDO1lBRTNDLEtBQUssSUFBSSxLQUFLLEdBQUcsQ0FBQyxFQUFFLEtBQUssR0FBRyxHQUFHLENBQUMsTUFBTSxFQUFFLEtBQUssRUFBRSxFQUFFO2dCQUMvQyxJQUFJLENBQUMsR0FBRyxHQUFHLENBQUMsS0FBSyxDQUFDLEVBQUU7b0JBQ2xCLE9BQU8sQ0FBQyxTQUFTLEdBQUcsUUFBUSxDQUFDLEdBQUcsS0FBSyxDQUFDO29CQUN0QyxNQUFNO2lCQUNQO2FBQ0Y7U0FDRjtLQUNGO0lBRUQsSUFBSSxDQUFDLFVBQVUsRUFBRTtRQUNmLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxhQUFhLENBQUMsQ0FBQztLQUN0RDtJQUVELE9BQU8sT0FBTyxDQUFDLGNBQWMsQ0FBQyxRQUFRLEVBQUUsT0FBTyxFQUFFLE9BQU8sQ0FBQyxDQUFDO0FBQzVELENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxpQkFBaUIsR0FBaUI7SUFDN0MsVUFBVSxFQUFFLFdBQVc7SUFDdkIsV0FBVyxFQUFFLEtBQUs7SUFDbEIsVUFBVSxFQUFFLFdBQW9DO0NBQ2pELENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG
4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7S2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBNdWx0aW5vbWlhbCwgTXVsdGlub21pYWxBdHRycywgTXVsdGlub21pYWxJbnB1dHMsIFRlbnNvckluZm8sIFR5cGVkQXJyYXksIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5pbXBvcnQgKiBhcyBzZWVkcmFuZG9tIGZyb20gJ3NlZWRyYW5kb20nO1xuXG5pbXBvcnQge01hdGhCYWNrZW5kQ1BVfSBmcm9tICcuLi9iYWNrZW5kX2NwdSc7XG5pbXBvcnQge2Fzc2VydE5vdENvbXBsZXh9IGZyb20gJy4uL2NwdV91dGlsJztcblxuaW1wb3J0IHtzb2Z0bWF4fSBmcm9tICcuL1NvZnRtYXgnO1xuXG5leHBvcnQgZnVuY3Rpb24gbXVsdGlub21pYWwoYXJnczoge1xuICBpbnB1dHM6IE11bHRpbm9taWFsSW5wdXRzLFxuICBiYWNrZW5kOiBNYXRoQmFja2VuZENQVSxcbiAgYXR0cnM6IE11bHRpbm9taWFsQXR0cnNcbn0pOiBUZW5zb3JJbmZvIHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZCwgYXR0cnN9ID0gYXJncztcbiAgY29uc3Qge2xvZ2l0c30gPSBpbnB1dHM7XG4gIGNvbnN0IHtudW1TYW1wbGVzLCBzZWVkLCBub3JtYWxpemVkfSA9IGF0dHJzO1xuXG4gIGFzc2VydE5vdENvbXBsZXgobG9naXRzLCAnbXVsdGlub21pYWwnKTtcblxuICBjb25zdCBwcm9iYWJpbGl0aWVzID0gbm9ybWFsaXplZCA/XG4gICAgICBsb2dpdHMgOlxuICAgICAgc29mdG1heCh7aW5wdXRzOiB7bG9naXRzfSwgYmFja2VuZCwgYXR0cnM6IHtkaW06IC0xfX0pO1xuXG4gIGNvbnN0IGJhdGNoU2l6ZSA9IHByb2JhYmlsaXRpZXMuc2hhcGVbMF07XG4gIGNvbnN0IG51bUV2ZW50cyA9IHByb2JhYmlsaXRpZXMuc2hhcGVbMV07XG4gIGNvbnN0IHByb2JWYWxzID0gYmFja2VuZC5kYXRhLmdldChwcm9iYWJpbGl0aWVzLmRhdGFJZCkudmFsdWVzIGFzIFR5cGVkQXJyYXk7XG4gIGNvbnN0IHJlc1NoYXBlID0gW2JhdGNoU2l6ZSwgbnVtU2FtcGxlc107XG4gIGNvbnN0IHJlc1ZhbHMgPVxuICAgICAgdXRpbC5tYWtlWmVyb3NUeXBlZEFycmF5KHV0aWwuc2l6ZUZyb21TaGFwZShyZXNTaGFwZSksICdpbnQzMicpO1xuXG4gIGZvciAobGV0IGIgPSAwOyBiIDwgYmF0Y2hTaXplOyArK2IpIHtcbiAgICBjb25zdCBvZmZzZXQgPSBiICogbnVtRXZlbnRzO1xuICAgIC8vIFRoZSBjZGYgd29uJ3QgaW5jbHVkZSB0aGUgbGFzdCBldmVudC4gSXQgd2lsbCBiZSBpbXBsaWNpdCBpZiBubyBvdGhlclxuICAgIC8vIGV2ZW50IGhhcHBlbmVkLlxuICAgIGNvbnN0IGNkZiA9IG5ldyBGbG9hdDMyQXJyYXkobnVtRXZlbnRzIC0gMSk7XG4gICAgY2RmWzBdID0gcHJvYlZhbHNbb2Zmc2V0XTtcbiAgICBmb3IgKGxldCBldmVudCA9IDE7IGV2ZW50IDwgY2RmLmxlbmd0aDsgKytldmVudCkge1xuICAgICAgY2RmW2V2ZW50XSA9IGNkZltldmVudCAtIDFdIC
sgcHJvYlZhbHNbb2Zmc2V0ICsgZXZlbnRdO1xuICAgIH1cblxuICAgIGNvbnN0IHJhbmRvbSA9IHNlZWRyYW5kb20uYWxlYShzZWVkLnRvU3RyaW5nKCkpO1xuICAgIGNvbnN0IG91dE9mZnNldCA9IGIgKiBudW1TYW1wbGVzO1xuICAgIGZvciAobGV0IHNhbXBsZUlkID0gMDsgc2FtcGxlSWQgPCBudW1TYW1wbGVzOyArK3NhbXBsZUlkKSB7XG4gICAgICBjb25zdCByID0gcmFuZG9tKCk7XG5cbiAgICAgIC8vIEFzc3VtZSBsYXN0IGV2ZW50IGhhcHBlbmVkIGJ5IGRlZmF1bHQuXG4gICAgICByZXNWYWxzW291dE9mZnNldCArIHNhbXBsZUlkXSA9IGNkZi5sZW5ndGg7XG5cbiAgICAgIGZvciAobGV0IGV2ZW50ID0gMDsgZXZlbnQgPCBjZGYubGVuZ3RoOyBldmVudCsrKSB7XG4gICAgICAgIGlmIChyIDwgY2RmW2V2ZW50XSkge1xuICAgICAgICAgIHJlc1ZhbHNbb3V0T2Zmc2V0ICsgc2FtcGxlSWRdID0gZXZlbnQ7XG4gICAgICAgICAgYnJlYWs7XG4gICAgICAgIH1cbiAgICAgIH1cbiAgICB9XG4gIH1cblxuICBpZiAoIW5vcm1hbGl6ZWQpIHtcbiAgICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHByb2JhYmlsaXRpZXMpO1xuICB9XG5cbiAgcmV0dXJuIGJhY2tlbmQubWFrZVRlbnNvckluZm8ocmVzU2hhcGUsICdpbnQzMicsIHJlc1ZhbHMpO1xufVxuXG5leHBvcnQgY29uc3QgbXVsdGlub21pYWxDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogTXVsdGlub21pYWwsXG4gIGJhY2tlbmROYW1lOiAnY3B1JyxcbiAga2VybmVsRnVuYzogbXVsdGlub21pYWwgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19
|