1. 程式人生 > 程式設計 >深度學習入門之Pytorch 資料增強的實現

深度學習入門之Pytorch 資料增強的實現

資料增強

卷積神經網路非常容易出現過擬合的問題,而資料增強的方法是對抗過擬合問題的一個重要方法。

2012 年 AlexNet 在 ImageNet 上大獲全勝,圖片增強方法功不可沒,因為有了圖片增強,使得訓練的資料集比實際資料集多了很多'新'樣本,減少了過擬合的問題,下面我們來具體解釋一下。

常用的資料增強方法

常用的資料增強方法如下:
1.對圖片進行一定比例縮放
2.對圖片進行隨機位置的擷取
3.對圖片進行隨機的水平和豎直翻轉
4.對圖片進行隨機角度的旋轉
5.對圖片進行亮度、對比度和顏色的隨機變化

這些方法 pytorch 都已經為我們內建在了 torchvision 裡面,我們在安裝 pytorch 的時候也安裝了 torchvision,下面我們來依次展示一下這些資料增強方法。

import sys
sys.path.append('..')

from PIL import Image
from torchvision import transforms as tfs

# 讀入一張圖片
im = Image.open('./cat.png')
im

深度學習入門之Pytorch 資料增強的實現

隨機比例放縮

隨機比例縮放主要使用的是 torchvision.transforms.Resize() 這個函式,第一個引數可以是一個整數,那麼圖片會儲存現在的寬和高的比例,並將更短的邊縮放到這個整數的大小,第一個引數也可以是一個 tuple,那麼圖片會直接把寬和高縮放到這個大小;第二個引數表示放縮圖片使用的方法,比如最鄰近法,或者雙線性差值等,一般雙線性差值能夠保留圖片更多的資訊,所以 pytorch 預設使用的是雙線性差值,你可以手動去改這個引數,更多的資訊可以看看文件

# 比例縮放
print('before scale,shape: {}'.format(im.size))
new_im = tfs.Resize((100,200))(im)
print('after scale,shape: {}'.format(new_im.size))
new_im

深度學習入門之Pytorch 資料增強的實現

隨機位置擷取

隨機位置擷取能夠提取出圖片中區域性的資訊,使得網路接受的輸入具有多尺度的特徵,所以能夠有較好的效果。在 torchvision 中主要有下面兩種方式,一個是 torchvision.transforms.RandomCrop(),傳入的引數就是截取出的圖片的長和寬,對圖片在隨機位置進行擷取;第二個是 torchvision.transforms.CenterCrop()

,同樣傳入介曲初的圖片的大小作為引數,會在圖片的中心進行擷取

# 隨機裁剪出 100 x 100 的區域
random_im1 = tfs.RandomCrop(100)(im)
random_im1

深度學習入門之Pytorch 資料增強的實現

# 中心裁剪出 100 x 100 的區域
center_im = tfs.CenterCrop(100)(im)
center_im

深度學習入門之Pytorch 資料增強的實現

隨機的水平和豎直方向翻轉

對於上面這一張貓的圖片,如果我們將它翻轉一下,它仍然是一張貓,但是圖片就有了更多的多樣性,所以隨機翻轉也是一種非常有效的手段。在 torchvision 中,隨機翻轉使用的是 torchvision.transforms.RandomHorizontalFlip()torchvision.transforms.RandomVerticalFlip()

# 隨機水平翻轉
h_filp = tfs.RandomHorizontalFlip()(im)
h_filp

深度學習入門之Pytorch 資料增強的實現

# 隨機豎直翻轉
v_flip = tfs.RandomVerticalFlip()(im)
v_flip

深度學習入門之Pytorch 資料增強的實現

隨機角度旋轉

一些角度的旋轉仍然是非常有用的資料增強方式,在 torchvision 中,使用 torchvision.transforms.RandomRotation() 來實現,其中第一個引數就是隨機旋轉的角度,比如填入 10,那麼每次圖片就會在 -10 ~ 10 度之間隨機旋轉

