TL;DR

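If you hit this kind of error, especially when the code normally runs fine and only breaks after you switch to a different dataset, the dataset itself is most likely the problem. Debug along that line.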

Problem description

With num_workers set to any value greater than 0, the DataLoader fails with the following error:

Traceback (most recent call last):
  File "/home/username/distort/main.py", line 131, in <module>
    model, perms, accs = train_model(dinfos, args.mid, args.pretrained, args.num_classes, args.treps, args.testep, args.test_dist, device, args.distort)
  File "/home/username/distort/main.py", line 65, in train_model
    for img, y in train_dataloader:
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
    data = self._next_data()
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1376, in _next_data
    return self._process_data(data)
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1402, in _process_data
    data.reraise()
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in default_collate
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in <listcomp>
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 140, in default_collate
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))
RuntimeError: Trying to resize storage that is not resizable

With num_workers set to 0, a different error appears instead:

Traceback (most recent call last):
  File "/home/username/distort/main.py", line 130, in <module>
    model, perms, accs = train_model(dinfos, args.mid, args.pretrained, args.num_classes, args.treps, args.testep, args.test_dist, device, args.distort)
  File "/home/username/distort/main.py", line 64, in train_model
    for img, y in train_dataloader:
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
    data = self._next_data()
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 721, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in default_collate
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in <listcomp>
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 141, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [3, 64, 64] at entry 0 and [1, 64, 64] at entry 32
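The num_workers=0 error can be reproduced without any image files. The sketch below is a toy, not the original project code: `MixedChannelDataset` and `reproduce` are hypothetical names, and the dataset simply returns 3×64×64 zero tensors except for one 1×64×64 sample, mimicking a grayscale image hidden in an RGB dataset.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class MixedChannelDataset(Dataset):
    """Hypothetical toy dataset: every sample is 3-channel except one 1-channel sample."""

    def __init__(self, n=8, gray_index=5):
        self.n = n
        self.gray_index = gray_index

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        channels = 1 if idx == self.gray_index else 3
        img = torch.zeros(channels, 64, 64)
        label = idx % 2
        return img, label


def reproduce(batch_size=8):
    """Iterate a DataLoader with num_workers=0 and return the error message, if any."""
    loader = DataLoader(MixedChannelDataset(), batch_size=batch_size, num_workers=0)
    try:
        for img, y in loader:
            pass
    except RuntimeError as exc:
        # default_collate calls torch.stack on the images, which rejects mixed shapes
        return str(exc)
    return None


if __name__ == "__main__":
    print(reproduce())
```

Because the single-channel sample lands in the same batch as 3-channel ones, default_collate's torch.stack fails with the same kind of "stack expects each tensor to be equal size" message as above.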

Investigation

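The second error is fairly easy to track down. In the custom Dataset class's __getitem__(), I added code that prints the path of the original data file whenever the loaded tensor has shape[0] equal to 1.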
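A minimal sketch of that check. It assumes a hypothetical dataset class that stores a list of file paths (`paths`), integer labels (`labels`) and a `load` callable that turns a path into a C×H×W tensor (for example a PIL open followed by `ToTensor()`); none of these names come from the original code. It also uses the slightly broader test `shape[0] != 3` instead of `== 1`.

```python
import torch
from torch.utils.data import Dataset


class DebugImageDataset(Dataset):
    """Hypothetical dataset that reports samples whose channel count is not 3."""

    def __init__(self, paths, labels, load):
        self.paths = paths    # list of image file paths (assumed)
        self.labels = labels  # list of integer labels (assumed)
        self.load = load      # callable: path -> tensor of shape [C, H, W] (assumed)
        # Offending paths; only reliable with num_workers=0, because each
        # worker process gets its own copy of the dataset object.
        self.bad_paths = []

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = self.load(self.paths[idx])
        if img.shape[0] != 3:
            # Print the source file so the offending image can be inspected
            print(f"unexpected shape {tuple(img.shape)} for {self.paths[idx]}")
            self.bad_paths.append(self.paths[idx])
        return img, self.labels[idx]
```

Iterating once over the dataset (or a DataLoader with num_workers=0) is enough to list every suspicious file.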

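It turned out that the dataset really does contain single-channel (grayscale) images (I was using tiny-imagenet-200). I did not expect the dataset itself to be at fault.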

Solution

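In __getitem__(), use the tensor's expand method: for a tensor with the wrong channel count, call expand(3, -1, -1), which works because expand can broadcast a dimension of size 1. After that, the dataset loads normally with num_workers set to 0 or to any positive number.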
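A small sketch of the fix as a standalone helper; `ensure_three_channels` is a hypothetical name, and the assumption is that images arrive as C×H×W tensors.

```python
import torch


def ensure_three_channels(img: torch.Tensor) -> torch.Tensor:
    """Broadcast a single-channel [1, H, W] tensor to [3, H, W].

    expand() only works on dimensions of size 1 and returns a view that shares
    memory with the input; the batch is materialized later when the samples are
    stacked in default_collate.
    """
    if img.shape[0] == 1:
        return img.expand(3, -1, -1)  # -1 keeps H and W unchanged
    return img
```

Inside __getitem__() you would call `img = ensure_three_channels(img)` just before returning the sample. Note that expand only helps when the channel dimension has size 1; a 4-channel RGBA image would need different handling, for example converting the PIL image with `.convert("RGB")` before `ToTensor()`.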

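Also note that some blogs claim num_workers has to match the number of GPU cores, which is frankly nonsensical. The first traceback already shows that the failure happens inside default_collate, in plain tensor-collation code with no CUDA library involved, so it cannot be a GPU problem. In the PyTorch version from the traceback, when collation runs in a worker process, default_collate allocates a shared-memory buffer sized to the total number of elements in the batch and then resizes it as if every sample had the shape of the first one; with a 1-channel image among 3-channel ones the sizes disagree, and the shared storage cannot be resized. With the usual way of loading datasets, num_workers simply sets how many worker subprocesses (processes, not threads) the DataLoader uses to load data on the CPU; 0 means loading happens in the main process.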
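The size mismatch behind the first error can be checked with plain arithmetic. This sketch only mimics the bookkeeping in default_collate's worker branch (storage sized as the sum of all sample sizes, then `resize_(len(batch), *elem.size())` as shown in the traceback); it is not PyTorch code, `collate_buffer_sizes` is a hypothetical name, and the batch of 33 samples with the grayscale image last is an assumption, since the original batch size is not stated.

```python
import math


def collate_buffer_sizes(shapes):
    """Mimic default_collate's size bookkeeping in a worker process.

    shapes: list of per-sample tensor shapes, e.g. [(3, 64, 64), (1, 64, 64)].
    Returns (allocated, requested): elements in the shared-memory storage that is
    allocated, and elements that resize_() asks for.
    """
    allocated = sum(math.prod(s) for s in shapes)     # sized from the real samples
    requested = len(shapes) * math.prod(shapes[0])    # assumes all samples match entry 0
    return allocated, requested


batch = [(3, 64, 64)] * 32 + [(1, 64, 64)]
allocated, requested = collate_buffer_sizes(batch)
print(allocated, requested)  # prints "397312 405504"
```

Since the requested size is larger than what was allocated, resize_() would have to grow the shared storage, which is exactly what "Trying to resize storage that is not resizable" refuses to do.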