gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
/**
 * @license
 * Copyright 2023 CodeSmith LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
import { image, serialization, tidy } from '@tensorflow/tfjs-core';
import { getExactlyOneTensor, getExactlyOneShape } from '../../utils/types_utils';
import { ValueError } from '../../errors';
import { BaseRandomLayer } from '../../engine/base_random_layer';
import { randomUniform } from '@tensorflow/tfjs-core';
/** Names of the resize algorithms this layer accepts for `interpolation`. */
const INTERPOLATION_KEYS = [
    'bilinear',
    'nearest',
];
/**
 * Set form of `INTERPOLATION_KEYS`, used for membership checks when the
 * `interpolation` argument is validated.
 */
export const INTERPOLATION_METHODS = new Set(INTERPOLATION_KEYS);
/**
 * Preprocessing layer which randomly varies image width during training.
 *
 * This layer randomly adjusts the width of a batch of images by a random
 * factor.
 *
 * The input should be a 3D (unbatched) or
 * 4D (batched) tensor in the `"channels_last"` image data format. Input pixel
 * values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and of integer
 * or floating point dtype. By default, the layer will output floats.
 *
 * tf methods implemented in tfjs: 'bilinear', 'nearest',
 * tf methods unimplemented in tfjs: 'bicubic', 'area', 'lanczos3', 'lanczos5',
 *                                   'gaussian', 'mitchellcubic'
 *
 */
