"use strict";
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Inheritance helper: wires `derived` up as a subclass of `base`.
// Re-uses an already-installed `this.__extends` when one is present.
var __extends = (this && this.__extends) || (function () {
    // Links `derived` to `base` at the constructor (static) level. The concrete
    // strategy is chosen on the first call and then replaces this function:
    //   1. Object.setPrototypeOf when available,
    //   2. assignment to `__proto__` when the engine honours it,
    //   3. otherwise a copy of the base's own enumerable properties.
    var linkStatics = function (derived, base) {
        if (Object.setPrototypeOf) {
            linkStatics = Object.setPrototypeOf;
        }
        else if ({ __proto__: [] } instanceof Array) {
            linkStatics = function (target, parent) { target.__proto__ = parent; };
        }
        else {
            linkStatics = function (target, parent) {
                for (var prop in parent) {
                    if (Object.prototype.hasOwnProperty.call(parent, prop)) {
                        target[prop] = parent[prop];
                    }
                }
            };
        }
        return linkStatics(derived, base);
    };
    return function (derived, base) {
        // Throw-away constructor: lets us create an object that inherits from
        // `base.prototype` without running `base` itself.
        function Surrogate() { this.constructor = derived; }
        var baseIsNull = base === null;
        if (!baseIsNull && typeof base !== "function") {
            throw new TypeError("Class extends value " + String(base) + " is not a constructor or null");
        }
        linkStatics(derived, base);
        if (baseIsNull) {
            derived.prototype = Object.create(base);
        }
        else {
            Surrogate.prototype = base.prototype;
            derived.prototype = new Surrogate();
        }
    };
})();
|
// Async-function helper: drives the generator produced by `generator` to
// completion and exposes the outcome as a promise of type `P` (the global
// Promise when `P` is not supplied). Every yielded value is awaited before
// the generator is resumed. Re-uses `this.__awaiter` when one is installed.
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
    if (!P) {
        P = Promise;
    }
    // Wraps plain values so that `.then` can be called on anything yielded.
    function toPromise(value) {
        if (value instanceof P) {
            return value;
        }
        return new P(function (settle) { settle(value); });
    }
    return new P(function (resolve, reject) {
        // Either finishes the outer promise or waits on the yielded value.
        function advance(result) {
            if (result.done) {
                resolve(result.value);
            }
            else {
                toPromise(result.value).then(onFulfilled, onRejected);
            }
        }
        // Resume the generator with the awaited value.
        function onFulfilled(value) {
            try {
                advance(generator.next(value));
            }
            catch (err) {
                reject(err);
            }
        }
        // Resume the generator by throwing the rejection reason into it.
        function onRejected(reason) {
            try {
                advance(generator["throw"](reason));
            }
            catch (err) {
                reject(err);
            }
        }
        generator = generator.apply(thisArg, _arguments || []);
        advance(generator.next());
    });
};
|
// Helper that emulates a generator as a label-driven state machine.
// `body` is called with `thisArg` as `this` and the state object `_`; it
// returns an instruction array `[opcode, operand]` that `step` interprets.
// An already-installed `this.__generator` is reused when present.
var __generator = (this && this.__generator) || function (thisArg, body) {
|
// `_` is the state handed to `body`: the current `label`, `sent()` (returns
// the operand of the last recorded op, or throws it when that op's code has
// bit 0 set), and the `trys` / `ops` stacks.
// `f` is the "currently executing" flag, `y` a delegated iterator, `t` a
// scratch slot and `g` the iterator object that is returned.
var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g;
|
// next / throw / return feed opcodes 0 / 1 / 2 into `step`; when `Symbol` is
// available the object also returns itself from `Symbol.iterator`.
return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g;
|
function verb(n) { return function (v) { return step([n, v]); }; }
|
function step(op) {
|
// Reject re-entrant calls made while a step is still running.
if (f) throw new TypeError("Generator is already executing.");
|
while (g && (g = 0, op[0] && (_ = 0)), _) try {
|
// While a delegated iterator `y` is active, forward the request to it and
// return its result directly as long as it is not done.
if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t;
|
if (y = 0, t) op = [op[0] & 2, t.value];
|
switch (op[0]) {
|
case 0: case 1: t = op; break;
|
// Opcode 4: hand `op[1]` to the caller as a not-done result.
case 4: _.label++; return { value: op[1], done: false };
|
// Opcode 5: start delegating to the iterator in `op[1]`.
case 5: _.label++; y = op[1]; op = [0]; continue;
|
// Opcode 7: resume the op saved on `ops` and drop the innermost `trys` entry.
case 7: op = _.ops.pop(); _.trys.pop(); continue;
|
default:
|
if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }
|
if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }
|
if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }
|
if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }
|
if (t[2]) _.ops.pop();
|
_.trys.pop(); continue;
|
}
|
op = body.call(thisArg, _);
|
// Anything thrown while stepping is converted into opcode 6 carrying the error.
} catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }
|
// An opcode with bit 0 or bit 2 set (`op[0] & 5`) rethrows its operand;
// otherwise the machine reports completion.
if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true };
|
}
|
};
|
Object.defineProperty(exports, "__esModule", { value: true });
|
exports.tensorBoard = exports.TensorBoardCallback = exports.getDisplayDecimalPlaces = exports.getSuccinctNumberDisplay = exports.ProgbarLogger = exports.progressBarHelper = void 0;
|
var tfjs_1 = require("@tensorflow/tfjs");
|
var path = require("path");
|
var ProgressBar = require("progress");
|
var tensorboard_1 = require("./tensorboard");
|
// Indirection object that exists for the tests: jasmine's `spyOn` can only
// replace member methods of an object, so the progress-bar constructor and the
// logging function are reached through this object's members.
// tslint:disable-next-line:no-any
var progressBarHelper = {
    ProgressBar: ProgressBar,
    log: console.log
};
exports.progressBarHelper = progressBarHelper;
|
/**
 * Terminal-based progress bar callback for tf.Model.fit().
 *
 * At the start of every epoch an "Epoch i / n" line is logged. A progress bar
 * is created when the batch with index 0 is reported and is ticked once per
 * batch with the current losses and metrics. When the epoch ends, a summary
 * line with the epoch duration, the per-step time and the final losses and
 * metrics is logged.
 */
