gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { FusedBatchNorm, util } from '@tensorflow/tfjs-core';
import { assertNotComplex } from '../cpu_util';
/**
 * CPU kernel function for fused batch normalization.
 *
 * For every element of the flat `x` buffer it computes
 * `offset + (x - mean) * scale / sqrt(variance + varianceEpsilon)`.
 * The `mean`, `variance`, `scale` and `offset` buffers are each walked with
 * their own index that restarts at 0 once the end of that buffer is reached,
 * so shorter parameter buffers are reused cyclically over `x`.
 * NOTE(review): this presumably relies on the parameters lining up with the
 * innermost dimension(s) of `x` — confirm against callers.
 *
 * @param {Object} args
 * @param {Object} args.inputs Tensor infos `x`, `mean`, `variance` and the
 *     optional `scale` and `offset`. When `scale` is missing a single 1 is
 *     used; when `offset` is missing a single 0 is used.
 * @param {Object} args.backend Backend exposing `data.get(dataId).values`
 *     and `makeTensorInfo(shape, dtype, values)`.
 * @param {Object} args.attrs Attributes; `varianceEpsilon` falls back to
 *     0.001 when it is null or undefined.
 * @returns {*} Whatever `backend.makeTensorInfo` returns for `x.shape`,
 *     `x.dtype` and the freshly computed Float32Array.
 */
export function batchNorm(args) {
    const { inputs, backend, attrs } = args;
    const { x, scale, offset, mean, variance } = inputs;

    // Rank checks: mean is the reference every other parameter must match.
    const meanRank = mean.shape.length;
    util.assert(meanRank === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' +
        'equal ranks.');
    util.assert(offset == null || meanRank === offset.shape.length, () => 'Batch normalization gradient requires mean and offset to have ' +
        'equal ranks.');
    util.assert(scale == null || meanRank === scale.shape.length, () => 'Batch normalization gradient requires mean and scale to have ' +
        'equal ranks.');
    assertNotComplex([x, mean, variance, scale, offset], 'batchNorm');

    const { varianceEpsilon } = attrs;
    const epsilon = varianceEpsilon == null ? 0.001 : varianceEpsilon;

    // Reads the raw value buffer the backend stores for a tensor info.
    const valuesOf = (tensorInfo) => backend.data.get(tensorInfo.dataId).values;

    const xVals = valuesOf(x);
    const meanVals = valuesOf(mean);
    const varianceVals = valuesOf(variance);
    // Identity scale / zero offset when the optional inputs are absent.
    const scaleVals = scale ? valuesOf(scale) : new Float32Array([1]);
    const offsetVals = offset ? valuesOf(offset) : new Float32Array([0]);

    const count = xVals.length;
    const result = new Float32Array(count);

    const offsetLen = offsetVals.length;
    const meanLen = meanVals.length;
    const scaleLen = scaleVals.length;
    const varianceLen = varianceVals.length;

    let offsetIdx = 0;
    let meanIdx = 0;
    let scaleIdx = 0;
    let varianceIdx = 0;
    for (let i = 0; i < count; ++i) {
        const centered = xVals[i] - meanVals[meanIdx];
        const stdDev = Math.sqrt(varianceVals[varianceIdx] + epsilon);
        result[i] = offsetVals[offsetIdx] + centered * scaleVals[scaleIdx] / stdDev;

        // Advance each parameter index, wrapping back to the start of its
        // buffer once the end is reached.
        offsetIdx = offsetIdx + 1 < offsetLen ? offsetIdx + 1 : 0;
        meanIdx = meanIdx + 1 < meanLen ? meanIdx + 1 : 0;
        scaleIdx = scaleIdx + 1 < scaleLen ? scaleIdx + 1 : 0;
        varianceIdx = varianceIdx + 1 < varianceLen ? varianceIdx + 1 : 0;
    }
    return backend.makeTensorInfo(x.shape, x.dtype, result);
}
/**
 * Kernel config that associates the `FusedBatchNorm` kernel name with the
 * `batchNorm` function for the 'cpu' backend.
 */
