The error is as follows:
RuntimeError: stack expects each tensor to be equal size, but got [3, 60, 32] at entry 0 and [3, 54, 32] at entry 2
```python
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.Resize((224)),  ### note: (224) is just the int 224, not a tuple
        ...
```
The cause is the argument passed to transforms.Resize(). A single int tells Resize to scale the shorter edge to that value while preserving the aspect ratio, so images with different aspect ratios come out in different shapes, and the default collate function cannot stack them into one batch tensor. Change the call as follows and the error disappears:
```python
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.Resize((224, 224)),  # an (h, w) tuple forces a fixed output size
        ...
```
Likewise, change transforms.Resize((224,224)) in val_dataset as well.
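A quick way to see the difference between the two calls, as a minimal sketch (the 400×600 image is an arbitrary stand-in):

```python
from PIL import Image
from torchvision import transforms

img = Image.new("RGB", (400, 600))  # arbitrary non-square image: width 400, height 600

short_side = transforms.Resize(224)         # int: shorter edge -> 224, aspect ratio kept
fixed_size = transforms.Resize((224, 224))  # (h, w) tuple: output is exactly 224x224

print(short_side(img).size)  # (224, 336): shape still varies with the input image
print(fixed_size(img).size)  # (224, 224): always stackable into a batch
```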
Supplement: a deeper look at PyTorch's DataLoader
- A dataloader is essentially an iterable: access it via iter(); calling next() on it directly does not work;
- iter(dataloader) returns an iterator, which you can then step through with next() (see the sketch after this list);
- You can also treat it as an iterable directly, with `for inputs, labels in dataloaders`;
- Usually we implement a Dataset object and pass it into the dataloader; internally, yield is used to return the data of each batch.
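A minimal sketch of the two access patterns, assuming a dataloader built as in the snippets below:

```python
# Pattern 1: make the iterator explicit, then pull batches with next()
it = iter(dataloader)      # a DataLoader is an iterable, so iter() is valid
inputs, labels = next(it)  # next() works on the iterator, not on the DataLoader itself

# Pattern 2: a for-loop calls iter() and next() behind the scenes
for inputs, labels in dataloader:
    pass  # one batch per step
```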
① DataLoader is essentially an iterable (just like Python's built-in types such as list); it uses multiple processes to speed up the preparation of batch data, and uses yield to work within limited memory
② Behavior of a Queue:
When the queue is empty, queue.get() blocks; while blocked, a queue.put() from another process/thread notifies this process/thread, and the get then succeeds.
When the queue is full, queue.put() blocks
③ DataLoader is an efficient, concise, and intuitive structure for feeding input data to the network, and is easy to use and extend
The input-data pipeline
PyTorch's sequence of operations for loading data into a model goes like this:
① Create a Dataset object
② Create a DataLoader object
③ Loop over the DataLoader object, feeding img, label into the model for training
```python
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
    for img, label in dataloader:
        ...  # forward/backward pass goes here
```
So, as the key step that feeds data straight into the model, the DataLoader matters a great deal.
First, a brief introduction. DataLoader is an important data-reading interface in PyTorch, defined in dataloader.py; essentially any PyTorch training code uses it (unless the user rewrites it…). Its purpose: wrap a custom Dataset, according to batch size, whether to shuffle, and so on, into batch-sized Tensors for the subsequent training.
The official description of DataLoader is: "Data loading combines a dataset and a sampler, and uses Python single- or multi-process iterators to process the data." Look up the concepts of iterator versus iterable if needed; the difference in practice is that an iterator has both __iter__ and __next__ methods, while an iterable only has __iter__.
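A tiny sketch of that distinction in plain Python:

```python
nums = [1, 2, 3]  # a list is an iterable: it has __iter__ but no __next__
it = iter(nums)   # iter() turns the iterable into an iterator

print(hasattr(nums, '__next__'))  # False: next(nums) would raise TypeError
print(hasattr(it, '__next__'))    # True:  next(it) yields 1, then 2, then 3
```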
1. DataLoader
First, the parameters of DataLoader(object) (a construction sketch follows the list):
dataset (Dataset)
: the dataset to load data from
batch_size (int, optional)
: how many samples per batch
shuffle (bool, optional)
: whether to reshuffle the data at the start of every epoch
sampler (Sampler, optional)
: a custom strategy for drawing samples from the dataset; if this is specified, shuffle must be False
batch_sampler (Sampler, optional)
: like sampler, but returns the indices of one whole batch at a time; note that once this is specified, batch_size, shuffle, sampler, and drop_last may no longer be specified (mutually exclusive)
num_workers (int, optional)
: how many subprocesses handle data loading; 0 means all data is loaded in the main process (default: 0)
collate_fn (callable, optional)
: the function that merges a list of samples into a mini-batch
pin_memory (bool, optional)
: if True, the data loader copies tensors into CUDA pinned memory before returning them
drop_last (bool, optional)
: this concerns the final incomplete batch; if True it is dropped. For example, with batch_size set to 64 and an epoch of only 100 samples, the trailing 36 samples are simply thrown away during training…
If False (the default), execution continues normally and the last batch is just smaller.
timeout (numeric, optional)
: if positive, how long to wait when collecting a batch from a worker process; if nothing arrives within the limit, a timeout error is raised. This value should always be >= 0. Default: 0
worker_init_fn (callable, optional)
: the per-worker initialization function. If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
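Putting several of these parameters together, here is a hedged construction sketch (the toy tensors and the my_collate helper are mine, not library defaults; my_collate just mirrors what default_collate does):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# a toy dataset: 100 RGB images of shape (3, 224, 224) with integer labels
dataset = TensorDataset(torch.randn(100, 3, 224, 224),
                        torch.randint(0, 10, (100,)))

# a custom collate_fn: takes a list of (img, label) samples, returns one mini-batch
def my_collate(samples):
    imgs = torch.stack([s[0] for s in samples])
    labels = torch.stack([s[1] for s in samples])
    return imgs, labels

loader = DataLoader(dataset,
                    batch_size=64,
                    shuffle=True,        # reshuffle at every epoch
                    num_workers=2,       # two loading subprocesses
                    collate_fn=my_collate,
                    pin_memory=True,     # page-locked memory for faster GPU copies
                    drop_last=True)      # 100 is not divisible by 64: the last 36 samples are dropped

for imgs, labels in loader:
    print(imgs.shape, labels.shape)  # torch.Size([64, 3, 224, 224]) torch.Size([64])
    break
```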
- First, when the dataloader is initialized, it obtains the sampling list over the dataset
```python
class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy
            tensors into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete
            batch, if the dataset size is not divisible by the batch size. If
            ``False`` and the size of dataset is not divisible by the batch
            size, then the last batch will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for
            collecting a batch from workers. Should always be non-negative.
            (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called
            on each worker subprocess with the worker id (an int in
            ``[0, num_workers - 1]``) as input, after seeding and before data
            loading. (default: None)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use ``torch.initial_seed()`` to access the PyTorch seed for each
              worker in :attr:`worker_init_fn`, and use it to set other seeds
              before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn`
                 cannot be an unpicklable object, e.g., a lambda function.
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)  # shuffles the index list
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))
        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
```
Here, RandomSampler and BatchSampler are what produce the index lists for sampling each batch; the yield-batch mechanism is already in place here!
```python
class RandomSampler(Sampler):
    r"""Samples elements randomly, without replacement.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).tolist())

    def __len__(self):
        return len(self.data_source)
```
```python
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integral value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
```
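Since BatchSampler is just an iterable of index lists, you can also build the sampling pipeline by hand and pass it to DataLoader through the batch_sampler argument (which, as noted above, excludes batch_size, shuffle, sampler, and drop_last). A small sketch, reusing the toy dataset from the earlier snippet:

```python
from torch.utils.data import DataLoader, RandomSampler, BatchSampler

sampler = RandomSampler(dataset)                                      # random index order
batch_sampler = BatchSampler(sampler, batch_size=3, drop_last=False)  # groups of 3 indices

# batch_sampler is mutually exclusive with batch_size, shuffle, sampler, drop_last
loader = DataLoader(dataset, batch_sampler=batch_sampler)
```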
- _DataLoaderIter(self) takes a dataloader object as its input; the num_workers=0 case is straightforward, while num_workers != 0 brings in worker processes to speed up data loading;
- Without workers, batch = self.collate_fn([self.dataset[i] for i in indices]) converts the indices into data and returns (image, label); self.dataset[i] invokes the Dataset object's __getitem__() method (see the sketch after this list);
- With workers, one index queue (index_queues) is created per worker, and all workers share a single worker_result_queue data queue; the data itself is loaded in the _worker_loop function;
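To make the __getitem__ hook concrete, here is a minimal custom Dataset (a toy of my own, purely for illustration):

```python
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Toy dataset: sample i is the pair (i, i*i)."""

    def __len__(self):
        return 10

    def __getitem__(self, i):
        # exactly the call that self.dataset[i] triggers inside _DataLoaderIter
        return i, i * i

ds = SquaresDataset()
indices = [2, 5, 7]                 # one batch's worth of indices from the sampler
samples = [ds[i] for i in indices]  # -> [(2, 4), (5, 25), (7, 49)]
# collate_fn would now merge these samples into a single mini-batch
```

The _DataLoaderIter source follows: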
```python
class _DataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queues[i],
                          self.worker_result_queue, self.collate_fn,
                          base_seed + i, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue,
                          self.done_event, self.pin_memory, maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def _get_batch(self):
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch
```
```python
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True

    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    random.seed(seed)
    torch.manual_seed(seed)

    if init_fn is not None:
        init_fn(worker_id)

    watchdog = ManagerWatchdog()

    while True:
        try:
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive():
                continue
            else:
                break
        if r is None:
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
            del samples
```
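Note the seeding lines above: each worker seeds random and torch with base_seed + worker_id, but other libraries are not reseeded. A hedged sketch of the worker_init_fn usage hinted at in the DataLoader docstring, seeding NumPy per worker (the function name is my own):

```python
import numpy as np
import torch

def seed_numpy_worker(worker_id):
    # runs in each worker subprocess after PyTorch seeding, before data loading;
    # without it, workers may all produce identical NumPy random numbers
    np.random.seed(torch.initial_seed() % 2**32)

# hypothetical usage:
# loader = DataLoader(dataset, num_workers=4, worker_init_fn=seed_numpy_worker)
```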
- Queue operations are needed to buffer the data, and this is what speeds loading up!
The above is my personal experience; I hope it gives everyone a useful reference, and I hope you will keep supporting 服务器之家.
Original link: https://blog.csdn.net/hfw6310/article/details/106992968