var ProgbarLogger = /** @class */ (function (_super) {
    __extends(ProgbarLogger, _super);
    /**
     * Constructor of ProgbarLogger.
     */
    function ProgbarLogger() {
        var _this = _super.call(this, {
            // Works out how many batches make up one epoch.
            onTrainBegin: function (logs) { return __awaiter(_this, void 0, void 0, function () {
                var samples, batchSize, steps;
                return __generator(this, function (_a) {
                    samples = this.params.samples;
                    batchSize = this.params.batchSize;
                    steps = this.params.steps;
                    if (samples != null || steps != null) {
                        this.numTrainBatchesPerEpoch =
                            samples != null ? Math.ceil(samples / batchSize) : steps;
                    }
                    else {
                        // Undetermined number of batches per epoch, e.g., due to
                        // `fitDataset()` without `batchesPerEpoch`.
                        this.numTrainBatchesPerEpoch = 0;
                    }
                    return [2 /*return*/];
                });
            }); },
            // Logs the epoch header and resets the per-epoch bookkeeping.
            onEpochBegin: function (epoch, logs) { return __awaiter(_this, void 0, void 0, function () {
                return __generator(this, function (_a) {
                    exports.progressBarHelper.log("Epoch ".concat(epoch + 1, " / ").concat(this.params.epochs));
                    this.currentEpochBegin = tfjs_1.util.now();
                    this.epochDurationMillis = null;
                    this.usPerStep = null;
                    this.batchesInLatestEpoch = 0;
                    // NOTE(review): `columns` is presumably undefined when stderr is
                    // not a TTY, which makes the derived widths below NaN - confirm.
                    this.terminalWidth = process.stderr.columns;
                    return [2 /*return*/];
                });
            }); },
            // Creates the bar on the first batch, ticks it, and records the epoch
            // timing once the last expected batch has been seen.
            onBatchEnd: function (batch, logs) { return __awaiter(_this, void 0, void 0, function () {
                var maxMetricsStringLength, tickTokens;
                return __generator(this, function (_a) {
                    switch (_a.label) {
                        case 0:
                            this.batchesInLatestEpoch++;
                            if (batch === 0) {
                                this.progressBar = new exports.progressBarHelper.ProgressBar('eta=:eta :bar :placeholderForLossesAndMetrics', {
                                    width: Math.floor(0.5 * this.terminalWidth),
                                    // One more than the batch count: `onEpochEnd` issues
                                    // one extra tick after the last batch.
                                    total: this.numTrainBatchesPerEpoch + 1,
                                    head: ">",
                                    renderThrottle: this.RENDER_THROTTLE_MS
                                });
                            }
                            maxMetricsStringLength = Math.floor(this.terminalWidth * 0.5 - 12);
                            tickTokens = {
                                placeholderForLossesAndMetrics: this.formatLogsAsMetricsContent(logs, maxMetricsStringLength)
                            };
                            if (this.numTrainBatchesPerEpoch === 0) {
                                // Undetermined number of batches per epoch.
                                this.progressBar.tick(0, tickTokens);
                            }
                            else {
                                this.progressBar.tick(tickTokens);
                            }
                            return [4 /*yield*/, (0, tfjs_1.nextFrame)()];
                        case 1:
                            _a.sent();
                            if (batch === this.numTrainBatchesPerEpoch - 1) {
                                this.epochDurationMillis = tfjs_1.util.now() - this.currentEpochBegin;
                                // Divides by the sample count when it is known, otherwise
                                // by the number of batches seen; result is in microseconds.
                                this.usPerStep = this.params.samples != null ?
                                    this.epochDurationMillis / this.params.samples * 1e3 :
                                    this.epochDurationMillis / this.batchesInLatestEpoch * 1e3;
                            }
                            return [2 /*return*/];
                    }
                });
            }); },
            // Issues the final tick and logs the epoch summary line.
            onEpochEnd: function (epoch, logs) { return __awaiter(_this, void 0, void 0, function () {
                var lossesAndMetricsString;
                return __generator(this, function (_a) {
                    switch (_a.label) {
                        case 0:
                            if (this.epochDurationMillis == null) {
                                // In cases where the number of batches per epoch is not determined,
                                // the calculation of the per-step duration is done at the end of the
                                // epoch. N.B., this includes the time spent on validation.
                                this.epochDurationMillis = tfjs_1.util.now() - this.currentEpochBegin;
                                this.usPerStep =
                                    this.epochDurationMillis / this.batchesInLatestEpoch * 1e3;
                            }
                            // The bar only exists once a batch with index 0 has been
                            // reported; skip the final tick instead of throwing when
                            // that never happened.
                            if (this.progressBar != null) {
                                this.progressBar.tick({ placeholderForLossesAndMetrics: '' });
                            }
                            lossesAndMetricsString = this.formatLogsAsMetricsContent(logs);
                            exports.progressBarHelper.log("".concat(this.epochDurationMillis.toFixed(0), "ms ") +
                                "".concat(this.usPerStep.toFixed(0), "us/step - ") +
                                "".concat(lossesAndMetricsString));
                            return [4 /*yield*/, (0, tfjs_1.nextFrame)()];
                        case 1:
                            _a.sent();
                            return [2 /*return*/];
                    }
                });
            }); },
        }) || this;
        // Passed to the progress bar as its `renderThrottle` option.
        _this.RENDER_THROTTLE_MS = 50;
        return _this;
    }
    /**
     * Format the relevant entries of `logs` as a `key=value ` sequence, with
     * keys in sorted order.
     *
     * @param logs Object mapping metric names to numeric values. A missing
     *   (null or undefined) object yields an empty string.
     * @param maxMetricsLength Optional maximum length of the returned string.
     *   Longer content is cut and terminated with '...'.
     * @return The formatted metrics string.
     */
    ProgbarLogger.prototype.formatLogsAsMetricsContent = function (logs, maxMetricsLength) {
        var ELLIPSIS = '...';
        var metricsContent = '';
        // Tolerate callbacks that are invoked without a logs object.
        var keys = logs == null ? [] : Object.keys(logs).sort();
        for (var _i = 0, keys_1 = keys; _i < keys_1.length; _i++) {
            var key = keys_1[_i];
            if (this.isFieldRelevant(key)) {
                var value = logs[key];
                metricsContent += "".concat(key, "=").concat(getSuccinctNumberDisplay(value), " ");
            }
        }
        if (maxMetricsLength != null && metricsContent.length > maxMetricsLength) {
            // Cut off metrics strings that are too long to avoid new lines being
            // constantly created. The kept length is clamped at 0: a negative
            // `end` would make `slice` count from the end of the string and keep
            // almost all of it on very narrow terminals.
            var keptLength = Math.max(0, maxMetricsLength - ELLIPSIS.length);
            metricsContent = metricsContent.slice(0, keptLength) + ELLIPSIS;
        }
        return metricsContent;
    };
    /**
     * Whether a key of the logs object should be displayed.
     *
     * @param key Key of the logs object.
     * @return `false` for the bookkeeping keys 'batch' and 'size', `true`
     *   otherwise.
     */
    ProgbarLogger.prototype.isFieldRelevant = function (key) {
        return key !== 'batch' && key !== 'size';
    };
    return ProgbarLogger;
}(tfjs_1.CustomCallback));
exports.ProgbarLogger = ProgbarLogger;
|
// Default number of digits shown after the decimal point; also the number of
// fraction digits used for exponential notation.
var BASE_NUM_DIGITS = 2;
// Largest number of decimal places still rendered in fixed notation.
var MAX_NUM_DECIMAL_PLACES = 4;
/**
 * Get a succinct string representation of a number.
 *
 * Fixed-point notation is used as long as the number of decimal places needed
 * (see `getDisplayDecimalPlaces`) does not exceed `MAX_NUM_DECIMAL_PLACES`;
 * smaller magnitudes are rendered in exponential notation instead.
 *
 * @param x Input number.
 * @return Succinct string representing `x`.
 */
