RuntimeError: stack expects each tensor to be equal size ??


RuntimeError: stack expects each tensor to be equal size, but got [1200, 1200, 3] at entry 0 and [1200, 1344, 3] at entry 1

pytorch 数据处理错误, 网上的各种方法都试过了

1: 检查过数据的输入通道是3, 标签是1,但是输入的大小尺寸不同

2: 进行如下方法也不行!!

data_tf = transforms.Compose([transforms.Resize((1024,1024)), # transforms.CenterCrop(1020),# transforms.RandomHorizontalFlip(),transforms.ToTensor(),])

:3: bs=1,不报错,bs>1 报错

4: bug 未解决, 使用本地书记resize 处理

最终的解决之道!!!!!

出现的问题是collate_fn 参数 默认打包image, label,如果一个batch有多个输出,而且大小不一样,且使用了默认的collate_fn,则会报错
理论如下:

collate_fn 参数
当继承Dataset类自定义类时,__getitem__方法一般返回一组类似于(image,label)的一个样本,在创建DataLoader类的对象时,collate_fn函数会将batch_size个样本整理成一个batch样本,便于批量训练。

default_collate(batch)中的参数就是这里的 [self.dataset[i] for i in indices],indices是从所有样本的索引中选取的batch_size个索引,表示本次批量获取这些样本进行训练。self.dataset[i]就是自定义Dataset子类中__getitem__返回的结果。默认的函数default_collate(batch) 只能对大小相同image的batch_size个image整理,如[(img0, label0), (img1, label1),(img2, label2), ] 整理成([img0,img1,img2,], [label0,label1,label2,]), 这里要求多个img的size相同。所以在我们的图像大小不同时,需要自定义函数callate_fn来将batch个图像整理成统一大小的,若读取的数据有(img, box, label)这种你也需要自定义,因为默认只能处理(img,label)。当然你可以提前将数据集全部整理成统一大小的。
参考:原文链接:https://blog.csdn.net/Decennie/article/details/121000380

方法:重写collate_fn,不同大小的image要resize到同一个大小, 例子如下:
例子1:

def yolo_dataset_collate(batch):images = []bboxes = []for img, box in batch:images.append(img)bboxes.append(box)images = np.array(images)bboxes = np.array(bboxes)return images, bboxes

例子2:

# DataLoader中collate_fn使用, 数据大小保持一致def deeplab_dataset_collate(batch):images = []pngs = []seg_labels = []for img, png, labels in batch:images.append(img)pngs.append(png)seg_labels.append(labels)images = np.array(images)pngs = np.array(pngs)seg_labels = np.array(seg_labels)return images, pngs, seg_labelstrain_dataset = DeeplabDataset(train_lines, inputs_size, NUM_CLASSES, True)val_dataset = DeeplabDataset(val_lines, inputs_size, NUM_CLASSES, False)gen = DataLoader(train_dataset, batch_size=Batch_size, num_workers=2, pin_memory=True,drop_last=True, collate_fn=deeplab_dataset_collate)gen_val = DataLoader(val_dataset, batch_size=Batch_size, num_workers=2,pin_memory=True, drop_last=True, collate_fn=deeplab_dataset_collate)
© 版权声明
THE END
喜欢就支持一下吧
点赞0 分享