/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
|
import { FusedBatchNorm, util } from '@tensorflow/tfjs-core';
|
import { assertNotComplex } from '../cpu_util';
|
/**
 * CPU kernel function for fused batch normalization.
 *
 * For every element of `x` it computes
 *
 *   out = offset + (x - mean) * scale / sqrt(variance + varianceEpsilon)
 *
 * `mean`, `variance`, `scale` and `offset` are read as flat value arrays and
 * walked cyclically while iterating over `x`: each has its own cursor that
 * wraps back to 0 when it reaches the end of its array, so any of them may
 * hold fewer values than `x`.
 *
 * @param {Object} args
 * @param {Object} args.inputs `{x, scale, offset, mean, variance}` tensor
 *     infos. `scale` and `offset` may be omitted, in which case a constant 1
 *     and a constant 0 are used respectively.
 * @param {Object} args.backend Backend exposing `data.get(dataId).values`
 *     and `makeTensorInfo(shape, dtype, values)`.
 * @param {Object} args.attrs `{varianceEpsilon}`; falls back to 0.001 when
 *     it is `null` or `undefined`.
 * @returns The value of `backend.makeTensorInfo(x.shape, x.dtype, values)`,
 *     where `values` is a `Float32Array` with one entry per element of `x`.
 */
export function batchNorm(args) {
  const { inputs, backend, attrs } = args;
  const { x, scale, offset, mean, variance } = inputs;

  // Only the ranks are compared here, not the full shapes.
  util.assert(
      mean.shape.length === variance.shape.length,
      () => 'Batch normalization gradient requires mean and variance to have ' +
          'equal ranks.');
  util.assert(
      offset == null || mean.shape.length === offset.shape.length,
      () => 'Batch normalization gradient requires mean and offset to have ' +
          'equal ranks.');
  util.assert(
      scale == null || mean.shape.length === scale.shape.length,
      () => 'Batch normalization gradient requires mean and scale to have ' +
          'equal ranks.');

  assertNotComplex([x, mean, variance, scale, offset], 'batchNorm');

  let { varianceEpsilon } = attrs;
  if (varianceEpsilon == null) {
    varianceEpsilon = 0.001;
  }

  const xVals = backend.data.get(x.dataId).values;
  const mVals = backend.data.get(mean.dataId).values;
  const varVals = backend.data.get(variance.dataId).values;
  // Missing scale/offset degrade to the identity transform (x * 1 + 0).
  const sVals = scale ? backend.data.get(scale.dataId).values :
                        new Float32Array([1]);
  const offVals = offset ? backend.data.get(offset.dataId).values :
                           new Float32Array([0]);

  const size = xVals.length;
  const outVals = new Float32Array(size);

  const offValsLength = offVals.length;
  const sValsLength = sVals.length;
  const varValsLength = varVals.length;
  const mValsLength = mVals.length;

  // Independent cursors into each parameter array; each wraps on its own.
  let offi = 0;
  let mi = 0;
  let si = 0;
  let vi = 0;
  for (let i = 0; i < size; ++i) {
    outVals[i] = offVals[offi] +
        (xVals[i] - mVals[mi]) * sVals[si] /
            Math.sqrt(varVals[vi] + varianceEpsilon);
    if (++offi >= offValsLength) {
      offi = 0;
    }
    if (++mi >= mValsLength) {
      mi = 0;
    }
    if (++si >= sValsLength) {
      si = 0;
    }
    if (++vi >= varValsLength) {
      vi = 0;
    }
  }
  return backend.makeTensorInfo(x.shape, x.dtype, outVals);
}
|
/**
 * Kernel registration entry that ties the `FusedBatchNorm` kernel name to the
 * `batchNorm` implementation for the backend named 'cpu'.
 */
