PyTorch Study Notes: data.RandomSampler (Random Data Sampling)

torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)

Purpose: draws sample indices from a dataset in random order.

Parameters:

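  • data_source: the dataset to sample from
  • replacement: the sampling strategy. If True, samples are drawn with replacement, so the same sample can be drawn more than once. If False (the default), samples are drawn without replacement, so each sample is drawn at most once.
  • num_samples: the number of samples to draw; defaults to len(data_source), i.e., all samples. When replacement=True, a different number of samples can be requested by setting num_samples. When replacement=False, this parameter should be left unset: the PyTorch version these notes were written against raises an error if it is specified, although recent releases also accept it.
  • generator: the random number generator (a torch.Generator) used for sampling. If omitted, a seed is drawn from PyTorch's global random number generator.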
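The generator parameter is what makes the sampling order reproducible. A minimal sketch, assuming two samplers over range(20) whose generators are seeded with the same arbitrary value (42): both yield the same permutation.

```python
import torch
from torch.utils.data import RandomSampler

# Two independent generators seeded with the same (arbitrary) value
g1 = torch.Generator()
g1.manual_seed(42)
g2 = torch.Generator()
g2.manual_seed(42)

sampler_a = RandomSampler(range(20), generator=g1)
sampler_b = RandomSampler(range(20), generator=g2)

order_a = list(sampler_a)
order_b = list(sampler_b)

# Same seed, same starting generator state, so the same permutation
print(order_a == order_b)  # True
print(order_a)
```

Note that the generator's state advances as it is used, so iterating sampler_a a second time will generally give a different order; re-seed the generator to replay the same sequence.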

Code Examples

Basic usage

from torch.utils.data import RandomSampler

# Sample indices 0..19 in random order (without replacement by default)
sampler = RandomSampler(range(20))
print([i for i in sampler])

Output

This is equivalent to shuffling the original data; the exact order differs from run to run:

[7, 17, 8, 1, 13, 9, 6, 4, 12, 18, 19, 14, 10, 3, 2, 16, 5, 15, 0, 11]
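In practice a RandomSampler is usually handed to a DataLoader through its sampler argument, and the DataLoader then fetches and batches the elements at the sampled indices. A minimal sketch, assuming a plain Python list of 20 integers as the dataset and an arbitrary batch size of 4. (DataLoader(..., shuffle=True) builds a RandomSampler internally, which is why sampler and shuffle=True cannot be passed together.)

```python
from torch.utils.data import DataLoader, RandomSampler

# A plain list works as a map-style dataset (it has __getitem__ and __len__)
data = list(range(20))

sampler = RandomSampler(data)
loader = DataLoader(data, batch_size=4, sampler=sampler)

# Each batch is a tensor of 4 elements taken in the sampler's random order
batches = [batch.tolist() for batch in loader]
for batch in batches:
    print(batch)
```

Because the sampler draws without replacement here, the five batches together cover every element exactly once.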

Difference between replacement=True and replacement=False

from torch.utils.data import RandomSampler

sampler_t = RandomSampler(range(20), replacement=True)                   # with replacement
sampler_f = RandomSampler(range(20), replacement=False)                  # without replacement (default)
sampler_t_8 = RandomSampler(range(20), replacement=True, num_samples=8)  # with replacement, draw only 8

print('sampler_t:', [i for i in sampler_t])
print('sampler_f:', [i for i in sampler_f])
print('sampler_t_8:', [i for i in sampler_t_8])

Output

# With replacement=True, the same sample can be drawn more than once
sampler_t: [7, 3, 13, 17, 4, 5, 8, 18, 15, 8, 1, 3, 17, 4, 13, 13, 16, 14, 15, 11]
# Otherwise, each sample is drawn only once
sampler_f: [3, 5, 19, 10, 6, 7, 13, 16, 15, 9, 14, 0, 4, 18, 12, 2, 11, 17, 1, 8]
# With replacement=True, the number of samples can be specified; here only 8 are drawn
sampler_t_8: [1, 9, 4, 5, 11, 18, 18, 4]
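With replacement=True, num_samples is not limited to the dataset size, which is one way to oversample a small dataset. A minimal sketch, where the dataset size (5), num_samples=50 and the seed (0) are arbitrary choices for illustration; collections.Counter is used only to show how often each index was drawn.

```python
from collections import Counter

import torch
from torch.utils.data import RandomSampler

g = torch.Generator()
g.manual_seed(0)

# With replacement, num_samples may exceed len(data_source)
sampler = RandomSampler(range(5), replacement=True, num_samples=50, generator=g)
draws = list(sampler)

counts = Counter(draws)
print(len(draws))  # 50
print(counts)      # how many times each index 0..4 was drawn
```

Since 50 draws are taken from only 5 indices, repeated indices are guaranteed here.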

Official documentation

torch.utils.data.RandomSampler: https://pytorch.org/docs/stable/data.html?highlight=randomsampler#torch.utils.data.RandomSampler

First draft completed: February 22, 2022