aboutsummaryrefslogtreecommitdiffstats
path: root/anime-face-detector/nms_wrapper.py
blob: ca900e8968a28957540a6dfc309243e8052f11ab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from enum import Enum


class NMSType(Enum):
    """Selector for the non-maximum-suppression backend used by NMSWrapper.

    Each member maps to one function that NMSWrapper imports on demand.
    """

    # Backed by ``nms.py_cpu_nms.py_cpu_nms``.
    PY_NMS = 1
    # Backed by ``nms.cpu_nms.cpu_nms``.
    CPU_NMS = 2
    # Backed by ``nms.gpu_nms.gpu_nms``.
    GPU_NMS = 3


# Backend chosen by ``NMSWrapper()`` when no ``nms_type`` argument is given.
default_nms_type = NMSType.PY_NMS


class NMSWrapper:
    """Callable front-end that dispatches to one ``nms`` backend function.

    The backend is imported lazily inside ``__init__``, so only the module for
    the selected ``NMSType`` has to be importable (for example,
    ``nms.gpu_nms`` is never imported unless ``NMSType.GPU_NMS`` is requested).
    """

    def __init__(self, nms_type=default_nms_type):
        """Select and import the NMS backend.

        Args:
            nms_type: An ``NMSType`` member choosing the backend. Defaults to
                the module-level ``default_nms_type``.

        Raises:
            TypeError: If ``nms_type`` is not an ``NMSType`` member.
            ValueError: If ``nms_type`` has no backend wired up below.
            ImportError: If the selected backend module cannot be imported.
        """
        # Explicit check instead of ``assert``: asserts are removed under
        # ``python -O``, which would let a bad value reach the ValueError
        # branch below with a misleading "not implemented" message.
        if not isinstance(nms_type, NMSType):
            raise TypeError(
                'nms_type must be an NMSType member, got {!r}'.format(nms_type))

        # Imports live inside each branch so that only the chosen backend
        # module is loaded.
        if nms_type is NMSType.PY_NMS:
            from nms.py_cpu_nms import py_cpu_nms
            self._nms = py_cpu_nms
        elif nms_type is NMSType.CPU_NMS:
            from nms.cpu_nms import cpu_nms
            self._nms = cpu_nms
        elif nms_type is NMSType.GPU_NMS:
            from nms.gpu_nms import gpu_nms
            self._nms = gpu_nms
        else:
            # Only reachable if NMSType gains a member without a branch here.
            raise ValueError('current nms type is not implemented yet')

        # Remember which backend was selected (used by __repr__).
        self.nms_type = nms_type

    def __repr__(self):
        return '{}(nms_type={!r})'.format(type(self).__name__, self.nms_type)

    def __call__(self, *args, **kwargs):
        """Run the selected backend, forwarding all arguments unchanged.

        Returns whatever the backend function returns.

        NOTE(review): the accepted arguments are defined by the ``nms``
        backend functions, not here (presumably a detections array and an
        overlap threshold) — confirm against that package.
        """
        return self._nms(*args, **kwargs)