gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import * as tf from '../../dist/tfjs.esm';
 
function IOU(boxes: tf.Tensor2D, i: number, j: number) {
  const boxesData = boxes.arraySync();
  const yminI = Math.min(boxesData[i][0], boxesData[i][2]);
  const xminI = Math.min(boxesData[i][1], boxesData[i][3]);
  const ymaxI = Math.max(boxesData[i][0], boxesData[i][2]);
  const xmaxI = Math.max(boxesData[i][1], boxesData[i][3]);
  const yminJ = Math.min(boxesData[j][0], boxesData[j][2]);
  const xminJ = Math.min(boxesData[j][1], boxesData[j][3]);
  const ymaxJ = Math.max(boxesData[j][0], boxesData[j][2]);
  const xmaxJ = Math.max(boxesData[j][1], boxesData[j][3]);
  const areaI = (ymaxI - yminI) * (xmaxI - xminI);
  const areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);
  if (areaI <= 0 || areaJ <= 0) return 0.0;
  const intersectionYmin = Math.max(yminI, yminJ);
  const intersectionXmin = Math.max(xminI, xminJ);
  const intersectionYmax = Math.min(ymaxI, ymaxJ);
  const intersectionXmax = Math.min(xmaxI, xmaxJ);
  const intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) * Math.max(intersectionXmax - intersectionXmin, 0.0);
  return intersectionArea / (areaI + areaJ - intersectionArea);
}
 
export function nonMaxSuppression(
  boxes: tf.Tensor2D,
  scores: number[],
  maxOutputSize: number,
  iouThreshold: number,
  scoreThreshold: number,
): number[] {
  const numBoxes = boxes.shape[0];
  const outputSize = Math.min(maxOutputSize, numBoxes);
 
  const candidates = scores
    .map((score, boxIndex) => ({ score, boxIndex }))
    .filter((c) => c.score > scoreThreshold)
    .sort((c1, c2) => c2.score - c1.score);
 
  const suppressFunc = (x: number) => (x <= iouThreshold ? 1 : 0);
  const selected: number[] = [];
 
  candidates.forEach((c) => {
    if (selected.length >= outputSize) return;
    const originalScore = c.score;
    for (let j = selected.length - 1; j >= 0; --j) {
      const iou = IOU(boxes, c.boxIndex, selected[j]);
      if (iou === 0.0) continue;
      c.score *= suppressFunc(iou);
      if (c.score <= scoreThreshold) break;
    }
    if (originalScore === c.score) {
      selected.push(c.boxIndex);
    }
  });
  return selected;
}