rot_im = tfs.RandomRotation(45)(im)
rot_im

深度學習入門之Pytorch 資料增強的實現

亮度、對比度和顏色的變化

除了形狀變化外,顏色變化又是另外一種增強方式,其中可以設定亮度變化,對比度變化和顏色變化等,在 torchvision 中主要使用 torchvision.transforms.ColorJitter() 來實現的,第一個引數就是亮度的比例,第二個是對比度,第三個是飽和度,第四個是顏色

# 亮度
bright_im = tfs.ColorJitter(brightness=1)(im) # 隨機從 0 ~ 2 之間亮度變化,1 表示原圖
bright_im

深度學習入門之Pytorch 資料增強的實現

# 對比度
contrast_im = tfs.ColorJitter(contrast=1)(im) # 隨機從 0 ~ 2 之間對比度變化,1 表示原圖
contrast_im

深度學習入門之Pytorch 資料增強的實現

# 顏色
color_im = tfs.ColorJitter(hue=0.5)(im) # 隨機從 -0.5 ~ 0.5 之間對顏色變化
color_im

深度學習入門之Pytorch 資料增強的實現

上面我們講了這麼圖片增強的方法,其實這些方法都不是孤立起來用的,可以聯合起來用,比如先做隨機翻轉,然後隨機擷取,再做對比度增強等等,torchvision 裡面有個非常方便的函式能夠將這些變化合起來,就是 torchvision.transforms.Compose(),下面我們舉個例子

im_aug = tfs.Compose([
  tfs.Resize(120),tfs.RandomHorizontalFlip(),tfs.RandomCrop(96),tfs.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5)
])
import matplotlib.pyplot as plt
%matplotlib inline
nrows = 3
ncols = 3
figsize = (8,8)
_,figs = plt.subplots(nrows,ncols,figsize=figsize)
for i in range(nrows):
  for j in range(ncols):
    figs[i][j].imshow(im_aug(im))
    figs[i][j].axes.get_xaxis().set_visible(False)
    figs[i][j].axes.get_yaxis().set_visible(False)
plt.show()

深度學習入門之Pytorch 資料增強的實現

可以看到每次做完增強之後的圖片都有一些變化,所以這就是我們前面講的,增加了一些'新'資料
下面我們使用影象增強進行訓練網路,看看具體的提升究竟在什麼地方,使用 ResNet 進行訓練

使用資料增強

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from utils import train,resnet
from torchvision import transforms as tfs
# 使用資料增強
def train_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(120),hue=0.5),tfs.ToTensor(),tfs.Normalize([0.5,0.5,0.5],[0.5,0.5])
  ])
  x = im_aug(x)
  return x

def test_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data',train=True,transform=train_tf)
train_data = torch.utils.data.DataLoader(train_set,batch_size=64,shuffle=True)
test_set = CIFAR10('./data',train=False,transform=test_tf)
test_data = torch.utils.data.DataLoader(test_set,batch_size=128,shuffle=False)

net = resnet(3,10)
optimizer = torch.optim.SGD(net.parameters(),lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net,train_data,test_data,10,optimizer,criterion)

深度學習入門之Pytorch 資料增強的實現

不使用資料增強

# 不使用資料增強
def data_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),transform=data_tf)
train_data = torch.utils.data.DataLoader(train_set,transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set,criterion)

深度學習入門之Pytorch 資料增強的實現

從上面可以看出,對於訓練集,不做資料增強跑 10 次,準確率已經到了 95%,而使用了資料增強,跑 10 次準確率只有 75%,說明資料增強之後變得更難了。

而對於測試集,使用資料增強進行訓練的時候,準確率會比不使用更高,因為資料增強提高了模型應對於更多的不同資料集的泛化能力,所以有更好的效果。

以上就是深度學習入門之Pytorch 資料增強的實現的詳細內容,更多關於Pytorch 資料增強的資料請關注我們其它相關文章!