1. 程式人生 > 程式設計 >一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關係

一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關係

以下內容都是針對Pytorch 1.0-1.1介紹。

很多文章都是從Dataset等物件自下往上進行介紹,但是對於初學者而言,其實這並不好理解,因為有的時候會不自覺地陷入到一些細枝末節中去,而不能把握重點,所以本文將會自上而下地對Pytorch資料讀取方法進行介紹。

自上而下理解三者關係

首先我們看一下DataLoader.next的原始碼長什麼樣,為方便理解我只選取了num_works為0的情況(num_works簡單理解就是能夠並行化地讀取資料)。

class DataLoader(object):
	...
	
  def __next__(self):
    if self.num_workers == 0: 
      indices = next(self.sample_iter) # Sampler
      batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
      if self.pin_memory:
        batch = _utils.pin_memory.pin_memory_batch(batch)
      return batch

在閱讀上面程式碼前,我們可以假設我們的資料是一組影象,每一張影象對應一個index,那麼如果我們要讀取資料就只需要對應的index即可,即上面程式碼中的indices,而選取index的方式有多種,有按順序的,也有亂序的,所以這個工作需要Sampler完成,現在你不需要具體的細節,後面會介紹,你只需要知道DataLoader和Sampler在這裡產生關係。

那麼Dataset和DataLoader在什麼時候產生關係呢?沒錯就是下面一行。我們已經拿到了indices,那麼下一步我們只需要根據index對資料進行讀取即可了。

再下面的if語句的作用簡單理解就是,如果pin_memory=True

,那麼Pytorch會採取一系列操作把資料拷貝到GPU,總之就是為了加速。

綜上可以知道DataLoader,Sampler和Dataset三者關係如下:

一文弄懂Pytorch的DataLoader,DataSet,Sampler之間的關係

在閱讀後文的過程中,你始終需要將上面的關係記在心裡,這樣能幫助你更好地理解。

Sampler

引數傳遞

要更加細緻地理解Sampler原理,我們需要先閱讀一下DataLoader 的原始碼,如下:

class DataLoader(object):
  def __init__(self,dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=default_collate,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None)

可以看到初始化引數裡有兩種sampler:samplerbatch_sampler,都預設為None。前者的作用是生成一系列的index,而batch_sampler則是將sampler生成的indices打包分組,得到一個又一個batch的index。例如下面示例中,BatchSamplerSequentialSampler生成的index按照指定的batch size分組。

>>>in : list(BatchSampler(SequentialSampler(range(10)),batch_size=3,drop_last=False))
>>>out: [[0,1,2],[3,4,5],[6,7,8],[9]]

Pytorch中已經實現的Sampler有如下幾種:

  • SequentialSampler
  • RandomSampler
  • WeightedSampler
  • SubsetRandomSampler

需要注意的是DataLoader的部分初始化引數之間存在互斥關係,這個你可以通過閱讀原始碼更深地理解,這裡只做總結:

  • 如果你自定義了batch_sampler,那麼這些引數都必須使用預設值:batch_size,shuffle,sampler,drop_last.
  • 如果你自定義了sampler,那麼shuffle需要設定為False
  • 如果sampler和batch_sampler都為None,那麼batch_sampler使用Pytorch已經實現好的BatchSampler,而sampler分兩種情況:
    • 若shuffle=True,則sampler=RandomSampler(dataset)
    • 若shuffle=False,則sampler=SequentialSampler(dataset)

如何自定義Sampler和BatchSampler?

仔細檢視原始碼其實可以發現,所有采樣器其實都繼承自同一個父類,即Sampler,其程式碼定義如下:

class Sampler(object):
  r"""Base class for all Samplers.
  Every Sampler subclass has to provide an :meth:`__iter__` method,providing a
  way to iterate over indices of dataset elements,and a :meth:`__len__` method
  that returns the length of the returned iterators.
  .. note:: The :meth:`__len__` method isn't strictly required by
       :class:`~torch.utils.data.DataLoader`,but is expected in any
       calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
  """

  def __init__(self,data_source):
    pass

  def __iter__(self):
    raise NotImplementedError
		
  def __len__(self):
    return len(self.data_source)

所以你要做的就是定義好__iter__(self)函式,不過要注意的是該函式的返回值需要是可迭代的。例如SequentialSampler返回的是iter(range(len(self.data_source)))

另外BatchSampler與其他Sampler的主要區別是它需要將Sampler作為引數進行打包,進而每次迭代返回以batch size為大小的index列表。也就是說在後面的讀取資料過程中使用的都是batch sampler。

Dataset

Dataset定義方式如下:

class Dataset(object):
	def __init__(self):
		...
		
	def __getitem__(self,index):
		return ...
	
	def __len__(self):
		return ...

上面三個方法是最基本的,其中__getitem__是最主要的方法,它規定了如何讀取資料。但是它又不同於一般的方法,因為它是python built-in方法,其主要作用是能讓該類可以像list一樣通過索引值對資料進行訪問。假如你定義好了一個dataset,那麼你可以直接通過dataset[0]來訪問第一個資料。在此之前我一直沒弄清楚__getitem__是什麼作用,所以一直不知道該怎麼進入到這個函式進行除錯。現在如果你想對__getitem__方法進行除錯,你可以寫一個for迴圈遍歷dataset來進行除錯了,而不用構建dataloader等一大堆東西了,建議學會使用ipdb這個庫,非常實用!!!以後有時間再寫一篇ipdb的使用教程。另外,其實我們通過最前面的Dataloader的__next__函式可以看到DataLoader對資料的讀取其實就是用了for迴圈來遍歷資料,不用往上翻了,我直接複製了一遍,如下:

class DataLoader(object): 
  ... 
   
  def __next__(self): 
    if self.num_workers == 0:  
      indices = next(self.sample_iter) 
      batch = self.collate_fn([self.dataset[i] for i in indices]) # this line 
      if self.pin_memory: 
        batch = _utils.pin_memory.pin_memory_batch(batch) 
      return batch

我們仔細看可以發現,前面還有一個self.collate_fn方法,這個是幹嘛用的呢?在介紹前我們需要知道每個引數的意義:

  • indices: 表示每一個iteration,sampler返回的indices,即一個batch size大小的索引列表
  • self.dataset[i]: 前面已經介紹了,這裡就是對第i個數據進行讀取操作,一般來說self.dataset[i]=(img,label)

看到這不難猜出collate_fn的作用就是將一個batch的資料進行合併操作。預設的collate_fn是將img和label分別合併成imgs和labels,所以如果你的__getitem__方法只是返回 img,label,那麼你可以使用預設的collate_fn方法,但是如果你每次讀取的資料有img,box,label等等,那麼你就需要自定義collate_fn來將對應的資料合併成一個batch資料,這樣方便後續的訓練步驟。

到此這篇關於一文弄懂Pytorch的DataLoader,Sampler之間的關係的文章就介紹到這了,更多相關Pytorch DataLoader DataSet Sampler內容請搜尋我們以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援我們!