Gym's Vector API vs SB3's Vector API

Diff (never expires): 12 removals and 12 additions; each side is 325 lines.
import argparse
import argparse
import os
import os
import random
import random
import time
import time
from distutils.util import strtobool
from distutils.util import strtobool


import gym
import gym
import numpy as np
import numpy as np
import torch
import torch
import torch.nn as nn
import torch.nn as nn
import torch.optim as optim
import torch.optim as optim
from gym.spaces import Discrete
from gym.spaces import Discrete
from gym.wrappers import Monitor
from gym.wrappers import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from gym.vector import SyncVectorEnv
from torch.distributions.categorical import Categorical
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard import SummaryWriter




def _strtobool(value):
    """Convert a command-line truth string to ``bool``.

    Accepts the same spellings as ``distutils.util.strtobool`` (case-insensitive
    ``y/yes/t/true/on/1`` and ``n/no/f/false/off/0``) without depending on
    ``distutils``, which PEP 632 deprecated and Python 3.12 removed.

    Raises:
        ValueError: for any other string; argparse reports it as a usage error.
    """
    normalized = value.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def parse_args(argv=None):
    """Parse the PPO command line and derive the batch sizes.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding every option below plus two derived fields:
        ``batch_size`` (``num_envs * num_steps``) and ``minibatch_size``
        (``batch_size // n_minibatch``).
    """
    # fmt: off
    parser = argparse.ArgumentParser(description='PPO agent')
    # splitext removes only the ".py" suffix; the former str.rstrip(".py") stripped
    # any trailing '.', 'p' or 'y' characters (e.g. "ppo_happy.py" -> "ppo_ha").
    parser.add_argument('--exp-name', type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="CartPole-v1",
        help='the id of the gym environment')
    parser.add_argument('--learning-rate', type=float, default=2.5e-4,
        help='the learning rate of the optimizer')
    parser.add_argument('--seed', type=int, default=1,
        help='seed of the experiment')
    parser.add_argument('--total-timesteps', type=int, default=25000,
        help='total timesteps of the experiments')
    parser.add_argument('--torch-deterministic', type=_strtobool, default=True, nargs='?', const=True,
        help='if toggled, `torch.backends.cudnn.deterministic=False`')
    parser.add_argument('--cuda', type=_strtobool, default=True, nargs='?', const=True,
        help='if toggled, cuda will not be enabled by default')
    parser.add_argument('--prod-mode', type=_strtobool, default=False, nargs='?', const=True,
        help='run the script in production mode and use wandb to log outputs')
    parser.add_argument('--capture-video', type=_strtobool, default=False, nargs='?', const=True,
        help='whether to capture videos of the agent performances (check out `videos` folder)')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
        help="the entity (team) of wandb's project")

    # Algorithm specific arguments
    parser.add_argument('--n-minibatch', type=int, default=4,
        help='the number of mini batch')
    parser.add_argument('--num-envs', type=int, default=4,
        help='the number of parallel game environment')
    parser.add_argument('--num-steps', type=int, default=128,
        help='the number of steps per game environment')
    parser.add_argument('--gamma', type=float, default=0.99,
        help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.95,
        help='the lambda for the general advantage estimation')
    parser.add_argument('--ent-coef', type=float, default=0.01,
        help="coefficient of the entropy")
    parser.add_argument('--vf-coef', type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
        help='the maximum norm for the gradient clipping')
    parser.add_argument('--clip-coef', type=float, default=0.2,
        help="the surrogate clipping coefficient")
    parser.add_argument('--update-epochs', type=int, default=4,
        help="the K epochs to update the policy")
    parser.add_argument('--kle-stop', type=_strtobool, default=False, nargs='?', const=True,
        help='If toggled, the policy updates will be early stopped w.r.t target-kl')
    parser.add_argument('--target-kl', type=float, default=0.03,
        help='the target-kl variable that is referred by --kle-stop')
    parser.add_argument('--gae', type=_strtobool, default=True, nargs='?', const=True,
        help='Use GAE for advantage computation')
    parser.add_argument('--norm-adv', type=_strtobool, default=True, nargs='?', const=True,
        help="Toggles advantages normalization")
    parser.add_argument('--anneal-lr', type=_strtobool, default=True, nargs='?', const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--clip-vloss', type=_strtobool, default=True, nargs='?', const=True,
        help='Toggles whether or not to use a clipped loss for the value function, as per the paper.')
    args = parser.parse_args(argv)
    # A seed of 0 is falsy, so it is replaced by the current wall-clock time.
    if not args.seed:
        args.seed = int(time.time())
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.n_minibatch)
    # fmt: on
    return args