function getSuccinctNumberDisplay(x) {
    var places = getDisplayDecimalPlaces(x);
    if (places > MAX_NUM_DECIMAL_PLACES) {
        return x.toExponential(BASE_NUM_DIGITS);
    }
    return x.toFixed(places);
}
|
exports.getSuccinctNumberDisplay = getSuccinctNumberDisplay;
|
/**
 * Determine the number of decimal places to display.
 *
 * Non-finite values, zero and values outside [-1, 1] get the default
 * `BASE_NUM_DIGITS`. For the remaining values the count grows as the
 * magnitude shrinks, so that the leading significant digits stay visible.
 *
 * @param x Number to display.
 * @return Number of decimal places to display for `x`.
 */
function getDisplayDecimalPlaces(x) {
    var isNonZeroFraction = Number.isFinite(x) && x !== 0 && x <= 1 && x >= -1;
    if (!isNonZeroFraction) {
        return BASE_NUM_DIGITS;
    }
    // Zero or negative for |x| <= 1, so subtracting it adds places.
    var orderOfMagnitude = Math.floor(Math.log10(Math.abs(x)));
    return BASE_NUM_DIGITS - orderOfMagnitude;
}
|
exports.getDisplayDecimalPlaces = getDisplayDecimalPlaces;
|
/**
 * Callback for logging to TensorBoard during training.
 *
 * Training values are written under `<logdir>/train` and validation values
 * (log keys starting with `val_`) under `<logdir>/val`.
 *
 * Users are expected to access this class through the `tensorBoard()`
 * factory method instead.
 */
