前言

在修改模型结构时,本来想着简单替换主干网络,用轻量级结构的替换原来的复杂模型,但是过程没想象中的顺利;其中比较关键的一点是两个主干网络输出的特征图类型不一致。

问题描述

主干网络A(轻量级),它输出特征图的类型是tuple,输出维度是[1, 3, 640, 640];

主干网络B(复杂的),它输出特征图的类型是torch.Tensor,输出维度也是[1, 3, 640, 640];

但是如果直接把主干网络B替换为主干网络A,后面接着原来的特征提取结构和任务头,会报错的。

tuple 转 torch.Tensor

把主干网络B替换为主干网络A后,加多一步操作,将输出特征图从tuple 转 torch.Tensor即可。

转换的基本思路是:使用 torch.cat( ) 把特征图进行拼接起来,通常是在维度 dim=0 进行拼接的。

A、当特征图的tuple数量为1

import torch# 假设模型输出的特征图为 feature_map, feature_map 是一个 tuple# 获取特征图个数num_maps = len(feature_map)# 打印原来的特征图信息print("type feature_raw:", type(outs))for out in feature_map:print(out.size())print("len feature_raw:", num_maps)# 按第 0 维度拼接特征图feature_map = torch.cat([fm for fm in feature_map], dim=0)# 检查特征图类型print("type feature_map:", type(feature_map))# 输出: # 检查特征图维度print("size feature_map:", feature_map.size())

示例输出:

type feature_raw:
torch.Size([8, 32, 640, 640])
len feature_raw: 1

type feature_map:
feature_map: torch.Size([8, 32, 640, 640])

B、当特征图的tuple数量为多个

如果主干网络输出的特征图类型为tuple,而且它包含多个特征图。我们想把它们变为一个torch.Tensor,可以使用torch.cat函数把它们拼接在一起。

import torch# 假设模型输出的特征图为 feature_map, feature_map 是一个 tuple# 获取特征图个数num_maps = len(feature_map)# 打印原来的特征图信息print("type feature_raw:", type(outs))for out in feature_map:print(out.size())print("len feature_raw:", num_maps)# 按第 0 维度拼接特征图feature_map = torch.cat([fm.unsqueeze(0) for fm in feature_map], dim=0)# 检查特征图类型print("type feature_map:", type(feature_map))# 输出: # 检查特征图维度print("size feature_map:", feature_map.size())

这样就可以将输出的特征图类型由tuple变为torch.Tensor了。拼接时,通过unsqueeze(0)把每个特征图在第0维度上增加一维,这样才能用torch.cat进行拼接。

示例输出:

type feature_raw:
torch.Size([8, 32, 640, 640])

torch.Size([8, 32, 640, 640])

torch.Size([8, 32, 640, 640])

torch.Size([8, 32, 640, 640])
len feature_raw: 1

type feature_map:
feature_map: torch.Size([4, 8, 32, 640, 640])

分享完成,欢迎交流~