1. 程式人生 > 程式設計 >Pytorch 實現資料集自定義讀取

Pytorch 實現資料集自定義讀取

以讀取VOC2012語義分割資料集為例,具體見程式碼註釋:

VocDataset.py

from PIL import Image
import torch
import torch.utils.data as data
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import time

#VOC資料集分類對應顏色標籤
VOC_COLORMAP = [[0,0],[128,[0,128,128],[64,[192,64,192,128]]

#顏色標籤空間轉到序號標籤空間,就他媽這裡浪費巨量的時間,這裡還他媽的有問題
def voc_label_indices(colormap,colormap2label):
  """Assign label indices for Pascal VOC2012 Dataset."""
  idx = ((colormap[:,:,2] * 256 + colormap[ :,1]) * 256+ colormap[:,0])
  #out = np.empty(idx.shape,dtype = np.int64) 
  out = colormap2label[idx]
  out=out.astype(np.int64)#資料型別轉換
  end = time.time()
  return out

class MyDataset(data.Dataset):#建立自定義的資料讀取類
  def __init__(self,root,is_train,crop_size=(320,480)):
    self.rgb_mean =(0.485,0.456,0.406)
    self.rgb_std = (0.229,0.224,0.225)
    self.root=root
    self.crop_size=crop_size
    images = []#建立空列表存檔名稱
    txt_fname = '%s/ImageSets/Segmentation/%s' % (root,'train.txt' if is_train else 'val.txt')
    with open(txt_fname,'r') as f:
      self.images = f.read().split()
    #資料名稱整理
    self.files = []
    for name in self.images:
      img_file = os.path.join(self.root,"JPEGImages/%s.jpg" % name)
      label_file = os.path.join(self.root,"SegmentationClass/%s.png" % name)
      self.files.append({
        "img": img_file,"label": label_file,"name": name
      })
    self.colormap2label = np.zeros(256**3)
    #整個迴圈的意思就是將顏色標籤對映為單通道的陣列索引
    for i,cm in enumerate(VOC_COLORMAP):
      self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i
  #按照索引讀取每個元素的具體內容
  def __getitem__(self,index):
    
    datafiles = self.files[index]
    name = datafiles["name"]
    image = Image.open(datafiles["img"])
    label = Image.open(datafiles["label"]).convert('RGB')#開啟的是PNG格式的圖片要轉到rgb的格式下,不然結果會比較要命
    #以影象中心為中心擷取固定大小影象,小於固定大小的影象則自動填0
    imgCenterCrop = transforms.Compose([
       transforms.CenterCrop(self.crop_size),transforms.ToTensor(),transforms.Normalize(self.rgb_mean,self.rgb_std),#影象資料正則化
     ])
    labelCenterCrop = transforms.CenterCrop(self.crop_size)
    cropImage=imgCenterCrop(image)
    croplabel=labelCenterCrop(label)
    croplabel=torch.from_numpy(np.array(croplabel)).long()#把標籤資料型別轉為torch
    
    #將顏色標籤圖轉為序號標籤圖
    mylabel=voc_label_indices(croplabel,self.colormap2label)
    
    return cropImage,mylabel
  #返回影象資料長度
  def __len__(self):
    return len(self.files)

Train.py

import matplotlib.pyplot as plt
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np

from PIL import Image
from VocDataset import MyDataset

#VOC資料集分類對應顏色標籤
VOC_COLORMAP = [[0,128]]

root='../data/VOCdevkit/VOC2012'
train_data=MyDataset(root,True)
trainloader = data.DataLoader(train_data,4)

#從資料集中拿出一個批次的資料
for i,data in enumerate(trainloader):
  getimgs,labels= data
  img = transforms.ToPILImage()(getimgs[0])

  labels = labels.numpy()#tensor轉numpy
  labels=labels[0]#獲得批次標籤集中的一張標籤影象
  labels = labels.transpose((1,0))#陣列維度切換,將第1維換到第0維,第0維換到第1維

  ##將單通道索引標籤圖片映射回顏色標籤圖片
  newIm= Image.new('RGB',(480,320))#建立一張與標籤大小相同的圖片,用以顯示標籤所對應的顏色
  for i in range(0,480):
    for j in range(0,320):
      sele=labels[i][j]#取得座標點對應畫素的值
      newIm.putpixel((i,j),(int(VOC_COLORMAP[sele][0]),int(VOC_COLORMAP[sele][1]),int(VOC_COLORMAP[sele][2])))

  #顯示影象和標籤
  plt.figure("image")
  ax1 = plt.subplot(1,2,1)
  ax2 = plt.subplot(1,2)
  plt.sca(ax1)
  plt.imshow(img)
  plt.sca(ax2)
  plt.imshow(newIm)
  plt.show()

以上這篇Pytorch 實現資料集自定義讀取就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。