/**
 * @license
 * Copyright 2023 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '@tensorflow/tfjs-core';
import { util, env, tensor, tensor1d, tensor2d, browser, tidy, expandDims, cast, image, reshape } from '@tensorflow/tfjs-core';

var commonjsGlobal = typeof globalThis !== 'undefined' ? globalThis : typeof window !== 'undefined' ? window : typeof global !== 'undefined' ? global : typeof self !== 'undefined' ? self : {};

function getAugmentedNamespace(n) {
  if (n.__esModule) return n;
  var f = n.default;
  if (typeof f == "function") {
    var a = function a () {
      if (this instanceof a) {
        var args = [null];
        args.push.apply(args, arguments);
        var Ctor = Function.bind.apply(f, args);
        return new Ctor();
      }
      return f.apply(this, arguments);
    };
    a.prototype = f.prototype;
  } else a = {};
  Object.defineProperty(a, '__esModule', {value: true});
  Object.keys(n).forEach(function (k) {
    var d = Object.getOwnPropertyDescriptor(n, k);
    Object.defineProperty(a, k, d.get ? d : {
      enumerable: true,
      get: function () {
        return n[k];
      }
    });
  });
  return a;
}

var alea$1 = {exports: {}};

(function (module) {
  // A port of an algorithm by Johannes Baagøe <baagoe@baagoe.com>, 2010
  // http://baagoe.com/en/RandomMusings/javascript/
  // https://github.com/nquinlan/better-random-numbers-for-javascript-mirror
  // Original work is under MIT license -

  // Copyright (C) 2010 by Johannes Baagøe <baagoe@baagoe.org>
  //
  // Permission is hereby granted, free of charge, to any person obtaining a copy
  // of this software and associated documentation files (the "Software"), to deal
  // in the Software without restriction, including without limitation the rights
  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  // copies of the Software, and to permit persons to whom the Software is
  // furnished to do so, subject to the following conditions:
  //
  // The above copyright notice and this permission notice shall be included in
  // all copies or substantial portions of the Software.
  //
  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  // THE SOFTWARE.

  (function(global, module, define) {

    function Alea(seed) {
      var me = this, mash = Mash();

      me.next = function() {
        var t = 2091639 * me.s0 + me.c * 2.3283064365386963e-10; // 2^-32
        me.s0 = me.s1;
        me.s1 = me.s2;
        return me.s2 = t - (me.c = t | 0);
      };

      // Apply the seeding algorithm from Baagoe.
      me.c = 1;
      me.s0 = mash(' ');
      me.s1 = mash(' ');
      me.s2 = mash(' ');
      me.s0 -= mash(seed);
      if (me.s0 < 0) { me.s0 += 1; }
      me.s1 -= mash(seed);
      if (me.s1 < 0) { me.s1 += 1; }
      me.s2 -= mash(seed);
      if (me.s2 < 0) { me.s2 += 1; }
      mash = null;
    }

    function copy(f, t) {
      t.c = f.c;
      t.s0 = f.s0;
      t.s1 = f.s1;
      t.s2 = f.s2;
      return t;
    }

    function impl(seed, opts) {
      var xg = new Alea(seed),
          state = opts && opts.state,
          prng = xg.next;
      prng.int32 = function() { return (xg.next() * 0x100000000) | 0; };
      prng.double = function() {
        return prng() + (prng() * 0x200000 | 0) * 1.1102230246251565e-16; // 2^-53
      };
      prng.quick = prng;
      if (state) {
        if (typeof(state) == 'object') copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }

    function Mash() {
      var n = 0xefc8249d;

      var mash = function(data) {
        data = String(data);
        for (var i = 0; i < data.length; i++) {
          n += data.charCodeAt(i);
          var h = 0.02519603282416938 * n;
          n = h >>> 0;
          h -= n;
          h *= n;
          n = h >>> 0;
          h -= n;
          n += h * 0x100000000; // 2^32
        }
        return (n >>> 0) * 2.3283064365386963e-10; // 2^-32
      };

      return mash;
    }

    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function() { return impl; });
    } else {
      this.alea = impl;
    }

  })(
    commonjsGlobal,
    module, // present in node.js
    (typeof undefined) == 'function' // present with an AMD loader
  );
} (alea$1));

var aleaExports = alea$1.exports;
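
// Usage sketch (illustrative only, not executed by this bundle): the alea
// factory returns a seeded PRNG with optional state save/restore. The names
// `rng`, `saved`, and `rng2` below are hypothetical.
//
//   var rng = aleaExports('my-seed', {state: true});
//   var x = rng();                  // float in [0, 1), 32-bit resolution
//   var d = rng.double();           // float in [0, 1), 53-bit resolution
//   var saved = rng.state();        // plain object {c, s0, s1, s2}
//   var rng2 = aleaExports('', {state: saved});  // resumes the sequence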

var xor128$1 = {exports: {}};

(function (module) {
  // A Javascript implementation of the "xor128" prng algorithm by
  // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper

  (function(global, module, define) {

    function XorGen(seed) {
      var me = this, strseed = '';

      me.x = 0;
      me.y = 0;
      me.z = 0;
      me.w = 0;

      // Set up generator function.
      me.next = function() {
        var t = me.x ^ (me.x << 11);
        me.x = me.y;
        me.y = me.z;
        me.z = me.w;
        return me.w ^= (me.w >>> 19) ^ t ^ (t >>> 8);
      };

      if (seed === (seed | 0)) {
        // Integer seed.
        me.x = seed;
      } else {
        // String seed.
        strseed += seed;
      }

      // Mix in string seed, then discard an initial batch of 64 values.
      for (var k = 0; k < strseed.length + 64; k++) {
        me.x ^= strseed.charCodeAt(k) | 0;
        me.next();
      }
    }

    function copy(f, t) {
      t.x = f.x;
      t.y = f.y;
      t.z = f.z;
      t.w = f.w;
      return t;
    }

    function impl(seed, opts) {
      var xg = new XorGen(seed),
          state = opts && opts.state,
          prng = function() { return (xg.next() >>> 0) / 0x100000000; };
      prng.double = function() {
        do {
          var top = xg.next() >>> 11,
              bot = (xg.next() >>> 0) / 0x100000000,
              result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (typeof(state) == 'object') copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }

    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function() { return impl; });
    } else {
      this.xor128 = impl;
    }

  })(
    commonjsGlobal,
    module, // present in node.js
    (typeof undefined) == 'function' // present with an AMD loader
  );
} (xor128$1));

var xor128Exports = xor128$1.exports;

var xorwow$1 = {exports: {}};

(function (module) {
  // A Javascript implementation of the "xorwow" prng algorithm by
  // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper

  (function(global, module, define) {

    function XorGen(seed) {
      var me = this, strseed = '';

      // Set up generator function.
      me.next = function() {
        var t = (me.x ^ (me.x >>> 2));
        me.x = me.y; me.y = me.z; me.z = me.w; me.w = me.v;
        return (me.d = (me.d + 362437 | 0)) +
            (me.v = (me.v ^ (me.v << 4)) ^ (t ^ (t << 1))) | 0;
      };

      me.x = 0;
      me.y = 0;
      me.z = 0;
      me.w = 0;
      me.v = 0;

      if (seed === (seed | 0)) {
        // Integer seed.
        me.x = seed;
      } else {
        // String seed.
        strseed += seed;
      }

      // Mix in string seed, then discard an initial batch of 64 values.
      for (var k = 0; k < strseed.length + 64; k++) {
        me.x ^= strseed.charCodeAt(k) | 0;
        if (k == strseed.length) {
          me.d = me.x << 10 ^ me.x >>> 4;
        }
        me.next();
      }
    }

    function copy(f, t) {
      t.x = f.x;
      t.y = f.y;
      t.z = f.z;
      t.w = f.w;
      t.v = f.v;
      t.d = f.d;
      return t;
    }

    function impl(seed, opts) {
      var xg = new XorGen(seed),
          state = opts && opts.state,
          prng = function() { return (xg.next() >>> 0) / 0x100000000; };
      prng.double = function() {
        do {
          var top = xg.next() >>> 11,
              bot = (xg.next() >>> 0) / 0x100000000,
              result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (typeof(state) == 'object') copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }

    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function() { return impl; });
    } else {
      this.xorwow = impl;
    }

  })(
    commonjsGlobal,
    module, // present in node.js
    (typeof undefined) == 'function' // present with an AMD loader
  );
} (xorwow$1));

var xorwowExports = xorwow$1.exports;

var xorshift7$1 = {exports: {}};

(function (module) {
  // A Javascript implementation of the "xorshift7" algorithm by
  // François Panneton and Pierre L'ecuyer:
  // "On the Xorshift Random Number Generators"
  // http://saluc.engr.uconn.edu/refs/crypto/rng/panneton05onthexorshift.pdf

  (function(global, module, define) {

    function XorGen(seed) {
      var me = this;

      // Set up generator function.
      me.next = function() {
        // Update xor generator.
        var X = me.x, i = me.i, t, v;
        t = X[i]; t ^= (t >>> 7); v = t ^ (t << 24);
        t = X[(i + 1) & 7]; v ^= t ^ (t >>> 10);
        t = X[(i + 3) & 7]; v ^= t ^ (t >>> 3);
        t = X[(i + 4) & 7]; v ^= t ^ (t << 7);
        t = X[(i + 7) & 7]; t = t ^ (t << 13); v ^= t ^ (t << 9);
        X[i] = v;
        me.i = (i + 1) & 7;
        return v;
      };

      function init(me, seed) {
        var j, X = [];

        if (seed === (seed | 0)) {
          // Seed state array using a 32-bit integer.
          X[0] = seed;
        } else {
          // Seed state using a string.
          seed = '' + seed;
          for (j = 0; j < seed.length; ++j) {
            X[j & 7] = (X[j & 7] << 15) ^
                (seed.charCodeAt(j) + X[(j + 1) & 7] << 13);
          }
        }
        // Enforce an array length of 8, not all zeroes.
        while (X.length < 8) X.push(0);
        for (j = 0; j < 8 && X[j] === 0; ++j);
        if (j == 8) X[7] = -1; else X[j];

        me.x = X;
        me.i = 0;

        // Discard an initial 256 values.
        for (j = 256; j > 0; --j) {
          me.next();
        }
      }

      init(me, seed);
    }

    function copy(f, t) {
      t.x = f.x.slice();
      t.i = f.i;
      return t;
    }

    function impl(seed, opts) {
      if (seed == null) seed = +(new Date);
      var xg = new XorGen(seed),
          state = opts && opts.state,
          prng = function() { return (xg.next() >>> 0) / 0x100000000; };
      prng.double = function() {
        do {
          var top = xg.next() >>> 11,
              bot = (xg.next() >>> 0) / 0x100000000,
              result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (state.x) copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }

    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function() { return impl; });
    } else {
      this.xorshift7 = impl;
    }

  })(
    commonjsGlobal,
    module, // present in node.js
    (typeof undefined) == 'function' // present with an AMD loader
  );
} (xorshift7$1));

var xorshift7Exports = xorshift7$1.exports;

var xor4096$1 = {exports: {}};

(function (module) {
  // A Javascript implementation of Richard Brent's Xorgens xor4096 algorithm.
  //
  // This fast non-cryptographic random number generator is designed for
  // use in Monte-Carlo algorithms. It combines a long-period xorshift
  // generator with a Weyl generator, and it passes all common batteries
  // of statistical tests for randomness while consuming only a few nanoseconds
  // for each prng generated. For background on the generator, see Brent's
  // paper: "Some long-period random number generators using shifts and xors."
  // http://arxiv.org/pdf/1004.3115v1.pdf
  //
  // Usage:
  //
  // var xor4096 = require('xor4096');
  // random = xor4096(1);                        // Seed with int32 or string.
  // assert.equal(random(), 0.1520436450538547); // (0, 1) range, 53 bits.
  // assert.equal(random.int32(), 1806534897);   // signed int32, 32 bits.
  //
  // For nonzero numeric keys, this implementation provides a sequence
  // identical to that by Brent's xorgens 3 implementation in C. This
  // implementation also provides for initializing the generator with
  // string seeds, or for saving and restoring the state of the generator.
  //
  // On Chrome, this prng benchmarks about 2.1 times slower than
  // Javascript's built-in Math.random().

  (function(global, module, define) {

    function XorGen(seed) {
      var me = this;

      // Set up generator function.
      me.next = function() {
        var w = me.w,
            X = me.X, i = me.i, t, v;
        // Update Weyl generator.
        me.w = w = (w + 0x61c88647) | 0;
        // Update xor generator.
        v = X[(i + 34) & 127];
        t = X[i = ((i + 1) & 127)];
        v ^= v << 13;
        t ^= t << 17;
        v ^= v >>> 15;
        t ^= t >>> 12;
        // Update Xor generator array state.
        v = X[i] = v ^ t;
        me.i = i;
        // Result is the combination.
        return (v + (w ^ (w >>> 16))) | 0;
      };

      function init(me, seed) {
        var t, v, i, j, w, X = [], limit = 128;
        if (seed === (seed | 0)) {
          // Numeric seeds initialize v, which is used to generate X.
          v = seed;
          seed = null;
        } else {
          // String seeds are mixed into v and X one character at a time.
          seed = seed + '\0';
          v = 0;
          limit = Math.max(limit, seed.length);
        }
        // Initialize circular array and weyl value.
        for (i = 0, j = -32; j < limit; ++j) {
          // Put the unicode characters into the array, and shuffle them.
          if (seed) v ^= seed.charCodeAt((j + 32) % seed.length);
          // After 32 shuffles, take v as the starting w value.
          if (j === 0) w = v;
          v ^= v << 10;
          v ^= v >>> 15;
          v ^= v << 4;
          v ^= v >>> 13;
          if (j >= 0) {
            w = (w + 0x61c88647) | 0;     // Weyl.
            t = (X[j & 127] ^= (v + w));  // Combine xor and weyl to init array.
            i = (0 == t) ? i + 1 : 0;     // Count zeroes.
          }
        }
        // We have detected all zeroes; make the key nonzero.
        if (i >= 128) {
          X[(seed && seed.length || 0) & 127] = -1;
        }
        // Run the generator 512 times to further mix the state before using it.
        // Factoring this as a function slows the main generator, so it is just
        // unrolled here. The weyl generator is not advanced while warming up.
        i = 127;
        for (j = 4 * 128; j > 0; --j) {
          v = X[(i + 34) & 127];
          t = X[i = ((i + 1) & 127)];
          v ^= v << 13;
          t ^= t << 17;
          v ^= v >>> 15;
          t ^= t >>> 12;
          X[i] = v ^ t;
        }
        // Storing state as object members is faster than using closure variables.
        me.w = w;
        me.X = X;
        me.i = i;
      }

      init(me, seed);
    }

    function copy(f, t) {
      t.i = f.i;
      t.w = f.w;
      t.X = f.X.slice();
      return t;
    }
    function impl(seed, opts) {
      if (seed == null) seed = +(new Date);
      var xg = new XorGen(seed),
          state = opts && opts.state,
          prng = function() { return (xg.next() >>> 0) / 0x100000000; };
      prng.double = function() {
        do {
          var top = xg.next() >>> 11,
              bot = (xg.next() >>> 0) / 0x100000000,
              result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (state.X) copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }

    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function() { return impl; });
    } else {
      this.xor4096 = impl;
    }

  })(
    commonjsGlobal, // window object or global
    module, // present in node.js
    (typeof undefined) == 'function' // present with an AMD loader
  );
} (xor4096$1));

var xor4096Exports = xor4096$1.exports;

var tychei$1 = {exports: {}};

(function (module) {
  // A Javascript implementation of the "Tyche-i" prng algorithm by
  // Samuel Neves and Filipe Araujo.
  // See https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf

  (function(global, module, define) {

    function XorGen(seed) {
      var me = this, strseed = '';

      // Set up generator function.
      me.next = function() {
        var b = me.b, c = me.c, d = me.d, a = me.a;
        b = (b << 25) ^ (b >>> 7) ^ c;
        c = (c - d) | 0;
        d = (d << 24) ^ (d >>> 8) ^ a;
        a = (a - b) | 0;
        me.b = b = (b << 20) ^ (b >>> 12) ^ c;
        me.c = c = (c - d) | 0;
        me.d = (d << 16) ^ (c >>> 16) ^ a;
        return me.a = (a - b) | 0;
      };

      /* The following is non-inverted tyche, which has better internal
       * bit diffusion, but which is about 25% slower than tyche-i in JS.
      me.next = function() {
        var a = me.a, b = me.b, c = me.c, d = me.d;
        a = (me.a + me.b | 0) >>> 0;
        d = me.d ^ a; d = d << 16 ^ d >>> 16;
        c = me.c + d | 0;
        b = me.b ^ c; b = b << 12 ^ d >>> 20;
        me.a = a = a + b | 0;
        d = d ^ a; me.d = d = d << 8 ^ d >>> 24;
        me.c = c = c + d | 0;
        b = b ^ c;
        return me.b = (b << 7 ^ b >>> 25);
      }
      */

      me.a = 0;
      me.b = 0;
      me.c = 2654435769 | 0;
      me.d = 1367130551;

      if (seed === Math.floor(seed)) {
        // Integer seed.
        me.a = (seed / 0x100000000) | 0;
        me.b = seed | 0;
      } else {
        // String seed.
        strseed += seed;
      }

      // Mix in string seed, then discard an initial batch of 20 values.
      for (var k = 0; k < strseed.length + 20; k++) {
        me.b ^= strseed.charCodeAt(k) | 0;
        me.next();
      }
    }

    function copy(f, t) {
      t.a = f.a;
      t.b = f.b;
      t.c = f.c;
      t.d = f.d;
      return t;
    }
    function impl(seed, opts) {
      var xg = new XorGen(seed),
          state = opts && opts.state,
          prng = function() { return (xg.next() >>> 0) / 0x100000000; };
      prng.double = function() {
        do {
          var top = xg.next() >>> 11,
              bot = (xg.next() >>> 0) / 0x100000000,
              result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (typeof(state) == 'object') copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }

    if (module && module.exports) {
      module.exports = impl;
    } else if (define && define.amd) {
      define(function() { return impl; });
    } else {
      this.tychei = impl;
    }

  })(
    commonjsGlobal,
    module, // present in node.js
    (typeof undefined) == 'function' // present with an AMD loader
  );
} (tychei$1));

var tycheiExports = tychei$1.exports;

var seedrandom$1 = {exports: {}};

var _nodeResolve_empty = {};

var _nodeResolve_empty$1 = {
  __proto__: null,
  default: _nodeResolve_empty
};

var require$$0 = /*@__PURE__*/getAugmentedNamespace(_nodeResolve_empty$1);

/*
Copyright 2019 David Bau.

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

*/

(function (module) {
  (function (global, pool, math) {
    //
    // The following constants are related to IEEE 754 limits.
    //

    var width = 256,        // each RC4 output is 0 <= x < 256
        chunks = 6,         // at least six RC4 outputs for each double
        digits = 52,        // there are 52 significant digits in a double
        rngname = 'random', // rngname: name for Math.random and Math.seedrandom
        startdenom = math.pow(width, chunks),
        significance = math.pow(2, digits),
        overflow = significance * 2,
        mask = width - 1,
        nodecrypto;         // node.js crypto module, initialized at the bottom.

    //
    // seedrandom()
    // This is the seedrandom function described above.
    //
    function seedrandom(seed, options, callback) {
      var key = [];
      options = (options == true) ? { entropy: true } : (options || {});

      // Flatten the seed string or build one from local entropy if needed.
      var shortseed = mixkey(flatten(
        options.entropy ? [seed, tostring(pool)] :
        (seed == null) ? autoseed() : seed, 3), key);

      // Use the seed to initialize an ARC4 generator.
      var arc4 = new ARC4(key);

      // This function returns a random double in [0, 1) that contains
      // randomness in every bit of the mantissa of the IEEE 754 value.
      var prng = function() {
        var n = arc4.g(chunks),    // Start with a numerator n < 2 ^ 48
            d = startdenom,        //   and denominator d = 2 ^ 48.
            x = 0;                 //   and no 'extra last byte'.
        while (n < significance) { // Fill up all significant digits by
          n = (n + x) * width;     //   shifting numerator and
          d *= width;              //   denominator and generating a
          x = arc4.g(1);           //   new least-significant-byte.
        }
        while (n >= overflow) {    // To avoid rounding up, before adding
          n /= 2;                  //   last byte, shift everything
          d /= 2;                  //   right using integer math until
          x >>>= 1;                //   we have exactly the desired bits.
        }
        return (n + x) / d;        // Form the number within [0, 1).
      };

      prng.int32 = function() { return arc4.g(4) | 0; };
      prng.quick = function() { return arc4.g(4) / 0x100000000; };
      prng.double = prng;

      // Mix the randomness into accumulated entropy.
      mixkey(tostring(arc4.S), pool);

      // Calling convention: what to return as a function of prng, seed, is_math.
      return (options.pass || callback ||
          function(prng, seed, is_math_call, state) {
            if (state) {
              // Load the arc4 state from the given state if it has an S array.
              if (state.S) { copy(state, arc4); }
              // Only provide the .state method if requested via options.state.
              prng.state = function() { return copy(arc4, {}); };
            }

            // If called as a method of Math (Math.seedrandom()), mutate
            // Math.random because that is how seedrandom.js has worked since v1.0.
            if (is_math_call) { math[rngname] = prng; return seed; }

            // Otherwise, it is a newer calling convention, so return the
            // prng directly.
            else return prng;
          })(
              prng,
              shortseed,
              'global' in options ? options.global : (this == math),
              options.state);
    }

    //
    // ARC4
    //
    // An ARC4 implementation. The constructor takes a key in the form of
    // an array of at most (width) integers that should be 0 <= x < (width).
    //
    // The g(count) method returns a pseudorandom integer that concatenates
    // the next (count) outputs from ARC4. Its return value is a number x
    // that is in the range 0 <= x < (width ^ count).
    //
    function ARC4(key) {
      var t, keylen = key.length,
          me = this, i = 0, j = me.i = me.j = 0, s = me.S = [];

      // The empty key [] is treated as [0].
      if (!keylen) { key = [keylen++]; }

      // Set up S using the standard key scheduling algorithm.
      while (i < width) {
        s[i] = i++;
      }
      for (i = 0; i < width; i++) {
        s[i] = s[j = mask & (j + key[i % keylen] + (t = s[i]))];
        s[j] = t;
      }

      // The "g" method returns the next (count) outputs as one number.
      (me.g = function(count) {
        // Using instance members instead of closure state nearly doubles speed.
        var t, r = 0,
            i = me.i, j = me.j, s = me.S;
        while (count--) {
          t = s[i = mask & (i + 1)];
          r = r * width + s[mask & ((s[i] = s[j = mask & (j + t)]) + (s[j] = t))];
        }
        me.i = i; me.j = j;
        return r;
        // For robust unpredictability, the function call below automatically
        // discards an initial batch of values. This is called RC4-drop[256].
        // See http://google.com/search?q=rsa+fluhrer+response&btnI
      })(width);
    }

    //
    // copy()
    // Copies internal state of ARC4 to or from a plain object.
    //
    function copy(f, t) {
      t.i = f.i;
      t.j = f.j;
      t.S = f.S.slice();
      return t;
    }
    //
    // flatten()
    // Converts an object tree to nested arrays of strings.
    //
    function flatten(obj, depth) {
      var result = [], typ = (typeof obj), prop;
      if (depth && typ == 'object') {
        for (prop in obj) {
          try { result.push(flatten(obj[prop], depth - 1)); } catch (e) {}
        }
      }
      return (result.length ? result : typ == 'string' ? obj : obj + '\0');
    }

    //
    // mixkey()
    // Mixes a string seed into a key that is an array of integers, and
    // returns a shortened string seed that is equivalent to the result key.
    //
    function mixkey(seed, key) {
      var stringseed = seed + '', smear, j = 0;
      while (j < stringseed.length) {
        key[mask & j] =
          mask & ((smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++));
      }
      return tostring(key);
    }

    //
    // autoseed()
    // Returns an object for autoseeding, using window.crypto and Node crypto
    // module if available.
    //
    function autoseed() {
      try {
        var out;
        if (nodecrypto && (out = nodecrypto.randomBytes)) {
          // The use of 'out' to remember randomBytes makes tight minified code.
          out = out(width);
        } else {
          out = new Uint8Array(width);
          (global.crypto || global.msCrypto).getRandomValues(out);
        }
        return tostring(out);
      } catch (e) {
        var browser = global.navigator,
            plugins = browser && browser.plugins;
        return [+new Date, global, plugins, global.screen, tostring(pool)];
      }
    }

    //
    // tostring()
    // Converts an array of charcodes to a string
    //
    function tostring(a) {
      return String.fromCharCode.apply(0, a);
    }

    //
    // When seedrandom.js is loaded, we immediately mix a few bits
    // from the built-in RNG into the entropy pool. Because we do
    // not want to interfere with deterministic PRNG state later,
    // seedrandom will not call math.random on its own again after
    // initialization.
    //
    mixkey(math.random(), pool);

    //
    // Nodejs and AMD support: export the implementation as a module using
    // either convention.
    //
    if (module.exports) {
      module.exports = seedrandom;
      // When in node.js, try using crypto package for autoseeding.
      try {
        nodecrypto = require$$0;
      } catch (ex) {}
    } else {
      // When included as a plain script, set up Math.seedrandom global.
      math['seed' + rngname] = seedrandom;
    }

    // End anonymous scope, and pass initial values.
  })(
    // global: `self` in browsers (including strict mode and web workers),
    // otherwise `this` in Node and other environments
    (typeof self !== 'undefined') ? self : commonjsGlobal,
    [],   // pool: entropy pool starts empty
    Math  // math: package containing random, pow, and seedrandom
  );
} (seedrandom$1));

var seedrandomExports = seedrandom$1.exports;

// A library of seedable RNGs implemented in Javascript.
//
// Usage:
//
// var seedrandom = require('seedrandom');
// var random = seedrandom(1); // or any seed.
// var x = random();       // 0 <= x < 1.  Every bit is random.
// var x = random.quick(); // 0 <= x < 1.  32 bits of randomness.

// alea, a 53-bit multiply-with-carry generator by Johannes Baagøe.
// Period: ~2^116
// Reported to pass all BigCrush tests.
var alea = aleaExports;

// xor128, a pure xor-shift generator by George Marsaglia.
// Period: 2^128-1.
// Reported to fail: MatrixRank and LinearComp.
var xor128 = xor128Exports;

// xorwow, George Marsaglia's 160-bit xor-shift combined plus weyl.
// Period: 2^192-2^32
// Reported to fail: CollisionOver, SimpPoker, and LinearComp.
var xorwow = xorwowExports;

// xorshift7, by François Panneton and Pierre L'ecuyer, takes
// a different approach: it adds robustness by allowing more shifts
// than Marsaglia's original three. It is a 7-shift generator
// with 256 bits, that passes BigCrush with no systematic failures.
// Period 2^256-1.
// No systematic BigCrush failures reported.
var xorshift7 = xorshift7Exports;

// xor4096, by Richard Brent, is a 4096-bit xor-shift with a
// very long period that also adds a Weyl generator. It also passes
// BigCrush with no systematic failures. Its long period may
// be useful if you have many generators and need to avoid
// collisions.
// Period: 2^4128-2^32.
// No systematic BigCrush failures reported.
var xor4096 = xor4096Exports;

// Tyche-i, by Samuel Neves and Filipe Araujo, is a bit-shifting random
// number generator derived from ChaCha, a modern stream cipher.
// https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf
// Period: ~2^127
// No systematic BigCrush failures reported.
var tychei = tycheiExports;

// The original ARC4-based prng included in this library.
// Period: ~2^1600
var sr = seedrandomExports;

sr.alea = alea;
sr.xor128 = xor128;
sr.xorwow = xorwow;
sr.xorshift7 = xorshift7;
sr.xor4096 = xor4096;
sr.tychei = tychei;

var seedrandom = sr;
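
// Usage sketch (illustrative only, not executed by this bundle): every
// alternative generator hangs off the main seedrandom export and shares the
// same calling convention. The names `rng`, `s`, and `resumed` below are
// hypothetical.
//
//   var rng = seedrandom.xor4096('seed', {state: true});
//   var x = rng();             // float in [0, 1), 32 bits of randomness
//   var s = rng.state();       // serializable plain-object state
//   var resumed = seedrandom.xor4096('', {state: s});  // continues from s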

/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * =============================================================================
 */
/**
 * Apply a mapping function to a nested structure in a recursive manner.
 *
 * The result of the mapping is an object with the same nested structure (i.e.,
 * of arrays and dicts) as the input, except that some subtrees are replaced,
 * according to the results of the mapping function.
 *
 * Mappings are memoized. Thus, if the nested structure contains the same
 * object in multiple positions, the output will contain the same mapped object
 * in those positions. Cycles are not supported, however.
 *
 * @param input: The object to which to apply the mapping function.
 * @param mapFn: A function that expects a single node of the object tree, and
 *   returns a `DeepMapResult`. The `DeepMapResult` either provides a
 *   replacement value for that node (i.e., replacing the subtree), or
 *   indicates that the node should be processed recursively.
 */
function deepMap(input, mapFn) {
  return deepMapInternal(input, mapFn);
}
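
// Usage sketch (illustrative only): a mapFn either replaces a node outright
// or asks deepMap to recurse into it. The doubling example below is
// hypothetical.
//
//   const doubled = deepMap({a: 1, b: [2, 3]}, (node) =>
//       typeof node === 'number' ? {value: node * 2, recurse: false} :
//                                  {value: null, recurse: true});
//   // => {a: 2, b: [4, 6]}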

/**
 * @param seen: A Map of known object mappings (i.e., memoized results of
 *   `mapFn()`)
 * @param containedIn: A set containing objects on the reference path currently
 *   being processed (used to detect cycles).
 */
function deepMapInternal(input, mapFn, seen = new Map(), containedIn = new Set()) {
  if (input == null) {
    return null;
  }
  if (typeof Blob === 'function' && input instanceof Blob) {
    return input.slice();
  }
  if (containedIn.has(input)) {
    throw new Error('Circular references are not supported.');
  }
  if (seen.has(input)) {
    return seen.get(input);
  }
  const result = mapFn(input);
  if (result.recurse && result.value !== null) {
    throw new Error('A deep map function may not return both a value and recurse=true.');
  }
  if (!result.recurse) {
    seen.set(input, result.value);
    return result.value;
  }
  else if (isIterable(input)) {
    // tslint:disable-next-line:no-any
    const mappedIterable = Array.isArray(input) ? [] : {};
    containedIn.add(input);
    for (const k in input) {
      const child = input[k];
      const childResult = deepMapInternal(child, mapFn, seen, containedIn);
      mappedIterable[k] = childResult;
    }
    containedIn.delete(input);
    if (input.__proto__) {
      mappedIterable.__proto__ = input.__proto__;
    }
    return mappedIterable;
  }
  else {
    throw new Error(`Can't recurse into non-iterable type: ${input}`);
  }
}
// TODO(soergel, kangyizhang) Reconsider naming of deepZip() to avoid confusion
// with zip()
/**
 * Zip nested structures together in a recursive manner.
 *
 * This has the effect of transposing or pivoting data, e.g. converting it from
 * a row-major representation to a column-major representation.
 *
 * For example, `deepZip([{a: 1, b: 2}, {a: 3, b: 4}])` returns
 * `{a: [1, 3], b: [2, 4]}`.
 *
 * The inputs should all have the same nested structure (i.e., of arrays and
 * dicts). The result is a single object with the same nested structure, where
 * the leaves are arrays collecting the values of the inputs at that location
 * (or, optionally, the result of a custom function applied to those arrays).
 *
 * @param inputs: An array of the objects to zip together.
 * @param zipFn: (optional) A function that expects an array of elements at a
 *   single node of the object tree, and returns a `DeepMapResult`. The
 *   `DeepMapResult` either provides a result value for that node (i.e.,
 *   representing the subtree), or indicates that the node should be processed
 *   recursively. The default zipFn recurses as far as possible and places
 *   arrays at the leaves.
 */
function deepZip(inputs, zipFn = zipToList) {
  return deepZipInternal(inputs, zipFn);
}
/**
 * @param containedIn: A set containing objects on the reference path currently
 *   being processed (used to detect cycles).
 */
function deepZipInternal(inputs, zipFn, containedIn = new Set()) {
  // The recursion follows the structure of input 0; it's assumed that all the
  // other inputs have the same structure.
  const input = inputs[0];
  if (containedIn.has(input)) {
    throw new Error('Circular references are not supported.');
  }
  const result = zipFn(inputs);
  if (result.recurse && result.value !== null) {
    throw new Error('A deep zip function may not return both a value and recurse=true.');
  }
  if (!result.recurse) {
    return result.value;
  }
  else if (isIterable(input)) {
    // tslint:disable-next-line:no-any
    const mappedIterable = Array.isArray(input) ? [] : {};
    containedIn.add(input);
    for (const k in input) {
      const children = inputs.map(x => x[k]);
      const childResult = deepZipInternal(children, zipFn, containedIn);
      mappedIterable[k] = childResult;
    }
    containedIn.delete(input);
    return mappedIterable;
  }
  else {
    throw new Error(`Can't recurse into non-iterable type: ${input}`);
  }
}
// tslint:disable-next-line:no-any
function zipToList(x) {
  if (x === null) {
    return null;
  }
  // TODO(soergel): validate array type?
  if (isIterable(x[0])) {
    return { value: null, recurse: true };
  }
  else {
    return { value: x, recurse: false };
  }
}
/**
 * Apply an async mapping function to a nested structure in a recursive manner.
 *
 * This first creates a nested structure of Promises, and then awaits all of
 * those, resulting in a single Promise for a resolved nested structure.
 *
 * The result of the mapping is an object with the same nested structure (i.e.,
 * of arrays and dicts) as the input, except that some subtrees are replaced,
 * according to the results of the mapping function.
 *
 * Mappings are memoized. Thus, if the nested structure contains the same
 * object in multiple positions, the output will contain the same mapped object
 * in those positions. Cycles are not supported, however.
 *
 * @param input: The object to which to apply the mapping function.
 * @param mapFn: A function that expects a single node of the object tree, and
 *   returns a `DeepMapAsyncResult`. The `DeepMapAsyncResult` either provides
 *   a `Promise` for a replacement value for that node (i.e., replacing the
 *   subtree), or indicates that the node should be processed recursively. Note
 *   that the decision whether or not to recurse must be made immediately; only
 *   the mapped value may be promised.
 */
async function deepMapAndAwaitAll(input, mapFn) {
  const seen = new Map();
  // First do a normal deepMap, collecting Promises in 'seen' as a side effect.
  deepMapInternal(input, mapFn, seen);
  // Replace the Promises in 'seen' in place.
  // Note TypeScript provides no async map iteration, and regular map iteration
  // is broken too, so sadly we have to do Array.from() to make it work.
  // (There's no advantage to Promise.all(), and that would be tricky anyway.)
  for (const key of Array.from(seen.keys())) {
    const value = seen.get(key);
    if (tf.util.isPromise(value)) {
      const mappedValue = await value;
      seen.set(key, mappedValue);
    }
  }
  // Normal deepMap again, this time filling in the resolved values.
  // It's unfortunate that we have to do two passes.
  // TODO(soergel): test performance and think harder about a fast solution.
  const result = deepMapInternal(input, mapFn, seen);
  return result;
}
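
// Usage sketch (illustrative only, inside an async function): the mapFn may
// promise the replacement value, but must decide whether to recurse
// synchronously. The example below is hypothetical.
//
//   const doubled = await deepMapAndAwaitAll({a: 1, b: [2, 3]}, (node) =>
//       typeof node === 'number' ?
//           {value: Promise.resolve(node * 2), recurse: false} :
//           {value: null, recurse: true});
//   // => {a: 2, b: [4, 6]}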

/**
 * Determine whether the argument is iterable.
 *
 * @returns true if the argument is an array or any non-Tensor object.
 */
// tslint:disable-next-line:no-any
function isIterable(obj) {
  let isTextDecoder = false;
  if (tf.env().get('IS_BROWSER')) {
    isTextDecoder = obj instanceof TextDecoder;
  }
  else {
    // tslint:disable-next-line:no-require-imports
    const { StringDecoder } = require('string_decoder');
    isTextDecoder = obj instanceof StringDecoder;
  }
  return obj != null && (!ArrayBuffer.isView(obj)) &&
      (Array.isArray(obj) ||
       (typeof obj === 'object' && !(obj instanceof tf.Tensor) &&
        !(obj instanceof Promise) && !isTextDecoder));
}
/**
 * Determine whether the argument can be converted to Tensor.
 *
 * Tensors, primitives, arrays, and TypedArrays all qualify; anything else does
 * not.
 *
 * @returns true if the argument can be converted to Tensor.
 */
// tslint:disable-next-line:no-any
function canTensorify(obj) {
  return obj == null || isPrimitive(obj) || Array.isArray(obj) ||
      (typeof obj === 'object' && (obj instanceof tf.Tensor)) ||
      tf.util.isTypedArray(obj);
}
/**
 * Returns true if the given `value` is a primitive type. Otherwise returns
 * false. This is equivalent to node util.isPrimitive.
 */
function isPrimitive(value) {
  return (value === null ||
          (typeof value !== 'object' && typeof value !== 'function'));
}

/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * =============================================================================
 */
function deepClone(container) {
  return deepMap(container, cloneIfTensor);
}
// tslint:disable-next-line: no-any
function cloneIfTensor(item) {
  if (item instanceof tf.Tensor) {
    return ({ value: item.clone(), recurse: false });
  }
  else if (isIterable(item)) {
    return { value: null, recurse: true };
  }
  else {
    return { value: item, recurse: false };
  }
}

/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * =============================================================================
 */
/**
 * A ring buffer, providing O(1) FIFO, LIFO, and related operations.
 */
class RingBuffer {
  /**
   * Constructs a `RingBuffer`.
   * @param capacity The number of items that the buffer can accommodate.
   */
  constructor(capacity) {
    this.capacity = capacity;
    // Note we store the indices in the range 0 <= index < 2*capacity.
    // This allows us to distinguish the full from the empty case.
    // See https://www.snellman.net/blog/archive/2016-12-13-ring-buffers/
    this.begin = 0; // inclusive
    this.end = 0;   // exclusive
    if (capacity == null) {
      throw new RangeError('Can\'t create a ring buffer of unknown capacity.');
    }
    if (capacity < 1) {
      throw new RangeError('Can\'t create ring buffer of capacity < 1.');
    }
    this.data = new Array(capacity);
    this.doubledCapacity = 2 * capacity;
  }
  /**
   * Map any index into the range 0 <= index < 2*capacity.
   */
  wrap(index) {
    // don't trust % on negative numbers
    while (index < 0) {
      index += this.doubledCapacity;
    }
    return index % this.doubledCapacity;
  }
  get(index) {
    if (index < 0) {
      throw new RangeError('Can\'t get item at a negative index.');
    }
    return this.data[index % this.capacity];
  }
  set(index, value) {
    if (index < 0) {
      throw new RangeError('Can\'t set item at a negative index.');
    }
    this.data[index % this.capacity] = value;
  }
  /**
   * Returns the current number of items in the buffer.
   */
  length() {
    let length = this.end - this.begin;
    if (length < 0) {
      length = this.doubledCapacity + length;
    }
    return length;
  }
  /**
   * Reports whether the buffer is full.
   * @returns true if the number of items in the buffer equals its capacity,
   *   and false otherwise.
   */
  isFull() {
    return this.length() === this.capacity;
  }
  /**
   * Reports whether the buffer is empty.
   * @returns true if the number of items in the buffer equals zero, and
   *   false otherwise.
   */
  isEmpty() {
    return this.length() === 0;
  }
  /**
   * Adds an item to the end of the buffer.
   */
  push(value) {
    if (this.isFull()) {
      throw new RangeError('Ring buffer is full.');
    }
    this.set(this.end, value);
    this.end = this.wrap(this.end + 1);
  }
  /**
   * Adds many items to the end of the buffer, in order.
   */
  pushAll(values) {
    for (const value of values) {
      this.push(value);
    }
  }
  /**
   * Removes and returns the last item in the buffer.
   */
  pop() {
    if (this.isEmpty()) {
      throw new RangeError('Ring buffer is empty.');
    }
    this.end = this.wrap(this.end - 1);
    const result = this.get(this.end);
    this.set(this.end, undefined);
    return result;
  }
  /**
   * Adds an item to the beginning of the buffer.
   */
  unshift(value) {
    if (this.isFull()) {
      throw new RangeError('Ring buffer is full.');
    }
    this.begin = this.wrap(this.begin - 1);
    this.set(this.begin, value);
  }
  /**
   * Removes and returns the first item in the buffer.
   */
  shift() {
    if (this.isEmpty()) {
      throw new RangeError('Ring buffer is empty.');
    }
    const result = this.get(this.begin);
    this.set(this.begin, undefined);
    this.begin = this.wrap(this.begin + 1);
    return result;
  }
  /**
   * Removes and returns a specific item in the buffer, and moves the last item
   * to the vacated slot. This is useful for implementing a shuffling stream.
   * Note that this operation necessarily scrambles the original order.
   *
   * @param relativeIndex: the index of the item to remove, relative to the
   *   first item in the buffer (e.g., hiding the ring nature of the underlying
   *   storage).
   */
  shuffleExcise(relativeIndex) {
    if (this.isEmpty()) {
      throw new RangeError('Ring buffer is empty.');
    }
    const index = this.wrap(this.begin + relativeIndex);
    const result = this.get(index);
    this.set(index, this.pop());
    return result;
  }
}
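
// Usage sketch (illustrative only): the same buffer supports FIFO and LIFO
// access at both ends. The name `buf` below is hypothetical.
//
//   const buf = new RingBuffer(4);
//   buf.push('a'); buf.push('b'); buf.unshift('z');
//   buf.shift();   // => 'z'   (removes from the front)
//   buf.pop();     // => 'b'   (removes from the back)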
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
class GrowingRingBuffer extends RingBuffer {
|
/**
|
* Constructs a `GrowingRingBuffer`.
|
*/
|
constructor() {
|
super(GrowingRingBuffer.INITIAL_CAPACITY);
|
}
|
isFull() {
|
return false;
|
}
|
push(value) {
|
if (super.isFull()) {
|
this.expand();
|
}
|
super.push(value);
|
}
|
unshift(value) {
|
if (super.isFull()) {
|
this.expand();
|
}
|
super.unshift(value);
|
}
|
/**
|
* Doubles the capacity of the buffer.
|
*/
|
expand() {
|
const newCapacity = this.capacity * 2;
|
const newData = new Array(newCapacity);
|
const len = this.length();
|
// Rotate the buffer to start at index 0 again, since we can't just
|
// allocate more space at the end.
|
for (let i = 0; i < len; i++) {
|
newData[i] = this.get(this.wrap(this.begin + i));
|
}
|
this.data = newData;
|
this.capacity = newCapacity;
|
this.doubledCapacity = 2 * this.capacity;
|
this.begin = 0;
|
this.end = len;
|
}
|
}
|
GrowingRingBuffer.INITIAL_CAPACITY = 32;
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
// Here we implement a simple asynchronous iterator.
|
// This lets us avoid using either third-party stream libraries or
|
// recent TypeScript language support requiring polyfills.
|
/**
|
* Create a `LazyIterator` from an array of items.
|
*/
|
function iteratorFromItems(items) {
|
return new ArrayIterator(items);
|
}
|
/**
|
* Create a `LazyIterator` from a function.
|
*
|
* ```js
|
* let i = -1;
|
* const func = () =>
|
* ++i < 5 ? {value: i, done: false} : {value: null, done: true};
|
* const iter = tf.data.iteratorFromFunction(func);
|
* await iter.forEachAsync(e => console.log(e));
|
* ```
|
*
|
* @param func A function that produces data on each call.
|
*/
|
function iteratorFromFunction(func) {
|
return new FunctionCallIterator(func);
|
}
|
/**
|
* Create a `LazyIterator` by concatenating underlying streams, which are
|
* themselves provided as a stream.
|
*
|
* This can also be thought of as a "stream flatten" operation.
|
*
|
* @param baseIterators A stream of streams to be concatenated.
|
* @param baseErrorHandler An optional function that can intercept `Error`s
|
* raised during a `next()` call on the base stream. This function can decide
|
* whether the error should be propagated, whether the error should be
|
* ignored, or whether the base stream should be terminated.
|
*/
|
function iteratorFromConcatenated(baseIterators, baseErrorHandler) {
|
return new ChainedIterator(baseIterators, baseErrorHandler);
|
}
|
/**
|
* Create a `LazyIterator` by zipping together an array, dict, or nested
|
* structure of `LazyIterator`s (and perhaps additional constants).
|
*
|
* The underlying streams must provide elements in a consistent order such
|
* that they correspond.
|
*
|
* Typically, the underlying streams should have the same number of
|
* elements. If they do not, the behavior is determined by the
|
* `mismatchMode` argument.
|
*
|
* The nested structure of the `iterators` argument determines the
|
* structure of elements in the resulting iterator.
|
*
|
* @param iterators: An array or object containing LazyIterators at the
|
* leaves.
|
* @param mismatchMode: Determines what to do when one underlying iterator
|
* is exhausted before the others. `ZipMismatchMode.FAIL` (the default)
|
* causes an error to be thrown in this case. `ZipMismatchMode.SHORTEST`
|
* causes the zipped iterator to terminate with the furst underlying
|
* streams, so elements remaining on the longer streams are ignored.
|
* `ZipMismatchMode.LONGEST` causes the zipped stream to continue, filling
|
* in nulls for the exhausted streams, until all streams are exhausted.
|
*/
|
function iteratorFromZipped(iterators, mismatchMode = ZipMismatchMode.FAIL) {
|
return new ZipIterator(iterators, mismatchMode);
|
}
|
/**
|
* An asynchronous iterator, providing lazy access to a potentially
|
* unbounded stream of elements.
|
*
|
* Iterator can be obtained from a dataset:
|
* `const iter = await dataset.iterator();`
|
*/
|
class LazyIterator {
|
/**
|
* Collect all remaining elements of a bounded stream into an array.
|
* Obviously this will succeed only for small streams that fit in memory.
|
* Useful for testing.
|
*
|
* @returns A Promise for an array of stream elements, which will resolve
|
* when the stream is exhausted.
|
*/
|
async toArray() {
|
const result = [];
|
let x = await this.next();
|
while (!x.done) {
|
result.push(x.value);
|
x = await this.next();
|
}
|
return result;
|
}
|
/**
|
* Collect all elements of this dataset into an array with prefetching 100
|
* elements. This is useful for testing, because the prefetch changes the
|
* order in which the Promises are resolved along the processing pipeline.
|
* This may help expose bugs where results are dependent on the order of
|
* Promise resolution rather than on the logical order of the stream (i.e.,
|
* due to hidden mutable state).
|
*
|
* @returns A Promise for an array of stream elements, which will resolve
|
* when the stream is exhausted.
|
*/
|
async toArrayForTest() {
|
const stream = this.prefetch(100);
|
const result = [];
|
let x = await stream.next();
|
while (!x.done) {
|
result.push(x.value);
|
x = await stream.next();
|
}
|
return result;
|
}
|
/**
|
* Draw items from the stream until it is exhausted.
|
*
|
* This can be useful when the stream has side effects but no output. In
|
* that case, calling this function guarantees that the stream will be
|
* fully processed.
|
*/
|
async resolveFully() {
|
let x = await this.next();
|
while (!x.done) {
|
x = await this.next();
|
}
|
}
|
/**
|
* Draw items from the stream until it is exhausted, or a predicate fails.
|
*
|
* This can be useful when the stream has side effects but no output. In
|
* that case, calling this function guarantees that the stream will be
|
* fully processed.
|
*/
|
async resolveWhile(predicate) {
|
let x = await this.next();
|
let shouldContinue = predicate(x.value);
|
while ((!x.done) && shouldContinue) {
|
x = await this.next();
|
shouldContinue = predicate(x.value);
|
}
|
}
|
/**
|
* Handles errors thrown on this stream using a provided handler function.
|
*
|
* @param handler A function that handles any `Error` thrown during a `next()`
|
* call and returns true if the stream should continue (dropping the failed
|
* call) or false if the stream should quietly terminate. If the handler
|
* itself throws (or rethrows) an `Error`, that will be propagated.
|
*
|
* @returns A `LazyIterator` of elements passed through from upstream,
|
* possibly filtering or terminating on upstream `next()` calls that
|
* throw an `Error`.
|
*/
|
handleErrors(handler) {
|
return new ErrorHandlingLazyIterator(this, handler);
|
}
|
// TODO(soergel): Implement reduce() etc.
|
/**
|
* Filters this stream according to `predicate`.
|
*
|
* @param predicate A function mapping a stream element to a boolean or a
|
* `Promise` for one.
|
*
|
* @returns A `LazyIterator` of elements for which the predicate was true.
|
*/
|
filter(predicate) {
|
return new FilterIterator(this, predicate);
|
}
|
/**
|
* Maps this stream through a 1-to-1 transform.
|
*
|
* @param transform A function mapping a stream element to a transformed
|
* element.
|
*
|
* @returns A `LazyIterator` of transformed elements.
|
*/
|
map(transform) {
|
return new MapIterator(this, transform);
|
}
|
/**
|
* Maps this stream through an async 1-to-1 transform.
|
*
|
* @param transform A function mapping a stream element to a `Promise` for a
|
* transformed stream element.
|
*
|
* @returns A `LazyIterator` of transformed elements.
|
*/
|
mapAsync(transform) {
|
return new AsyncMapIterator(this, transform);
|
}
|
/**
|
* Maps this stream through a 1-to-1 transform, forcing serial execution.
|
*
|
* @param transform A function mapping a stream element to a transformed
|
* element.
|
*
|
* @returns A `LazyIterator` of transformed elements.
|
*/
|
serialMapAsync(transform) {
|
return new AsyncMapIterator(this, transform).serial();
|
}
|
/**
|
* Maps this stream through a 1-to-many transform.
|
*
|
* @param transform A function mapping a stream element to an array of
|
* transformed elements.
|
*
|
* @returns A `DataStream` of transformed elements.
|
*/
|
flatmap(transform) {
|
return new FlatmapIterator(this, transform);
|
}
|
/**
|
* Apply a function to every element of the stream.
|
*
|
* @param f A function to apply to each stream element.
|
*/
|
async forEachAsync(f) {
|
return this.map(f).resolveFully();
|
}
|
/**
|
* Apply a function to every element of the stream, forcing serial execution.
|
*
|
* @param f A function to apply to each stream element. Should return 'true'
|
* to indicate that the stream should continue, or 'false' to cause it to
|
* terminate.
|
*/
|
async serialForEach(f) {
|
return this.serialMapAsync(f).resolveWhile(x => (x === true));
|
}
|
/**
|
* Groups elements into batches, represented as arrays of elements.
|
*
|
* We can think of the elements of this iterator as 'rows' (even if they are
|
* nested structures). By the same token, consecutive values for a given
|
* key within the elements form a 'column'. This matches the usual sense of
|
* 'row' and 'column' when processing tabular data (e.g., parsing a CSV).
|
*
|
* Thus, "Row-major" means that the resulting batch is simply a collection of
|
* rows: `[row1, row2, row3, ...]`. This is contrast to the column-major
|
* form, which is needed for vectorized computation.
|
*
|
* @param batchSize The number of elements desired per batch.
|
* @param smallLastBatch Whether to emit the final batch when it has fewer
|
* than batchSize elements. Default true.
|
* @returns A `LazyIterator` of batches of elements, represented as arrays
|
* of the original element type.
|
*/
|
rowMajorBatch(batchSize, smallLastBatch = true) {
|
return new RowMajorBatchIterator(this, batchSize, smallLastBatch);
|
}
|
/**
|
* Groups elements into batches, represented in column-major form.
|
*
|
* We can think of the elements of this iterator as 'rows' (even if they are
|
* nested structures). By the same token, consecutive values for a given
|
* key within the elements form a 'column'. This matches the usual sense of
|
* 'row' and 'column' when processing tabular data (e.g., parsing a CSV).
|
*
|
* Thus, "column-major" means that the resulting batch is a (potentially
|
* nested) structure representing the columns. Each column entry, then,
|
* contains a collection of the values found in that column for a range of
|
* input elements. This representation allows for vectorized computation, in
|
* contrast to the row-major form.
|
*
|
* The inputs should all have the same nested structure (i.e., of arrays and
|
* dicts). The result is a single object with the same nested structure,
|
* where the leaves are arrays collecting the values of the inputs at that
|
* location (or, optionally, the result of a custom function applied to those
|
* arrays).
|
*
|
* @param batchSize The number of elements desired per batch.
|
* @param smallLastBatch Whether to emit the final batch when it has fewer
|
* than batchSize elements. Default true.
|
* @param zipFn: (optional) A function that expects an array of elements at a
|
* single node of the object tree, and returns a `DeepMapResult`. The
|
* `DeepMapResult` either provides a result value for that node (i.e.,
|
* representing the subtree), or indicates that the node should be processed
|
* recursively. The default zipFn recurses as far as possible and places
|
* arrays at the leaves.
|
* @returns A `LazyIterator` of batches of elements, represented as an object
|
* with collections at the leaves.
|
*/
|
columnMajorBatch(batchSize, smallLastBatch = true,
|
// tslint:disable-next-line:no-any
|
zipFn = zipToList) {
|
// First collect the desired number of input elements as a row-major batch.
|
const rowBatches = this.rowMajorBatch(batchSize, smallLastBatch);
|
// Now 'rotate' or 'pivot' the data, collecting all values from each column
|
// in the batch (i.e., for each key within the elements) into an array.
|
return rowBatches.map(x => deepZip(x, zipFn));
|
}
|
/**
|
* Concatenate this `LazyIterator` with another.
|
*
|
* @param iterator A `LazyIterator` to be concatenated onto this one.
|
* @param baseErrorHandler An optional function that can intercept `Error`s
|
* raised during a `next()` call on the base stream. This function can
|
* decide whether the error should be propagated, whether the error should
|
* be ignored, or whether the base stream should be terminated.
|
* @returns A `LazyIterator`.
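|
*
|
* A minimal sketch:
|
*
|
* ```js
|
* const ab = iteratorFromItems([1, 2]).concatenate(iteratorFromItems([3]));
|
* await ab.forEachAsync(e => console.log(e)); // 1, 2, 3
|
* ```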
|
*/
|
concatenate(iterator, baseErrorHandler) {
|
return new ChainedIterator(iteratorFromItems([this, iterator]), baseErrorHandler);
|
}
|
/**
|
* Limits this stream to return at most `count` items.
|
*
|
* @param count The maximum number of items to provide from the stream. If
|
* a negative or undefined value is given, the entire stream is returned
|
* unaltered.
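|
*
|
* A minimal sketch:
|
*
|
* ```js
|
* // Logs only 1 and 2.
|
* await iteratorFromItems([1, 2, 3]).take(2).forEachAsync(e => console.log(e));
|
* ```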
|
*/
|
take(count) {
|
if (count < 0 || count == null) {
|
return this;
|
}
|
return new TakeIterator(this, count);
|
}
|
/**
|
* Skips the first `count` items in this stream.
|
*
|
* @param count The number of items to skip. If a negative or undefined
|
* value is given, the entire stream is returned unaltered.
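|
*
|
* A minimal sketch:
|
*
|
* ```js
|
* // Logs 3; the first two items are skipped.
|
* await iteratorFromItems([1, 2, 3]).skip(2).forEachAsync(e => console.log(e));
|
* ```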
|
*/
|
skip(count) {
|
if (count < 0 || count == null) {
|
return this;
|
}
|
return new SkipIterator(this, count);
|
}
|
/**
|
* Prefetch the first `bufferSize` items in this stream.
|
*
|
* Note this prefetches Promises, but makes no guarantees about when those
|
* Promises resolve.
|
*
|
* @param bufferSize: An integer specifying the number of elements to be
|
* prefetched.
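|
*
|
* A minimal sketch: iteration order is unchanged, but up to `bufferSize`
|
* upstream `next()` calls are issued eagerly.
|
*
|
* ```js
|
* const prefetched = iteratorFromItems([1, 2, 3]).prefetch(2);
|
* await prefetched.forEachAsync(e => console.log(e)); // 1, 2, 3
|
* ```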
|
*/
|
prefetch(bufferSize) {
|
return new PrefetchIterator(this, bufferSize);
|
}
|
// TODO(soergel): deep sharded shuffle, where supported
|
/**
|
* Randomly shuffles the elements of this stream.
|
*
|
* @param windowSize: An integer specifying the number of elements from
|
* this stream from which the new stream will sample.
|
* @param seed: (Optional.) An integer specifying the random seed that
|
* will be used to create the distribution.
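|
*
|
* A minimal sketch (pass a fixed seed for a reproducible order):
|
*
|
* ```js
|
* const shuffled = iteratorFromItems([1, 2, 3, 4]).shuffle(2, '42');
|
* await shuffled.forEachAsync(e => console.log(e));
|
* ```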
|
*/
|
shuffle(windowSize, seed) {
|
return new ShuffleIterator(this, windowSize, seed);
|
}
|
/**
|
* Force an iterator to execute serially: each next() call will await the
|
* prior one, so that they cannot execute concurrently.
|
*/
|
serial() {
|
return new SerialIterator(this);
|
}
|
}
|
// ============================================================================
|
// The following private classes serve to implement the chainable methods
|
// on LazyIterator. Unfortunately they can't be placed in separate files,
|
// due to resulting trouble with circular imports.
|
// ============================================================================
|
// Iterators that just extend LazyIterator directly
|
// ============================================================================
|
class ArrayIterator extends LazyIterator {
|
constructor(items) {
|
super();
|
this.items = items;
|
this.trav = 0;
|
}
|
summary() {
|
return `Array of ${this.items.length} items`;
|
}
|
async next() {
|
if (this.trav >= this.items.length) {
|
return { value: null, done: true };
|
}
|
const item = this.items[this.trav];
|
this.trav++;
|
return { value: deepClone(item), done: false };
|
}
|
}
|
class FunctionCallIterator extends LazyIterator {
|
constructor(nextFn) {
|
super();
|
this.nextFn = nextFn;
|
}
|
summary() {
|
return `Function call`;
|
}
|
async next() {
|
try {
|
return this.nextFn();
|
}
|
catch (e) {
|
// Modify the error message but leave the stack trace intact
|
e.message =
|
`Error thrown while iterating through a dataset: ${e.message}`;
|
throw e;
|
}
|
}
|
}
|
class SerialIterator extends LazyIterator {
|
constructor(upstream) {
|
super();
|
this.upstream = upstream;
|
this.lastRead = Promise.resolve({ value: null, done: false });
|
}
|
summary() {
|
return `${this.upstream.summary()} -> Serial`;
|
}
|
async next() {
|
// This sets this.lastRead to a new Promise right away, as opposed to
|
// saying `await this.lastRead; this.lastRead = this.serialNext();` which
|
// would not work because this.lastRead would be updated only after the
|
// promise resolves.
|
this.lastRead = this.lastRead.then(() => this.serialNext());
|
return this.lastRead;
|
}
|
async serialNext() {
|
return this.upstream.next();
|
}
|
}
|
class SkipIterator extends LazyIterator {
|
constructor(upstream, maxCount) {
|
super();
|
this.upstream = upstream;
|
this.maxCount = maxCount;
|
// Local state that should not be clobbered by out-of-order execution.
|
this.count = 0;
|
this.lastRead = Promise.resolve({ value: null, done: false });
|
}
|
summary() {
|
return `${this.upstream.summary()} -> Skip`;
|
}
|
async next() {
|
// This sets this.lastRead to a new Promise right away, as opposed to
|
// saying `await this.lastRead; this.lastRead = this.serialNext();` which
|
// would not work because this.lastRead would be updated only after the
|
// promise resolves.
|
this.lastRead = this.lastRead.then(() => this.serialNext());
|
return this.lastRead;
|
}
|
async serialNext() {
|
// TODO(soergel): consider tradeoffs of reading in parallel, eg.
|
// collecting next() promises in an Array and then waiting for
|
// Promise.all() of those. Benefit: pseudo-parallel execution. Drawback:
|
// maybe delayed GC.
|
while (this.count++ < this.maxCount) {
|
const skipped = await this.upstream.next();
|
// short-circuit if upstream is already empty
|
if (skipped.done) {
|
return skipped;
|
}
|
tf.dispose(skipped.value);
|
}
|
return this.upstream.next();
|
}
|
}
|
class TakeIterator extends LazyIterator {
|
constructor(upstream, maxCount) {
|
super();
|
this.upstream = upstream;
|
this.maxCount = maxCount;
|
this.count = 0;
|
}
|
summary() {
|
return `${this.upstream.summary()} -> Take`;
|
}
|
async next() {
|
if (this.count++ >= this.maxCount) {
|
return { value: null, done: true };
|
}
|
return this.upstream.next();
|
}
|
}
|
// Note this batch just groups items into row-wise element arrays.
|
// Rotating these to a column-wise representation happens only at the dataset
|
// level.
|
class RowMajorBatchIterator extends LazyIterator {
|
constructor(upstream, batchSize, enableSmallLastBatch = true) {
|
super();
|
this.upstream = upstream;
|
this.batchSize = batchSize;
|
this.enableSmallLastBatch = enableSmallLastBatch;
|
this.lastRead = Promise.resolve({ value: null, done: false });
|
}
|
summary() {
|
return `${this.upstream.summary()} -> RowMajorBatch`;
|
}
|
async next() {
|
// This sets this.lastRead to a new Promise right away, as opposed to
|
// saying `await this.lastRead; this.lastRead = this.serialNext();` which
|
// would not work because this.lastRead would be updated only after the
|
// promise resolves.
|
this.lastRead = this.lastRead.then(() => this.serialNext());
|
return this.lastRead;
|
}
|
async serialNext() {
|
const batch = [];
|
while (batch.length < this.batchSize) {
|
const item = await this.upstream.next();
|
if (item.done) {
|
if (this.enableSmallLastBatch && batch.length > 0) {
|
return { value: batch, done: false };
|
}
|
return { value: null, done: true };
|
}
|
batch.push(item.value);
|
}
|
return { value: batch, done: false };
|
}
|
}
|
class FilterIterator extends LazyIterator {
|
constructor(upstream, predicate) {
|
super();
|
this.upstream = upstream;
|
this.predicate = predicate;
|
this.lastRead = Promise.resolve({ value: null, done: false });
|
}
|
summary() {
|
return `${this.upstream.summary()} -> Filter`;
|
}
|
async next() {
|
// This sets this.lastRead to a new Promise right away, as opposed to
|
// saying `await this.lastRead; this.lastRead = this.serialNext();` which
|
// would not work because this.lastRead would be updated only after the
|
// promise resolves.
|
this.lastRead = this.lastRead.then(() => this.serialNext());
|
return this.lastRead;
|
}
|
async serialNext() {
|
while (true) {
|
const item = await this.upstream.next();
|
if (item.done || this.predicate(item.value)) {
|
return item;
|
}
|
tf.dispose(item.value);
|
}
|
}
|
}
|
class MapIterator extends LazyIterator {
|
constructor(upstream, transform) {
|
super();
|
this.upstream = upstream;
|
this.transform = transform;
|
}
|
summary() {
|
return `${this.upstream.summary()} -> Map`;
|
}
|
async next() {
|
const item = await this.upstream.next();
|
if (item.done) {
|
return { value: null, done: true };
|
}
|
const inputTensors = tf.tensor_util.getTensorsInContainer(item.value);
|
// Careful: the transform may mutate the item in place.
|
// That's why we have to remember the input Tensors above, and then
|
// below dispose only those that were not passed through to the output.
|
// Note too that the transform function is responsible for tidying
|
// any intermediate Tensors. Here we are concerned only about the
|
// inputs.
|
const mapped = this.transform(item.value);
|
const outputTensors = tf.tensor_util.getTensorsInContainer(mapped);
|
// TODO(soergel) faster intersection
|
// TODO(soergel) move to tf.disposeExcept(in, out)?
|
for (const t of inputTensors) {
|
if (!tf.tensor_util.isTensorInList(t, outputTensors)) {
|
t.dispose();
|
}
|
}
|
return { value: mapped, done: false };
|
}
|
}
|
class ErrorHandlingLazyIterator extends LazyIterator {
|
constructor(upstream, handler) {
|
super();
|
this.upstream = upstream;
|
this.handler = handler;
|
this.count = 0;
|
this.lastRead = Promise.resolve({ value: null, done: false });
|
}
|
summary() {
|
return `${this.upstream.summary()} -> handleErrors`;
|
}
|
async next() {
|
// This sets this.lastRead to a new Promise right away, as opposed to
|
// saying `await this.lastRead; this.lastRead = this.serialNext();` which
|
// would not work because this.lastRead would be updated only after the
|
// promise resolves.
|
this.lastRead = this.lastRead.then(() => this.serialNext());
|
return this.lastRead;
|
}
|
async serialNext() {
|
while (true) {
|
try {
|
return await this.upstream.next();
|
}
|
catch (e) {
|
if (!this.handler(e)) {
|
return { value: null, done: true };
|
}
|
// If the handler returns true, loop and fetch the next upstream item.
|
// If the upstream iterator throws an endless stream of errors, and if
|
// the handler says to ignore them, then we loop forever here. That is
|
// the correct behavior-- it's up to the handler to decide when to stop.
|
}
|
}
|
}
|
}
|
class AsyncMapIterator extends LazyIterator {
|
constructor(upstream, transform) {
|
super();
|
this.upstream = upstream;
|
this.transform = transform;
|
}
|
summary() {
|
return `${this.upstream.summary()} -> AsyncMap`;
|
}
|
async next() {
|
const item = await this.upstream.next();
|
if (item.done) {
|
return { value: null, done: true };
|
}
|
const inputTensors = tf.tensor_util.getTensorsInContainer(item.value);
|
// Careful: the transform may mutate the item in place.
|
// That's why we have to remember the input Tensors above, and then
|
// below dispose only those that were not passed through to the output.
|
// Note too that the transform function is responsible for tidying
|
// any intermediate Tensors. Here we are concerned only about the
|
// inputs.
|
const mapped = await this.transform(item.value);
|
const outputTensors = tf.tensor_util.getTensorsInContainer(mapped);
|
// TODO(soergel) faster intersection
|
// TODO(soergel) move to tf.disposeExcept(in, out)?
|
for (const t of inputTensors) {
|
if (!tf.tensor_util.isTensorInList(t, outputTensors)) {
|
t.dispose();
|
}
|
}
|
return { value: mapped, done: false };
|
}
|
}
|
// Iterators that maintain a queue of pending items
|
// ============================================================================
|
/**
|
* A base class for transforming streams that operate by maintaining an
|
* output queue of elements that are ready to return via next(). This is
|
* commonly required when the transformation is 1-to-many: A call to next()
|
* may trigger a call to the underlying stream, which will produce many
|
* mapped elements of this stream-- of which we need to return only one, so
|
* we have to queue the rest.
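|
*
|
* Subclasses implement `pump()`, which reads from upstream and pushes zero
|
* or more items onto `this.outputQueue`, returning `false` once the
|
* upstream is exhausted. A minimal sketch (a hypothetical iterator that
|
* emits every upstream item twice):
|
*
|
* ```js
|
* class DuplicateIterator extends OneToManyIterator {
|
*   constructor(upstream) {
|
*     super();
|
*     this.upstream = upstream;
|
*   }
|
*   summary() {
|
*     return `${this.upstream.summary()} -> Duplicate`;
|
*   }
|
*   async pump() {
|
*     const item = await this.upstream.next();
|
*     if (item.done) {
|
*       return false; // nothing was queued; the stream is exhausted
|
*     }
|
*     this.outputQueue.push(item.value);
|
*     this.outputQueue.push(item.value);
|
*     return true;
|
*   }
|
* }
|
* ```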
|
*/
|
class OneToManyIterator extends LazyIterator {
|
constructor() {
|
super();
|
this.outputQueue = new GrowingRingBuffer();
|
this.lastRead = Promise.resolve({ value: null, done: false });
|
}
|
async next() {
|
// This sets this.lastRead to a new Promise right away, as opposed to
|
// saying `await this.lastRead; this.lastRead = this.serialNext();` which
|
// would not work because this.lastRead would be updated only after the
|
// promise resolves.
|
this.lastRead = this.lastRead.then(() => this.serialNext());
|
return this.lastRead;
|
}
|
async serialNext() {
|
// Fetch so that the queue contains at least one item if possible.
|
// If the upstream source is exhausted, AND there are no items left in
|
// the output queue, then this stream is also exhausted.
|
while (this.outputQueue.length() === 0) {
|
// TODO(soergel): consider parallel reads.
|
if (!await this.pump()) {
|
return { value: null, done: true };
|
}
|
}
|
return { value: this.outputQueue.shift(), done: false };
|
}
|
}
|
class FlatmapIterator extends OneToManyIterator {
|
constructor(upstream, transform) {
|
super();
|
this.upstream = upstream;
|
this.transform = transform;
|
}
|
summary() {
|
return `${this.upstream.summary()} -> Flatmap`;
|
}
|
async pump() {
|
const item = await this.upstream.next();
|
if (item.done) {
|
return false;
|
}
|
const inputTensors = tf.tensor_util.getTensorsInContainer(item.value);
|
// Careful: the transform may mutate the item in place.
|
// That's why we have to remember the input Tensors above, and then
|
// below dispose only those that were not passed through to the output.
|
// Note too that the transform function is responsible for tidying any
|
// intermediate Tensors. Here we are concerned only about the inputs.
|
const mappedArray = this.transform(item.value);
|
const outputTensors = tf.tensor_util.getTensorsInContainer(mappedArray);
|
this.outputQueue.pushAll(mappedArray);
|
// TODO(soergel) faster intersection, and deduplicate outputTensors
|
// TODO(soergel) move to tf.disposeExcept(in, out)?
|
for (const t of inputTensors) {
|
if (!tf.tensor_util.isTensorInList(t, outputTensors)) {
|
t.dispose();
|
}
|
}
|
return true;
|
}
|
}
|
/**
|
* Provides a `LazyIterator` that concatenates a stream of underlying
|
* streams.
|
*
|
* Doing this in a concurrency-safe way requires some trickery. In
|
* particular, we want this stream to return the elements from the
|
* underlying streams in the correct order according to when next() was
|
* called, even if the resulting Promises resolve in a different order.
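|
*
|
* A minimal sketch (this is what `LazyIterator.concatenate()` builds):
|
*
|
* ```js
|
* const chained = new ChainedIterator(
|
*     iteratorFromItems([iteratorFromItems([1, 2]), iteratorFromItems([3])]));
|
* await chained.forEachAsync(e => console.log(e)); // 1, 2, 3
|
* ```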
|
*/
|
class ChainedIterator extends LazyIterator {
|
constructor(iterators, baseErrorHandler) {
|
super();
|
this.baseErrorHandler = baseErrorHandler;
|
// Strict Promise execution order:
|
// a next() call may not even begin until the previous one completes.
|
this.lastRead = null;
|
// Local state that should not be clobbered by out-of-order execution.
|
this.iterator = null;
|
this.moreIterators = iterators;
|
}
|
summary() {
|
const upstreamSummaries = 'TODO: fill in upstream of chained summaries';
|
return `${upstreamSummaries} -> Chained`;
|
}
|
async next() {
|
this.lastRead = this.readFromChain(this.lastRead);
|
return this.lastRead;
|
}
|
async readFromChain(lastRead) {
|
// Must await on the previous read since the previous read may have advanced
|
// the stream of streams, from which we need to read.
|
// This is unfortunate since we can't parallelize reads, which means
|
// prefetching of chained streams is a no-op.
|
// One solution is to prefetch immediately upstream of this.
|
await lastRead;
|
if (this.iterator == null) {
|
const iteratorResult = await this.moreIterators.next();
|
if (iteratorResult.done) {
|
// No more streams to stream from.
|
return { value: null, done: true };
|
}
|
this.iterator = iteratorResult.value;
|
if (this.baseErrorHandler != null) {
|
this.iterator = this.iterator.handleErrors(this.baseErrorHandler);
|
}
|
}
|
const itemResult = await this.iterator.next();
|
if (itemResult.done) {
|
this.iterator = null;
|
return this.readFromChain(lastRead);
|
}
|
return itemResult;
|
}
|
}
|
var ZipMismatchMode;
|
(function (ZipMismatchMode) {
|
ZipMismatchMode[ZipMismatchMode["FAIL"] = 0] = "FAIL";
|
ZipMismatchMode[ZipMismatchMode["SHORTEST"] = 1] = "SHORTEST";
|
ZipMismatchMode[ZipMismatchMode["LONGEST"] = 2] = "LONGEST"; // use nulls for exhausted streams; use up the longest stream.
|
})(ZipMismatchMode || (ZipMismatchMode = {}));
|
/**
|
* Provides a `LazyIterator` that zips together an array, dict, or nested
|
* structure of `LazyIterator`s (and perhaps additional constants).
|
*
|
* The underlying streams must provide elements in a consistent order such
|
* that they correspond.
|
*
|
* Typically, the underlying streams should have the same number of
|
* elements. If they do not, the behavior is determined by the
|
* `mismatchMode` argument.
|
*
|
* The nested structure of the `iterators` argument determines the
|
* structure of elements in the resulting iterator.
|
*
|
* Doing this in a concurrency-safe way requires some trickery. In
|
* particular, we want this stream to return the elements from the
|
* underlying streams in the correct order according to when next() was
|
* called, even if the resulting Promises resolve in a different order.
|
*
|
* @param iterators: An array or object containing LazyIterators at the
|
* leaves.
|
* @param mismatchMode: Determines what to do when one underlying iterator
|
* is exhausted before the others. `ZipMismatchMode.FAIL` (the default)
|
* causes an error to be thrown in this case. `ZipMismatchMode.SHORTEST`
|
* causes the zipped iterator to terminate with the first underlying
|
* stream, so elements remaining on the longer streams are ignored.
|
* `ZipMismatchMode.LONGEST` causes the zipped stream to continue, filling
|
* in nulls for the exhausted streams, until all streams are exhausted.
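|
*
|
* A minimal sketch:
|
*
|
* ```js
|
* const zipped = new ZipIterator(
|
*     [iteratorFromItems([1, 2]), iteratorFromItems(['a', 'b'])]);
|
* await zipped.forEachAsync(e => console.log(e)); // [1, 'a'], then [2, 'b']
|
* ```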
|
*/
|
class ZipIterator extends LazyIterator {
|
constructor(iterators, mismatchMode = ZipMismatchMode.FAIL) {
|
super();
|
this.iterators = iterators;
|
this.mismatchMode = mismatchMode;
|
this.count = 0;
|
this.currentPromise = null;
|
}
|
summary() {
|
const upstreamSummaries = 'TODO: fill in upstream of zip summaries';
|
return `{${upstreamSummaries}} -> Zip`;
|
}
|
async nextState(afterState) {
|
// This chaining ensures that the underlying next() are not even called
|
// before the previous ones have resolved.
|
await afterState;
|
// Collect underlying iterator "done" signals as a side effect in
|
// getNext()
|
let numIterators = 0;
|
let iteratorsDone = 0;
|
function getNext(container) {
|
if (container instanceof LazyIterator) {
|
const result = container.next();
|
return {
|
value: result.then(x => {
|
numIterators++;
|
if (x.done) {
|
iteratorsDone++;
|
}
|
return x.value;
|
}),
|
recurse: false
|
};
|
}
|
else {
|
return { value: null, recurse: true };
|
}
|
}
|
const mapped = await deepMapAndAwaitAll(this.iterators, getNext);
|
if (numIterators === iteratorsDone) {
|
// The streams have all ended.
|
return { value: null, done: true };
|
}
|
if (iteratorsDone > 0) {
|
switch (this.mismatchMode) {
|
case ZipMismatchMode.FAIL:
|
throw new Error('Zipped streams should have the same length. ' +
|
`Mismatched at element ${this.count}.`);
|
case ZipMismatchMode.SHORTEST:
|
return { value: null, done: true };
|
case ZipMismatchMode.LONGEST:
|
// Continue. The exhausted streams already produced value: null.
|
}
|
}
|
this.count++;
|
return { value: mapped, done: false };
|
}
|
async next() {
|
this.currentPromise = this.nextState(this.currentPromise);
|
return this.currentPromise;
|
}
|
}
|
// Iterators that maintain a ring buffer of pending promises
|
// ============================================================================
|
/**
|
* A stream that prefetches a given number of items from an upstream source,
|
* returning them in FIFO order.
|
*
|
* Note this prefetches Promises, but makes no guarantees about when those
|
* Promises resolve.
|
*/
|
class PrefetchIterator extends LazyIterator {
|
constructor(upstream, bufferSize) {
|
super();
|
this.upstream = upstream;
|
this.bufferSize = bufferSize;
|
this.buffer = new RingBuffer(bufferSize);
|
}
|
summary() {
|
return `${this.upstream.summary()} -> Prefetch`;
|
}
|
/**
|
* Refill the prefetch buffer. Returns only after the buffer is full, or
|
* the upstream source is exhausted.
|
*/
|
refill() {
|
while (!this.buffer.isFull()) {
|
const v = this.upstream.next();
|
this.buffer.push(v);
|
}
|
}
|
next() {
|
this.refill();
|
// This shift will never throw an error because the buffer is always
|
// full after a refill. If the stream is exhausted, the buffer will be
|
// full of Promises that will resolve to the end-of-stream signal.
|
return this.buffer.shift();
|
}
|
}
|
/**
|
* A stream that performs a sliding-window random shuffle on an upstream
|
* source. This is like a `PrefetchIterator` except that the items are
|
* returned in randomized order. Mixing naturally improves as the buffer
|
* size increases.
|
*/
|
class ShuffleIterator extends PrefetchIterator {
|
constructor(upstream, windowSize, seed) {
|
super(upstream, windowSize);
|
this.upstream = upstream;
|
this.windowSize = windowSize;
|
// Local state that should not be clobbered by out-of-order execution.
|
this.upstreamExhausted = false;
|
this.random = seedrandom.alea(seed || tf.util.now().toString());
|
this.lastRead = Promise.resolve({ value: null, done: false });
|
}
|
async next() {
|
// This sets this.lastRead to a new Promise right away, as opposed to
|
// saying `await this.lastRead; this.lastRead = this.serialNext();` which
|
// would not work because this.lastRead would be updated only after the
|
// promise resolves.
|
this.lastRead = this.lastRead.then(() => this.serialNext());
|
return this.lastRead;
|
}
|
randomInt(max) {
|
return Math.floor(this.random() * max);
|
}
|
chooseIndex() {
|
return this.randomInt(this.buffer.length());
|
}
|
async serialNext() {
|
// TODO(soergel): consider performance
|
if (!this.upstreamExhausted) {
|
this.refill();
|
}
|
while (!this.buffer.isEmpty()) {
|
const chosenIndex = this.chooseIndex();
|
const result = await this.buffer.shuffleExcise(chosenIndex);
|
if (result.done) {
|
this.upstreamExhausted = true;
|
}
|
else {
|
this.refill();
|
return result;
|
}
|
}
|
return { value: null, done: true };
|
}
|
}
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
// TODO(soergel): consider vectorized operations within the pipeline.
|
/**
|
* Represents a potentially large list of independent data elements (typically
|
* 'samples' or 'examples').
|
*
|
* A 'data example' may be a primitive, an array, a map from string keys to
|
* values, or any nested structure of these.
|
*
|
* A `Dataset` represents an ordered collection of elements, together with a
|
* chain of transformations to be performed on those elements. Each
|
* transformation is a method of `Dataset` that returns another `Dataset`, so
|
* these may be chained, e.g.
|
* `const processedDataset = rawDataset.filter(...).map(...).batch(...)`.
|
*
|
* Data loading and transformation is done in a lazy, streaming fashion. The
|
* dataset may be iterated over multiple times; each iteration starts the data
|
* loading anew and recapitulates the transformations.
|
*
|
* A `Dataset` is typically processed as a stream of unbatched examples -- i.e.,
|
* its transformations are applied one example at a time. Batching produces a
|
* new `Dataset` where each element is a batch. Batching should usually come
|
* last in a pipeline, because data transformations are easier to express on a
|
* per-example basis than on a per-batch basis.
|
*
|
* The following code examples call `await dataset.forEachAsync(...)` to
|
* iterate once over the entire dataset in order to print out the data.
|
*
|
* @doc {heading: 'Data', subheading: 'Classes', namespace: 'data'}
|
*/
|
class Dataset {
|
constructor() {
|
this.size = null;
|
}
|
// TODO(soergel): Make Datasets report whether repeated iterator() calls
|
// produce the same result (e.g., reading from a file) or different results
|
// (e.g., from the webcam). Currently we don't make this distinction but it
|
// could be important for the user to know.
|
// abstract isDeterministic(): boolean;
|
/**
|
* Groups elements into batches.
|
*
|
* It is assumed that each of the incoming dataset elements has the same
|
* structure -- i.e. the same set of keys at each location in an object
|
* hierarchy. For each key, the resulting `Dataset` provides a batched
|
* element collecting all of the incoming values for that key.
|
*
|
* * Incoming primitives are grouped into a 1-D Tensor.
|
* * Incoming Tensors are grouped into a new Tensor where the 0th axis is
|
* the batch dimension.
|
* * Incoming arrays are converted to Tensor and then batched.
|
* * A nested array is interpreted as an n-D Tensor, so the batched result
|
* has n+1 dimensions.
|
* * An array that cannot be converted to Tensor produces an error.
|
*
|
* If an array should not be batched as a unit, it should first be converted
|
* to an object with integer keys.
|
*
|
* Here are a few examples:
|
*
|
* Batch a dataset of numbers:
|
* ```js
|
* const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8]).batch(4);
|
* await a.forEachAsync(e => e.print());
|
* ```
|
*
|
* Batch a dataset of arrays:
|
* ```js
|
* const b = tf.data.array([[1], [2], [3], [4], [5], [6], [7], [8]]).batch(4);
|
* await b.forEachAsync(e => e.print());
|
* ```
|
*
|
* Batch a dataset of objects:
|
* ```js
|
* const c = tf.data.array([{a: 1, b: 11}, {a: 2, b: 12}, {a: 3, b: 13},
|
* {a: 4, b: 14}, {a: 5, b: 15}, {a: 6, b: 16}, {a: 7, b: 17},
|
* {a: 8, b: 18}]).batch(4);
|
* await c.forEachAsync(e => {
|
* console.log('{');
|
* for(var key in e) {
|
* console.log(key+':');
|
* e[key].print();
|
* }
|
* console.log('}');
|
* })
|
* ```
|
*
|
* @param batchSize The number of elements desired per batch.
|
* @param smallLastBatch Whether to emit the final batch when it has fewer
|
* than batchSize elements. Default true.
|
* @returns A `Dataset`, from which a stream of batches can be obtained.
|
*
|
* @doc {heading: 'Data', subheading: 'Classes'}
|
*/
|
batch(batchSize, smallLastBatch = true) {
|
const base = this;
|
tf.util.assert(batchSize > 0, () => `batchSize needs to be positive, but it is
|
${batchSize}`);
|
let size;
|
if (this.size === Infinity || this.size == null) {
|
// If the size of this dataset is infinity or null, the new size stays
|
// the same.
|
size = this.size;
|
}
|
else if (smallLastBatch) {
|
// If the size of this dataset is known and the small last batch is
|
// included, the new size is the number of full batches plus one for the
|
// last (possibly partial) batch.
|
size = Math.ceil(this.size / batchSize);
|
}
|
else {
|
// If the size of this dataset is known and the small last batch is
|
// excluded, the new size is the number of full batches.
|
size = Math.floor(this.size / batchSize);
|
}
|
return datasetFromIteratorFn(async () => {
|
return (await base.iterator())
|
.columnMajorBatch(batchSize, smallLastBatch, deepBatchConcat);
|
}, size);
|
}
|
/**
|
* Concatenates this `Dataset` with another.
|
*
|
* ```js
|
* const a = tf.data.array([1, 2, 3]);
|
* const b = tf.data.array([4, 5, 6]);
|
* const c = a.concatenate(b);
|
* await c.forEachAsync(e => console.log(e));
|
* ```
|
*
|
* @param dataset A `Dataset` to be concatenated onto this one.
|
* @returns A `Dataset`.
|
*
|
* @doc {heading: 'Data', subheading: 'Classes'}
|
*/
|
concatenate(dataset) {
|
const base = this;
|
let size;
|
if (this.size === Infinity || dataset.size === Infinity) {
|
// If the size of either of these two datasets is infinity, the new size
|
// is infinity.
|
size = Infinity;
|
}
|
else if (this.size != null && dataset.size != null) {
|
// If the sizes of both datasets are known and finite, the new size is
|
// the sum of the two sizes.
|
size = this.size + dataset.size;
|
}
|
else {
|
// Otherwise (neither size is infinite and at least one is null), the
|
// new size is null.
|
size = null;
|
}
|
return datasetFromIteratorFn(async () => (await base.iterator()).concatenate(await dataset.iterator()), size);
|
}
|
/**
|
* Filters this dataset according to `predicate`.
|
*
|
* ```js
|
* const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
* .filter(x => x%2 === 0);
|
* await a.forEachAsync(e => console.log(e));
|
* ```
|
*
|
* @param predicate A function mapping a dataset element to a boolean or a
|
* `Promise` for one.
|
*
|
* @returns A `Dataset` of elements for which the predicate was true.
|
*
|
* @doc {heading: 'Data', subheading: 'Classes'}
|
*/
|
filter(predicate) {
|
const base = this;
|
let size;
|
if (this.size === Infinity) {
|
// If the size of this dataset is infinity, new size is infinity
|
size = Infinity;
|
}
|
else {
|
// If this dataset has a finite number of elements, the new size is null
|
// because the filter may exclude an unpredictable number of them.
|
size = null;
|
}
|
return datasetFromIteratorFn(async () => {
|
return (await base.iterator()).filter(x => tf.tidy(() => predicate(x)));
|
}, size);
|
}
|
/**
|
* Apply a function to every element of the dataset.
|
*
|
* After the function is applied to a dataset element, any Tensors contained
|
* within that element are disposed.
|
*
|
* ```js
|
* const a = tf.data.array([1, 2, 3]);
|
* await a.forEachAsync(e => console.log(e));
|
* ```
|
*
|
* @param f A function to apply to each dataset element.
|
* @returns A `Promise` that resolves after all elements have been processed.
|
*
|
* @doc {heading: 'Data', subheading: 'Classes'}
|
*/
|
async forEachAsync(f) {
|
return (await this.iterator()).forEachAsync(f);
|
}
|
/**
|
* Maps this dataset through a 1-to-1 transform.
|
*
|
* ```js
|
* const a = tf.data.array([1, 2, 3]).map(x => x*x);
|
* await a.forEachAsync(e => console.log(e));
|
* ```
|
*
|
* @param transform A function mapping a dataset element to a transformed
|
* dataset element.
|
*
|
* @returns A `Dataset` of transformed elements.
|
*
|
* @doc {heading: 'Data', subheading: 'Classes'}
|
*/
|
map(transform) {
|
const base = this;
|
return datasetFromIteratorFn(async () => {
|
return (await base.iterator()).map(x => tf.tidy(() => transform(x)));
|
}, this.size);
|
}
|
/**
|
* Maps this dataset through an async 1-to-1 transform.
|
*
|
* ```js
|
* const a =
|
* tf.data.array([1, 2, 3]).mapAsync(x => new Promise(function(resolve){
|
* setTimeout(() => {
|
* resolve(x * x);
|
* }, Math.random()*1000 + 500);
|
* }));
|
* console.log(await a.toArray());
|
* ```
|
*
|
* @param transform A function mapping a dataset element to a `Promise` for a
|
* transformed dataset element. This transform is responsible for disposing
|
* any intermediate `Tensor`s, i.e. by wrapping its computation in
|
* `tf.tidy()`; that cannot be automated here (as it is in the synchronous
|
* `map()` case).
|
*
|
* @returns A `Dataset` of transformed elements.
|
*
|
* @doc {heading: 'Data', subheading: 'Classes'}
|
*/
|
mapAsync(transform) {
|
const base = this;
|
return datasetFromIteratorFn(async () => {
|
return (await base.iterator()).mapAsync(transform);
|
}, this.size);
|
}
|
/**
|
* Creates a `Dataset` that prefetches elements from this dataset.
|
*
|
* @param bufferSize: An integer specifying the number of elements to be
|
* prefetched.
|
* @returns A `Dataset`.
|
*
|
* @doc {heading: 'Data', subheading: 'Classes'}
|
*/
|
prefetch(bufferSize) {
|
if (bufferSize == null) {
|
throw new RangeError('`Dataset.prefetch()` requires bufferSize to be specified.');
|
}
|
const base = this;
|
return datasetFromIteratorFn(async () => (await base.iterator()).prefetch(bufferSize), this.size);
|
}
|
/**
|
* Repeats this dataset `count` times.
|
*
|
* NOTE: If this dataset is a function of global state (e.g. a random number
|
* generator), then different repetitions may produce different elements.
|
*
|
* ```js
|
* const a = tf.data.array([1, 2, 3]).repeat(3);
|
* await a.forEachAsync(e => console.log(e));
|
* ```
|
*
|
* @param count: (Optional) An integer, representing the number of times
|
* the dataset should be repeated. The default behavior (if `count` is
|
* `undefined` or negative) is for the dataset to be repeated indefinitely.
|
* @returns A `Dataset`.
|
*
|
* @doc {heading: 'Data', subheading: 'Classes'}
|
*/
|
repeat(count) {
|
const base = this;
|
let size;
|
if (this.size != null && count > 0) {
|
// If this dataset has a known size and count is positive, the new size
|
// is the current size multiplied by count. This also covers the case
|
// that the current size is infinity.
|
size = this.size * count;
|
}
|
else if (count === 0) {
|
// If count is 0, new size is 0.
|
size = 0;
|
}
|
else if (this.size != null && (count === undefined || count < 0)) {
|
// If this dataset has a known size and count is undefined or negative,
|
// the dataset is repeated indefinitely and the new size is infinity.
|
size = Infinity;
|
}
|
else {
|
// If the size of this dataset is null, the new dataset's size is null.
|
size = null;
|
}
|
return datasetFromIteratorFn(async () => {
|
const iteratorIterator = iteratorFromFunction(async () => ({ value: await base.iterator(), done: false }));
|
return iteratorFromConcatenated(iteratorIterator.take(count));
|
}, size);
|
}
|
/**
|
* Creates a `Dataset` that skips `count` initial elements from this dataset.
|
*
|
* ```js
|
* const a = tf.data.array([1, 2, 3, 4, 5, 6]).skip(3);
|
* await a.forEachAsync(e => console.log(e));
|
* ```
|
*
|
* @param count: The number of elements of this dataset that should be skipped
|
* to form the new dataset. If `count` is greater than the size of this
|
* dataset, the new dataset will contain no elements. If `count`
|
* is `undefined` or negative, skips the entire dataset.
|
*
|
* @returns A `Dataset`.
|
*
|
* @doc {heading: 'Data', subheading: 'Classes'}
|
*/
|
skip(count) {
|
const base = this;
|
let size;
|
if (this.size != null && count >= 0 && this.size >= count) {
|
// If the size of this dataset is at least count, the new dataset's size
|
// is the current size minus the skipped count. This also covers the case
|
// that the current size is infinity.
|
size = this.size - count;
|
}
|
else if (this.size != null &&
|
(this.size < count || count === undefined || count < 0)) {
|
// If the size of this dataset is smaller than count, or count is
|
// undefined or negative, skips the entire dataset and the new size is 0.
|
size = 0;
|
}
|
else {
|
// If the size of this dataset is null, the new dataset's size is null.
|
size = null;
|
}
|
return datasetFromIteratorFn(async () => (await base.iterator()).skip(count), size);
|
}
|
/**
|
* Pseudorandomly shuffles the elements of this dataset. This is done in a
|
* streaming manner, by sampling from a given number of prefetched elements.
|
*
|
* ```js
|
* const a = tf.data.array([1, 2, 3, 4, 5, 6]).shuffle(3);
|
* await a.forEachAsync(e => console.log(e));
|
* ```
|
*
|
* @param bufferSize: An integer specifying the number of elements from this
|
* dataset from which the new dataset will sample.
|
* @param seed: (Optional) An integer specifying the random seed that will
|
* be used to create the distribution.
|
* @param reshuffleEachIteration: (Optional) A boolean, which if true
|
* indicates that the dataset should be pseudorandomly reshuffled each time
|
* it is iterated over. If false, elements will be returned in the same
|
* shuffled order on each iteration. (Defaults to `true`.)
|
* @returns A `Dataset`.
|
*
|
* @doc {heading: 'Data', subheading: 'Classes'}
|
*/
|
shuffle(bufferSize, seed, reshuffleEachIteration = true) {
|
if (bufferSize == null || bufferSize < 0) {
|
if (this.size == null) {
|
throw new RangeError('`Dataset.shuffle()` requires bufferSize to be specified.');
|
}
|
else {
|
throw new RangeError('`Dataset.shuffle()` requires bufferSize to be specified. ' +
|
'If your data fits in main memory (for regular JS objects), ' +
|
'and/or GPU memory (for `tf.Tensor`s), consider setting ' +
|
`bufferSize to the dataset size (${this.size} elements)`);
|
}
|
}
|
const base = this;
|
const random = seedrandom.alea(seed || tf.util.now().toString());
|
return datasetFromIteratorFn(async () => {
|
let seed2 = random.int32();
|
if (reshuffleEachIteration) {
|
seed2 += random.int32();
|
}
|
return (await base.iterator()).shuffle(bufferSize, seed2.toString());
|
}, this.size);
|
}
|
/**
|
* Creates a `Dataset` with at most `count` initial elements from this
|
* dataset.
|
*
|
* ```js
|
* const a = tf.data.array([1, 2, 3, 4, 5, 6]).take(3);
|
* await a.forEachAsync(e => console.log(e));
|
* ```
|
*
|
* @param count: The number of elements of this dataset that should be taken
|
* to form the new dataset. If `count` is `undefined` or negative, or if
|
* `count` is greater than the size of this dataset, the new dataset will
|
* contain all elements of this dataset.
|
* @returns A `Dataset`.
|
*
|
* @doc {heading: 'Data', subheading: 'Classes'}
|
*/
|
take(count) {
|
const base = this;
|
let size;
|
if (this.size != null && this.size > count) {
|
// If the size of this dataset is greater than count, the new dataset's
|
// size is count.
|
size = count;
|
}
|
else if (this.size != null && this.size <= count) {
|
// If the size of this dataset is equal or smaller than count, the new
|
// dataset's size is the size of this dataset.
|
size = this.size;
|
}
|
else {
|
// If the size of this dataset is null, the new dataset's size is null.
|
size = null;
|
}
|
return datasetFromIteratorFn(async () => (await base.iterator()).take(count), size);
|
}
|
/**
|
* Collect all elements of this dataset into an array.
|
*
|
* Obviously this will succeed only for small datasets that fit in memory.
|
* Useful for testing and generally should be avoided if possible.
|
*
|
* ```js
|
* const a = tf.data.array([1, 2, 3, 4, 5, 6]);
|
* console.log(await a.toArray());
|
* ```
|
*
|
* @returns A Promise for an array of elements, which will resolve
|
* when a new stream has been obtained and fully consumed.
|
*
|
* @doc {heading: 'Data', subheading: 'Classes'}
|
*/
|
async toArray() {
|
if (this.size === Infinity) {
|
throw new Error('Can not convert infinite data stream to array.');
|
}
|
return (await this.iterator()).toArray();
|
}
|
/**
|
* Collect all elements of this dataset into an array with prefetching 100
|
* elements. This is useful for testing, because the prefetch changes the
|
* order in which the Promises are resolved along the processing pipeline.
|
* This may help expose bugs where results are dependent on the order of
|
* Promise resolution rather than on the logical order of the stream (i.e.,
|
* due to hidden mutable state).
|
*
|
* @returns A Promise for an array of elements, which will resolve
|
* when a new stream has been obtained and fully consumed.
|
*/
|
async toArrayForTest() {
|
if (this.size === Infinity) {
|
throw new Error('Can not convert infinite data stream to array.');
|
}
|
return (await this.iterator()).toArrayForTest();
|
}
|
}
|
// TODO(soergel): deep sharded shuffle, where supported
|
Dataset.MAX_BUFFER_SIZE = 10000;
|
/**
|
* Create a `Dataset` defined by a provided iterator() function.
|
*
|
* ```js
|
* let i = -1;
|
* const func = () =>
|
* ++i < 5 ? {value: i, done: false} : {value: null, done: true};
|
* const ds = tf.data.datasetFromIteratorFn(
|
*     async () => tf.data.iteratorFromFunction(func));
|
* await ds.forEachAsync(e => console.log(e));
|
* ```
|
*/
|
function datasetFromIteratorFn(iteratorFn, size = null) {
|
return new class extends Dataset {
|
constructor() {
|
super(...arguments);
|
this.size = size;
|
}
|
/*
|
* Provide a new stream of elements. Note this will also start new streams
|
* from any underlying `Dataset`s.
|
*/
|
async iterator() {
|
return iteratorFn();
|
}
|
}();
|
}
|
/**
|
* Create a `Dataset` from an array of elements.
|
*
|
* Create a Dataset from an array of objects:
|
* ```js
|
* const a = tf.data.array([{'item': 1}, {'item': 2}, {'item': 3}]);
|
* await a.forEachAsync(e => console.log(e));
|
* ```
|
*
|
* Create a Dataset from an array of numbers:
|
* ```js
|
* const a = tf.data.array([4, 5, 6]);
|
* await a.forEachAsync(e => console.log(e));
|
* ```
|
* @param items An array of elements that will be parsed as items in a dataset.
|
*
|
* @doc {heading: 'Data', subheading: 'Creation', namespace: 'data'}
|
*/
|
function array(items) {
|
return datasetFromIteratorFn(async () => iteratorFromItems(items), items.length);
|
}
|
/**
|
* Create a `Dataset` by zipping together an array, dict, or nested
|
* structure of `Dataset`s (and perhaps additional constants).
|
* The underlying datasets must provide elements in a consistent order such that
|
* they correspond.
|
*
|
* The number of elements in the resulting dataset is the same as the size of
|
* the smallest dataset in datasets.
|
*
|
* The nested structure of the `datasets` argument determines the
|
* structure of elements in the resulting iterator.
|
*
|
* Note this means that, given an array of two datasets that produce dict
|
* elements, the result is a dataset that produces elements that are arrays
|
* of two dicts:
|
*
|
* Zip an array of datasets:
|
* ```js
|
* console.log('Zip two datasets of objects:');
|
* const ds1 = tf.data.array([{a: 1}, {a: 2}, {a: 3}]);
|
* const ds2 = tf.data.array([{b: 4}, {b: 5}, {b: 6}]);
|
* const ds3 = tf.data.zip([ds1, ds2]);
|
* await ds3.forEachAsync(e => console.log(JSON.stringify(e)));
|
*
|
* // If the goal is to merge the dicts in order to produce elements like
|
* // {a: ..., b: ...}, this requires a second step such as:
|
* console.log('Merge the objects:');
|
* const ds4 = ds3.map(x => {return {a: x[0].a, b: x[1].b}});
|
* await ds4.forEachAsync(e => console.log(e));
|
* ```
|
*
|
* Zip a dict of datasets:
|
* ```js
|
* const a = tf.data.array([{a: 1}, {a: 2}, {a: 3}]);
|
* const b = tf.data.array([{b: 4}, {b: 5}, {b: 6}]);
|
* const c = tf.data.zip({c: a, d: b});
|
* await c.forEachAsync(e => console.log(JSON.stringify(e)));
|
* ```
|
*
|
* @doc {heading: 'Data', subheading: 'Operations', namespace: 'data'}
|
*/
|
function zip(datasets) {
|
// manually type-check the argument for JS users
|
if (!isIterable(datasets)) {
|
throw new Error('The argument to zip() must be an object or array.');
|
}
|
let size;
|
if (Array.isArray(datasets)) {
|
for (let i = 0; i < datasets.length; i++) {
|
size = size == null ? datasets[i].size :
|
Math.min(size, datasets[i].size);
|
}
|
}
|
else if (datasets instanceof Object) {
|
for (const ds in datasets) {
|
size = size == null ? datasets[ds].size :
|
Math.min(size, datasets[ds].size);
|
}
|
}
|
return datasetFromIteratorFn(async () => {
|
const streams = await deepMapAndAwaitAll(datasets, d => {
|
if (d instanceof Dataset) {
|
return { value: d.iterator(), recurse: false };
|
}
|
else if (isIterable(d)) {
|
return { value: null, recurse: true };
|
}
|
else {
|
throw new Error('Leaves of the structure passed to zip() must be Datasets, ' +
|
'not primitives.');
|
}
|
});
|
return iteratorFromZipped(streams, ZipMismatchMode.SHORTEST);
|
}, size);
|
}
|
/**
|
* A zip function for use with deepZip, passed via the columnMajorBatch call.
|
*
|
* Accepts an array of identically-structured nested elements and either batches
|
* them (if they are primitives, numeric arrays, or Tensors) or requests
|
* recursion (if not).
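|
*
|
* For example (a sketch):
|
*
|
* ```js
|
* deepBatchConcat([1, 2, 3]);
|
* // -> {value: <Tensor [1, 2, 3]>, recurse: false}
|
* deepBatchConcat([{a: 1}, {a: 2}]);
|
* // -> {value: null, recurse: true}, so the caller recurses into key 'a'
|
* ```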
|
*/
|
// tslint:disable-next-line:no-any
|
function deepBatchConcat(rows) {
|
if (rows === null) {
|
return null;
|
}
|
// use the first item to decide whether to recurse or batch here.
|
const exampleRow = rows[0];
|
if (canTensorify(exampleRow)) {
|
// rows is an array of primitives, Tensors, or arrays. Batch them.
|
const value = batchConcat(rows);
|
return { value, recurse: false };
|
}
|
// the example row is an object, so recurse into it.
|
return { value: null, recurse: true };
|
}
|
/**
|
* Assembles a list of same-shaped numbers, number arrays, or Tensors
|
* into a single new Tensor where axis 0 is the batch dimension.
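|
*
|
* For example (a sketch):
|
*
|
* ```js
|
* batchConcat([1, 2, 3]);                              // Tensor of shape [3]
|
* batchConcat([tf.tensor([1, 2]), tf.tensor([3, 4])]); // Tensor of shape [2, 2]
|
* ```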
|
*/
|
function batchConcat(arrays) {
|
if (arrays.length === 0) {
|
// We can't return an empty Tensor because we don't know the element shape.
|
throw new Error('Can\'t make a batch of zero elements.');
|
}
|
if (arrays[0] instanceof tf.Tensor) {
|
// Input is an array of Tensors
|
return tf.stack(arrays);
|
}
|
else {
|
// Input is a possibly-nested array of numbers.
|
return tf.tensor(arrays);
|
}
|
}
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
/**
|
* Represents a potentially large collection of text lines.
|
*
|
* The results are not batched.
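|
*
|
* A minimal sketch (assuming `source` is any `DataSource` whose stream
|
* yields UTF8-encoded text chunks):
|
*
|
* ```js
|
* const lines = new TextLineDataset(source);
|
* await lines.forEachAsync(line => console.log(line));
|
* ```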
|
*/
|
class TextLineDataset extends Dataset {
|
/**
|
* Create a `TextLineDataset`.
|
*
|
* @param input A `DataSource` providing a chunked, UTF8-encoded byte stream.
|
*/
|
constructor(input) {
|
super();
|
this.input = input;
|
}
|
async iterator() {
|
const inputIterator = await this.input.iterator();
|
const utf8Iterator = inputIterator.decodeUTF8();
|
const lineIterator = utf8Iterator.split('\n').map(line => {
|
// Windows/DOS-format text files have an extra '\r' at the end of each line.
|
if (line.endsWith('\r')) {
|
line = line.slice(0, -1);
|
}
|
return line;
|
});
|
return lineIterator;
|
}
|
}
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
const CODE_QUOTE = '"';
|
const STATE_OUT = Symbol('out');
|
const STATE_FIELD = Symbol('field');
|
const STATE_QUOTE = Symbol('quote');
|
const STATE_QUOTE_AFTER_QUOTE = Symbol('quoteafterquote');
|
const STATE_WITHIN_QUOTE_IN_QUOTE = Symbol('quoteinquote');
|
/**
|
* Represents a potentially large collection of delimited text records.
|
*
|
* The produced `TensorContainer`s each contain one key-value pair for
|
* every column of the table. When a field is empty in the incoming data,
|
* the resulting value is `undefined`, or an error is thrown if the column
|
* is marked as required. Values that can be parsed as numbers are emitted
|
* as type `number`; other values are parsed as `string`.
|
*
|
* The results are not batched.
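|
*
|
* A minimal sketch via the public factory (the URL here is a placeholder):
|
*
|
* ```js
|
* const csvDataset = tf.data.csv('https://example.com/data.csv');
|
* console.log(await csvDataset.columnNames());
|
* await csvDataset.forEachAsync(e => console.log(e));
|
* ```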
|
*
|
* @doc {heading: 'Data', subheading: 'Classes', namespace: 'data'}
|
*/
|
class CSVDataset extends Dataset {
|
/**
|
* Returns column names of the csv dataset. If `configuredColumnsOnly` is
|
* true, return column names in `columnConfigs`. If `configuredColumnsOnly` is
|
* false and `columnNames` is provided, return `columnNames`. If
|
* `configuredColumnsOnly` is false and `columnNames` is not provided, return
|
* all column names parsed from the csv file. For example usage please go to
|
* `tf.data.csv`.
|
*
|
* @doc {heading: 'Data', subheading: 'Classes'}
|
*/
|
async columnNames() {
|
if (!this.columnNamesValidated) {
|
await this.setColumnNames();
|
}
|
return this.configuredColumnsOnly ? Object.keys(this.columnConfigs) :
|
this.fullColumnNames;
|
}
|
/* 1) If `columnNames` is provided as string[], use this string[] as output
|
* keys in corresponding order. The length must match the number of inferred
|
* columns if `hasHeader` is true.
|
* 2) If `columnNames` is not provided, parse header line as `columnNames` if
|
* hasHeader is true. If `hasHeader` is false, throw an error.
|
* 3) If `columnConfigs` is provided, all the keys in `columnConfigs` must
|
* exist in parsed `columnNames`.
|
*/
|
async setColumnNames() {
|
const columnNamesFromFile = await this.maybeReadHeaderLine();
|
if (!this.fullColumnNames && !columnNamesFromFile) {
|
// Throw an error if columnNames is not provided and no header line.
|
throw new Error('Column names must be provided if there is no header line.');
|
}
|
else if (this.fullColumnNames && columnNamesFromFile) {
|
// Check provided columnNames match header line.
|
util.assert(columnNamesFromFile.length === this.fullColumnNames.length, () => 'The length of provided columnNames (' +
|
this.fullColumnNames.length.toString() +
|
') does not match the length of the header line read from ' +
|
'file (' + columnNamesFromFile.length.toString() + ').');
|
}
|
if (!this.fullColumnNames) {
|
this.fullColumnNames = columnNamesFromFile;
|
}
|
// Check if there are duplicate column names.
|
const counts = this.fullColumnNames.reduce((countAcc, name) => {
|
countAcc[name] = (countAcc[name] + 1) || 1;
|
return countAcc;
|
}, {});
|
const duplicateNames = Object.keys(counts).filter((name) => (counts[name] > 1));
|
util.assert(duplicateNames.length === 0, () => 'Duplicate column names found: ' + duplicateNames.toString());
|
// Check if keys in columnConfigs match columnNames.
|
if (this.columnConfigs) {
|
for (const key of Object.keys(this.columnConfigs)) {
|
const index = this.fullColumnNames.indexOf(key);
|
if (index === -1) {
|
throw new Error('The key "' + key +
|
'" provided in columnConfigs does not match any of the column ' +
|
'names (' + this.fullColumnNames.toString() + ').');
|
}
|
}
|
}
|
this.columnNamesValidated = true;
|
}
|
async maybeReadHeaderLine() {
|
if (this.hasHeader) {
|
const iter = await this.base.iterator();
|
const firstElement = await iter.next();
|
if (firstElement.done) {
|
throw new Error('No data was found for CSV parsing.');
|
}
|
const firstLine = firstElement.value;
|
const headers = this.parseRow(firstLine, false);
|
return headers;
|
}
|
else {
|
return null;
|
}
|
}
|
/**
|
* Create a `CSVDataset`.
|
*
|
* @param input A `DataSource` providing a chunked, UTF8-encoded byte stream.
|
* @param csvConfig (Optional) A CSVConfig object that contains configurations
|
* of reading and decoding from CSV file(s).
|
*
|
* hasHeader: (Optional) A boolean value that indicates whether the first
|
* row of provided CSV file is a header line with column names, and should
|
* not be included in the data. Defaults to `true`.
|
*
|
* columnNames: (Optional) A list of strings that corresponds to
|
* the CSV column names, in order. If provided, it ignores the column
|
* names inferred from the header row. If not provided, infers the column
|
* names from the first row of the records. If hasHeader is false and
|
* columnNames is not provided, this method throws an error.
|
*
|
* columnConfigs: (Optional) A dictionary whose keys are column names and
|
* whose values are objects stating whether the column is required, the
|
* column's data type, its default value, and whether the column is a
|
* label. If provided, keys must correspond to names provided in
|
* columnNames or inferred from the file header lines. If isLabel is true
|
* for any column, the iterator returns an array of two items: the first
|
* item is a dict of features key/value pairs, the second item is a dict
|
* of labels key/value pairs. If no column is marked as label, a dict of
|
* features only is returned.
|
*
|
* configuredColumnsOnly (Optional) If true, only columns provided in
|
* columnConfigs will be parsed and provided during iteration.
|
*
|
* delimiter (Optional) The string used to parse each line of the input
|
* file. Defaults to `,`.
|
*/
|
constructor(input, csvConfig) {
|
super();
|
this.input = input;
|
this.hasHeader = true;
|
this.fullColumnNames = null;
|
this.columnNamesValidated = false;
|
this.columnConfigs = null;
|
this.configuredColumnsOnly = false;
|
this.delimiter = ',';
|
this.delimWhitespace = false;
|
this.base = new TextLineDataset(input);
|
if (!csvConfig) {
|
csvConfig = {};
|
}
|
this.hasHeader = csvConfig.hasHeader === false ? false : true;
|
this.fullColumnNames = csvConfig.columnNames;
|
this.columnConfigs = csvConfig.columnConfigs;
|
this.configuredColumnsOnly = csvConfig.configuredColumnsOnly;
|
if (csvConfig.delimWhitespace) {
|
util.assert(csvConfig.delimiter == null, () => 'Delimiter should not be provided when delimWhitespace is true.');
|
this.delimWhitespace = true;
|
this.delimiter = ' ';
|
}
|
else {
|
this.delimiter = csvConfig.delimiter ? csvConfig.delimiter : ',';
|
}
|
}
|
async iterator() {
|
if (!this.columnNamesValidated) {
|
await this.setColumnNames();
|
}
|
let lines = await this.base.iterator();
|
if (this.hasHeader) {
|
// We previously read the first line to get the columnNames.
|
// Now that we're providing data, skip it.
|
lines = lines.skip(1);
|
}
|
return lines.map(x => this.makeDataElement(x));
|
}
|
makeDataElement(line) {
|
const values = this.parseRow(line);
|
const features = {};
|
const labels = {};
|
for (let i = 0; i < this.fullColumnNames.length; i++) {
|
const key = this.fullColumnNames[i];
|
const config = this.columnConfigs ? this.columnConfigs[key] : null;
|
if (this.configuredColumnsOnly && !config) {
|
// This column is not selected.
|
continue;
|
}
|
else {
|
const value = values[i];
|
let parsedValue = null;
|
if (value === '') {
|
// If default value is provided, use it. If default value is not
|
// provided, set as undefined.
|
if (config && config.default !== undefined) {
|
parsedValue = config.default;
|
}
|
else if (config && (config.required || config.isLabel)) {
|
throw new Error(`Required column ${key} is empty in this line: ${line}`);
|
}
|
else {
|
parsedValue = undefined;
|
}
|
}
|
else {
|
// A value is present, so parse it based on type
|
const valueAsNum = Number(value);
|
if (isNaN(valueAsNum)) {
|
// The value is a string. If this column is declared as boolean
|
// in the config, parse it as boolean.
|
if (config && config.dtype === 'bool') {
|
parsedValue = this.getBoolean(value);
|
}
|
else {
|
// Set value as string
|
parsedValue = value;
|
}
|
}
|
else if (!config || !config.dtype) {
|
// If this value is a number and no type config is provided, return
|
// it as number.
|
parsedValue = valueAsNum;
|
}
|
else {
|
// If this value is a number and data type is provided, parse it
|
// according to provided data type.
|
switch (config.dtype) {
|
case 'float32':
|
parsedValue = valueAsNum;
|
break;
|
case 'int32':
|
parsedValue = Math.floor(valueAsNum);
|
break;
|
case 'bool':
|
parsedValue = this.getBoolean(value);
|
break;
|
default:
|
parsedValue = valueAsNum;
|
}
|
}
|
}
|
// Check if this column is label.
|
(config && config.isLabel) ? labels[key] = parsedValue :
|
features[key] = parsedValue;
|
}
|
}
|
// If label exists, return an object of features and labels as {xs:features,
|
// ys:labels}, otherwise return features only.
|
if (Object.keys(labels).length === 0) {
|
return features;
|
}
|
else {
|
return { xs: features, ys: labels };
|
}
|
}
|
getBoolean(value) {
|
if (value === '1' || value.toLowerCase() === 'true') {
|
return 1;
|
}
|
else {
|
return 0;
|
}
|
}
|
// adapted from https://beta.observablehq.com/@mbostock/streaming-csv
|
parseRow(line, validateElementCount = true) {
|
const result = [];
|
let readOffset = 0;
|
const readLength = line.length;
|
let currentState = STATE_OUT;
|
// Goes through the line to parse quote.
|
for (let i = 0; i < readLength; i++) {
|
switch (currentState) {
|
// Before enter a new field
|
case STATE_OUT:
|
switch (line.charAt(i)) {
|
// Enter a quoted field
|
case CODE_QUOTE:
|
readOffset = i + 1;
|
currentState = STATE_QUOTE;
|
break;
|
// Read an empty field
|
case this.delimiter:
|
readOffset = i + 1;
|
// If delimiter is white space and configured to collapse
|
// multiple white spaces, ignore this white space.
|
if (this.delimiter === ' ' && this.delimWhitespace) {
|
break;
|
}
|
result.push('');
|
currentState = STATE_OUT;
|
break;
|
// Enter an unquoted field
|
default:
|
currentState = STATE_FIELD;
|
readOffset = i;
|
break;
|
}
|
break;
|
// In an unquoted field
|
case STATE_FIELD:
|
switch (line.charAt(i)) {
|
// Exit an unquoted field, add it to result
|
case this.delimiter:
|
result.push(line.substring(readOffset, i));
|
currentState = STATE_OUT;
|
readOffset = i + 1;
|
break;
|
}
|
break;
|
// In a quoted field
|
case STATE_QUOTE:
|
switch (line.charAt(i)) {
|
// Read a quote after a quote
|
case CODE_QUOTE:
|
currentState = STATE_QUOTE_AFTER_QUOTE;
|
break;
|
}
|
break;
|
// This state means it's right after a second quote in a field
|
case STATE_QUOTE_AFTER_QUOTE:
|
switch (line.charAt(i)) {
|
// Finished a quoted field
|
case this.delimiter:
|
result.push(line.substring(readOffset, i - 1));
|
currentState = STATE_OUT;
|
readOffset = i + 1;
|
break;
|
// Finished a quoted part in a quoted field
|
case CODE_QUOTE:
|
currentState = STATE_QUOTE;
|
break;
|
// In a quoted part in a quoted field
|
default:
|
currentState = STATE_WITHIN_QUOTE_IN_QUOTE;
|
break;
|
}
|
break;
|
case STATE_WITHIN_QUOTE_IN_QUOTE:
|
switch (line.charAt(i)) {
|
// Exit a quoted part in a quoted field
|
case CODE_QUOTE:
|
currentState = STATE_QUOTE;
|
break;
|
}
|
break;
|
}
|
}
|
// Add the last item based on whether it is quoted.
|
if (currentState === STATE_QUOTE_AFTER_QUOTE) {
|
result.push(line.substring(readOffset, readLength - 1));
|
}
|
else {
|
result.push(line.substring(readOffset));
|
}
|
// Check if each row has the same number of elements as column names.
|
if (validateElementCount && result.length !== this.fullColumnNames.length) {
|
throw new Error(`Invalid row in csv file. Should have ${this.fullColumnNames.length} elements in a row, but got ${result}`);
|
}
|
return result;
|
}
|
}
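|
// Illustrative sketch of parseRow quote handling, assuming CODE_QUOTE is the
|
// double-quote character and the delimiter is ',':
|
// parseRow('a,"b,c",d', false) yields ['a', 'b,c', 'd']; the quotes are
|
// stripped and the delimiter inside the quoted field is preserved.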
|
// TODO(soergel): add more basic datasets for parity with tf.data
|
// tf.data.FixedLengthRecordDataset()
|
// tf.data.TFRecordDataset()
|
|
/**
|
* @license
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
/**
|
* Provide a stream of tensors from a microphone audio stream. The tensors
|
* represent audio data as a frequency-domain spectrogram generated with the
|
* browser's native FFT. Tensors representing the time-domain waveform are
|
* available depending on configuration. Only works in the browser environment.
|
*/
|
class MicrophoneIterator extends LazyIterator {
|
constructor(microphoneConfig) {
|
super();
|
this.microphoneConfig = microphoneConfig;
|
this.isClosed = false;
|
this.fftSize = microphoneConfig.fftSize || 1024;
|
const fftSizeLog2 = Math.log2(this.fftSize);
|
if (this.fftSize < 0 || fftSizeLog2 < 4 || fftSizeLog2 > 14 ||
|
!Number.isInteger(fftSizeLog2)) {
|
throw new Error(`Invalid fftSize: it must be a power of 2 between ` +
|
`2**4 and 2**14, but got ${this.fftSize}`);
|
}
|
this.numFrames = microphoneConfig.numFramesPerSpectrogram || 43;
|
this.sampleRateHz = microphoneConfig.sampleRateHz;
|
this.columnTruncateLength =
|
microphoneConfig.columnTruncateLength || this.fftSize;
|
this.audioTrackConstraints = microphoneConfig.audioTrackConstraints;
|
this.smoothingTimeConstant = microphoneConfig.smoothingTimeConstant || 0;
|
this.includeSpectrogram = microphoneConfig.includeSpectrogram !== false;
|
this.includeWaveform = microphoneConfig.includeWaveform === true;
|
if (!this.includeSpectrogram && !this.includeWaveform) {
|
throw new Error('Both includeSpectrogram and includeWaveform are false. ' +
|
'At least one type of data should be returned.');
|
}
|
}
|
summary() {
|
return `microphone`;
|
}
|
// Construct a MicrophoneIterator and start the audio stream.
|
static async create(microphoneConfig = {}) {
|
if (!env().get('IS_BROWSER')) {
|
throw new Error('microphone API is only supported in browser environment.');
|
}
|
const microphoneIterator = new MicrophoneIterator(microphoneConfig);
|
// Call async function start() to initialize the audio stream.
|
await microphoneIterator.start();
|
return microphoneIterator;
|
}
|
// Start the audio stream and FFT.
|
async start() {
|
try {
|
this.stream = await navigator.mediaDevices.getUserMedia({
|
audio: this.audioTrackConstraints == null ? true :
|
this.audioTrackConstraints,
|
video: false
|
});
|
}
|
catch (e) {
|
throw new Error(`Error thrown while initializing audio stream: ${e.message}`);
|
}
|
if (!this.stream) {
|
throw new Error('Could not obtain audio from microphone.');
|
}
|
const ctxConstructor =
|
// tslint:disable-next-line:no-any
|
window.AudioContext || window.webkitAudioContext;
|
this.audioContext = new ctxConstructor();
|
if (!this.sampleRateHz) {
|
// If sample rate is not provided, use the available sample rate on
|
// device.
|
this.sampleRateHz = this.audioContext.sampleRate;
|
}
|
else if (this.audioContext.sampleRate !== this.sampleRateHz) {
|
throw new Error(`Mismatch in sampling rate: ` +
|
`Expected: ${this.sampleRateHz}; ` +
|
`Actual: ${this.audioContext.sampleRate}`);
|
}
|
const streamSource = this.audioContext.createMediaStreamSource(this.stream);
|
this.analyser = this.audioContext.createAnalyser();
|
this.analyser.fftSize = this.fftSize * 2;
|
this.analyser.smoothingTimeConstant = this.smoothingTimeConstant;
|
streamSource.connect(this.analyser);
|
this.freqData = new Float32Array(this.fftSize);
|
this.timeData = new Float32Array(this.fftSize);
|
return;
|
}
|
async next() {
|
if (this.isClosed) {
|
return { value: null, done: true };
|
}
|
let spectrogramTensor;
|
let waveformTensor;
|
const audioDataQueue = await this.getAudioData();
|
if (this.includeSpectrogram) {
|
const freqData = this.flattenQueue(audioDataQueue.freqDataQueue);
|
spectrogramTensor = this.getTensorFromAudioDataArray(freqData, [this.numFrames, this.columnTruncateLength, 1]);
|
}
|
if (this.includeWaveform) {
|
const timeData = this.flattenQueue(audioDataQueue.timeDataQueue);
|
waveformTensor = this.getTensorFromAudioDataArray(timeData, [this.numFrames * this.fftSize, 1]);
|
}
|
return {
|
value: { 'spectrogram': spectrogramTensor, 'waveform': waveformTensor },
|
done: false
|
};
|
}
|
// Capture one result from the audio stream, and extract the value from
|
// iterator.next() result.
|
async capture() {
|
return (await this.next()).value;
|
}
|
async getAudioData() {
|
const freqDataQueue = [];
|
const timeDataQueue = [];
|
let currentFrames = 0;
|
return new Promise(resolve => {
|
const intervalID = setInterval(() => {
|
if (this.includeSpectrogram) {
|
this.analyser.getFloatFrequencyData(this.freqData);
|
// If the audio stream is still initializing, resolve with empty queues.
|
if (this.freqData[0] === -Infinity) {
|
resolve({ freqDataQueue, timeDataQueue });
|
}
|
freqDataQueue.push(this.freqData.slice(0, this.columnTruncateLength));
|
}
|
if (this.includeWaveform) {
|
this.analyser.getFloatTimeDomainData(this.timeData);
|
timeDataQueue.push(this.timeData.slice());
|
}
|
// Clear the interval and resolve when all frames have been collected.
|
if (++currentFrames === this.numFrames) {
|
clearInterval(intervalID);
|
resolve({ freqDataQueue, timeDataQueue });
|
}
|
}, this.fftSize / this.sampleRateHz * 1e3);
|
});
|
}
|
// Stop the audio stream and pause the iterator.
|
stop() {
|
if (!this.isClosed) {
|
this.isClosed = true;
|
this.analyser.disconnect();
|
this.audioContext.close();
|
if (this.stream != null && this.stream.getTracks().length > 0) {
|
this.stream.getTracks()[0].stop();
|
}
|
}
|
}
|
// Override toArray() function to prevent collecting.
|
toArray() {
|
throw new Error('Can not convert infinite audio stream to array.');
|
}
|
// Return audio sampling rate in Hz
|
getSampleRate() {
|
return this.sampleRateHz;
|
}
|
flattenQueue(queue) {
|
const frameSize = queue[0].length;
|
const freqData = new Float32Array(queue.length * frameSize);
|
queue.forEach((data, i) => freqData.set(data, i * frameSize));
|
return freqData;
|
}
|
getTensorFromAudioDataArray(freqData, shape) {
|
const vals = new Float32Array(util.sizeFromShape(shape));
|
// If the data is shorter than the output shape, the front is zero-padded
|
// (the data is right-aligned in the output).
|
vals.set(freqData, vals.length - freqData.length);
|
return tensor(vals, shape);
|
}
|
}
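|
// Illustrative sketch, not part of the library: in
|
// MicrophoneIterator.getTensorFromAudioDataArray, data shorter than the
|
// requested shape is right-aligned, so freqData [1, 2, 3] with shape [5, 1]
|
// produces tensor values [0, 0, 1, 2, 3].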
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
/**
|
* Provide a stream of image tensors from a webcam video stream. Only works in
|
* the browser environment.
|
*/
|
class WebcamIterator extends LazyIterator {
|
constructor(webcamVideoElement, webcamConfig) {
|
super();
|
this.webcamVideoElement = webcamVideoElement;
|
this.webcamConfig = webcamConfig;
|
this.isClosed = true;
|
this.resize = false;
|
if (this.needToResize()) {
|
this.resize = true;
|
this.cropSize =
|
[this.webcamConfig.resizeHeight, this.webcamConfig.resizeWidth];
|
this.cropBoxInd = tensor1d([0], 'int32');
|
if (this.webcamConfig.centerCrop) {
|
// Calculate the box based on resizing shape.
|
const widthCroppingRatio = this.webcamConfig.resizeWidth * 1.0 / this.webcamVideoElement.width;
|
const heightCroppingRatio = this.webcamConfig.resizeHeight * 1.0 /
|
this.webcamVideoElement.height;
|
const widthCropStart = (1 - widthCroppingRatio) / 2;
|
const heightCropStart = (1 - heightCroppingRatio) / 2;
|
const widthCropEnd = widthCropStart + widthCroppingRatio;
|
const heightCropEnd = heightCroppingRatio + heightCropStart;
|
this.cropBox = tensor2d([heightCropStart, widthCropStart, heightCropEnd, widthCropEnd], [1, 4]);
|
}
|
else {
|
this.cropBox = tensor2d([0, 0, 1, 1], [1, 4]);
|
}
|
}
|
}
|
summary() {
|
return `webcam`;
|
}
|
// Construct a WebcamIterator and start its video stream.
|
static async create(webcamVideoElement, webcamConfig = {}) {
|
if (!env().get('IS_BROWSER')) {
|
throw new Error('tf.data.webcam is only supported in browser environment.');
|
}
|
if (!webcamVideoElement) {
|
// If webcam video element is not provided, create a hidden video element
|
// with provided width and height.
|
webcamVideoElement = document.createElement('video');
|
if (!webcamConfig.resizeWidth || !webcamConfig.resizeHeight) {
|
throw new Error('Please provide webcam video element, or resizeWidth and ' +
|
'resizeHeight to create a hidden video element.');
|
}
|
webcamVideoElement.width = webcamConfig.resizeWidth;
|
webcamVideoElement.height = webcamConfig.resizeHeight;
|
}
|
const webcamIterator = new WebcamIterator(webcamVideoElement, webcamConfig);
|
// Call async function to initialize the video stream.
|
await webcamIterator.start();
|
return webcamIterator;
|
}
|
// Async function to start video stream.
|
async start() {
|
if (this.webcamConfig.facingMode) {
|
util.assert((this.webcamConfig.facingMode === 'user') ||
|
(this.webcamConfig.facingMode === 'environment'), () => `Invalid webcam facing mode: ${this.webcamConfig.facingMode}. ` +
|
`Please provide 'user' or 'environment'`);
|
}
|
try {
|
this.stream = await navigator.mediaDevices.getUserMedia({
|
video: {
|
deviceId: this.webcamConfig.deviceId,
|
facingMode: this.webcamConfig.facingMode ?
|
this.webcamConfig.facingMode :
|
'user',
|
width: this.webcamVideoElement.width,
|
height: this.webcamVideoElement.height
|
}
|
});
|
}
|
catch (e) {
|
// Modify the error message but leave the stack trace intact
|
e.message = `Error thrown while initializing video stream: ${e.message}`;
|
throw e;
|
}
|
if (!this.stream) {
|
throw new Error('Could not obtain video from webcam.');
|
}
|
// Older browsers may not have srcObject
|
try {
|
this.webcamVideoElement.srcObject = this.stream;
|
}
|
catch (error) {
|
console.log(error);
|
this.webcamVideoElement.src = window.URL.createObjectURL(this.stream);
|
}
|
// Start the webcam video stream
|
this.webcamVideoElement.play();
|
this.isClosed = false;
|
return new Promise(resolve => {
|
// Add event listener to make sure the webcam has been fully initialized.
|
this.webcamVideoElement.onloadedmetadata = () => {
|
resolve();
|
};
|
});
|
}
|
async next() {
|
if (this.isClosed) {
|
return { value: null, done: true };
|
}
|
let img;
|
try {
|
img = browser.fromPixels(this.webcamVideoElement);
|
}
|
catch (e) {
|
throw new Error(`Error thrown converting video to pixels: ${JSON.stringify(e)}`);
|
}
|
if (this.resize) {
|
try {
|
return { value: this.cropAndResizeFrame(img), done: false };
|
}
|
catch (e) {
|
throw new Error(`Error thrown cropping the video: ${e.message}`);
|
}
|
finally {
|
img.dispose();
|
}
|
}
|
else {
|
return { value: img, done: false };
|
}
|
}
|
needToResize() {
|
// If resizeWidth and resizeHeight are provided, and different from the
|
// width and height of the original HTMLVideoElement, then resizing and
|
// cropping are required.
|
if (this.webcamConfig.resizeWidth && this.webcamConfig.resizeHeight &&
|
(this.webcamVideoElement.width !== this.webcamConfig.resizeWidth ||
|
this.webcamVideoElement.height !== this.webcamConfig.resizeHeight)) {
|
return true;
|
}
|
return false;
|
}
|
// Cropping and resizing each frame based on config
|
cropAndResizeFrame(img) {
|
return tidy(() => {
|
const expandedImage = expandDims(cast(img, 'float32'), 0);
|
const resizedImage = image.cropAndResize(expandedImage, this.cropBox, this.cropBoxInd, this.cropSize, 'bilinear');
|
// Extract image from batch cropping.
|
const shape = resizedImage.shape;
|
return reshape(resizedImage, shape.slice(1));
|
});
|
}
|
// Capture one frame from the video stream, and extract the value from
|
// iterator.next() result.
|
async capture() {
|
return (await this.next()).value;
|
}
|
// Stop the video stream and pause webcam iterator.
|
stop() {
|
const tracks = this.stream.getTracks();
|
tracks.forEach(track => track.stop());
|
try {
|
this.webcamVideoElement.srcObject = null;
|
}
|
catch (error) {
|
console.log(error);
|
this.webcamVideoElement.src = null;
|
}
|
this.isClosed = true;
|
}
|
// Override toArray() function to prevent collecting.
|
toArray() {
|
throw new Error('Can not convert infinite video stream to array.');
|
}
|
}
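|
// Worked example of the centerCrop box math above, assuming a 200x200 video
|
// element with resizeWidth = resizeHeight = 100: each cropping ratio is
|
// 100 / 200 = 0.5, so cropStart = (1 - 0.5) / 2 = 0.25 and
|
// cropEnd = 0.25 + 0.5 = 0.75, giving cropBox [[0.25, 0.25, 0.75, 0.75]].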
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
/**
|
* Represents a data source readable as a stream of binary data chunks.
|
*
|
* Because `Dataset`s can be read repeatedly (via `Dataset.iterator()`), this
|
* provides a means to repeatedly create streams from the underlying data
|
* sources.
|
*/
|
class DataSource {
|
}
|
// TODO(soergel): consider convenience factory functions here
|
// in combination with chainable source->dataset above, e.g.:
|
// tf.data.url(...).asCsvDataset().shuffle().batch()
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
class StringIterator extends LazyIterator {
|
/**
|
* Splits a string stream on a given separator.
|
*
|
* It is assumed that the incoming chunk boundaries have no semantic meaning,
|
* so conceptually the incoming stream is treated simply as the concatenation
|
* of its elements.
|
*
|
* The outgoing stream provides chunks corresponding to the results of the
|
* standard string split() operation (even if such a chunk spanned incoming
|
* chunks). The separators are not included.
|
*
|
* A typical usage is to split a text file (represented as a stream with
|
* arbitrary chunk boundaries) into lines.
|
*
|
* @param upstream A readable stream of strings that can be treated as
|
* concatenated.
|
* @param separator A character to split on.
|
*/
|
split(separator) {
|
return new SplitIterator(this, separator);
|
}
|
}
|
// ============================================================================
|
// The following private classes serve to implement the chainable methods
|
// on StringIterator. Unfortunately they can't be placed in separate files, due
|
// to resulting trouble with circular imports.
|
// ============================================================================
|
// We wanted multiple inheritance, e.g.
|
// class SplitIterator extends QueueIterator<string>, StringIterator
|
// but the TypeScript mixin approach is a bit hacky, so we take this adapter
|
// approach instead.
|
class SplitIterator extends StringIterator {
|
constructor(upstream, separator) {
|
super();
|
this.upstream = upstream;
|
this.impl = new SplitIteratorImpl(upstream, separator);
|
}
|
summary() {
|
return this.impl.summary();
|
}
|
async next() {
|
return this.impl.next();
|
}
|
}
|
class SplitIteratorImpl extends OneToManyIterator {
|
constructor(upstream, separator) {
|
super();
|
this.upstream = upstream;
|
this.separator = separator;
|
// A partial string at the end of an upstream chunk
|
this.carryover = '';
|
}
|
summary() {
|
return `${this.upstream.summary()} -> Split('${this.separator}')`;
|
}
|
async pump() {
|
const chunkResult = await this.upstream.next();
|
if (chunkResult.done) {
|
if (this.carryover === '') {
|
return false;
|
}
|
// Pretend that the pump succeeded in order to emit the small last batch.
|
// The next pump() call will actually fail.
|
this.outputQueue.push(this.carryover);
|
this.carryover = '';
|
return true;
|
}
|
const lines = chunkResult.value.split(this.separator);
|
// Note the behavior: " ab ".split(' ') yields ['', 'ab', '']
|
// Thus the carryover may be '' if the separator falls on a chunk
|
// boundary; this produces the correct result.
|
lines[0] = this.carryover + lines[0];
|
for (const line of lines.slice(0, -1)) {
|
this.outputQueue.push(line);
|
}
|
this.carryover = lines[lines.length - 1];
|
return true;
|
}
|
}
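|
// Illustrative trace of the carryover logic, assuming upstream chunks
|
// 'ab\nc' and 'd\ne' split on '\n': the first pump() emits 'ab' and carries
|
// 'c'; the second emits 'cd' and carries 'e'; the final pump() emits the
|
// leftover 'e'.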
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
class ByteChunkIterator extends LazyIterator {
|
/**
|
* Decode a stream of UTF8-encoded byte arrays to a stream of strings.
|
*
|
* The byte arrays produced from the ByteChunkIterator on which this is
|
* called will be interpreted as concatenated. No assumptions are made about
|
* the boundaries of the incoming chunks, so a multi-byte UTF8 encoding of a
|
* character may span the boundary between chunks. This naturally happens,
|
* for instance, when reading fixed-size byte arrays from a file.
|
*/
|
decodeUTF8() {
|
return new Utf8Iterator(this);
|
}
|
}
|
// ============================================================================
|
// The following private classes serve to implement the chainable methods
|
// on ByteChunkIterator. Unfortunately they can't be placed in separate files,
|
// due to resulting trouble with circular imports.
|
// ============================================================================
|
// We wanted multiple inheritance, e.g.
|
// class Utf8Iterator extends QueueIterator<string>, StringIterator
|
// but the TypeScript mixin approach is a bit hacky, so we take this adapter
|
// approach instead.
|
class Utf8Iterator extends StringIterator {
|
constructor(upstream) {
|
super();
|
this.upstream = upstream;
|
this.impl = new Utf8IteratorImpl(upstream);
|
}
|
summary() {
|
return this.impl.summary();
|
}
|
async next() {
|
return this.impl.next();
|
}
|
}
|
/**
|
* Decode a stream of UTF8-encoded byte arrays to a stream of strings.
|
*
|
* This is tricky because the incoming byte array boundaries may disrupt a
|
* multi-byte UTF8 character. Thus any incomplete character data at the end of
|
* a chunk must be carried over and prepended to the next chunk before
|
* decoding. Fortunately, the native decoders (TextDecoder in the browser and
|
* string_decoder in Node) handle byte array boundaries automatically.
|
*
|
* In the context of an input pipeline for machine learning, UTF8 decoding is
|
* needed to parse text files containing training examples or prediction
|
* requests (e.g., formatted as CSV or JSON). We cannot use the built-in
|
* decoding provided by FileReader.readAsText() because here we are in a
|
* streaming context, which FileReader does not support.
|
*
|
* @param upstream A `LazyIterator` of `Uint8Arrays` containing UTF8-encoded
|
* text, which should be interpreted as concatenated. No assumptions are
|
* made about the boundaries of the incoming chunks, so a multi-byte UTF8
|
* encoding of a character may span the boundary between chunks. This
|
* naturally happens, for instance, when reading fixed-size byte arrays from a
|
* file.
|
*/
|
class Utf8IteratorImpl extends OneToManyIterator {
|
constructor(upstream) {
|
super();
|
this.upstream = upstream;
|
if (env().get('IS_BROWSER')) {
|
this.decoder = new TextDecoder('utf-8');
|
}
|
else {
|
// tslint:disable-next-line:no-require-imports
|
const { StringDecoder } = require('string_decoder');
|
this.decoder = new StringDecoder('utf8');
|
}
|
}
|
summary() {
|
return `${this.upstream.summary()} -> Utf8`;
|
}
|
async pump() {
|
const chunkResult = await this.upstream.next();
|
let chunk;
|
if (chunkResult.done) {
|
return false;
|
}
|
else {
|
chunk = chunkResult.value;
|
}
|
let text;
|
if (env().get('IS_BROWSER')) {
|
text = this.decoder.decode(chunk, { stream: true });
|
}
|
else {
|
text = this.decoder.write(Buffer.from(chunk.buffer));
|
}
|
this.outputQueue.push(text);
|
return true;
|
}
|
}
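|
// Illustrative sketch of why {stream: true} matters in the browser branch:
|
// the euro sign U+20AC is the three UTF-8 bytes [0xE2, 0x82, 0xAC].
|
//
|
//   const d = new TextDecoder('utf-8');
|
//   d.decode(new Uint8Array([0xE2, 0x82]), {stream: true}); // '' (incomplete)
|
//   d.decode(new Uint8Array([0xAC]), {stream: true});       // '€'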
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
/**
|
* Provide a stream of chunks from a File, Blob, or Uint8Array.
|
* @param file The source File, Blob or Uint8Array.
|
* @param options Optional settings controlling file reading.
|
* @returns a lazy Iterator of Uint8Arrays containing sequential chunks of the
|
* input File, Blob or Uint8Array.
|
*/
|
class FileChunkIterator extends ByteChunkIterator {
|
constructor(file, options = {}) {
|
super();
|
this.file = file;
|
this.options = options;
|
util.assert((file instanceof Uint8Array) ||
|
(env().get('IS_BROWSER') ?
|
(file instanceof File || file instanceof Blob) :
|
false), () => 'FileChunkIterator only supports File, Blob and Uint8Array ' +
|
'right now.');
|
this.offset = options.offset || 0;
|
// default 1MB chunk has tolerable perf on large files
|
this.chunkSize = options.chunkSize || 1024 * 1024;
|
}
|
summary() {
|
return `FileChunks ${this.file}`;
|
}
|
async next() {
|
if (this.offset >= ((this.file instanceof Uint8Array) ?
|
this.file.byteLength :
|
this.file.size)) {
|
return { value: null, done: true };
|
}
|
const chunk = new Promise((resolve, reject) => {
|
const end = this.offset + this.chunkSize;
|
if (this.file instanceof Uint8Array) {
|
// Note if end > this.file.byteLength, we just get a small last
|
// chunk.
|
resolve(new Uint8Array(this.file.slice(this.offset, end)));
|
}
|
else {
|
// This branch assumes that this.file type is File or Blob, which
|
// means it is in the browser environment.
|
// TODO(soergel): is this a performance issue?
|
const fileReader = new FileReader();
|
fileReader.onload = (event) => {
|
let data = fileReader.result;
|
// Not sure we can trust the return type of
|
// FileReader.readAsArrayBuffer See e.g.
|
// https://github.com/node-file-api/FileReader/issues/2
|
if (data instanceof ArrayBuffer) {
|
data = new Uint8Array(data);
|
}
|
if (!(data instanceof Uint8Array)) {
|
return reject(new TypeError('FileReader returned unknown type.'));
|
}
|
resolve(data);
|
};
|
fileReader.onabort = (event) => {
|
return reject(new Error('Aborted'));
|
};
|
fileReader.onerror = (event) => {
|
return reject(new Error(event.type));
|
};
|
// TODO(soergel): better handle onabort, onerror
|
// Note if end > this.file.size, we just get a small last chunk.
|
const slice = this.file.slice(this.offset, end);
|
// We can't use readAsText here (even if we know the file is text)
|
// because the slice boundary may fall within a multi-byte character.
|
fileReader.readAsArrayBuffer(slice);
|
}
|
this.offset = end;
|
});
|
return { value: (await chunk), done: false };
|
}
|
}
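|
// Illustrative usage sketch, not part of the library:
|
//
|
//   const it = new FileChunkIterator(
|
//       new Uint8Array([1, 2, 3, 4, 5]), {chunkSize: 2});
|
//   // Successive next() values: Uint8Array [1, 2], [3, 4], the small last
|
//   // chunk [5], and finally {value: null, done: true}.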
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
/**
|
* Provide a stream of chunks from a URL.
|
*
|
* Note this function first downloads the entire file into memory before providing
|
* the first element from the stream. This is because the Fetch API does not
|
* yet reliably provide a reader stream for the response body.
|
*/
|
async function urlChunkIterator(url, options = {}, fetchFunc) {
|
let urlString;
|
let requestInit;
|
if ((typeof url) === 'string') {
|
urlString = url;
|
}
|
else {
|
urlString = url.url;
|
requestInit = getRequestInitFromRequest(url);
|
}
|
const response = await (fetchFunc || util.fetch)(urlString, requestInit);
|
if (response.ok) {
|
const uint8Array = new Uint8Array(await response.arrayBuffer());
|
return new FileChunkIterator(uint8Array, options);
|
}
|
else {
|
throw new Error(response.statusText);
|
}
|
}
|
// Generate RequestInit from Request to match tf.util.fetch signature.
|
const getRequestInitFromRequest = (request) => {
|
const init = {
|
method: request.method,
|
headers: request.headers,
|
body: request.body,
|
mode: request.mode,
|
credentials: request.credentials,
|
cache: request.cache,
|
redirect: request.redirect,
|
referrer: request.referrer,
|
integrity: request.integrity,
|
};
|
return init;
|
};
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
// Skip the tslint any type check because this method's purpose is to check
|
// the type of the input.
|
// tslint:disable-next-line:no-any
|
function isLocalPath(source) {
|
return (typeof source === 'string') && source.slice(0, 7) === 'file://';
|
}
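|
// Illustrative: isLocalPath('file:///tmp/data.csv') is true, while
|
// isLocalPath('https://example.com/data.csv') is false.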
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
/**
|
* Represents a file, blob, or Uint8Array readable as a stream of binary data
|
* chunks.
|
*/
|
class FileDataSource extends DataSource {
|
/**
|
* Create a `FileDataSource`.
|
*
|
* @param input Local file path, or `File`/`Blob`/`Uint8Array` object to
|
* read. Local file only works in node environment.
|
* @param options Options passed to the underlying `FileChunkIterator`s,
|
* such as {chunkSize: 1024}.
|
*/
|
constructor(input, options = {}) {
|
super();
|
this.input = input;
|
this.options = options;
|
}
|
async iterator() {
|
if (isLocalPath(this.input) && env().get('IS_NODE')) {
|
// tslint:disable-next-line:no-require-imports
|
const fs = require('fs');
|
this.input = fs.readFileSync(this.input.slice(7));
|
}
|
// TODO(kangyizhang): Add LocalFileChunkIterator to split local streaming
|
// with file in browser.
|
return new FileChunkIterator(this.input, this.options);
|
}
|
}
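|
// Illustrative usage sketch, not part of the library (assumes an async
|
// context):
|
//
|
//   const source = new FileDataSource(new Uint8Array([104, 105]));
|
//   const chunks = await source.iterator(); // a FileChunkIterator
|
//   const first = await chunks.next();
|
//   // first is {value: Uint8Array [104, 105], done: false}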
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
/*
|
* Represents a URL readable as a stream of binary data chunks.
|
*/
|
class URLDataSource extends DataSource {
|
/**
|
* Create a `URLDataSource`.
|
*
|
* @param url A source URL string, or a `Request` object.
|
* @param options Options passed to the underlying `FileChunkIterator`s,
|
* such as {chunkSize: 1024}.
|
*/
|
constructor(url, fileOptions = {}) {
|
super();
|
this.url = url;
|
this.fileOptions = fileOptions;
|
}
|
// TODO(soergel): provide appropriate caching options. Currently this
|
// will download the URL anew for each call to iterator(). Since we have
|
// to treat the downloaded file as a blob/buffer anyway, we may as well retain
|
// it, but that raises GC issues. Also we may want a persistent disk cache.
|
async iterator() {
|
if (isLocalPath(this.url)) {
|
return (new FileDataSource(this.url, this.fileOptions))
|
.iterator();
|
}
|
else {
|
return urlChunkIterator(this.url, this.fileOptions);
|
}
|
}
|
}
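|
// Illustrative usage sketch; the URL is a placeholder and an async context
|
// is assumed:
|
//
|
//   const source = new URLDataSource('https://example.com/data.csv');
|
//   const chunks = await source.iterator();
|
//   // Downloads the whole file, then yields Uint8Array chunks.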
|
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*
|
* =============================================================================
|
*/
|
/**
|
* Create a `CSVDataset` by reading and decoding CSV file(s) from a provided URL
|
* or local path (the latter only works in a Node environment).
|
*
|
* Note: If isLabel in columnConfigs is `true` for at least one column, the
|
* element in returned `CSVDataset` will be an object of
|
* `{xs:features, ys:labels}`: xs is a dict of features key/value pairs, ys
|
* is a dict of labels key/value pairs. If no column is marked as label,
|
* returns a dict of features only.
|
*
|
* ```js
|
* const csvUrl =
|
* 'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';
|
*
|
* async function run() {
|
* // We want to predict the column "medv", which represents a median value of
|
* // a home (in $1000s), so we mark it as a label.
|
* const csvDataset = tf.data.csv(
|
* csvUrl, {
|
* columnConfigs: {
|
* medv: {
|
* isLabel: true
|
* }
|
* }
|
* });
|
*
|
* // Number of features is the number of column names minus one for the label
|
* // column.
|
* const numOfFeatures = (await csvDataset.columnNames()).length - 1;
|
*
|
* // Prepare the Dataset for training.
|
* const flattenedDataset =
|
* csvDataset
|
* .map(({xs, ys}) =>
|
* {
|
* // Convert xs(features) and ys(labels) from object form (keyed by
|
* // column name) to array form.
|
* return {xs:Object.values(xs), ys:Object.values(ys)};
|
* })
|
* .batch(10);
|
*
|
* // Define the model.
|
* const model = tf.sequential();
|
* model.add(tf.layers.dense({
|
* inputShape: [numOfFeatures],
|
* units: 1
|
* }));
|
* model.compile({
|
* optimizer: tf.train.sgd(0.000001),
|
* loss: 'meanSquaredError'
|
* });
|
*
|
* // Fit the model using the prepared Dataset
|
* return model.fitDataset(flattenedDataset, {
|
* epochs: 10,
|
* callbacks: {
|
* onEpochEnd: async (epoch, logs) => {
|
* console.log(epoch + ':' + logs.loss);
|
* }
|
* }
|
* });
|
* }
|
*
|
* await run();
|
* ```
|
*
|
* @param source URL or local path to get CSV file. If it's a local path, it
|
* must have prefix `file://` and it only works in node environment.
|
* @param csvConfig (Optional) A CSVConfig object that contains configurations
|
* of reading and decoding from CSV file(s).
|
*
|
* @doc {
|
* heading: 'Data',
|
* subheading: 'Creation',
|
* namespace: 'data',
|
* configParamIndices: [1]
|
* }
|
*/
|
function csv(source, csvConfig = {}) {
|
return new CSVDataset(new URLDataSource(source), csvConfig);
|
}
|
/**
|
* Create a `Dataset` that produces each element by calling a provided function.
|
*
|
* Note that repeated iterations over this `Dataset` may produce different
|
* results, because the function will be called anew for each element of each
|
* iteration.
|
*
|
* Also, beware that the sequence of calls to this function may be out of order
|
* in time with respect to the logical order of the Dataset. This is due to the
|
* asynchronous lazy nature of stream processing, and depends on downstream
|
* transformations (e.g. .shuffle()). If the provided function is pure, this is
|
* no problem, but if it is a closure over a mutable state (e.g., a traversal
|
* pointer), then the order of the produced elements may be scrambled.
|
*
|
* ```js
|
* let i = -1;
|
* const func = () =>
|
* ++i < 5 ? {value: i, done: false} : {value: null, done: true};
|
* const ds = tf.data.func(func);
|
* await ds.forEachAsync(e => console.log(e));
|
* ```
|
*
|
* @param f A function that produces one data element on each call.
|
*/
|
function func(f) {
|
const iter = iteratorFromFunction(f);
|
return datasetFromIteratorFn(async () => iter);
|
}
|
/**
|
* Create a `Dataset` that produces each element from provided JavaScript
|
* generator, which is a function*
|
* (https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators#Generator_functions),
|
* or a function that returns an
|
* iterator
|
* (https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Iteration_protocols).
|
*
|
* The returned iterator should have a `.next()` function that returns elements
|
* in the format `{value: TensorContainer, done: boolean}`.
|
*
|
* Example of creating a dataset from an iterator factory:
|
* ```js
|
* function makeIterator() {
|
* const numElements = 10;
|
* let index = 0;
|
*
|
* const iterator = {
|
* next: () => {
|
* let result;
|
* if (index < numElements) {
|
* result = {value: index, done: false};
|
* index++;
|
* return result;
|
* }
|
* return {value: index, done: true};
|
* }
|
* };
|
* return iterator;
|
* }
|
* const ds = tf.data.generator(makeIterator);
|
* await ds.forEachAsync(e => console.log(e));
|
* ```
|
*
|
* Example of creating a dataset from a generator:
|
* ```js
|
* function* dataGenerator() {
|
* const numElements = 10;
|
* let index = 0;
|
* while (index < numElements) {
|
* const x = index;
|
* index++;
|
* yield x;
|
* }
|
* }
|
*
|
* const ds = tf.data.generator(dataGenerator);
|
* await ds.forEachAsync(e => console.log(e));
|
* ```
|
*
|
* @param generator A JavaScript generator function that returns a JavaScript
|
* iterator.
|
*
|
* @doc {
|
* heading: 'Data',
|
* subheading: 'Creation',
|
* namespace: 'data',
|
* configParamIndices: [1]
|
* }
|
*/
|
function generator(generator) {
|
return datasetFromIteratorFn(async () => {
|
const gen = await generator();
|
return iteratorFromFunction(() => gen.next());
|
});
|
}
|
/**
|
* Create an iterator that generates `Tensor`s from a webcam video stream. This
|
* API only works in the browser environment when the device has a webcam.
|
*
|
* Note: this code snippet only works when the device has a webcam. It will
|
* request permission to open the webcam when running.
|
* ```js
|
* const videoElement = document.createElement('video');
|
* videoElement.width = 100;
|
* videoElement.height = 100;
|
* const cam = await tf.data.webcam(videoElement);
|
* const img = await cam.capture();
|
* img.print();
|
* cam.stop();
|
* ```
|
*
|
* @param webcamVideoElement A `HTMLVideoElement` used to play video from
|
* webcam. If this element is not provided, a hidden `HTMLVideoElement` will
|
* be created. In that case, `resizeWidth` and `resizeHeight` must be
|
* provided to set the generated tensor shape.
|
* @param webcamConfig A `WebcamConfig` object that contains configurations of
|
* reading and manipulating data from webcam video stream.
|
*
|
* @doc {
|
* heading: 'Data',
|
* subheading: 'Creation',
|
* namespace: 'data',
|
* ignoreCI: true
|
* }
|
*/
|
async function webcam(webcamVideoElement, webcamConfig) {
|
return WebcamIterator.create(webcamVideoElement, webcamConfig);
|
}
|
/**
|
* Create an iterator that generates frequency-domain spectrogram `Tensor`s from
|
* a microphone audio stream with the browser's native FFT. This API only works
|
* in the browser environment when the device has a microphone.
|
*
|
* Note: this code snippet only works when the device has a microphone. It will
|
* request permission to open the microphone when running.
|
* ```js
|
* const mic = await tf.data.microphone({
|
* fftSize: 1024,
|
* columnTruncateLength: 232,
|
* numFramesPerSpectrogram: 43,
|
* sampleRateHz:44100,
|
* includeSpectrogram: true,
|
* includeWaveform: true
|
* });
|
* const audioData = await mic.capture();
|
* const spectrogramTensor = audioData.spectrogram;
|
* spectrogramTensor.print();
|
* const waveformTensor = audioData.waveform;
|
* waveformTensor.print();
|
* mic.stop();
|
* ```
|
*
|
* @param microphoneConfig A `MicrophoneConfig` object that contains
|
* configurations of reading audio data from microphone.
|
*
|
* @doc {
|
* heading: 'Data',
|
* subheading: 'Creation',
|
* namespace: 'data',
|
* ignoreCI: true
|
* }
|
*/
|
async function microphone(microphoneConfig) {
|
return MicrophoneIterator.create(microphoneConfig);
|
}
|
|
/** @license See the LICENSE file. */
|
// This code is auto-generated, do not modify this file!
|
const version = '4.15.0';
|
|
export { CSVDataset, Dataset, FileDataSource, TextLineDataset, URLDataSource, array, csv, func, generator, microphone, version as version_data, webcam, zip };
|
//# sourceMappingURL=tf-data.fesm.js.map
|