gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import * as tf from '../../dist/tfjs.esm';
 
import { conv, convDown, convNoRelu } from './convLayer';
import { ResidualLayerParams } from './types';
 
export function residual(x: tf.Tensor4D, params: ResidualLayerParams): tf.Tensor4D {
  let out = conv(x, params.conv1);
  out = convNoRelu(out, params.conv2);
  out = tf.add(out, x);
  out = tf.relu(out);
  return out;
}
 
export function residualDown(x: tf.Tensor4D, params: ResidualLayerParams): tf.Tensor4D {
  let out = convDown(x, params.conv1);
  out = convNoRelu(out, params.conv2);
 
  let pooled = tf.avgPool(x, 2, 2, 'valid') as tf.Tensor4D;
  const zeros = tf.zeros<tf.Rank.R4>(pooled.shape);
  const isPad = pooled.shape[3] !== out.shape[3];
  const isAdjustShape = pooled.shape[1] !== out.shape[1] || pooled.shape[2] !== out.shape[2];
 
  if (isAdjustShape) {
    const padShapeX = [...out.shape] as [number, number, number, number];
    padShapeX[1] = 1;
    const zerosW = tf.zeros<tf.Rank.R4>(padShapeX);
    out = tf.concat([out, zerosW], 1);
 
    const padShapeY = [...out.shape] as [number, number, number, number];
    padShapeY[2] = 1;
    const zerosH = tf.zeros<tf.Rank.R4>(padShapeY);
    out = tf.concat([out, zerosH], 2);
  }
 
  pooled = isPad ? tf.concat([pooled, zeros], 3) : pooled;
  out = tf.add(pooled, out) as tf.Tensor4D;
 
  out = tf.relu(out);
  return out;
}