export const batchNormConfig = { kernelName: FusedBatchNorm, backendName: 'cpu', kernelFunc: batchNorm };
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQmF0Y2hOb3JtLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLWNwdS9zcmMva2VybmVscy9CYXRjaE5vcm0udHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLGNBQWMsRUFBK0YsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFHeEosT0FBTyxFQUFDLGdCQUFnQixFQUFDLE1BQU0sYUFBYSxDQUFDO0FBRTdDLE1BQU0sVUFBVSxTQUFTLENBQUMsSUFJekI7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBRSxLQUFLLEVBQUUsTUFBTSxFQUFFLElBQUksRUFBRSxRQUFRLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFFbEQsSUFBSSxDQUFDLE1BQU0sQ0FDUCxJQUFJLENBQUMsS0FBSyxDQUFDLE1BQU0sS0FBSyxRQUFRLENBQUMsS0FBSyxDQUFDLE1BQU0sRUFDM0MsR0FBRyxFQUFFLENBQUMsa0VBQWtFO1FBQ3BFLGNBQWMsQ0FBQyxDQUFDO0lBQ3hCLElBQUksQ0FBQyxNQUFNLENBQ1AsTUFBTSxJQUFJLElBQUksSUFBSSxJQUFJLENBQUMsS0FBSyxDQUFDLE1BQU0sS0FBSyxNQUFNLENBQUMsS0FBSyxDQUFDLE1BQU0sRUFDM0QsR0FBRyxFQUFFLENBQUMsZ0VBQWdFO1FBQ2xFLGNBQWMsQ0FBQyxDQUFDO0lBQ3hCLElBQUksQ0FBQyxNQUFNLENBQ1AsS0FBSyxJQUFJLElBQUksSUFBSSxJQUFJLENBQUMsS0FBSyxDQUFDLE1BQU0sS0FBSyxLQUFLLENBQUMsS0FBSyxDQUFDLE1BQU0sRUFDekQsR0FBRyxFQUFFLENBQUMsK0RBQStEO1FBQ2pFLGNBQWMsQ0FBQyxDQUFDO0lBRXhCLGdCQUFnQixDQUFDLENBQUMsQ0FBQyxFQUFFLElBQUksRUFBRSxRQUFRLEVBQUUsS0FBSyxFQUFFLE1BQU0sQ0FBQyxFQUFFLFdBQVcsQ0FBQyxDQUFDO0lBRWxFLElBQUksRUFBQyxlQUFlLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFDOUIsSUFBSSxlQUFlLElBQUksSUFBSSxFQUFFO1FBQzNCLGVBQWUsR0FBRyxLQUFLLENBQUM7S0FDekI7SUFFRCxNQUFNLEtBQUssR0FBRyxPQUFPLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsTUFBTSxDQUFDLENBQUMsTUFBb0IsQ0FBQztJQUM5RCxNQUFNLEtBQUssR0FBRyxPQUFPLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxJQUFJLENBQUMsTUFBTSxDQUFDLENBQUMsTUFBb0IsQ0FBQztJQUNqRSxNQUFNLE9BQU8sR0FBRyxPQUFPLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxRQUFRLENBQUMsTUFBTSxDQUFDLENBQUMsTUFBb0IsQ0FBQztJQUN2RSxNQUFNLEtBQUssR0FBRyxLQUFLLENBQUMsQ0FBQyxDQUFDLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDLENBQUM7UUFDckQsSUFBSSxZQUFZLENBQUMsQ0
FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQzVDLE1BQU0sT0FBTyxHQUFHLE1BQU0sQ0FBQyxDQUFDO1FBQ3BCLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLE1BQU0sQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDLENBQUM7UUFDdEQsSUFBSSxZQUFZLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQzFCLE1BQU0sT0FBTyxHQUFHLElBQUksWUFBWSxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUMsQ0FBQztJQUUvQyxNQUFNLGFBQWEsR0FBRyxPQUFPLENBQUMsTUFBTSxDQUFDO0lBQ3JDLE1BQU0sV0FBVyxHQUFHLEtBQUssQ0FBQyxNQUFNLENBQUM7SUFDakMsTUFBTSxhQUFhLEdBQUcsT0FBTyxDQUFDLE1BQU0sQ0FBQztJQUNyQyxNQUFNLFdBQVcsR0FBRyxLQUFLLENBQUMsTUFBTSxDQUFDO0lBRWpDLElBQUksSUFBSSxHQUFHLENBQUMsQ0FBQztJQUNiLElBQUksRUFBRSxHQUFHLENBQUMsQ0FBQztJQUNYLElBQUksRUFBRSxHQUFHLENBQUMsQ0FBQztJQUNYLElBQUksRUFBRSxHQUFHLENBQUMsQ0FBQztJQUNYLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxLQUFLLENBQUMsTUFBTSxFQUFFLEVBQUUsQ0FBQyxFQUFFO1FBQ3JDLE9BQU8sQ0FBQyxDQUFDLENBQUMsR0FBRyxPQUFPLENBQUMsSUFBSSxFQUFFLENBQUM7WUFDeEIsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEdBQUcsS0FBSyxDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUMsR0FBRyxLQUFLLENBQUMsRUFBRSxFQUFFLENBQUM7Z0JBQ2xDLElBQUksQ0FBQyxJQUFJLENBQUMsT0FBTyxDQUFDLEVBQUUsRUFBRSxDQUFDLEdBQUcsZUFBZSxDQUFDLENBQUM7UUFDbkQsSUFBSSxJQUFJLElBQUksYUFBYSxFQUFFO1lBQ3pCLElBQUksR0FBRyxDQUFDLENBQUM7U0FDVjtRQUNELElBQUksRUFBRSxJQUFJLFdBQVcsRUFBRTtZQUNyQixFQUFFLEdBQUcsQ0FBQyxDQUFDO1NBQ1I7UUFDRCxJQUFJLEVBQUUsSUFBSSxXQUFXLEVBQUU7WUFDckIsRUFBRSxHQUFHLENBQUMsQ0FBQztTQUNSO1FBQ0QsSUFBSSxFQUFFLElBQUksYUFBYSxFQUFFO1lBQ3ZCLEVBQUUsR0FBRyxDQUFDLENBQUM7U0FDUjtLQUNGO0lBQ0QsT0FBTyxPQUFPLENBQUMsY0FBYyxDQUFDLENBQUMsQ0FBQyxLQUFLLEVBQUUsQ0FBQyxDQUFDLEtBQUssRUFBRSxPQUFPLENBQUMsQ0FBQztBQUMzRCxDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sZUFBZSxHQUFpQjtJQUMzQyxVQUFVLEVBQUUsY0FBYztJQUMxQixXQUFXLEVBQUUsS0FBSztJQUNsQixVQUFVLEVBQUUsU0FBa0M7Q0FDL0MsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIH
RoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtGdXNlZEJhdGNoTm9ybSwgRnVzZWRCYXRjaE5vcm1BdHRycywgRnVzZWRCYXRjaE5vcm1JbnB1dHMsIEtlcm5lbENvbmZpZywgS2VybmVsRnVuYywgVGVuc29ySW5mbywgVHlwZWRBcnJheSwgdXRpbH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtNYXRoQmFja2VuZENQVX0gZnJvbSAnLi4vYmFja2VuZF9jcHUnO1xuaW1wb3J0IHthc3NlcnROb3RDb21wbGV4fSBmcm9tICcuLi9jcHVfdXRpbCc7XG5cbmV4cG9ydCBmdW5jdGlvbiBiYXRjaE5vcm0oYXJnczoge1xuICBpbnB1dHM6IEZ1c2VkQmF0Y2hOb3JtSW5wdXRzLFxuICBiYWNrZW5kOiBNYXRoQmFja2VuZENQVSxcbiAgYXR0cnM6IEZ1c2VkQmF0Y2hOb3JtQXR0cnNcbn0pOiBUZW5zb3JJbmZvIHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZCwgYXR0cnN9ID0gYXJncztcbiAgY29uc3Qge3gsIHNjYWxlLCBvZmZzZXQsIG1lYW4sIHZhcmlhbmNlfSA9IGlucHV0cztcblxuICB1dGlsLmFzc2VydChcbiAgICAgIG1lYW4uc2hhcGUubGVuZ3RoID09PSB2YXJpYW5jZS5zaGFwZS5sZW5ndGgsXG4gICAgICAoKSA9PiAnQmF0Y2ggbm9ybWFsaXphdGlvbiBncmFkaWVudCByZXF1aXJlcyBtZWFuIGFuZCB2YXJpYW5jZSB0byBoYXZlICcgK1xuICAgICAgICAgICdlcXVhbCByYW5rcy4nKTtcbiAgdXRpbC5hc3NlcnQoXG4gICAgICBvZmZzZXQgPT0gbnVsbCB8fCBtZWFuLnNoYXBlLmxlbmd0aCA9PT0gb2Zmc2V0LnNoYXBlLmxlbmd0aCxcbiAgICAgICgpID0+ICdCYXRjaCBub3JtYWxpemF0aW9uIGdyYWRpZW50IHJlcXVpcmVzIG1lYW4gYW5kIG9mZnNldCB0byBoYXZlICcgK1xuICAgICAgICAgICdlcXVhbCByYW5rcy4nKTtcbiAgdXRpbC5hc3NlcnQoXG4gICAgICBzY2FsZSA9PSBudWxsIHx8IG1lYW4uc2hhcGUubGVuZ3RoID09PSBzY2FsZS5zaGFwZS5sZW5ndGgsXG4gICAgICAoKSA9PiAnQmF0Y2ggbm9ybW
FsaXphdGlvbiBncmFkaWVudCByZXF1aXJlcyBtZWFuIGFuZCBzY2FsZSB0byBoYXZlICcgK1xuICAgICAgICAgICdlcXVhbCByYW5rcy4nKTtcblxuICBhc3NlcnROb3RDb21wbGV4KFt4LCBtZWFuLCB2YXJpYW5jZSwgc2NhbGUsIG9mZnNldF0sICdiYXRjaE5vcm0nKTtcblxuICBsZXQge3ZhcmlhbmNlRXBzaWxvbn0gPSBhdHRycztcbiAgaWYgKHZhcmlhbmNlRXBzaWxvbiA9PSBudWxsKSB7XG4gICAgdmFyaWFuY2VFcHNpbG9uID0gMC4wMDE7XG4gIH1cblxuICBjb25zdCB4VmFscyA9IGJhY2tlbmQuZGF0YS5nZXQoeC5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5O1xuICBjb25zdCBtVmFscyA9IGJhY2tlbmQuZGF0YS5nZXQobWVhbi5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5O1xuICBjb25zdCB2YXJWYWxzID0gYmFja2VuZC5kYXRhLmdldCh2YXJpYW5jZS5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5O1xuICBjb25zdCBzVmFscyA9IHNjYWxlID8gYmFja2VuZC5kYXRhLmdldChzY2FsZS5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5IDpcbiAgICAgICAgICAgICAgICAgICAgICAgIG5ldyBGbG9hdDMyQXJyYXkoWzFdKTtcbiAgY29uc3Qgb2ZmVmFscyA9IG9mZnNldCA/XG4gICAgICBiYWNrZW5kLmRhdGEuZ2V0KG9mZnNldC5kYXRhSWQpLnZhbHVlcyBhcyBUeXBlZEFycmF5IDpcbiAgICAgIG5ldyBGbG9hdDMyQXJyYXkoWzBdKTtcbiAgY29uc3Qgb3V0VmFscyA9IG5ldyBGbG9hdDMyQXJyYXkoeFZhbHMubGVuZ3RoKTtcblxuICBjb25zdCBvZmZWYWxzTGVuZ3RoID0gb2ZmVmFscy5sZW5ndGg7XG4gIGNvbnN0IHNWYWxzTGVuZ3RoID0gc1ZhbHMubGVuZ3RoO1xuICBjb25zdCB2YXJWYWxzTGVuZ3RoID0gdmFyVmFscy5sZW5ndGg7XG4gIGNvbnN0IG1WYWxzTGVuZ3RoID0gbVZhbHMubGVuZ3RoO1xuXG4gIGxldCBvZmZpID0gMDtcbiAgbGV0IG1pID0gMDtcbiAgbGV0IHNpID0gMDtcbiAgbGV0IHZpID0gMDtcbiAgZm9yIChsZXQgaSA9IDA7IGkgPCB4VmFscy5sZW5ndGg7ICsraSkge1xuICAgIG91dFZhbHNbaV0gPSBvZmZWYWxzW29mZmkrK10gK1xuICAgICAgICAoeFZhbHNbaV0gLSBtVmFsc1ttaSsrXSkgKiBzVmFsc1tzaSsrXSAvXG4gICAgICAgICAgICBNYXRoLnNxcnQodmFyVmFsc1t2aSsrXSArIHZhcmlhbmNlRXBzaWxvbik7XG4gICAgaWYgKG9mZmkgPj0gb2ZmVmFsc0xlbmd0aCkge1xuICAgICAgb2ZmaSA9IDA7XG4gICAgfVxuICAgIGlmIChtaSA+PSBtVmFsc0xlbmd0aCkge1xuICAgICAgbWkgPSAwO1xuICAgIH1cbiAgICBpZiAoc2kgPj0gc1ZhbHNMZW5ndGgpIHtcbiAgICAgIHNpID0gMDtcbiAgICB9XG4gICAgaWYgKHZpID49IHZhclZhbHNMZW5ndGgpIHtcbiAgICAgIHZpID0gMDtcbiAgICB9XG4gIH1cbiAgcmV0dXJuIGJhY2tlbmQubWFrZVRlbnNvckluZm8oeC5zaGFwZSwgeC5kdHlwZSwgb3V0VmFscyk7XG59XG5cbmV4cG9ydCBjb25zdCBiYXRjaE5vcm1Db25maWc6IEtlcm
5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogRnVzZWRCYXRjaE5vcm0sXG4gIGJhY2tlbmROYW1lOiAnY3B1JyxcbiAga2VybmVsRnVuYzogYmF0Y2hOb3JtIGFzIHVua25vd24gYXMgS2VybmVsRnVuYyxcbn07XG4iXX0=