1. 程式人生 > 程式設計 >pytorch中交叉熵損失(nn.CrossEntropyLoss())的計算過程詳解

pytorch中交叉熵損失(nn.CrossEntropyLoss())的計算過程詳解

公式

首先需要了解CrossEntropyLoss的計算過程,交叉熵的函式是這樣的:

其中,其中yi表示真實的分類結果。這裡只給出公式,關於CrossEntropyLoss的其他詳細細節請參照其他博文。

測試程式碼(一維)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1,5,requires_grad=True)
label = torch.empty(1,dtype=torch.long).random_(5)
loss = criterion(output,label)

print("網路輸出為5類:")
print(output)
print("要計算label的類別:")
print(label)
print("計算loss的結果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的計算結果:")
print(res)

測試程式碼(多維)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3,requires_grad=True)
label = torch.empty(3,label)

print("網路輸出為3個5類:")
print(output)
print("要計算loss的類別:")
print(label)
print("計算loss的結果:")
print(loss)

first = [0,0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0,0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的計算結果:")
print(res/3)

nn.CrossEntropyLoss()中的計算方法

注意:在計算CrossEntropyLosss時,真實的label(一個標量)被處理成onehot編碼的形式。

在pytorch中,CrossEntropyLoss計算公式為:

CrossEntropyLoss帶權重的計算公式為(預設weight=None):

以上這篇pytorch中交叉熵損失(nn.CrossEntropyLoss())的計算過程詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。