1. 程式人生 > 實用技巧 >強化學習 5 —— SARSA 和 Q-Learning演算法程式碼實現

強化學習 5 —— SARSA 和 Q-Learning演算法程式碼實現

上篇文章 強化學習——時序差分 (TD) --- SARSA and Q-Learning 我們介紹了時序差分TD演算法解決強化學習的評估和控制問題,TD對比MC有很多優勢,比如TD有更低方差,可以學習不完整的序列。所以我們可以在策略控制迴圈中使用TD來代替MC。優於TD演算法的諸多優點,因此現在主流的強化學習求解方法都是基於TD的。這篇文章會使用就用程式碼實現 SARSA 和 Q-Learning 這兩種演算法。

一、演算法介紹

關於SARSA 和 Q-Learning演算法的詳細介紹,本篇部落格不做過多介紹,若不熟悉可點選文章開頭連結檢視。

Sarsa 和 QLearning 時序差分TD解決強化學習控制問題的兩種演算法,兩者非常相似,從更新公式就能看出來:

  • SARSA:

\[A(S_t, A_t) \leftarrow A(S_t, A_t) + \alpha \left[R_{t+1} + \gamma Q(S_{t+1}, A_{t+1}) - A(S_t, A_t)\right] \]

  • Q-Learning

\[Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha[R_{t+1} + \gamma \; max_aQ(S_{t+1}, a) - Q(S_t, A_t)] \]

可以看出來,兩者的區別就在計算 TD-Target 的時候,下一個動作 a' 是如何選取的

對於 Sarsa 來說:

  • 1)在狀態 s' 時,就知道了要採取那個動作 a',並且真的採取了這個動作
  • 2)當前動作 a 和下一個動作 a' 都是 根據 \(\epsilon\) -貪婪策略選取的,因此稱為on-policy學習

對於 Q-Learning:

  • 1)在狀態s'時,只是計算了 在 s' 時要採取哪個 a' 可以得到更大的 Q 值,並沒有真的採取這個動作 a'。
  • 2)動作 a 的選取是根據當前 Q 網路以及 \(\epsilon\)-貪婪策略,即每一步都會根據當前的狀況選擇一個動作A,目標Q值的計算是根據 Q 值最大的動作 a' 計算得來,因此為 off-policy 學習。

二、程式碼

1、SARSA

定義 SARSA agent 類,

class Sarsa:
    def __init__(self, state_dim, action_dim, lr=0.01, gamma=0.9, e_greed=0.1):
        self.action_dim = action_dim
        self.lr = lr
        self.gamma = gamma
        self.epsilon = e_greed
        self.Q = np.zeros((state_dim, action_dim))

    def sample(self, state):
        """
        使用 epsilon 貪婪策略獲取動作
        return: action
        """
        if np.random.uniform() < self.epsilon:
            action = np.random.choice(self.action_dim)
        else: action = self.predict(state)
        return action

    def predict(self, state):
        """ 根據輸入觀察值,預測輸出的動作值 """
        all_actions = self.Q[state, :]
        max_action = np.max(all_actions)
        # 防止最大的 Q 值有多個,找出所有最大的 Q,然後再隨機選擇
        # where函式返回一個 array, 每個元素為下標
        max_action_list = np.where(all_actions == max_action)[0]
        action = np.random.choice(max_action_list)
        return action

    def learn(self, state, action, reward, next_state, next_action, done):
        """
        更新 Q-table 方法
        next_action 就是下一步選的動作,所以直接用 self.Q[next_state, next_action]
        然後計算 td-target,然後更新 Q-table
        """
        if done: target_q = reward
        else:
            target_q = reward + self.gamma * self.Q[next_state, next_action]
        self.Q[state, action] += self.lr * (target_q - self.Q[state, action])

上面程式碼重點是 learn() 方法中的 Q-table 的更新,結合公式還是比較容易理解的。下面是每一個 episode 的流程:對於一個 episode 先呼叫 reset() 方法獲得初始化狀態state,然後選擇當前的動作 action ,使用當前的動作讓環境執行一步,獲取到下一個狀態 next_state 以及獎勵 reward ,然後利用這些資料進行更新Q表格,注意 更新之後要把下一個狀態和動作賦值給當前的狀態和動作,然後迴圈。

def run_episode(self, render=False):
    state = self.env.reset()
    action = self.model.sample(state)
    while True:
        next_state, reward, done, _ = self.env.step(action)
        next_action = self.model.sample(next_state)
        # 訓練 Q-learning演算法
        self.model.learn(state, action, reward, next_state, next_action, done)
        state = next_state
        action = next_action
        if render: self.env.render()
        if done: break

完整程式碼見強化學習——SARSA 演算法 ,勞煩大人點個 star 可好?

2、Q-Learning

由上可知,Q-Learning 和 SARSA 演算法很相似,程式碼幾乎相同,下面就展示下與 SARSA 演算法不同的部分

class QLearning:
    # ...
    # 其他方法見 SARSA 部分
    def learn(self, state, action, reward, next_state, done):
        """
        Q-Learning 更新 Q-table 方法
        這裡沒有明確選擇下一個動作 next_action, 而是選擇 next_state 下有最大價值的動作
        所以用 np.max(self.Q[next_state, :]) 來計算 td-target
        然後更新 Q-table
        """
        if done:
            target_q = reward
        else:
            target_q = reward + self.gamma * np.max(self.Q[next_state, :])
        self.Q[state, action] += self.lr * (target_q - self.Q[state, action])

對於 Q-Learning 的演算法流程部分 ,和 SARSA 也有些細微區別:在Q-Learning 中的 learn() 方法不需要傳入 next_action 引數,因為在計算td-target 時只是查看了一下下一個狀態的所有動作價值,並選擇一個最優動作讓環境去執行。還請仔細區分兩者的不同:

def run_episode(self, render=False):
    state = self.env.reset()
    while True:
        action = self.model.sample(state)
        next_state, reward, done, _ = self.env.step(action)
        # 訓練 Q-learning演算法
        self.model.learn(state, action, reward, next_state, done)
        
        state = next_state
        if render: self.env.render()
        if done: break

完整程式碼見強化學習——Q-Learning 演算法,勞煩大人點個 star 可好?