/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { ENGINE } from './engine';
|
import { Tensor, Variable } from './tensor';
|
import { convertToTensor, convertToTensorArray } from './tensor_util_env';
|
import * as util from './util';
|
/**
|
* Provided `f(x)`, returns another function `g(x, dy?)`, which gives the
|
* gradient of `f(x)` with respect to `x`.
|
*
|
* If `dy` is provided, the gradient of `f(x).mul(dy).sum()` with respect to
|
* `x` is computed instead. `f(x)` must take a single tensor `x` and return a
|
* single tensor `y`. If `f()` takes multiple inputs, use `tf.grads` instead.
|
*
|
* ```js
|
* // f(x) = x ^ 2
|
* const f = x => x.square();
|
* // f'(x) = 2x
|
* const g = tf.grad(f);
|
*
|
* const x = tf.tensor1d([2, 3]);
|
* g(x).print();
|
* ```
|
*
|
* ```js
|
* // f(x) = x ^ 3
|
* const f = x => x.pow(tf.scalar(3, 'int32'));
|
* // f'(x) = 3x ^ 2
|
* const g = tf.grad(f);
|
* // f''(x) = 6x
|
* const gg = tf.grad(g);
|
*
|
* const x = tf.tensor1d([2, 3]);
|
* gg(x).print();
|
* ```
|
*
|
* @param f The function f(x), to compute gradient for.
|
*
|
* @doc {heading: 'Training', subheading: 'Gradients'}
|
*/
|
function grad(f) {
  util.assert(util.isFunction(f), () => 'The f passed in grad(f) must be a function');
  return (x, dy) => {
    // `x` may hold any dtype (strings included), hence 'string_or_numeric'.
    const xTensor = convertToTensor(x, 'x', 'tf.grad', 'string_or_numeric');
    // Optional upstream gradient; when absent the engine seeds with its default.
    const dyTensor = dy == null ? null : convertToTensor(dy, 'dy', 'tf.grad');
    // Everything created while differentiating is released except the result.
    return ENGINE.tidy(() => {
      const result = ENGINE.gradients(() => f(xTensor), [xTensor], dyTensor);
      if (dyTensor != null) {
        util.assertShapesMatch(
            result.value.shape, dyTensor.shape,
            'The shape of dy passed in grad(f)(x, dy) must match the shape ' +
                'returned by f(x)');
      }
      checkGrads(result.grads);
      const [dx] = result.grads;
      return dx;
    });
  };
}
|
/**
|
* Provided `f(x1, x2,...)`, returns another function `g([x1, x2,...], dy?)`,
|
* which gives an array of gradients of `f()` with respect to each input
|
* [`x1`,`x2`,...].
|
*
|
* If `dy` is passed when calling `g()`, the gradient of
|
* `f(x1,...).mul(dy).sum()` with respect to each input is computed instead.
|
* The provided `f` must take one or more tensors and return a single tensor
|
* `y`. If `f()` takes a single input, we recommend using `tf.grad` instead.
|
*
|
* ```js
|
* // f(a, b) = a * b
|
* const f = (a, b) => a.mul(b);
|
* // df / da = b, df / db = a
|
* const g = tf.grads(f);
|
*
|
* const a = tf.tensor1d([2, 3]);
|
* const b = tf.tensor1d([-2, -3]);
|
* const [da, db] = g([a, b]);
|
* console.log('da');
|
* da.print();
|
* console.log('db');
|
* db.print();
|
* ```
|
*
|
* @param f The function `f(x1, x2,...)` to compute gradients for.
|
*
|
* @doc {heading: 'Training', subheading: 'Gradients'}
|
*/
|
function grads(f) {
  util.assert(util.isFunction(f), () => 'The f passed in grads(f) must be a function');
  return (args, dy) => {
    util.assert(
        Array.isArray(args),
        () => 'The args passed in grads(f)(args) must be an array ' +
            'of `Tensor`s or `TensorLike`s');
    // Inputs may have any dtype (strings included), hence 'string_or_numeric'.
    const inputs = convertToTensorArray(args, 'args', 'tf.grads', 'string_or_numeric');
    // Optional upstream gradient for the output of f.
    const seed = dy == null ? null : convertToTensor(dy, 'dy', 'tf.grads');
    // Intermediate tensors are disposed; only the returned gradients survive.
    return ENGINE.tidy(() => {
      const result = ENGINE.gradients(() => f(...inputs), inputs, seed);
      if (seed != null) {
        util.assertShapesMatch(
            result.value.shape, seed.shape,
            'The shape of dy passed in grads(f)([x1,...], dy) must ' +
                'match the shape returned by f([x1,...])');
      }
      checkGrads(result.grads);
      return result.grads;
    });
  };
}
|
/**
|
* Like `tf.grad`, but also returns the value of `f()`. Useful when `f()`
|
* returns a metric you want to show.
|
*
|
* The result is a rich object with the following properties:
|
* - grad: The gradient of `f(x)` w.r.t. `x` (result of `tf.grad`).
|
* - value: The value returned by `f(x)`.
|
*
|
* ```js
|
* // f(x) = x ^ 2
|
* const f = x => x.square();
|
* // f'(x) = 2x
|
* const g = tf.valueAndGrad(f);
|
*
|
* const x = tf.tensor1d([2, 3]);
|
* const {value, grad} = g(x);
|
*
|
* console.log('value');
|
* value.print();
|
* console.log('grad');
|
* grad.print();
|
* ```
|
*
|
* @doc {heading: 'Training', subheading: 'Gradients'}
|
*/
|
function valueAndGrad(f) {
  util.assert(util.isFunction(f), () => 'The f passed in valueAndGrad(f) must be a function');
  return (x, dy) => {
    util.assert(x instanceof Tensor, () => 'The x passed in valueAndGrad(f)(x) must be a tensor');
    util.assert(dy == null || dy instanceof Tensor, () => 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor');
    const { grads, value } = ENGINE.gradients(() => f(x), [x], dy);
    // Validate the upstream gradient against f(x), exactly as grad(), grads()
    // and valueAndGrads() do; previously a mismatched dy was silently accepted.
    if (dy != null) {
      util.assertShapesMatch(value.shape, dy.shape, 'The shape of dy passed in valueAndGrad(f)(x, dy) must match the shape ' +
          'returned by f(x)');
    }
    // Throws if x is not connected to the output of f.
    checkGrads(grads);
    return { grad: grads[0], value };
  };
}
|
/**
|
* Like `tf.grads`, but also returns the value of `f()`. Useful when `f()`
|
* returns a metric you want to show.
|
*
|
* The result is a rich object with the following properties:
|
* - grads: The gradients of `f()` w.r.t. each input (result of `tf.grads`).
|
* - value: The value returned by `f(x1, x2,...)`.
|
*
|
* ```js
|
* // f(a, b) = a * b
|
* const f = (a, b) => a.mul(b);
|
* // df/da = b, df/db = a
|
* const g = tf.valueAndGrads(f);
|
*
|
* const a = tf.tensor1d([2, 3]);
|
* const b = tf.tensor1d([-2, -3]);
|
* const {value, grads} = g([a, b]);
|
*
|
* const [da, db] = grads;
|
*
|
* console.log('value');
|
* value.print();
|
*
|
* console.log('da');
|
* da.print();
|
* console.log('db');
|
* db.print();
|
* ```
|
*
|
* @doc {heading: 'Training', subheading: 'Gradients'}
|
*/
|
function valueAndGrads(f) {
  util.assert(util.isFunction(f), () => 'The f passed in valueAndGrads(f) must be a function');
  return (args, dy) => {
    // Unlike grads(), inputs must already be Tensor instances (no conversion).
    const argsAreTensors = Array.isArray(args) && args.every(arg => arg instanceof Tensor);
    util.assert(argsAreTensors, () => 'The args passed in valueAndGrads(f)(args) must be array of tensors');
    const dyIsValid = dy == null || dy instanceof Tensor;
    util.assert(dyIsValid, () => 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor');
    // The engine result already has the {value, grads} shape we hand back.
    const result = ENGINE.gradients(() => f(...args), args, dy);
    if (dy != null) {
      util.assertShapesMatch(
          result.value.shape, dy.shape,
          'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' +
              'match the shape returned by f([x1,...])');
    }
    checkGrads(result.grads);
    return result;
  };
}
|
/**
|
* Computes and returns the gradient of f(x) with respect to the list of
|
* trainable variables provided by `varList`. If no list is provided, it
|
* defaults to all trainable variables.
|
*
|
* ```js
|
* const a = tf.variable(tf.tensor1d([3, 4]));
|
* const b = tf.variable(tf.tensor1d([5, 6]));
|
* const x = tf.tensor1d([1, 2]);
|
*
|
* // f(a, b) = a * x ^ 2 + b * x
|
* const f = () => a.mul(x.square()).add(b.mul(x)).sum();
|
* // df/da = x ^ 2, df/db = x
|
* const {value, grads} = tf.variableGrads(f);
|
*
|
* Object.keys(grads).forEach(varName => grads[varName].print());
|
* ```
|
*
|
* @param f The function to execute. f() should return a scalar.
|
* @param varList The list of variables to compute the gradients with respect
|
* to. Defaults to all trainable variables.
|
* @returns An object with the following keys and values:
|
* - `value`: The value of the function `f`.
|
* - `grads`: A map from the names of the variables to the gradients.
|
* If the `varList` argument is provided explicitly and contains a subset of
|
* non-trainable variables, this map in the return value will contain keys
|
* that map the names of the non-trainable variables to `null`.
|
*
|
* @doc {heading: 'Training', subheading: 'Gradients'}
|
*/
|
function variableGrads(f, varList) {
  util.assert(util.isFunction(f), () => 'The f passed in variableGrads(f) must be a function');
  util.assert(
      varList == null ||
          Array.isArray(varList) && varList.every(v => v instanceof Variable),
      () => 'The varList passed in variableGrads(f, varList) must be an array ' +
          'of variables');

  const callerGaveList = varList != null;
  let candidates;
  if (callerGaveList) {
    candidates = varList;
  } else {
    // No list given: gather every registered variable. Non-trainable ones are
    // dropped by the trainability filter below.
    candidates = [];
    for (const varName in ENGINE.registeredVariables) {
      candidates.push(ENGINE.registeredVariables[varName]);
    }
  }

  // Only an explicit list reports its non-trainable members (as null grads).
  const nonTrainableRequested = callerGaveList ? candidates.filter(v => !v.trainable) : null;

  const totalCount = candidates.length;
  const trainable = candidates.filter(v => v.trainable);
  util.assert(
      trainable.length > 0,
      () => `variableGrads() expects at least one of the input variables to ` +
          `be trainable, but none of the ${totalCount} variables is ` +
          `trainable.`);

  // Variables unreachable from f are tolerated here (null gradients); the
  // check below only requires that at least one is connected.
  const { value, grads } = ENGINE.gradients(f, trainable, null, true);

  util.assert(
      grads.some(g => g != null),
      () => 'Cannot find a connection between any variable and the result of ' +
          'the loss function y=f(x). Please make sure the operations that ' +
          'use variables are inside the function f passed to minimize().');
  util.assert(
      value.rank === 0,
      () => `The f passed in variableGrads(f) must return a scalar, but it ` +
          `returned a rank-${value.rank} tensor`);

  const namedGrads = {};
  trainable.forEach((variable, idx) => {
    const g = grads[idx];
    if (g != null) {
      namedGrads[variable.name] = g;
    }
  });
  if (nonTrainableRequested != null) {
    for (const variable of nonTrainableRequested) {
      namedGrads[variable.name] = null;
    }
  }
  return { value, grads: namedGrads };
}
|
/**
|
* Overrides the gradient computation of a function `f`.
|
*
|
* Takes a function
|
* `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}`
|
* and returns another function `g(...inputs)` which takes the same inputs as
|
* `f`. When called, `g` returns `f().value`. In backward mode, custom gradients
|
* with respect to each input of `f` are computed using `f().gradFunc`.
|
*
|
* The `save` function passed to `f` should be used for saving tensors needed
|
* in the gradient. The `saved` argument passed to `gradFunc` is an array
|
* containing those saved tensors, in the order they were passed to `save`.
|
*
|
* ```js
|
* const customOp = tf.customGrad((x, save) => {
|
* // Save x to make sure it's available later for the gradient.
|
* save([x]);
|
* // Override gradient of our custom x ^ 2 op to be dy * abs(x);
|
* return {
|
* value: x.square(),
|
* // Note `saved[0]`, which points to the `x` we saved earlier.
|
* gradFunc: (dy, saved) => [dy.mul(saved[0].abs())]
|
* };
|
* });
|
*
|
* const x = tf.tensor1d([-1, -2, 3]);
|
* const dx = tf.grad(x => customOp(x));
|
*
|
* console.log(`f(x):`);
|
* customOp(x).print();
|
* console.log(`f'(x):`);
|
* dx(x).print();
|
* ```
|
*
|
* @param f The function to evaluate in forward mode, which should return
|
* `{value: Tensor, gradFunc: (dy, saved) => Tensor[]}`, where `gradFunc`
|
* returns the custom gradients of `f` with respect to its inputs.
|
*
|
* @doc {heading: 'Training', subheading: 'Gradients'}
|
*/
|
function customGrad(f) {
  // Thin delegate: the engine builds the wrapped function that uses the
  // gradFunc returned by `f` during backprop.
  const wrapped = ENGINE.customGrad(f);
  return wrapped;
}
|
/**
 * Verifies that every requested gradient was produced.
 *
 * A null/undefined entry means the corresponding input is not connected to
 * the output of f, so the gradient cannot be computed.
 *
 * @param grads Array of gradients, one per input.
 * @throws Error if any entry is null or undefined.
 */
function checkGrads(grads) {
  if (grads.some(g => g == null)) {
    // Single-line message: the previous template literal embedded a source
    // line break and indentation into the user-visible error text.
    throw new Error(
        'Cannot compute gradient of y=f(x) with respect to x. Make sure that ' +
        'the f you passed encloses all operations that lead from x to y.');
  }
}
|
export { customGrad, variableGrads, valueAndGrad, valueAndGrads, grad, grads, };
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"gradients.js","sourceRoot":"","sources":["../../../../../tfjs-core/src/gradients.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAqB,MAAM,EAAC,MAAM,UAAU,CAAC;AACpD,OAAO,EAAS,MAAM,EAAE,QAAQ,EAAC,MAAM,UAAU,CAAC;AAElD,OAAO,EAAC,eAAe,EAAE,oBAAoB,EAAC,MAAM,mBAAmB,CAAC;AAExE,OAAO,KAAK,IAAI,MAAM,QAAQ,CAAC;AAE/B;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAiCG;AACH,SAAS,IAAI,CAAC,CAAwB;IAEpC,IAAI,CAAC,MAAM,CACP,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,GAAG,EAAE,CAAC,4CAA4C,CAAC,CAAC;IAC5E,OAAO,CAAC,CAAoB,EAAE,EAAsB,EAAU,EAAE;QAC9D,yDAAyD;QACzD,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,GAAG,EAAE,SAAS,EAAE,mBAAmB,CAAC,CAAC;QACnE,MAAM,GAAG,GACL,CAAC,EAAE,IAAI,IAAI,CAAC,CAAC,CAAC,CAAC,eAAe,CAAC,EAAE,EAAE,IAAI,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC;QAC/D,OAAO,MAAM,CAAC,IAAI,CAAC,GAAG,EAAE;YACtB,MAAM,EAAC,KAAK,EAAE,KAAK,EAAC,GAAG,MAAM,CAAC,SAAS,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,GAAG,CAAC,CAAC;YAChE,IAAI,GAAG,IAAI,IAAI,EAAE;gBACf,IAAI,CAAC,iBAAiB,CAClB,KAAK,CAAC,KAAK,EAAE,GAAG,CAAC,KAAK,EACtB,gEAAgE;oBAC5D,kBAAkB,CAAC,CAAC;aAC7B;YACD,UAAU,CAAC,KAAK,CAAC,CAAC;YAClB,OAAO,KAAK,CAAC,CAAC,CAAC,CAAC;QAClB,CAAC,CAAC,CAAC;IACL,CAAC,CAAC;AACJ,CAAC;AAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA4BG;AACH,SAAS,KAAK,CAAC,CAAgC;IAE7C,IAAI,CAAC,MAAM,CACP,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,GAAG,EAAE,CAAC,6CAA6C,CAAC,CAAC;IAC7E,OAAO,CAAC,IAA8B,EAAE,EAAsB,EAAY,EAAE;QAC1E,IAAI,CAAC,MAAM,CACP,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,EACnB,GAAG,EAAE,CAAC,qDAAqD;YACvD,+BAA+B,CAAC,CAAC;QACzC,4DAA4D;QAC5D,MAAM,KAAK,GACP,oBAAoB,CAAC,IAAI,EAAE,MAAM,EAAE,UAAU,EAAE,mBAAmB,CAAC,CAAC;QACxE,MAAM,GAAG,GACL,CAAC,EAAE,IAAI,IAAI,CAAC,CAAC,CAAC,CAAC,eAAe,CAAC,EAAE,EAAE,IAAI,EAAE,UAAU,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC;QAChE,OAAO,MAAM,CAAC,IAAI,CAAC,GAAG,EAAE;YACtB,MAAM,EAAC,KAAK,EAAE,KAAK,EAAC,GAAG,MAAM,CAAC,SAAS,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,GAAG,KAAK,CAAC,EAAE,KAAK,EAAE,GAAG,CAAC,CAAC;YACvE,IAAI,GAAG,IAAI,IAAI,EAAE;gBACf,IAAI,CAAC,iBAAiB,CAClB,KAAK,
CAAC,KAAK,EAAE,GAAG,CAAC,KAAK,EACtB,wDAAwD;oBACpD,yCAAyC,CAAC,CAAC;aACpD;YACD,UAAU,CAAC,KAAK,CAAC,CAAC;YAClB,OAAO,KAAK,CAAC;QACf,CAAC,CAAC,CAAC;IACL,CAAC,CAAC;AACJ,CAAC;AAED;;;;;;;;;;;;;;;;;;;;;;;;GAwBG;AACH,SAAS,YAAY,CAAqC,CAAc;IAKtE,IAAI,CAAC,MAAM,CACP,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,EAClB,GAAG,EAAE,CAAC,oDAAoD,CAAC,CAAC;IAChE,OAAO,CAAC,CAAI,EAAE,EAAM,EAAE,EAAE;QACtB,IAAI,CAAC,MAAM,CACP,CAAC,YAAY,MAAM,EACnB,GAAG,EAAE,CAAC,qDAAqD,CAAC,CAAC;QACjE,IAAI,CAAC,MAAM,CACP,EAAE,IAAI,IAAI,IAAI,EAAE,YAAY,MAAM,EAClC,GAAG,EAAE,CAAC,0DAA0D,CAAC,CAAC;QACtE,MAAM,EAAC,KAAK,EAAE,KAAK,EAAC,GAAG,MAAM,CAAC,SAAS,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAC7D,UAAU,CAAC,KAAK,CAAC,CAAC;QAClB,OAAO,EAAC,IAAI,EAAE,KAAK,CAAC,CAAC,CAAM,EAAE,KAAK,EAAC,CAAC;IACtC,CAAC,CAAC;AACJ,CAAC;AAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA8BG;AACH,SAAS,aAAa,CAAmB,CAA2B;IAKlE,IAAI,CAAC,MAAM,CACP,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,EAClB,GAAG,EAAE,CAAC,qDAAqD,CAAC,CAAC;IACjE,OAAO,CAAC,IAAc,EAAE,EAAM,EAAE,EAAE;QAChC,IAAI,CAAC,MAAM,CACP,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,IAAI,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,YAAY,MAAM,CAAC,EAC/D,GAAG,EAAE,CAAC,6DAA6D;YAC/D,SAAS,CAAC,CAAC;QACnB,IAAI,CAAC,MAAM,CACP,EAAE,IAAI,IAAI,IAAI,EAAE,YAAY,MAAM,EAClC,GAAG,EAAE,CAAC,8DAA8D,CAAC,CAAC;QAC1E,MAAM,GAAG,GAAG,MAAM,CAAC,SAAS,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,EAAE,IAAI,EAAE,EAAE,CAAC,CAAC;QACzD,IAAI,EAAE,IAAI,IAAI,EAAE;YACd,IAAI,CAAC,iBAAiB,CAClB,GAAG,CAAC,KAAK,CAAC,KAAK,EAAE,EAAE,CAAC,KAAK,EACzB,gEAAgE;gBAC5D,yCAAyC,CAAC,CAAC;SACpD;QACD,UAAU,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC;QACtB,OAAO,GAAG,CAAC;IACb,CAAC,CAAC;AACJ,CAAC;AAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA6BG;AACH,SAAS,aAAa,CAAC,CAAe,EAAE,OAAoB;IAE1D,IAAI,CAAC,MAAM,CACP,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,EAClB,GAAG,EAAE,CAAC,qDAAqD,CAAC,CAAC;IACjE,IAAI,CAAC,MAAM,CACP,OAAO,IAAI,IAAI;QACX,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,IAAI,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,YAAY,QAAQ,CAAC,EACvE,GAAG,EAAE,CACD,mEAAmE;QACnE,cAAc,CAAC,CAAC;IAExB,MAAM,gBAAgB,GAAG,OA
AO,IAAI,IAAI,CAAC;IACzC,IAAI,CAAC,gBAAgB,EAAE;QACrB,sCAAsC;QACtC,OAAO,GAAG,EAAE,CAAC;QACb,KAAK,MAAM,OAAO,IAAI,MAAM,CAAC,mBAAmB,EAAE;YAChD,OAAO,CAAC,IAAI,CAAC,MAAM,CAAC,mBAAmB,CAAC,OAAO,CAAC,CAAC,CAAC;SACnD;KACF;IAED,MAAM,qBAAqB,GACvB,gBAAgB,CAAC,CAAC,CAAC,OAAO,CAAC,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC,QAAQ,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC;IAE9E,iCAAiC;IACjC,MAAM,gBAAgB,GAAG,OAAO,CAAC,MAAM,CAAC;IACxC,OAAO,GAAG,OAAO,CAAC,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,QAAQ,CAAC,SAAS,CAAC,CAAC;IACzD,IAAI,CAAC,MAAM,CACP,OAAO,CAAC,MAAM,GAAG,CAAC,EAClB,GAAG,EAAE,CAAC,iEAAiE;QACnE,iCAAiC,gBAAgB,gBAAgB;QACjE,YAAY,CAAC,CAAC;IAEtB,MAAM,gBAAgB,GAAG,IAAI,CAAC;IAC9B,MAAM,EAAC,KAAK,EAAE,KAAK,EAAC,GAAG,MAAM,CAAC,SAAS,CAAC,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,gBAAgB,CAAC,CAAC;IAE5E,IAAI,CAAC,MAAM,CACP,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,IAAI,IAAI,CAAC,EAC1B,GAAG,EAAE,CAAC,kEAAkE;QACpE,iEAAiE;QACjE,+DAA+D,CAAC,CAAC;IACzE,IAAI,CAAC,MAAM,CACP,KAAK,CAAC,IAAI,KAAK,CAAC,EAChB,GAAG,EAAE,CAAC,gEAAgE;QAClE,mBAAmB,KAAK,CAAC,IAAI,SAAS,CAAC,CAAC;IAEhD,MAAM,UAAU,GAAmB,EAAE,CAAC;IACtC,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE;QACvB,IAAI,KAAK,CAAC,CAAC,CAAC,IAAI,IAAI,EAAE;YACpB,UAAU,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC;SAC/B;IACH,CAAC,CAAC,CAAC;IACH,IAAI,qBAAqB,IAAI,IAAI,EAAE;QACjC,uEAAuE;QACvE,yDAAyD;QACzD,qBAAqB,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,UAAU,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,IAAI,CAAC,CAAC;KAC/D;IACD,OAAO,EAAC,KAAK,EAAE,KAAK,EAAE,UAAU,EAAC,CAAC;AACpC,CAAC;AAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAuCG;AACH,SAAS,UAAU,CAAmB,CAAwB;IAE5D,OAAO,MAAM,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;AAC9B,CAAC;AAED,SAAS,UAAU,CAAC,KAAe;IACjC,MAAM,gBAAgB,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,IAAI,IAAI,CAAC,CAAC,MAAM,CAAC;IAC7D,IAAI,gBAAgB,GAAG,CAAC,EAAE;QACxB,MAAM,IAAI,KAAK,CACX;oEAC4D,CAAC,CAAC;KACnE;AACH,CAAC;AAED,OAAO,EACL,UAAU,EACV,aAAa,EACb,YAAY,EACZ,aAAa,EACb,IAAI,EACJ,KAAK,GACN,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {CustomGradientFunc, ENGINE} from './engine';\nimport {Scalar, Tensor, Variable} from './tensor';\nimport {NamedTensorMap} from './tensor_types';\nimport {convertToTensor, convertToTensorArray} from './tensor_util_env';\nimport {TensorLike} from './types';\nimport * as util from './util';\n\n/**\n * Provided `f(x)`, returns another function `g(x, dy?)`, which gives the\n * gradient of `f(x)` with respect to `x`.\n *\n * If `dy` is provided, the gradient of `f(x).mul(dy).sum()` with respect to\n * `x` is computed instead. `f(x)` must take a single tensor `x` and return a\n * single tensor `y`. 
If `f()` takes multiple inputs, use `tf.grads` instead.\n *\n * ```js\n * // f(x) = x ^ 2\n * const f = x => x.square();\n * // f'(x) = 2x\n * const g = tf.grad(f);\n *\n * const x = tf.tensor1d([2, 3]);\n * g(x).print();\n * ```\n *\n * ```js\n * // f(x) = x ^ 3\n * const f = x => x.pow(tf.scalar(3, 'int32'));\n * // f'(x) = 3x ^ 2\n * const g = tf.grad(f);\n * // f''(x) = 6x\n * const gg = tf.grad(g);\n *\n * const x = tf.tensor1d([2, 3]);\n * gg(x).print();\n * ```\n *\n * @param f The function f(x), to compute gradient for.\n *\n * @doc {heading: 'Training', subheading: 'Gradients'}\n */\nfunction grad(f: (x: Tensor) => Tensor): (\n    x: TensorLike|Tensor, dy?: TensorLike|Tensor) => Tensor {\n  util.assert(\n      util.isFunction(f), () => 'The f passed in grad(f) must be a function');\n  return (x: TensorLike|Tensor, dy?: TensorLike|Tensor): Tensor => {\n    // x can be of any dtype, thus null as the last argument.\n    const $x = convertToTensor(x, 'x', 'tf.grad', 'string_or_numeric');\n    const $dy: Tensor =\n        (dy != null) ? convertToTensor(dy, 'dy', 'tf.grad') : null;\n    return ENGINE.tidy(() => {\n      const {value, grads} = ENGINE.gradients(() => f($x), [$x], $dy);\n      if ($dy != null) {\n        util.assertShapesMatch(\n            value.shape, $dy.shape,\n            'The shape of dy passed in grad(f)(x, dy) must match the shape ' +\n                'returned by f(x)');\n      }\n      checkGrads(grads);\n      return grads[0];\n    });\n  };\n}\n\n/**\n * Provided `f(x1, x2,...)`, returns another function `g([x1, x2,...], dy?)`,\n * which gives an array of gradients of `f()` with respect to each input\n * [`x1`,`x2`,...].\n *\n * If `dy` is passed when calling `g()`, the gradient of\n * `f(x1,...).mul(dy).sum()` with respect to each input is computed instead.\n * The provided `f` must take one or more tensors and return a single tensor\n * `y`. 
If `f()` takes a single input, we recommend using `tf.grad` instead.\n *\n * ```js\n * // f(a, b) = a * b\n * const f = (a, b) => a.mul(b);\n * // df / da = b, df / db = a\n * const g = tf.grads(f);\n *\n * const a = tf.tensor1d([2, 3]);\n * const b = tf.tensor1d([-2, -3]);\n * const [da, db] = g([a, b]);\n * console.log('da');\n * da.print();\n * console.log('db');\n * db.print();\n * ```\n *\n * @param f The function `f(x1, x2,...)` to compute gradients for.\n *\n * @doc {heading: 'Training', subheading: 'Gradients'}\n */\nfunction grads(f: (...args: Tensor[]) => Tensor): (\n    args: Array<Tensor|TensorLike>, dy?: Tensor|TensorLike) => Tensor[] {\n  util.assert(\n      util.isFunction(f), () => 'The f passed in grads(f) must be a function');\n  return (args: Array<Tensor|TensorLike>, dy?: Tensor|TensorLike): Tensor[] => {\n    util.assert(\n        Array.isArray(args),\n        () => 'The args passed in grads(f)(args) must be an array ' +\n            'of `Tensor`s or `TensorLike`s');\n    // args can be of any dtype, thus null as the last argument.\n    const $args =\n        convertToTensorArray(args, 'args', 'tf.grads', 'string_or_numeric');\n    const $dy: Tensor =\n        (dy != null) ? convertToTensor(dy, 'dy', 'tf.grads') : null;\n    return ENGINE.tidy(() => {\n      const {value, grads} = ENGINE.gradients(() => f(...$args), $args, $dy);\n      if ($dy != null) {\n        util.assertShapesMatch(\n            value.shape, $dy.shape,\n            'The shape of dy passed in grads(f)([x1,...], dy) must ' +\n                'match the shape returned by f([x1,...])');\n      }\n      checkGrads(grads);\n      return grads;\n    });\n  };\n}\n\n/**\n * Like `tf.grad`, but also returns the value of `f()`. Useful when `f()`\n * returns a metric you want to show.\n *\n * The result is a rich object with the following properties:\n * - grad: The gradient of `f(x)` w.r.t. 
`x` (result of `tf.grad`).\n * - value: The value returned by `f(x)`.\n *\n * ```js\n * // f(x) = x ^ 2\n * const f = x => x.square();\n * // f'(x) = 2x\n * const g = tf.valueAndGrad(f);\n *\n * const x = tf.tensor1d([2, 3]);\n * const {value, grad} = g(x);\n *\n * console.log('value');\n * value.print();\n * console.log('grad');\n * grad.print();\n * ```\n *\n * @doc {heading: 'Training', subheading: 'Gradients'}\n */\nfunction valueAndGrad<I extends Tensor, O extends Tensor>(f: (x: I) => O): (\n    x: I, dy?: O) => {\n  value: O;\n  grad: I;\n} {\n  util.assert(\n      util.isFunction(f),\n      () => 'The f passed in valueAndGrad(f) must be a function');\n  return (x: I, dy?: O) => {\n    util.assert(\n        x instanceof Tensor,\n        () => 'The x passed in valueAndGrad(f)(x) must be a tensor');\n    util.assert(\n        dy == null || dy instanceof Tensor,\n        () => 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor');\n    const {grads, value} = ENGINE.gradients(() => f(x), [x], dy);\n    checkGrads(grads);\n    return {grad: grads[0] as I, value};\n  };\n}\n\n/**\n * Like `tf.grads`, but returns also the value of `f()`. Useful when `f()`\n * returns a metric you want to show.\n *\n * The result is a rich object with the following properties:\n * - grads: The gradients of `f()` w.r.t. 
each input (result of `tf.grads`).\n * - value: The value returned by `f(x)`.\n *\n * ```js\n * // f(a, b) = a * b\n * const f = (a, b) => a.mul(b);\n * // df/da = b, df/db = a\n * const g = tf.valueAndGrads(f);\n *\n * const a = tf.tensor1d([2, 3]);\n * const b = tf.tensor1d([-2, -3]);\n * const {value, grads} = g([a, b]);\n *\n * const [da, db] = grads;\n *\n * console.log('value');\n * value.print();\n *\n * console.log('da');\n * da.print();\n * console.log('db');\n * db.print();\n * ```\n *\n * @doc {heading: 'Training', subheading: 'Gradients'}\n */\nfunction valueAndGrads<O extends Tensor>(f: (...args: Tensor[]) => O): (\n    args: Tensor[], dy?: O) => {\n  grads: Tensor[];\n  value: O;\n} {\n  util.assert(\n      util.isFunction(f),\n      () => 'The f passed in valueAndGrads(f) must be a function');\n  return (args: Tensor[], dy?: O) => {\n    util.assert(\n        Array.isArray(args) && args.every(arg => arg instanceof Tensor),\n        () => 'The args passed in valueAndGrads(f)(args) must be array of ' +\n            'tensors');\n    util.assert(\n        dy == null || dy instanceof Tensor,\n        () => 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor');\n    const res = ENGINE.gradients(() => f(...args), args, dy);\n    if (dy != null) {\n      util.assertShapesMatch(\n          res.value.shape, dy.shape,\n          'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' +\n              'match the shape returned by f([x1,...])');\n    }\n    checkGrads(res.grads);\n    return res;\n  };\n}\n\n/**\n * Computes and returns the gradient of f(x) with respect to the list of\n * trainable variables provided by `varList`. 
If no list is provided, it\n * defaults to all trainable variables.\n *\n * ```js\n * const a = tf.variable(tf.tensor1d([3, 4]));\n * const b = tf.variable(tf.tensor1d([5, 6]));\n * const x = tf.tensor1d([1, 2]);\n *\n * // f(a, b) = a * x ^ 2 + b * x\n * const f = () => a.mul(x.square()).add(b.mul(x)).sum();\n * // df/da = x ^ 2, df/db = x\n * const {value, grads} = tf.variableGrads(f);\n *\n * Object.keys(grads).forEach(varName => grads[varName].print());\n * ```\n *\n * @param f The function to execute. f() should return a scalar.\n * @param varList The list of variables to compute the gradients with respect\n *     to. Defaults to all trainable variables.\n * @returns An object with the following keys and values:\n *   - `value`: The value of the function `f`.\n *   - `grads`: A map from the names of the variables to the gradients.\n *     If the `varList` argument is provided explicitly and contains a subset of\n *     non-trainable variables, this map in the return value will contain keys\n *     that map the names of the non-trainable variables to `null`.\n *\n * @doc {heading: 'Training', subheading: 'Gradients'}\n */\nfunction variableGrads(f: () => Scalar, varList?: Variable[]):\n    {value: Scalar, grads: NamedTensorMap} {\n  util.assert(\n      util.isFunction(f),\n      () => 'The f passed in variableGrads(f) must be a function');\n  util.assert(\n      varList == null ||\n          Array.isArray(varList) && varList.every(v => v instanceof Variable),\n      () =>\n          'The varList passed in variableGrads(f, varList) must be an array ' +\n          'of variables');\n\n  const specifiedVarList = varList != null;\n  if (!specifiedVarList) {\n    // Get all of the trainable variables.\n    varList = [];\n    for (const varName in ENGINE.registeredVariables) {\n      varList.push(ENGINE.registeredVariables[varName]);\n    }\n  }\n\n  const specifiedNonTrainable: Variable[] =\n      specifiedVarList ? 
varList.filter(variable => !variable.trainable) : null;\n\n  // Prune non-trainable variables.\n  const originalVarCount = varList.length;\n  varList = varList.filter(variable => variable.trainable);\n  util.assert(\n      varList.length > 0,\n      () => `variableGrads() expects at least one of the input variables to ` +\n          `be trainable, but none of the ${originalVarCount} variables is ` +\n          `trainable.`);\n\n  const allowNoGradients = true;\n  const {value, grads} = ENGINE.gradients(f, varList, null, allowNoGradients);\n\n  util.assert(\n      grads.some(g => g != null),\n      () => 'Cannot find a connection between any variable and the result of ' +\n          'the loss function y=f(x). Please make sure the operations that ' +\n          'use variables are inside the function f passed to minimize().');\n  util.assert(\n      value.rank === 0,\n      () => `The f passed in variableGrads(f) must return a scalar, but it ` +\n          `returned a rank-${value.rank} tensor`);\n\n  const namedGrads: NamedTensorMap = {};\n  varList.forEach((v, i) => {\n    if (grads[i] != null) {\n      namedGrads[v.name] = grads[i];\n    }\n  });\n  if (specifiedNonTrainable != null) {\n    // If varList is explicitly provided and contains non-trainable values,\n    // add them to the returned gradients with `null` values.\n    specifiedNonTrainable.forEach(v => namedGrads[v.name] = null);\n  }\n  return {value, grads: namedGrads};\n}\n\n/**\n * Overrides the gradient computation of a function `f`.\n *\n * Takes a function\n * `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}`\n * and returns another function `g(...inputs)` which takes the same inputs as\n * `f`. When called, `g` returns `f().value`. In backward mode, custom gradients\n * with respect to each input of `f` are computed using `f().gradFunc`.\n *\n * The `save` function passed to `f` should be used for saving tensors needed\n * in the gradient. 
And the `saved` passed to the `gradFunc` is a\n * `NamedTensorMap`, which contains those saved tensors.\n *\n * ```js\n * const customOp = tf.customGrad((x, save) => {\n *   // Save x to make sure it's available later for the gradient.\n *   save([x]);\n *   // Override gradient of our custom x ^ 2 op to be dy * abs(x);\n *   return {\n *     value: x.square(),\n *     // Note `saved.x` which points to the `x` we saved earlier.\n *     gradFunc: (dy, saved) => [dy.mul(saved[0].abs())]\n *   };\n * });\n *\n * const x = tf.tensor1d([-1, -2, 3]);\n * const dx = tf.grad(x => customOp(x));\n *\n * console.log(`f(x):`);\n * customOp(x).print();\n * console.log(`f'(x):`);\n * dx(x).print();\n * ```\n *\n * @param f The function to evaluate in forward mode, which should return\n *     `{value: Tensor, gradFunc: (dy, saved) => Tensor[]}`, where `gradFunc`\n *     returns the custom gradients of `f` with respect to its inputs.\n *\n * @doc {heading: 'Training', subheading: 'Gradients'}\n */\nfunction customGrad<T extends Tensor>(f: CustomGradientFunc<T>):\n    (...args: Tensor[]) => T {\n  return ENGINE.customGrad(f);\n}\n\nfunction checkGrads(grads: Tensor[]) {\n  const numNullGradients = grads.filter(g => g == null).length;\n  if (numNullGradients > 0) {\n    throw new Error(\n        `Cannot compute gradient of y=f(x) with respect to x. Make sure that\n    the f you passed encloses all operations that lead from x to y.`);\n  }\n}\n\nexport {\n  customGrad,\n  variableGrads,\n  valueAndGrad,\n  valueAndGrads,\n  grad,\n  grads,\n};\n"]}
|