def make_env(gym_id, seed, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one seeded environment.

    The factory (rather than the env itself) is returned so the vectorized-env
    constructor can create each sub-environment on its own.

    Args:
        gym_id: id passed to ``gym.make``.
        seed: seed applied to the env and to its action and observation spaces.
        idx: index of this sub-environment; only index 0 may record video.
        capture_video: whether the index-0 env is wrapped in ``Monitor``.
        run_name: sub-directory under ``videos/`` used for recordings.
    """
    def thunk():
        env = gym.wrappers.RecordEpisodeStatistics(gym.make(gym_id))
        if capture_video and idx == 0:
            # Record videos for the first sub-environment only.
            env = Monitor(env, f'videos/{run_name}')
        env.seed(seed)
        for space in (env.action_space, env.observation_space):
            space.seed(seed)
        return env

    return thunk


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise ``layer.weight`` and fill ``layer.bias`` with a constant.

    Args:
        layer: module exposing ``weight`` and ``bias`` tensors (e.g. ``nn.Linear``).
        std: gain passed to ``torch.nn.init.orthogonal_``.
        bias_const: value written to every bias entry.

    Returns:
        The same ``layer`` object, modified in place, so calls can be nested
        inside ``nn.Sequential(...)``.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    init.constant_(layer.bias, val=bias_const)
    return layer




class Agent(nn.Module):
    """Actor-critic network with separate MLP heads for a discrete action space.

    Both heads take the flattened observation through two hidden layers of 64
    tanh units. The critic ends in a single output (the state value); the actor
    ends in one logit per discrete action.

    ``envs`` may follow either vector-env API: gym's ``SyncVectorEnv`` exposes
    the per-environment spaces as ``single_observation_space`` /
    ``single_action_space``, while SB3's ``DummyVecEnv`` exposes them as
    ``observation_space`` / ``action_space``.
    """

    def __init__(self, envs, frames=4):
        # `frames` is not used by this MLP; it is kept so existing call sites still work.
        super(Agent, self).__init__()
        obs_dim = int(np.prod(self._per_env_space(envs, "observation_space").shape))
        n_actions = self._per_env_space(envs, "action_space").n
        # Critic is built before the actor so the RNG draws match the original order.
        self.critic = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.),
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            # Small gain keeps the initial policy close to uniform.
            layer_init(nn.Linear(64, n_actions), std=0.01),
        )

    @staticmethod
    def _per_env_space(envs, name):
        """Return the single-environment space ``name`` under either vector-env API."""
        single = getattr(envs, "single_" + name, None)
        return single if single is not None else getattr(envs, name)

    def get_action_and_value(self, x, action=None):
        """Evaluate policy and value function on a batch of observations.

        Args:
            x: float tensor of observations; assumed flat, i.e. (batch, obs_dim) — TODO confirm.
            action: optional actions to score; when ``None`` an action is sampled
                from the current policy.

        Returns:
            Tuple ``(action, log_prob(action), entropy, value)``; ``value`` is the
            critic output, whose last dimension has size 1.
        """
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)

    def get_value(self, x):
        """Return the critic's state-value estimate for observations ``x``."""
        return self.critic(x)




if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.prod_mode:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    # env setup
    # gym's SyncVectorEnv is used here. SB3's DummyVecEnv accepts the same list of
    # env factories but exposes the per-env spaces as `observation_space` /
    # `action_space` instead of `single_observation_space` / `single_action_space`.
    envs = SyncVectorEnv(
        [make_env(args.gym_id, args.seed + i, i, args.capture_video, run_name)
         for i in range(args.num_envs)]
    )
    # Always close the envs and flush/close the TensorBoard writer, even on error.
    try:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(envs.single_action_space, Discrete):
            raise ValueError("only discrete action space is supported")

        agent = Agent(envs).to(device)
        optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

        # ALGO Logic: Storage for epoch data, indexed [step, env]
        obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
        actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
        logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
        rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
        dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
        values = torch.zeros((args.num_steps, args.num_envs)).to(device)

        # TRY NOT TO MODIFY: start the game
        global_step = 0
        start_time = time.time()
        # Note how `next_obs` and `next_done` are used; their usage is equivalent to
        # https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/84a7582477fb0d5c82ad6d850fe476829dddd2e1/a2c_ppo_acktr/storage.py#L60
        next_obs = torch.Tensor(envs.reset()).to(device)
        next_done = torch.zeros(args.num_envs).to(device)
        num_updates = args.total_timesteps // args.batch_size
        for update in range(1, num_updates + 1):
            # Annealing the rate if instructed to do so: linear decay from the
            # initial learning rate towards 0 over all updates.
            # https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/defaults.py#L20
            if args.anneal_lr:
                frac = 1.0 - (update - 1.0) / num_updates
                optimizer.param_groups[0]["lr"] = frac * args.learning_rate

            # TRY NOT TO MODIFY: prepare the execution of the game.
            for step in range(0, args.num_steps):
                global_step += args.num_envs
                obs[step] = next_obs
                dones[step] = next_done

                # ALGO LOGIC: put action logic here
                with torch.no_grad():
                    action, logproba, _, vs = agent.get_action_and_value(next_obs)
                    values[step] = vs.flatten()
                actions[step] = action
                logprobs[step] = logproba

                # TRY NOT TO MODIFY: execute the game and log data.
                next_obs, rs, ds, infos = envs.step(action.cpu().numpy())
                next_obs = torch.Tensor(next_obs).to(device)
                rewards[step] = torch.tensor(rs).to(device).view(-1)
                next_done = torch.Tensor(ds).to(device)

                # Log at most one finished episode per vector step.
                for info in infos:
                    if "episode" in info:
                        print(f"global_step={global_step}, episode_reward={info['episode']['r']}")
                        writer.add_scalar("charts/episode_reward", info["episode"]["r"], global_step)
                        break

            # bootstrap reward if not done. reached the batch limit
            with torch.no_grad():
                # Flattened to (num_envs,) so it matches the per-step rows of `values`.
                last_value = agent.get_value(next_obs).flatten()
                if args.gae:
                    advantages = torch.zeros_like(rewards).to(device)
                    lastgaelam = 0
                    for t in reversed(range(args.num_steps)):
                        if t == args.num_steps - 1:
                            nextnonterminal = 1.0 - next_done
                            nextvalues = last_value
                        else:
                            nextnonterminal = 1.0 - dones[t + 1]
                            nextvalues = values[t + 1]
                        delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                        advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
                    returns = advantages + values
                else:
                    returns = torch.zeros_like(rewards).to(device)
                    for t in reversed(range(args.num_steps)):
                        if t == args.num_steps - 1:
                            nextnonterminal = 1.0 - next_done
                            next_return = last_value
                        else:
                            nextnonterminal = 1.0 - dones[t + 1]
                            next_return = returns[t + 1]
                        returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
                    advantages = returns - values

            # flatten the batch
            b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
            b_logprobs = logprobs.reshape(-1)
            # Cast once here instead of on every minibatch.
            b_actions = actions.reshape((-1,) + envs.single_action_space.shape).long()
            b_advantages = advantages.reshape(-1)
            b_returns = returns.reshape(-1)
            b_values = values.reshape(-1)

            # Optimizing the policy and value network
            inds = np.arange(args.batch_size)
            for _epoch in range(args.update_epochs):
                np.random.shuffle(inds)
                for start in range(0, args.batch_size, args.minibatch_size):
                    end = start + args.minibatch_size
                    minibatch_ind = inds[start:end]
                    mb_advantages = b_advantages[minibatch_ind]
                    if args.norm_adv:
                        mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                    _, newlogproba, entropy, new_values = agent.get_action_and_value(
                        b_obs[minibatch_ind], b_actions[minibatch_ind]
                    )
                    log_ratio = newlogproba - b_logprobs[minibatch_ind]
                    ratio = log_ratio.exp()

                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    with torch.no_grad():
                        approx_kl = ((ratio - 1) - log_ratio).mean()

                    # Policy loss
                    pg_loss1 = -mb_advantages * ratio
                    pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                    pg_loss = torch.max(pg_loss1, pg_loss2).mean()
                    entropy_loss = entropy.mean()

                    # Value loss
                    new_values = new_values.view(-1)
                    if args.clip_vloss:
                        v_loss_unclipped = (new_values - b_returns[minibatch_ind]) ** 2
                        v_clipped = b_values[minibatch_ind] + torch.clamp(
                            new_values - b_values[minibatch_ind],
                            -args.clip_coef,
                            args.clip_coef,
                        )
                        v_loss_clipped = (v_clipped - b_returns[minibatch_ind]) ** 2
                        v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                        v_loss = 0.5 * v_loss_max.mean()
                    else:
                        v_loss = 0.5 * ((new_values - b_returns[minibatch_ind]) ** 2).mean()

                    loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                    optimizer.step()

                # Early-stop the remaining update epochs once the KL estimate of the
                # last minibatch exceeds the target.
                if args.kle_stop and approx_kl > args.target_kl:
                    break

            # TRY NOT TO MODIFY: record rewards for plotting purposes
            writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
            writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
            writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
            writer.add_scalar("losses/entropy", entropy.mean().item(), global_step)
            writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
            # Computed once so the printed and logged values agree.
            sps = int(global_step / (time.time() - start_time))
            print("SPS:", sps)
            writer.add_scalar("charts/SPS", sps, global_step)
    finally:
        envs.close()
        writer.close()