var TensorBoardCallback = /** @class */ (function (_super) {
    __extends(TensorBoardCallback, _super);
    /**
     * @param logdir Directory to which the logs will be written. Defaults to
     *   './logs'.
     * @param args Optional configuration object. `updateFreq` must be 'batch'
     *   or 'epoch' (default 'epoch'); `histogramFreq` must be a non-negative
     *   integer (default 0, which disables weight histograms). Missing fields
     *   are filled in on the object that was passed in.
     */
    function TensorBoardCallback(logdir, args) {
        if (logdir === void 0) { logdir = './logs'; }
        var _this = _super.call(this, {
            // Counts every batch; logs per-batch scalars unless updateFreq is 'epoch'.
            onBatchEnd: function (batch, logs) { return __awaiter(_this, void 0, void 0, function () {
                return __generator(this, function (_a) {
                    this.batchesSeen++;
                    if (this.args.updateFreq !== 'epoch') {
                        this.logMetrics(logs, 'batch_', this.batchesSeen);
                    }
                    return [2 /*return*/];
                });
            }); },
            // Logs per-epoch scalars and, every `histogramFreq` epochs, the weights.
            onEpochEnd: function (epoch, logs) { return __awaiter(_this, void 0, void 0, function () {
                return __generator(this, function (_a) {
                    this.logMetrics(logs, 'epoch_', epoch + 1);
                    if (this.args.histogramFreq > 0 &&
                        epoch % this.args.histogramFreq === 0) {
                        this.logWeights(epoch);
                    }
                    return [2 /*return*/];
                });
            }); },
            // Flushes whichever writers were created during training.
            onTrainEnd: function (logs) { return __awaiter(_this, void 0, void 0, function () {
                return __generator(this, function (_a) {
                    if (this.trainWriter != null) {
                        this.trainWriter.flush();
                    }
                    if (this.valWriter != null) {
                        this.valWriter.flush();
                    }
                    return [2 /*return*/];
                });
            }); }
        }) || this;
        _this.logdir = logdir;
        _this.model = null;
        _this.args = args == null ? {} : args;
        if (_this.args.updateFreq == null) {
            _this.args.updateFreq = 'epoch';
        }
        tfjs_1.util.assert(['batch', 'epoch'].indexOf(_this.args.updateFreq) !== -1, function () { return "Expected updateFreq to be 'batch' or 'epoch', but got " +
            "".concat(_this.args.updateFreq); });
        if (_this.args.histogramFreq == null) {
            _this.args.histogramFreq = 0;
        }
        // Note that 0 passes this check even though the message says "positive".
        tfjs_1.util.assert(Number.isInteger(_this.args.histogramFreq) &&
            _this.args.histogramFreq >= 0, function () { return "Expected histogramFreq to be a positive integer, but got " +
            "".concat(_this.args.histogramFreq); });
        _this.batchesSeen = 0;
        return _this;
    }
    /**
     * Store the model whose weights are logged as histograms.
     *
     * @param model The model being trained.
     */
    TensorBoardCallback.prototype.setModel = function (model) {
        // This method is inherited from BaseCallback. To avoid cyclical imports,
        // that class uses Container instead of LayersModel, and uses a run-time
        // check to make sure the model is a LayersModel.
        // Since this subclass isn't imported by tfjs-layers, we can safely type
        // the parameter as a LayersModel.
        this.model = model;
    };
    /**
     * Write every entry of `logs` as a scalar summary.
     *
     * Keys 'batch', 'size' and 'num_steps' are skipped. Keys starting with
     * 'val_' go to the validation writer with that prefix stripped; all other
     * keys go to the training writer.
     *
     * @param logs Object mapping metric names to values.
     * @param prefix String prepended to every scalar name ('batch_' or
     *   'epoch_' in this file).
     * @param step Step value recorded with each scalar.
     */
    TensorBoardCallback.prototype.logMetrics = function (logs, prefix, step) {
        for (var key in logs) {
            if (key === 'batch' || key === 'size' || key === 'num_steps') {
                continue;
            }
            var VAL_PREFIX = 'val_';
            if (key.startsWith(VAL_PREFIX)) {
                this.ensureValWriterCreated();
                var scalarName = prefix + key.slice(VAL_PREFIX.length);
                this.valWriter.scalar(scalarName, logs[key], step);
            }
            else {
                this.ensureTrainWriterCreated();
                this.trainWriter.scalar("".concat(prefix).concat(key), logs[key], step);
            }
        }
    };
    /**
     * Write a histogram summary for every weight of the model.
     *
     * Reads `this.model.weights`, so `setModel()` must have been called first.
     *
     * @param step Step value recorded with each histogram.
     */
    TensorBoardCallback.prototype.logWeights = function (step) {
        for (var _i = 0, _a = this.model.weights; _i < _a.length; _i++) {
            var weights = _a[_i];
            // The training writer only exists if logMetrics() has already seen a
            // non-validation key, so make sure it is there before using it.
            this.ensureTrainWriterCreated();
            this.trainWriter.histogram(weights.name, weights.read(), step);
        }
    };
    /**
     * Create the writer for `<logdir>/train` on first use; later calls reuse it.
     */
    TensorBoardCallback.prototype.ensureTrainWriterCreated = function () {
        if (this.trainWriter == null) {
            this.trainWriter = (0, tensorboard_1.summaryFileWriter)(path.join(this.logdir, 'train'));
        }
    };
    /**
     * Create the writer for `<logdir>/val` on first use; later calls reuse it.
     */
    TensorBoardCallback.prototype.ensureValWriterCreated = function () {
        if (this.valWriter == null) {
            this.valWriter = (0, tensorboard_1.summaryFileWriter)(path.join(this.logdir, 'val'));
        }
    };
    return TensorBoardCallback;
}(tfjs_1.CustomCallback));
exports.TensorBoardCallback = TensorBoardCallback;
|
/**
 * Create a callback that logs to TensorBoard during training.
 *
 * The callback writes the loss and metric values (if any) to the specified
 * log directory (`logdir`), from which TensorBoard can ingest and visualize
 * them. It is usually passed as a callback to `tf.Model.fit()` or
 * `tf.Model.fitDataset()` calls during model training. How often values are
 * logged is controlled by the `updateFreq` field of the configuration object
 * (2nd argument).
 *
 * Usage example:
 * ```js
 * // Construct a toy multilayer-perceptron regressor for demo purpose.
 * const model = tf.sequential();
 * model.add(
 *     tf.layers.dense({units: 100, activation: 'relu', inputShape: [200]}));
 * model.add(tf.layers.dense({units: 1}));
 * model.compile({
 *   loss: 'meanSquaredError',
 *   optimizer: 'sgd',
 *   metrics: ['MAE']
 * });
 *
 * // Generate some random fake data for demo purpose.
 * const xs = tf.randomUniform([10000, 200]);
 * const ys = tf.randomUniform([10000, 1]);
 * const valXs = tf.randomUniform([1000, 200]);
 * const valYs = tf.randomUniform([1000, 1]);
 *
 * // Start model training process.
 * await model.fit(xs, ys, {
 *   epochs: 100,
 *   validationData: [valXs, valYs],
 *    // Add the tensorBoard callback here.
 *   callbacks: tf.node.tensorBoard('/tmp/fit_logs_1')
 * });
 * ```
 *
 * Then you can use the following commands to point tensorboard
 * to the logdir:
 *
 * ```sh
 * pip install tensorboard  # Unless you've already installed it.
 * tensorboard --logdir /tmp/fit_logs_1
 * ```
 *
 * @param logdir Directory to which the logs will be written. Defaults to
 *   './logs' when omitted.
 * @param args Optional configuration arguments.
 * @returns An instance of `TensorBoardCallback`, which is a subclass of
 *   `tf.CustomCallback`.
 *
 * @doc {heading: 'TensorBoard', namespace: 'node'}
 */
function tensorBoard(logdir, args) {
    // Only an omitted / undefined `logdir` falls back to the default.
    var targetDir = typeof logdir === 'undefined' ? './logs' : logdir;
    return new TensorBoardCallback(targetDir, args);
}
|
exports.tensorBoard = tensorBoard;
|