gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util } from '@tensorflow/tfjs-core';
export class BatchNormPackedProgram {
    constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
        this.packedInputs = true;
        this.packedOutput = true;
        this.variableNames = ['x', 'mean', 'variance'];
        backend_util.assertAndGetBroadcastShape(xShape, meanShape);
        backend_util.assertAndGetBroadcastShape(xShape, varianceShape);
        let offsetSnippet = 'vec4(0.0)';
        if (offsetShape != null) {
            backend_util.assertAndGetBroadcastShape(xShape, offsetShape);
            this.variableNames.push('offset');
            offsetSnippet = 'getOffsetAtOutCoords()';
        }
        let scaleSnippet = 'vec4(1.0)';
        if (scaleShape != null) {
            backend_util.assertAndGetBroadcastShape(xShape, scaleShape);
            this.variableNames.push('scale');
            scaleSnippet = 'getScaleAtOutCoords()';
        }
        this.outputShape = xShape;
        this.userCode = `
      void main() {
        vec4 offset = ${offsetSnippet};
        vec4 scale = ${scaleSnippet};
 
        vec4 x = getXAtOutCoords();
        vec4 mean = getMeanAtOutCoords();
        vec4 variance = getVarianceAtOutCoords();
 
        vec4 inv = scale * inversesqrt(variance + vec4(${varianceEpsilon}));
 
        setOutput((x - mean) * inv + offset);
      }
    `;
    }
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYmF0Y2hub3JtX3BhY2tlZF9ncHUuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2ViZ2wvc3JjL2JhdGNobm9ybV9wYWNrZWRfZ3B1LnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUduRCxNQUFNLE9BQU8sc0JBQXNCO0lBT2pDLFlBQ0ksTUFBZ0IsRUFBRSxTQUFtQixFQUFFLGFBQXVCLEVBQzlELFdBQTBCLEVBQUUsVUFBeUIsRUFDckQsZUFBdUI7UUFOM0IsaUJBQVksR0FBRyxJQUFJLENBQUM7UUFDcEIsaUJBQVksR0FBRyxJQUFJLENBQUM7UUFNbEIsSUFBSSxDQUFDLGFBQWEsR0FBRyxDQUFDLEdBQUcsRUFBRSxNQUFNLEVBQUUsVUFBVSxDQUFDLENBQUM7UUFDL0MsWUFBWSxDQUFDLDBCQUEwQixDQUFDLE1BQU0sRUFBRSxTQUFTLENBQUMsQ0FBQztRQUMzRCxZQUFZLENBQUMsMEJBQTBCLENBQUMsTUFBTSxFQUFFLGFBQWEsQ0FBQyxDQUFDO1FBRS9ELElBQUksYUFBYSxHQUFHLFdBQVcsQ0FBQztRQUNoQyxJQUFJLFdBQVcsSUFBSSxJQUFJLEVBQUU7WUFDdkIsWUFBWSxDQUFDLDBCQUEwQixDQUFDLE1BQU0sRUFBRSxXQUFXLENBQUMsQ0FBQztZQUM3RCxJQUFJLENBQUMsYUFBYSxDQUFDLElBQUksQ0FBQyxRQUFRLENBQUMsQ0FBQztZQUNsQyxhQUFhLEdBQUcsd0JBQXdCLENBQUM7U0FDMUM7UUFFRCxJQUFJLFlBQVksR0FBRyxXQUFXLENBQUM7UUFDL0IsSUFBSSxVQUFVLElBQUksSUFBSSxFQUFFO1lBQ3RCLFlBQVksQ0FBQywwQkFBMEIsQ0FBQyxNQUFNLEVBQUUsVUFBVSxDQUFDLENBQUM7WUFDNUQsSUFBSSxDQUFDLGFBQWEsQ0FBQyxJQUFJLENBQUMsT0FBTyxDQUFDLENBQUM7WUFDakMsWUFBWSxHQUFHLHVCQUF1QixDQUFDO1NBQ3hDO1FBRUQsSUFBSSxDQUFDLFdBQVcsR0FBRyxNQUFNLENBQUM7UUFDMUIsSUFBSSxDQUFDLFFBQVEsR0FBRzs7d0JBRUksYUFBYTt1QkFDZCxZQUFZOzs7Ozs7eURBTXNCLGVBQWU7Ozs7S0FJbkUsQ0FBQztJQUNKLENBQUM7Q0FDRiIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDE4IEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5pbXBvcnQge0dQR1BVUHJvZ3JhbX0gZnJvbSAnLi9ncGdwdV9tYXRoJztcblxuZXhwb3J0IGNsYXNzIEJhdGNoTm9ybVBhY2tlZFByb2dyYW0gaW1wbGVtZW50cyBHUEdQVVByb2dyYW0ge1xuICB2YXJpYWJsZU5hbWVzOiBzdHJpbmdbXTtcbiAgb3V0cHV0U2hhcGU6IG51bWJlcltdO1xuICB1c2VyQ29kZTogc3RyaW5nO1xuICBwYWNrZWRJbnB1dHMgPSB0cnVlO1xuICBwYWNrZWRPdXRwdXQgPSB0cnVlO1xuXG4gIGNvbnN0cnVjdG9yKFxuICAgICAgeFNoYXBlOiBudW1iZXJbXSwgbWVhblNoYXBlOiBudW1iZXJbXSwgdmFyaWFuY2VTaGFwZTogbnVtYmVyW10sXG4gICAgICBvZmZzZXRTaGFwZTogbnVtYmVyW118bnVsbCwgc2NhbGVTaGFwZTogbnVtYmVyW118bnVsbCxcbiAgICAgIHZhcmlhbmNlRXBzaWxvbjogbnVtYmVyKSB7XG4gICAgdGhpcy52YXJpYWJsZU5hbWVzID0gWyd4JywgJ21lYW4nLCAndmFyaWFuY2UnXTtcbiAgICBiYWNrZW5kX3V0aWwuYXNzZXJ0QW5kR2V0QnJvYWRjYXN0U2hhcGUoeFNoYXBlLCBtZWFuU2hhcGUpO1xuICAgIGJhY2tlbmRfdXRpbC5hc3NlcnRBbmRHZXRCcm9hZGNhc3RTaGFwZSh4U2hhcGUsIHZhcmlhbmNlU2hhcGUpO1xuXG4gICAgbGV0IG9mZnNldFNuaXBwZXQgPSAndmVjNCgwLjApJztcbiAgICBpZiAob2Zmc2V0U2hhcGUgIT0gbnVsbCkge1xuICAgICAgYmFja2VuZF91dGlsLmFzc2VydEFuZEdldEJyb2FkY2FzdFNoYXBlKHhTaGFwZSwgb2Zmc2V0U2hhcGUpO1xuICAgICAgdGhpcy52YXJpYWJsZU5hbWVzLnB1c2goJ29mZnNldCcpO1xuICAgICAgb2Zmc2V0U25pcHBldCA9ICdnZXRPZmZzZXRBdE91dENvb3JkcygpJztcbiAgICB9XG5cbiAgICBsZXQgc2NhbGVTbmlwcGV0ID0gJ3ZlYzQoMS4wKSc7XG4gICAgaWYgKHNjYWxlU2hhcGUgIT0gbnVsbCkge1xuICAgICAgYmFja2VuZF91dGlsLmFzc2VydEFuZEdldEJyb2FkY2FzdFNoYXBlKHhTaGFwZSwgc2NhbGVTaGFwZSk7XG4gICAgICB0aGlzLnZhcmlhYmxlTmFtZXMucHVzaCgnc2NhbGUnKTtcbiAgICAgIHNjYWxlU25pcHBldCA9ICdnZXRTY2FsZUF0T3V0Q29vcmRzKCknO1xuICAgIH1cblxuICAgIHRoaXMub3V0cHV0U2hhcGUgPSB4U2hhcGU7XG4gICAgdGhpcy51c2VyQ29kZSA9IGBcbiAgICAgIHZvaWQgbWFpbigpIHtcbiAgICAgICAgdmVjNCBvZmZzZXQgPSAke29mZnNldFNuaXBwZXR9O1xuICAgICAgICB2ZWM0IHNjYWxlID0gJHtzY2FsZVNuaXBwZXR9O1xuXG4gICAgICAgIHZlYzQgeCA9IGdldFhBdE91dENvb3JkcygpO1xuICAgICAgICB2ZWM0IG1lYW4gPSBnZXRNZWFuQXRPdXRDb29yZHMoKTtcbiAgICAgICAgdmVjNCB2YXJpYW5jZSA9IGdldFZhcmlhbmNlQXRPdXRDb29yZHMoKTtcblxuICAgICAgICB2ZWM0IGludiA9IHNjYWxlICogaW52ZXJzZXNxcnQodmFyaWFuY2UgKyB2ZWM0KCR7dmFyaWFuY2VFcHNpbG9ufSkpO1xuXG4gICAgICAgIHNldE91dHB1dCgoeCAtIG1lYW4pICogaW52ICsgb2Zmc2V0KTtcbiAgICAgIH1cbiAgICBgO1xuICB9XG59XG4iXX0=