class RandomWidth extends BaseRandomLayer {
    /**
     * Validates the configuration and derives the relative width bounds.
     *
     * @param args Layer arguments; the whole object is also forwarded to the
     *     `BaseRandomLayer` constructor. `args.factor` is either a positive
     *     number `f` (treated as the range `[-f, f]`) or a two-element array
     *     `[lower, upper]`. `args.interpolation` is `'bilinear'` (the default
     *     when omitted) or `'nearest'`.
     * @throws ValueError if `factor` is neither a positive number nor a
     *     two-element array, if either bound is below -1, if the upper bound
     *     is below the lower bound, or if `interpolation` is not one of
     *     `INTERPOLATION_METHODS`.
     */
    constructor(args) {
        super(args);
        const { factor, interpolation = 'bilinear' } = args;
        this.factor = factor;
        if (Array.isArray(this.factor) && this.factor.length === 2) {
            // Explicit [lower, upper] pair.
            this.widthLower = this.factor[0];
            this.widthUpper = this.factor[1];
        }
        else if (!Array.isArray(this.factor) && this.factor > 0) {
            // A single positive number f means the symmetric range [-f, f].
            this.widthLower = -this.factor;
            this.widthUpper = this.factor;
        }
        else {
            throw new ValueError(`Invalid factor: ${this.factor}. Must be positive number or tuple of 2 numbers`);
        }
        // The bounds are added to 1.0 in call() to form the width scale, so a
        // bound below -1 would produce a negative scale.
        if (this.widthLower < -1.0 || this.widthUpper < -1.0) {
            throw new ValueError(`factor must have values larger than -1. Got: ${this.factor}`);
        }
        if (this.widthUpper < this.widthLower) {
            throw new ValueError(`factor cannot have upper bound less than lower bound.
        Got upper bound: ${this.widthUpper}.
        Got lower bound: ${this.widthLower}
      `);
        }
        // A falsy `interpolation` (e.g. null) skips validation and leaves
        // `this.interpolation` unset; call() then throws for it.
        if (interpolation) {
            if (INTERPOLATION_METHODS.has(interpolation)) {
                this.interpolation = interpolation;
            }
            else {
                throw new ValueError(`Invalid interpolation parameter: ${interpolation} is not implemented`);
            }
        }
    }
    /**
     * Returns the layer configuration: `factor` and `interpolation` merged
     * with the base-class config. Entries from the base config take
     * precedence on key collisions because they are assigned last.
     */
    getConfig() {
        const config = {
            'factor': this.factor,
            'interpolation': this.interpolation,
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
    /**
     * Computes the output shape for a given input shape.
     *
     * The result is derived from `inputShape` alone, so it is valid before
     * `call` has run and for both unbatched `[height, width, channels]` and
     * batched `[batch, height, width, channels]` shapes. Only the width axis
     * (second to last) changes; it is reported as `-1` because the output
     * width is sampled at call time.
     *
     * @param inputShape Shape (or single-element list of shapes) of the input.
     * @returns A copy of the input shape with the width axis set to `-1`.
     */
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        // Copy so the caller's shape array is never mutated.
        const outputShape = inputShape.slice();
        outputShape[outputShape.length - 2] = -1;
        return outputShape;
    }
    /**
     * Resizes the input image(s) to a randomly scaled width, keeping the
     * height unchanged.
     *
     * The scale is drawn uniformly between `1 + widthLower` and
     * `1 + widthUpper`; the new width is that scale times the input width,
     * rounded to the nearest integer.
     *
     * @param inputs Image tensor (or single-element list holding it) whose
     *     last three axes are height, width and channels.
     * @param kwargs Unused by this layer.
     * @returns The input resized to `[height, adjustedWidth]` with the
     *     configured interpolation method.
     * @throws Error if `this.interpolation` is not a supported method.
     */
    call(inputs, kwargs) {
        return tidy(() => {
            const input = getExactlyOneTensor(inputs);
            // Height and width are read relative to the end of the shape so the
            // same code serves 3D (unbatched) and 4D (batched) inputs.
            this.imgHeight = input.shape[input.shape.length - 3];
            const imgWidth = input.shape[input.shape.length - 2];
            // NOTE(review): this tensor is created inside tidy() and is not
            // returned, so it is presumably disposed once call() finishes;
            // do not rely on `this.widthFactor` afterwards.
            this.widthFactor = randomUniform([1], (1.0 + this.widthLower), (1.0 + this.widthUpper), 'float32', this.randomGenerator.next());
            // dataSync() reads the sampled scale synchronously.
            let adjustedWidth = this.widthFactor.dataSync()[0] * imgWidth;
            adjustedWidth = Math.round(adjustedWidth);
            const size = [this.imgHeight, adjustedWidth];
            // Resize the unwrapped tensor rather than the raw `inputs` argument,
            // which may be a single-element array.
            switch (this.interpolation) {
                case 'bilinear':
                    return image.resizeBilinear(input, size);
                case 'nearest':
                    return image.resizeNearestNeighbor(input, size);
                default:
                    throw new Error(`Interpolation is ${this.interpolation}
          but only ${[...INTERPOLATION_METHODS]} are supported`);
            }
        });
    }
}
// Static class name, assigned after the class body; the class is then
// exported and passed to the tfjs-core serialization registry.
/** @nocollapse */
RandomWidth.className = 'RandomWidth';
export { RandomWidth };
// NOTE(review): registration presumably relies on `className` being set
// above — keep this statement after the assignment.
serialization.registerClass(RandomWidth);
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"random_width.js","sourceRoot":"","sources":["../../../../../../../tfjs-layers/src/layers/preprocessing/random_width.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,OAAO,EAAE,KAAK,EAAQ,aAAa,EAAU,IAAI,EAAE,MAAM,uBAAuB,CAAC;AACjF,OAAO,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,MAAM,yBAAyB,CAAC;AAGlF,OAAO,EAAE,UAAU,EAAE,MAAM,cAAc,CAAC;AAC1C,OAAO,EAAuB,eAAe,EAAE,MAAM,gCAAgC,CAAC;AACtF,OAAO,EAAE,aAAa,EAAE,MAAM,uBAAuB,CAAC;AAStD,MAAM,kBAAkB,GAAG,CAAC,UAAU,EAAE,SAAS,CAAU,CAAC;AAC5D,MAAM,CAAC,MAAM,qBAAqB,GAAG,IAAI,GAAG,CAAC,kBAAkB,CAAC,CAAC;AAGjE;;;;;;;;;;;;;;;GAeG;AAEH,MAAa,WAAY,SAAQ,eAAe;IAU9C,YAAY,IAAqB;QAC/B,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,MAAM,EAAC,MAAM,EAAE,aAAa,GAAG,UAAU,EAAC,GAAG,IAAI,CAAC;QAElD,IAAI,CAAC,MAAM,GAAG,MAAM,CAAC;QAErB,IAAI,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,MAAM,CAAC,IAAI,IAAI,CAAC,MAAM,CAAC,MAAM,KAAK,CAAC,EAAE;YAC1D,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;YACjC,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;SAClC;aAAM,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,MAAM,CAAC,IAAI,IAAI,CAAC,MAAM,GAAG,CAAC,EAAC;YACxD,IAAI,CAAC,UAAU,GAAG,CAAC,IAAI,CAAC,MAAM,CAAC;YAC/B,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC;SAC/B;aAAM;YACL,MAAM,IAAI,UAAU,CAClB,mBAAmB,IAAI,CAAC,MAAM,iDAAiD,CAChF,CAAC;SACH;QACD,IAAI,IAAI,CAAC,UAAU,GAAG,CAAC,GAAG,IAAI,IAAI,CAAC,UAAU,GAAG,CAAC,GAAG,EAAE;YACpD,MAAM,IAAI,UAAU,CAClB,gDAAgD,IAAI,CAAC,MAAM,EAAE,CAC9D,CAAC;SACH;QAED,IAAI,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,UAAU,EAAE;YACrC,MAAM,IAAI,UAAU,CAClB;2BACmB,IAAI,CAAC,UAAU;2BACf,IAAI,CAAC,UAAU;OACnC,CAAC,CAAC;SACJ;QAED,IAAI,aAAa,EAAE;YACjB,IAAI,qBAAqB,CAAC,GAAG,CAAC,aAAa,CAAC,EAAE;gBAC5C,IAAI,CAAC,aAAa,GAAG,aAAa,CAAC;aACpC;iBAAM;gBACL,MAAM,IAAI,UAAU,CAAC,oCACjB,aAAa,qBAAqB,CAAC,CAAC;aACzC;SACF;IACH,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B;YACvC,QAAQ,EAAE,IAAI,CAAC,MAAM;YACrB,eAAe,EAAE,IAAI,CAAC,aAAa;SACpC,CAAC;QAEF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC
;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,UAAU,GAAG,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAC5C,MAAM,WAAW,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC;QAClC,OAAO,CAAC,IAAI,CAAC,SAAS,EAAE,CAAC,CAAC,EAAE,WAAW,CAAC,CAAC;IAC3C,CAAC;IAEQ,IAAI,CAAC,MAAuC,EACnD,MAAc;QAEd,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,MAAM,KAAK,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;YAC1C,IAAI,CAAC,SAAS,GAAG,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YACrD,MAAM,QAAQ,GAAG,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YAErD,IAAI,CAAC,WAAW,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC,EAClC,CAAC,GAAG,GAAG,IAAI,CAAC,UAAU,CAAC,EAAE,CAAC,GAAG,GAAG,IAAI,CAAC,UAAU,CAAC,EAChD,SAAS,EAAE,IAAI,CAAC,eAAe,CAAC,IAAI,EAAE,CACvC,CAAC;YAEF,IAAI,aAAa,GAAG,IAAI,CAAC,WAAW,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,GAAG,QAAQ,CAAC;YAC9D,aAAa,GAAG,IAAI,CAAC,KAAK,CAAC,aAAa,CAAC,CAAC;YAE1C,MAAM,IAAI,GAAoB,CAAC,IAAI,CAAC,SAAS,EAAE,aAAa,CAAC,CAAC;YAE9D,QAAQ,IAAI,CAAC,aAAa,EAAE;gBAC1B,KAAK,UAAU;oBACb,OAAO,KAAK,CAAC,cAAc,CAAC,MAAM,EAAE,IAAI,CAAC,CAAC;gBAC5C,KAAK,SAAS;oBACZ,OAAO,KAAK,CAAC,qBAAqB,CAAC,MAAM,EAAE,IAAI,CAAC,CAAC;gBACnD;oBACE,MAAM,IAAI,KAAK,CAAC,oBAAoB,IAAI,CAAC,aAAa;qBAC3C,CAAC,GAAG,qBAAqB,CAAC,gBAAgB,CAAC,CAAC;aAC1D;QACH,CAAC,CAAC,CAAC;IACL,CAAC;;AA/FD,kBAAkB;AACF,qBAAS,GAAG,aAAa,CAAC;SAF/B,WAAW;AAmGxB,aAAa,CAAC,aAAa,CAAC,WAAW,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 CodeSmith LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\nimport { image, Rank, serialization, Tensor, tidy } from '@tensorflow/tfjs-core';\nimport { getExactlyOneTensor, getExactlyOneShape } from '../../utils/types_utils';\nimport { Shape } from '../../keras_format/common';\nimport { Kwargs } from '../../types';\nimport { ValueError } from '../../errors';\nimport { BaseRandomLayerArgs, BaseRandomLayer } from '../../engine/base_random_layer';\nimport { randomUniform } 
from '@tensorflow/tfjs-core';\n\nexport declare interface RandomWidthArgs extends BaseRandomLayerArgs {\n   factor: number | [number, number];\n   interpolation?: InterpolationType; // default = 'bilinear';\n   seed?: number; // default = null;\n   autoVectorize?: boolean;\n}\n\nconst INTERPOLATION_KEYS = ['bilinear', 'nearest'] as const;\nexport const INTERPOLATION_METHODS = new Set(INTERPOLATION_KEYS);\ntype InterpolationType = typeof INTERPOLATION_KEYS[number];\n\n/**\n * Preprocessing Layer with randomly varies image during training\n *\n * This layer randomly adjusts the width of a batch of images of a\n * batch of images by a random factor.\n *\n * The input should be a 3D (unbatched) or\n * 4D (batched) tensor in the `\"channels_last\"` image data format. Input pixel\n * values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and of interger\n * or floating point dtype. By default, the layer will output floats.\n *\n * tf methods implemented in tfjs: 'bilinear', 'nearest',\n * tf methods unimplemented in tfjs: 'bicubic', 'area', 'lanczos3', 'lanczos5',\n *                                   'gaussian', 'mitchellcubic'\n *\n */\n\nexport class RandomWidth extends BaseRandomLayer {\n  /** @nocollapse */\n  static override className = 'RandomWidth';\n  private readonly factor: number | [number, number];\n  private readonly interpolation?: InterpolationType;  // defualt = 'bilinear\n  private widthLower: number;\n  private widthUpper: number;\n  private imgHeight: number;\n  private widthFactor: Tensor<Rank.R1>;\n\n  constructor(args: RandomWidthArgs) {\n    super(args);\n    const {factor, interpolation = 'bilinear'} = args;\n\n    this.factor = factor;\n\n    if (Array.isArray(this.factor) && this.factor.length === 2) {\n      this.widthLower = this.factor[0];\n      this.widthUpper = this.factor[1];\n    } else if (!Array.isArray(this.factor) && this.factor > 0){\n      this.widthLower = -this.factor;\n      this.widthUpper = this.factor;\n    } else {\n    
  throw new ValueError(\n        `Invalid factor: ${this.factor}. Must be positive number or tuple of 2 numbers`\n      );\n    }\n    if (this.widthLower < -1.0 || this.widthUpper < -1.0) {\n      throw new ValueError(\n        `factor must have values larger than -1. Got: ${this.factor}`\n      );\n    }\n\n    if (this.widthUpper < this.widthLower) {\n      throw new ValueError(\n        `factor cannot have upper bound less than lower bound.\n        Got upper bound: ${this.widthUpper}.\n        Got lower bound: ${this.widthLower}\n      `);\n    }\n\n    if (interpolation) {\n      if (INTERPOLATION_METHODS.has(interpolation)) {\n        this.interpolation = interpolation;\n      } else {\n        throw new ValueError(`Invalid interpolation parameter: ${\n            interpolation} is not implemented`);\n      }\n    } \n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {\n      'factor': this.factor,\n      'interpolation': this.interpolation,\n    };\n\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    inputShape = getExactlyOneShape(inputShape);\n    const numChannels = inputShape[2];\n    return [this.imgHeight, -1, numChannels];\n  }\n\n  override call(inputs: Tensor<Rank.R3>|Tensor<Rank.R4>,\n    kwargs: Kwargs): Tensor[]|Tensor {\n\n    return tidy(() => {\n      const input = getExactlyOneTensor(inputs);\n      this.imgHeight = input.shape[input.shape.length - 3];\n      const imgWidth = input.shape[input.shape.length - 2];\n\n      this.widthFactor = randomUniform([1],\n        (1.0 + this.widthLower), (1.0 + this.widthUpper),\n        'float32', this.randomGenerator.next()\n      );\n\n      let adjustedWidth = this.widthFactor.dataSync()[0] * imgWidth;\n      adjustedWidth = Math.round(adjustedWidth);\n\n      const size:[number, number] = [this.imgHeight, 
adjustedWidth];\n\n      switch (this.interpolation) {\n        case 'bilinear':\n          return image.resizeBilinear(inputs, size);\n        case 'nearest':\n          return image.resizeNearestNeighbor(inputs, size);\n        default:\n          throw new Error(`Interpolation is ${this.interpolation}\n          but only ${[...INTERPOLATION_METHODS]} are supported`);\n      }\n    });\n  }\n}\n\nserialization.registerClass(RandomWidth);\n"]}