Reinforcement learning: advantage actor-critic loss explodes

Hello, this is a repost from the PyTorch forums. I'm sorry if this doesn't exactly fit the category, but I really need some help…

I have some problems implementing the actor-critic policy gradient algorithm.

When I implement REINFORCE like this, everything is okay:

class Estimator(nn.Module):
    def __init__(self, num_actions):
        super().__init__()
        self.num_actions = num_actions
        self.dense_1 = nn.Linear(4, 32)
        self.out = nn.Linear(32, num_actions)

    def forward(self, x):
        x = self.dense_1(x)
        x = F.softmax(self.out(x))
        return x
    
env = gym.make("CartPole-v0")
estimator = Estimator(2)
estimator.cuda()
opt = optim.Adam(estimator.parameters())
loss = []
running_reward = 0
for i in range(100000): # number episodes
    episode = []
    chosen_actions = []
    rewards = []
    done = False
    state = env.reset()
                  
    while not done:
        probs = estimator(Variable(torch.unsqueeze(torch.from_numpy(state),0).float().cuda())) # calculate the probs of choosing actions
        action = probs.multinomial()
        chosen_actions.append(action)
        next_state, reward, done, _ = env.step(action.data[0,0])
        rewards.append(reward)
        state = next_state
       
    
    # compute (undiscounted) returns; inserting them at the front of the list
    # keeps them aligned with chosen_actions in the zip below
    R = 0
    for r in rewards[::-1]:
        R = r + R
        rewards.insert(0, R)
        
    for action, r in zip(chosen_actions, rewards):
        action.reinforce(r)
        
    opt.zero_grad()
    autograd.backward(chosen_actions, [None for _ in chosen_actions])
    opt.step()
    running_reward = running_reward * 0.99 + len(chosen_actions) * 0.01
    if (i+1) % 10 == 0:
        print("Episode: {} Running Reward: {}".format(i+1,round(running_reward,2)))

But when I try to implement the actor-critic version using the Generalized Advantage Estimator, the algorithm fails.

Two things are happening:

  1. The policy learns to ALWAYS choose one of the actions (its probability approaches 1).
  2. The loss (and the output) of the state value estimator explodes.
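
For context, the "advantage" I feed into reinforce() at each timestep is just the one-step TD error, which as far as I understand is the λ = 0 special case of GAE. As a sketch (not the exact code below; one_step_advantage and the *_var arguments are just illustrative names, and gamma is the discount factor, which I effectively set to 1):

# one-step TD error used as the advantage (the lambda = 0 case of GAE)
def one_step_advantage(v_estimator, state_var, reward, next_state_var, gamma=1.0):
    # V(s') and V(s) both come from the current state value network
    return reward + gamma * v_estimator(next_state_var) - v_estimator(state_var)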

I checked the implementation of the policy gradient by pretraining a state value estimator using TD and then plugging it into the code. That works just fine, so I suspect I have made some error implementing the state value updates.
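
Roughly, that check looked like this (a simplified sketch, not the exact code; v_estimate stands for the separately pretrained and then frozen value network, and state_var / next_state_var are the Variable-wrapped states as in the full code below):

# sanity check: advantage computed from a pretrained, frozen value network
# v_estimate was trained beforehand with TD and is not updated here
advantage = reward + v_estimate(next_state_var) - v_estimate(state_var)
action.reinforce(advantage.data)  # reinforce() takes the reward as a plain tensor
action.backward()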

Here is the code:

(I tried a lot of different hyperparameters, like the learning rates of both optimizers and the number of state value estimator updates per timestep, and I also tried starting the policy gradient updates only after the state value estimator has had some time to learn; unfortunately nothing has worked…)

class Estimator(nn.Module):
    def __init__(self, num_actions):
        super().__init__()
        self.num_actions = num_actions
        self.dense_1 = nn.Linear(4, 32)
        self.out = nn.Linear(32, num_actions)

    def forward(self, x):
        x = self.dense_1(x)
        x = F.softmax(self.out(x))
        return x

class V_Estimator(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense_1 = nn.Linear(4, 32)
        self.out = nn.Linear(32, 1)

    def forward(self, x):
        x = F.relu(self.dense_1(x))
        x = self.out(x)
        return x
    
    
estimator = Estimator(2)
estimator.cuda()
v_estimator = V_Estimator()
v_estimator.cuda()
opt = optim.Adam(estimator.parameters(), lr=0.0001)
v_opt = optim.Adam(v_estimator.parameters(), lr=0.0001)
env = gym.make("CartPole-v0")
mse = nn.MSELoss()
buffer = ReplayBuffer(100000)
running_reward = 0
for i in range(10000): # number episodes
    episode_len = 0
    done = False
    state = env.reset()
    

    while not done:
        episode_len += 1
        state_python = state
        state = Variable(torch.unsqueeze(torch.from_numpy(state),0).float().cuda())
        probs = estimator(state) 
        #print(probs.data.cpu().numpy()) # one of the action probabilites just approaches 1
        action = probs.multinomial()

        action_python = action.data[0,0]
        v_estimate_curr = v_estimator(state)
        #v_estimate_curr = v_estimate(state)
        #print(v_estimate_curr)
        next_state, reward, done, _ = env.step(action_python)
        v_estimate_next = v_estimator(Variable(torch.unsqueeze(torch.from_numpy(next_state),0).float().cuda()))
        #v_estimate_next = v_estimate(Variable(torch.unsqueeze(torch.from_numpy(next_state),0).float().cuda()))
        #print(v_estimate_next)
        
        td_error = reward + v_estimate_next - v_estimate_curr
        
        
        buffer.add(state_python, action_python, reward, done, next_state)
        state = next_state
        
        #refit v-estimator
        average_state_value_loss = 0
        state_value_updates = 30
        for j in range(state_value_updates):
            s_batch, a_batch, r_batch, d_batch, s2_batch  = buffer.sample_batch(128)
            #print("s_batch shape: {}".format(s_batch.shape))
            targ = v_estimator(Variable(torch.from_numpy(s2_batch)).float().cuda())
            #print("targ shape: {}".format(targ.data.cpu().numpy().shape))
            torch_rew_batch = Variable(torch.unsqueeze(torch.from_numpy(r_batch).float().cuda(),-1))
            #print("torch_rew_batch shape: {}".format(torch_rew_batch.data.cpu().numpy().shape))
            targ = targ + torch_rew_batch
            targ = targ.detach()
            targ.requires_grad = False
            #print("targ shape: {}".format(targ.data.cpu().numpy().shape))
            out = v_estimator(Variable(torch.from_numpy(s_batch)).float().cuda())
            #print("out shape: {}".format(out.data.cpu().numpy().shape))
            v_loss = mse(out, targ)
            average_state_value_loss += v_loss.data[0] / state_value_updates
            
            v_opt.zero_grad()
            v_loss.backward()
            v_opt.step()

        # update policy gradient
        #if i > 100: # starting after 100 episodes to give the state value nn some time to learn
        opt.zero_grad()
        action.reinforce(td_error.data)
        action.backward()
        opt.step()
    running_reward = running_reward * 0.9 + episode_len * 0.1
    print("current episode: " + str(i)+ " - running reward: " + str(round(running_reward,2)) + " - average state value estimator loss: {}".format(average_state_value_loss))

I looked at the actor-critic implementation in the PyTorch examples repo, but they do things a little differently (like sharing parameters between the policy and the value function).

If anybody has any idea how to fix this, I would greatly appreciate it.

Johannes

PS: for reproducibility, execute this first; then both code samples run:

import random
from collections import deque

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd
from torch import optim
from torch.autograd import Variable

class ReplayBuffer(object):

    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.count = 0
        self.buffer = deque()

    def add(self, s, a, r, d, s2):
        experience = (s, a, r, d, s2)
        if self.count < self.buffer_size:
            self.buffer.append(experience)
            self.count += 1
        else:
            self.buffer.popleft()
            self.buffer.append(experience)

    def size(self):
        return self.count

    def sample_batch(self, batch_size):
        '''
        batch_size specifies the number of experiences to sample.
        If the replay buffer has fewer than batch_size elements,
        simply return all of the elements within the buffer.
        '''

        if self.count < batch_size:
            batch = random.sample(self.buffer, self.count)
        else:
            batch = random.sample(self.buffer, batch_size)

        s_batch = np.array([np.array(_[0]) for _ in batch])
        a_batch = np.array([_[1] for _ in batch])
        r_batch = np.array([_[2] for _ in batch])
        d_batch = np.array([_[3] for _ in batch])
        s2_batch = np.array([np.array(_[4]) for _ in batch])

        return s_batch, a_batch, r_batch, d_batch, s2_batch

    def clear(self):
        self.buffer.clear()
        self.count = 0