export const batchNormConfig = {
  kernelName: FusedBatchNorm,
  backendName: 'cpu',
  kernelFunc: batchNorm,
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQmF0Y2hOb3JtLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLWNwdS9zcmMva2VybmVscy9CYXRjaE5vcm0udHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLGNBQWMsRUFBK0YsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFHeEosT0FBTyxFQUFDLGdCQUFnQixFQUFDLE1BQU0sYUFBYSxDQUFDO0FBRTdDLE1BQU0sVUFBVSxTQUFTLENBQUMsSUFJekI7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBRSxLQUFLLEVBQUUsTUFBTSxFQUFFLElBQUksRUFBRSxRQUFRLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFFbEQsSUFBSSxDQUFDLE1BQU0sQ0FDUCxJQUFJLENBQUMsS0FBSyxDQUFDLE1BQU0sS0FBSyxRQUFRLENBQUMsS0FBSyxDQUFDLE1BQU0sRUFDM0MsR0FBRyxFQUFFLENBQUMsa0VBQWtFO1FBQ3BFLGNBQWMsQ0FBQyxDQUFDO0lBQ3hCLElBQUksQ0FBQyxNQUFNLENBQ1AsTUFBTSxJQUFJLElBQUksSUFBSSxJQUFJLENBQUMsS0FBSyxDQUFDLE1BQU0sS0FBSyxNQUFNLENBQUMsS0FBSyxDQUFDLE1BQU0sRUFDM0QsR0FBRyxFQUFFLENBQUMsZ0VBQWdFO1FBQ2xFLGNBQWMsQ0FBQyxDQUFDO0lBQ3hCLElBQUksQ0FBQyxNQUFNLENBQ1AsS0FBSyxJQUFJLElBQUksSUFBSSxJQUFJLENBQUMsS0FBSyxDQUFDLE1BQU0sS0FBSyxLQUFLLENBQUMsS0FBSyxDQUFDLE1BQU0sRUFDekQsR0FBRyxFQUFFLENBQUMsK0RBQStEO1FBQ2pFLGNBQWMsQ0FBQyxDQUFDO0lBRXhCLGdCQUFnQixDQUFDLENBQUMsQ0FBQyxFQUFFLElBQUksRUFBRSxRQUFRLEVBQUUsS0FBSyxFQUFFLE1BQU0sQ0FBQyxFQUFFLFdBQVcsQ0FBQyxDQUFDO0lBRWxFLElBQUksRUFBQyxlQUFlLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFDOUIsSUFBSSxlQUFlLElBQUksSUFBSSxFQUFFO1FBQzNCLGVBQWUsR0FBRyxLQUFLLENBQUM7S0FDekI7SUFFRCxNQUFNLEtBQUssR0FBRyxPQUFPLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsTUFBTSxDQUFDLENBQUMsTUFBb0IsQ0FBQztJQUM5RCxNQUFNLEtBQUssR0FBRyxPQUFPLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxJQUFJLENBQUMsTUFBTSxDQUFDLENBQUMsTUFBb0IsQ0FBQztJQUNqRSxNQUFNLE9BQU8sR0FBRyxPQUFPLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxRQUFRLENBQUMsTUFBTSxDQUFDLENBQUMsTUFBb0IsQ0FBQztJQUN2RSxNQUFNLEtBQUssR0FBRyxLQUFLLENBQUMsQ0FBQyxDQUFDLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDLENBQUM7UUFDckQsSUFBSSxZQUFZLENBQUMsQ0
FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQzVDLE1BQU0sT0FBTyxHQUFHLE1BQU0sQ0FBQyxDQUFDO1FBQ3BCLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLE1BQU0sQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDLENBQUM7UUFDdEQsSUFBSSxZQUFZLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQzFCLE1BQU0sT0FBTyxHQUFHLElBQUksWUFBWSxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUMsQ0FBQztJQUUvQyxNQUFNLGFBQWEsR0FBRyxPQUFPLENBQUMsTUFBTSxDQUFDO0lBQ3JDLE1BQU0sV0FBVyxHQUFHLEtBQUssQ0FBQyxNQUFNLENBQUM7SUFDakMsTUFBTSxhQUFhLEdBQUcsT0FBTyxDQUFDLE1BQU0sQ0FBQztJQUNyQyxNQUFNLFdBQVcsR0FBRyxLQUFLLENBQUMsTUFBTSxDQUFDO0lBRWpDLElBQUksSUFBSSxHQUFHLENBQUMsQ0FBQztJQUNiLElBQUksRUFBRSxHQUFHLENBQUMsQ0FBQztJQUNYLElBQUksRUFBRSxHQUFHLENBQUMsQ0FBQztJQUNYLElBQUksRUFBRSxHQUFHLENBQUMsQ0FBQztJQUNYLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxLQUFLLENBQUMsTUFBTSxFQUFFLEVBQUUsQ0FBQyxFQUFFO1FBQ3JDLE9BQU8sQ0FBQyxDQUFDLENBQUMsR0FBRyxPQUFPLENBQUMsSUFBSSxFQUFFLENBQUM7WUFDeEIsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEdBQUcsS0FBSyxDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUMsR0FBRyxLQUFLLENBQUMsRUFBRSxFQUFFLENBQUM7Z0JBQ2xDLElBQUksQ0FBQyxJQUFJLENBQUMsT0FBTyxDQUFDLEVBQUUsRUFBRSxDQUFDLEdBQUcsZUFBZSxDQUFDLENBQUM7UUFDbkQsSUFBSSxJQUFJLElBQUksYUFBYSxFQUFFO1lBQ3pCLElBQUksR0FBRyxDQUFDLENBQUM7U0FDVjtRQUNELElBQUksRUFBRSxJQUFJLFdBQVcsRUFBRTtZQUNyQixFQUFFLEdBQUcsQ0FBQyxDQUFDO1NBQ1I7UUFDRCxJQUFJLEVBQUUsSUFBSSxXQUFXLEVBQUU7WUFDckIsRUFBRSxHQUFHLENBQUMsQ0FBQztTQUNSO1FBQ0QsSUFBSSxFQUFFLElBQUksYUFBYSxFQUFFO1lBQ3ZCLEVBQUUsR0FBRyxDQUFDLENBQUM7U0FDUjtLQUNGO0lBQ0QsT0FBTyxPQUFPLENBQUMsY0FBYyxDQUFDLENBQUMsQ0FBQyxLQUFLLEVBQUUsQ0FBQyxDQUFDLEtBQUssRUFBRSxPQUFPLENBQUMsQ0FBQztBQUMzRCxDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sZUFBZSxHQUFpQjtJQUMzQyxVQUFVLEVBQUUsY0FBYztJQUMxQixXQUFXLEVBQUUsS0FBSztJQUNsQixVQUFVLEVBQUUsU0FBa0M7Q0FDL0MsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIH
RoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtGdXNlZEJhdGNoTm9ybSwgRnVzZWRCYXRjaE5vcm1BdHRycywgRnVzZWRCYXRjaE5vcm1JbnB1dHMsIEtlcm5lbENvbmZpZywgS2VybmVsRnVuYywgVGVuc29ySW5mbywgVHlwZWRBcnJheSwgdXRpbH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtNYXRoQmFja2VuZENQVX0gZnJvbSAnLi4vYmFja2VuZF9jcHUnO1xuaW1wb3J0IHthc3NlcnROb3RDb21wbGV4fSBmcm9tICcuLi9jcHVfdXRpbCc7XG5cbmV4cG9ydCBmdW5jdGlvbiBiYXRjaE5vcm0oYXJnczoge1xuICBpbnB1dHM6IEZ1c2VkQmF0Y2hOb3JtSW5wdXRzLFxuICBiYWNrZW5kOiBNYXRoQmFja2VuZENQVSxcbiAgYXR0cnM6IEZ1c2VkQmF0Y2hOb3JtQXR0cnNcbn0pOiBUZW5zb3JJbmZvIHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZCwgYXR0cnN9ID0gYXJncztcbiAgY29uc3Qge3gsIHNjYWxlLCBvZmZzZXQsIG1lYW4sIHZhcmlhbmNlfSA9IGlucHV0cztcblxuICB1dGlsLmFzc2VydChcbiAgICAgIG1lYW4uc2hhcGUubGVuZ3RoID09PSB2YXJpYW5jZS5zaGFwZS5sZW5ndGgsXG4gICAgICAoKSA9PiAnQmF0Y2ggbm9ybWFsaXphdGlvbiBncmFkaWVudCByZXF1aXJlcyBtZWFuIGFuZCB2YXJpYW5jZSB0byBoYXZlICcgK1xuICAgICAgICAgICdlcXVhbCByYW5rcy4nKTtcbiAgdXRpbC5hc3NlcnQoXG4gICAgICBvZmZzZXQgPT0gbnVsbCB8fCBtZWFuLnNoYXBlLmxlbmd0aCA9PT0gb2Zmc2V0LnNoYXBlLmxlbmd0aCxcbiAgICAgICgpID0+ICdCYXRjaCBub3JtYWxpemF0aW9uIGdyYWRpZW50IHJlcXVpcmVzIG1lYW4gYW5kIG9mZnNldCB0byBoYXZlICcgK1xuICAgICAgICAgICdlcXVhbCByYW5rcy4nKTtcbiAgdXRpbC5hc3NlcnQoXG4gICAgICBzY2FsZSA9PSBudWxsIHx8IG1lYW4uc2hhcGUubGVuZ3RoID09PSBzY2FsZS5zaGFwZS5sZW5ndGgsXG4gICAgICAoKSA9PiAnQmF0Y2ggbm9ybW
FsaXphdGlvbiBncmFkaWVudCByZXF1aXJlcyBtZWFuIGFuZCBzY2FsZSB0byBoYXZlICcgK1xuICAgICAgICAgICdlcXVhbCByYW5rcy4nKTtcblxuICBhc3NlcnROb3RDb21wbGV4KFt4LCBtZWFuLCB2YXJpYW5jZSwgc2NhbGUsIG9mZnNldF0sICdiYXRjaE5vcm0nKTtcblxuICBsZXQge3ZhcmlhbmNlRXBzaWxvbn0gPSBhdHRycztcbiAgaWYgKHZhcmlhbmNlRXBzaWxvbiA9PSBudWxsKSB7XG4gICAgdmFyaWFuY2VFcHNpbG9uID0gMC4wMDE7XG4gIH1cblxuICBjb25zdCB4VmFscyA9IGJhY2tlbmQuZGF0YS5nZXQoeC5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5O1xuICBjb25zdCBtVmFscyA9IGJhY2tlbmQuZGF0YS5nZXQobWVhbi5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5O1xuICBjb25zdCB2YXJWYWxzID0gYmFja2VuZC5kYXRhLmdldCh2YXJpYW5jZS5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5O1xuICBjb25zdCBzVmFscyA9IHNjYWxlID8gYmFja2VuZC5kYXRhLmdldChzY2FsZS5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5IDpcbiAgICAgICAgICAgICAgICAgICAgICAgIG5ldyBGbG9hdDMyQXJyYXkoWzFdKTtcbiAgY29uc3Qgb2ZmVmFscyA9IG9mZnNldCA/XG4gICAgICBiYWNrZW5kLmRhdGEuZ2V0KG9mZnNldC5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5IDpcbiAgICAgIG5ldyBGbG9hdDMyQXJyYXkoWzBdKTtcbiAgY29uc3Qgb3V0VmFscyA9IG5ldyBGbG9hdDMyQXJyYXkoeFZhbHMubGVuZ3RoKTtcblxuICBjb25zdCBvZmZWYWxzTGVuZ3RoID0gb2ZmVmFscy5sZW5ndGg7XG4gIGNvbnN0IHNWYWxzTGVuZ3RoID0gc1ZhbHMubGVuZ3RoO1xuICBjb25zdCB2YXJWYWxzTGVuZ3RoID0gdmFyVmFscy5sZW5ndGg7XG4gIGNvbnN0IG1WYWxzTGVuZ3RoID0gbVZhbHMubGVuZ3RoO1xuXG4gIGxldCBvZmZpID0gMDtcbiAgbGV0IG1pID0gMDtcbiAgbGV0IHNpID0gMDtcbiAgbGV0IHZpID0gMDtcbiAgZm9yIChsZXQgaSA9IDA7IGkgPCB4VmFscy5sZW5ndGg7ICsraSkge1xuICAgIG91dFZhbHNbaV0gPSBvZmZWYWxzW29mZmkrK10gK1xuICAgICAgICAoeFZhbHNbaV0gLSBtVmFsc1ttaSsrXSkgKiBzVmFsc1tzaSsrXSAvXG4gICAgICAgICAgICBNYXRoLnNxcnQodmFyVmFsc1t2aSsrXSArIHZhcmlhbmNlRXBzaWxvbik7XG4gICAgaWYgKG9mZmkgPj0gb2ZmVmFsc0xlbmd0aCkge1xuICAgICAgb2ZmaSA9IDA7XG4gICAgfVxuICAgIGlmIChtaSA+PSBtVmFsc0xlbmd0aCkge1xuICAgICAgbWkgPSAwO1xuICAgIH1cbiAgICBpZiAoc2kgPj0gc1ZhbHNMZW5ndGgpIHtcbiAgICAgIHNpID0gMDtcbiAgICB9XG4gICAgaWYgKHZpID49IHZhclZhbHNMZW5ndGgpIHtcbiAgICAgIHZpID0gMDtcbiAgICB9XG4gIH1cbiAgcmV0dXJuIGJhY2tlbmQubWFrZVRlbnNvckluZm8oeC5zaGFwZSwgeC5kdHlwZSwgb3V0VmFscyk7XG59XG5cbmV4cG9ydCBjb25zdCBiYXRjaE5vcm1Db25maWc6IEtlcm
5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogRnVzZWRCYXRjaE5vcm0sXG4gIGJhY2tlbmROYW1lOiAnY3B1JyxcbiAga2VybmVsRnVuYzogYmF0Y2hOb3JtIGFzIHVua25vd24gYXMgS2VybmVsRnVuYyxcbn07XG4iXX0=
|