详解:yolov5中推理时置信度,设置的conf和iou_thres具体含义

一、模型输出解析:

设输出图片大小为1280,768,类别个数为2,则yolov5输出的三种特征图,其维度分别为:[1,3,96,160,7],[1,3,48,80,7],[1,3,24,40,7];相当于yolov5模型总共输出(96*160+48*80+24*40)*3=60480个目标框;

其中,[1,3,96,160,7] 中1指代输入图像个数为1,3指的是该尺度下的3种anchor,(96,160) 指的是特征图的尺寸,7具体指的是:(center_x,center_y, width, height, obj_conf, class_1_prob, class_2_prob ),即分别为box框中心点x,y,长和宽 width,height,以及该框存在目标的置信度obj_conf,类别1和类别2 的置信度,若class_1_prob > class_2_prob,则该框的类别为class1;因此,obj_conf和class_1_prob一个指得是该框存在目标的概率,一个指是该框分类为类别1的概率;

二、yolov5后处理解析;

从一可知模型输出了60480个目标框,因此,要经过NMS进行过滤,进NMS之前需要经过初筛(即将obj_conf小于我们设置的置信度的框去除),再计算每个box框的综合置信度conf:conf = obj_conf * max(class_1_prob ,class_2_prob),此时的conf是综合了obj_conf以及class_prob的综合概率;再经过进一步的过滤(即将conf小于我们设置的置信度的框去除),最后,将剩余的框通过NMS算法,得出最终的框;(NMS中用到了我们设置的iou_thres);

因此,最终我们可视化在box上方的置信度是综合了obj_conf以及class_prob的综合概率;

以下是yolov5中NMS源码,可查看细节:

def non_max_suppression(prediction,                        conf_thres=0.25,                        iou_thres=0.45,                        classes=None,                        agnostic=False,                        multi_label=False,                        labels=(),                        max_det=300):    """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes    Returns:         list of detections, on (n,6) tensor per image [xyxy, conf, cls]    """    bs = prediction.shape[0]  # batch size    nc = prediction.shape[2] - 5  # number of classes    xc = prediction[..., 4] > conf_thres  # candidates    # Checks    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'    assert 0 <= iou_thres  1  # multiple labels per box (adds 0.5ms/img)    merge = False  # use merge-NMS    t = time.time()    output = [torch.zeros((0, 6), device=prediction.device)] * bs    for xi, x in enumerate(prediction):  # image index, image inference        # Apply constraints        # x[((x[..., 2:4]  max_wh)).any(1), 4] = 0  # width-height        x = x[xc[xi]]  # confidence        # Cat apriori labels if autolabelling        if labels and len(labels[xi]):            lb = labels[xi]            v = torch.zeros((len(lb), nc + 5), device=x.device)            v[:, :4] = lb[:, 1:5]  # box            v[:, 4] = 1.0  # conf            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls            x = torch.cat((x, v), 0)        # If none remain process next image        if not x.shape[0]:            continue        # Compute conf        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf        # Box (center x, center y, width, height) to (x1, y1, x2, y2)        box = xywh2xyxy(x[:, :4])        # Detections matrix nx6 (xyxy, conf, cls)        if multi_label:            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)        else:  # best class only            conf, j = x[:, 5:].max(1, keepdim=True)            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]        # Filter by class        if classes is not None:            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]        # Apply finite constraint        # if not torch.isfinite(x).all():        #     x = x[torch.isfinite(x).all(1)]        # Check shape        n = x.shape[0]  # number of boxes        if not n:  # no boxes            continue        elif n > max_nms:  # excess boxes            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence        # Batched NMS        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS        if i.shape[0] > max_det:  # limit detections            i = i[:max_det]        if merge and (1 < n  iou_thres  # iou matrix            weights = iou * scores[None]  # box weights            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes            if redundant:                i = i[iou.sum(1) > 1]  # require redundancy        output[xi] = x[i]        if (time.time() - t) > time_limit:            LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')            break  # time limit exceeded    return output

© 版权声明
THE END
喜欢就支持一下吧
点赞0 分享