Implementation of Invalid Action Masking

Relative to the baseline PPO script without masking, the change amounts to 8 line removals and 22 line additions. The full modified script is reproduced below.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from cleanrl.common import preprocess_obs_space, preprocess_ac_space
import argparse
import numpy as np
import gym
import gym_microrts
from gym.wrappers import TimeLimit, Monitor
import pybullet_envs
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import random
import os

# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd(object):
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean([x], axis=0)
        batch_var = np.var([x], axis=0)
        batch_count = 1
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count

class NormalizedEnv(gym.core.Wrapper):
    def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
        super(NormalizedEnv, self).__init__(env)
        self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
        self.ret_rms = RunningMeanStd(shape=(1,)) if ret else None
        self.clipob = clipob
        self.cliprew = cliprew
        self.ret = np.zeros(())
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, news, infos = self.env.step(action)
        infos['real_reward'] = rews
        self.ret = self.ret * self.gamma + rews
        obs = self._obfilt(obs)
        if self.ret_rms:
            self.ret_rms.update(np.array([self.ret].copy()))
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        self.ret = self.ret * (1-float(news))
        return obs, rews, news, infos

    def _obfilt(self, obs):
        if self.ob_rms:
            self.ob_rms.update(obs)
            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
            return obs
        else:
            return obs

    def reset(self):
        self.ret = np.zeros(())
        obs = self.env.reset()
        return self._obfilt(obs)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PPO agent')
    # Common arguments
    parser.add_argument('--exp-name', type=str, default=os.path.basename(__file__).rstrip(".py"),
                        help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="MicrortsMining10x10F9-v0",
                        help='the id of the gym environment')
    parser.add_argument('--seed', type=int, default=1,
                        help='seed of the experiment')
    parser.add_argument('--episode-length', type=int, default=0,
                        help='the maximum length of each episode')
    parser.add_argument('--total-timesteps', type=int, default=100000,
                        help='total timesteps of the experiments')
    parser.add_argument('--no-torch-deterministic', action='store_false', dest="torch_deterministic", default=True,
                        help='if toggled, `torch.backends.cudnn.deterministic=False`')
    parser.add_argument('--no-cuda', action='store_false', dest="cuda", default=True,
                        help='if toggled, cuda will not be enabled by default')
    parser.add_argument('--prod-mode', action='store_true', default=False,
                        help='run the script in production mode and use wandb to log outputs')
    parser.add_argument('--capture-video', action='store_true', default=False,
                        help='whether to capture videos of the agent performances (check out the `videos` folder)')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
                        help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
                        help="the entity (team) of wandb's project")

    # Algorithm specific arguments
    parser.add_argument('--batch-size', type=int, default=2048,
                        help='the batch size of ppo')
    parser.add_argument('--minibatch-size', type=int, default=256,
                        help='the mini batch size of ppo')
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.97,
                        help='the lambda for the general advantage estimation')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                        help="coefficient of the entropy")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                        help='the maximum norm for the gradient clipping')
    parser.add_argument('--clip-coef', type=float, default=0.2,
                        help="the surrogate clipping coefficient")
    parser.add_argument('--update-epochs', type=int, default=10,
                        help="the K epochs to update the policy")
    parser.add_argument('--kle-stop', action='store_true', default=False,
                        help='If toggled, the policy updates will be early stopped w.r.t. target-kl')
    parser.add_argument('--kle-rollback', action='store_true', default=False,
                        help='If toggled, the policy updates will roll back to the previous policy if the KL exceeds target-kl')
    parser.add_argument('--target-kl', type=float, default=0.015,
                        help='the target KL divergence threshold used by --kle-stop and --kle-rollback')
    parser.add_argument('--gae', action='store_true', default=True,
                        help='Use GAE for advantage computation')
    parser.add_argument('--policy-lr', type=float, default=3e-4,
                        help="the learning rate of the policy optimizer")
    parser.add_argument('--value-lr', type=float, default=3e-4,
                        help="the learning rate of the critic optimizer")
    parser.add_argument('--norm-obs', action='store_true', default=True,
                        help="Toggles observation normalization")
    parser.add_argument('--norm-returns', action='store_true', default=False,
                        help="Toggles returns normalization")
    parser.add_argument('--norm-adv', action='store_true', default=True,
                        help="Toggles advantages normalization")
    parser.add_argument('--obs-clip', type=float, default=10.0,
                        help="Value for observation clipping, as per the paper")
    parser.add_argument('--rew-clip', type=float, default=10.0,
                        help="Value for reward clipping, as per the paper")
    parser.add_argument('--anneal-lr', action='store_true', default=True,
                        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--weights-init', default="orthogonal", choices=["xavier", 'orthogonal'],
                        help='Selects the scheme to be used for weights initialization')
    parser.add_argument('--clip-vloss', action="store_true", default=True,
                        help='Toggles whether or not to use a clipped loss for the value function, as per the paper')
    parser.add_argument('--pol-layer-norm', action='store_true', default=False,
                        help='Enables layer normalization in the policy network')
    args = parser.parse_args()
    if not args.seed:
        args.seed = int(time.time())

args.features_turned_on = sum([args.kle_stop, args.kle_rollback, args.gae, args.norm_obs, args.norm_returns, args.norm_adv, args.anneal_lr, args.clip_vloss, args.pol_layer_norm])

# TRY NOT TO MODIFY: setup the environment
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{experiment_name}")
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
    '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))
if args.prod_mode:
    import wandb
    wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True)
    writer = SummaryWriter(f"/tmp/{experiment_name}")
    wandb.save(os.path.abspath(__file__))

# TRY NOT TO MODIFY: seeding
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
env = gym.make(args.gym_id)
# respect the default timelimit
assert isinstance(env.action_space, MultiDiscrete), "only MultiDiscrete action space is supported"
assert isinstance(env, TimeLimit) or int(args.episode_length), "the gym env does not have a built in TimeLimit, please specify by using --episode-length"
if isinstance(env, TimeLimit):
    if int(args.episode_length):
        env._max_episode_steps = int(args.episode_length)
    args.episode_length = env._max_episode_steps
else:
    env = TimeLimit(env, int(args.episode_length))
env = NormalizedEnv(env.env, ob=args.norm_obs, ret=args.norm_returns, clipob=args.obs_clip, cliprew=args.rew_clip, gamma=args.gamma)
env = TimeLimit(env, int(args.episode_length))
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
env.seed(args.seed)
env.action_space.seed(args.seed)
env.observation_space.seed(args.seed)
if args.capture_video:
    env = Monitor(env, f'videos/{experiment_name}')

# ALGO LOGIC: initialize agent here:
class CategoricalMasked(Categorical):
    def __init__(self, probs=None, logits=None, validate_args=None, masks=[]):
        self.masks = masks
        if len(self.masks) == 0:
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
        else:
            self.masks = masks.type(torch.BoolTensor).to(device)
            logits = torch.where(self.masks, logits, torch.tensor(-1e+8).to(device))
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)

    def entropy(self):
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.).to(device))
        return -p_log_p.sum(-1)
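
# Note on the masking mechanism: CategoricalMasked replaces the logits of
# invalid actions with a large negative constant (-1e+8), so their softmax
# probabilities become numerically zero, and torch.where routes gradients only
# to the selected (valid) branch, so no gradient flows through masked entries.
# entropy() is overridden so that masked-out actions contribute nothing to the
# entropy bonus.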

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3,),
            nn.MaxPool2d(1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.MaxPool2d(1),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*6*6, 128),
            nn.ReLU(),
            nn.Linear(128, env.action_space.nvec.sum())
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def get_action(self, x, action=None, invalid_action_masks=None):
        logits = self.forward(x)
        split_logits = torch.split(logits, env.action_space.nvec.tolist(), dim=1)
        if invalid_action_masks is not None:
            split_invalid_action_masks = torch.split(invalid_action_masks, env.action_space.nvec.tolist(), dim=1)
            multi_categoricals = [CategoricalMasked(logits=logits, masks=iam) for (logits, iam) in zip(split_logits, split_invalid_action_masks)]
        else:
            multi_categoricals = [Categorical(logits=logits) for logits in split_logits]
        if action is None:
            action = torch.stack([categorical.sample() for categorical in multi_categoricals])
        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        return action, logprob, [], multi_categoricals
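
# Note: for the MultiDiscrete action space, the flat logits vector is split
# according to env.action_space.nvec into one group per action component, and
# an independent (masked) categorical distribution is built for each group; in
# the rollout below, the first component is masked with env.unit_location_mask
# and the last with env.target_unit_location_mask.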

class Value(nn.Module):
    def __init__(self):
        super(Value, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3,),
            nn.MaxPool2d(1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.MaxPool2d(1),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*6*6, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

def discount_cumsum(x, dones, gamma):
    """
    computing discounted cumulative sums of vectors that reset with dones
    input:
        vector x,  vector dones,
        [x0,       [0,
         x1,        0,
         x2,        1,
         x3,        0,
         x4]        0]
    output:
        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2,
         x3 + discount * x4,
         x4]
    """
    discount_cumsum = np.zeros_like(x)
    discount_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0]-1)):
        discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1] * (1-dones[t])
    return discount_cumsum

pg = Policy().to(device)
vf = Value().to(device)

# MODIFIED: Separate optimizers and learning rates
pg_optimizer = optim.Adam(list(pg.parameters()), lr=args.policy_lr)
v_optimizer = optim.Adam(list(vf.parameters()), lr=args.value_lr)

# MODIFIED: Initializing learning rate anneal scheduler when needed
if args.anneal_lr:
    anneal_fn = lambda f: max(0, 1-f / args.total_timesteps)
    pg_lr_scheduler = optim.lr_scheduler.LambdaLR(pg_optimizer, lr_lambda=anneal_fn)
    vf_lr_scheduler = optim.lr_scheduler.LambdaLR(v_optimizer, lr_lambda=anneal_fn)

loss_fn = nn.MSELoss()

# TRY NOT TO MODIFY: start the game
global_step = 0
while global_step < args.total_timesteps:
    if args.capture_video:
        env.stats_recorder.done = True
    next_obs = np.array(env.reset())

    # ALGO Logic: Storage for epoch data
    obs = np.empty((args.batch_size,) + env.observation_space.shape)
    actions = np.empty((args.batch_size,) + env.action_space.shape)
    logprobs = torch.zeros((env.action_space.nvec.shape[0], args.batch_size,)).to(device)
    rewards = np.zeros((args.batch_size,))
    raw_rewards = np.zeros((len(env.rfs), args.batch_size,))
    real_rewards = []
    returns = np.zeros((args.batch_size,))
    dones = np.zeros((args.batch_size,))
    values = torch.zeros((args.batch_size,)).to(device)
    invalid_action_masks = torch.zeros((args.batch_size, env.action_space.nvec.sum()))

    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.batch_size):
        env.render()
        global_step += 1
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        # build this step's invalid action mask from the source-unit and
        # target-unit location masks exposed by the environment
        invalid_action_mask = torch.ones(env.action_space.nvec.sum())
        invalid_action_mask[0:env.action_space.nvec[0]] = torch.tensor(env.unit_location_mask)
        invalid_action_mask[-env.action_space.nvec[-1]:] = torch.tensor(env.target_unit_location_mask)
        invalid_action_masks[step] = invalid_action_mask
        with torch.no_grad():
            values[step] = vf.forward(obs[step:step+1])
            action, logproba, _, probs = pg.get_action(obs[step:step+1], invalid_action_masks=invalid_action_masks[step:step+1])
        actions[step] = action[:,0].data.cpu().numpy()
        logprobs[:,[step]] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], info = env.step(action[:,0].data.cpu().numpy())
        raw_rewards[:,step] = info["rewards"]
        real_rewards += [info['real_reward']]
        next_obs = np.array(next_obs)

        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            pg_lr_scheduler.step()
            vf_lr_scheduler.step()

        if dones[step]:
            # Computing the discounted returns:
            writer.add_scalar("charts/episode_reward", np.sum(real_rewards), global_step)
            print(f"global_step={global_step}, episode_reward={np.sum(real_rewards)}")
            for i in range(len(env.rfs)):
                writer.add_scalar(f"charts/episode_reward/{str(env.rfs[i])}", raw_rewards.sum(1)[i], global_step)
            real_rewards = []
            next_obs = np.array(env.reset())

    # bootstrap the reward if not done (reached the batch limit)
    last_value = 0
    if not dones[step]:
        last_value = vf.forward(next_obs.reshape((1,)+next_obs.shape))[0].detach().cpu().numpy()[0]
    bootstrapped_rewards = np.append(rewards, last_value)

    # calculate the returns and advantages
    if args.gae:
        bootstrapped_values = np.append(values.detach().cpu().numpy(), last_value)
        deltas = bootstrapped_rewards[:-1] + args.gamma * bootstrapped_values[1:] * (1-dones) - bootstrapped_values[:-1]
        advantages = discount_cumsum(deltas, dones, args.gamma * args.gae_lambda)
        advantages = torch.Tensor(advantages).to(device)
        returns = advantages + values
    else:
        returns = discount_cumsum(bootstrapped_rewards, dones, args.gamma)[:-1]
        advantages = returns - values.detach().cpu().numpy()
        advantages = torch.Tensor(advantages).to(device)
        returns = torch.Tensor(returns).to(device)

    # Advantage normalization
    if args.norm_adv:
        EPS = 1e-10
        advantages = (advantages - advantages.mean()) / (advantages.std() + EPS)

    # Optimizing the policy network
    entropys = []
    target_pg = Policy().to(device)
    inds = np.arange(args.batch_size,)
    for i_epoch_pi in range(args.update_epochs):
        np.random.shuffle(inds)
        for start in range(0, args.batch_size, args.minibatch_size):
            end = start + args.minibatch_size
            minibatch_ind = inds[start:end]
            target_pg.load_state_dict(pg.state_dict())
            _, newlogproba, _, _ = pg.get_action(
                obs[minibatch_ind],
                torch.LongTensor(actions[minibatch_ind].astype(np.int)).to(device).T,
                invalid_action_masks[minibatch_ind])
            ratio = (newlogproba - logprobs[:,minibatch_ind]).exp()

            # Policy loss as in OpenAI SpinUp
            clip_adv = torch.where(advantages[minibatch_ind] > 0,
                                   (1.+args.clip_coef) * advantages[minibatch_ind],
                                   (1.-args.clip_coef) * advantages[minibatch_ind]).to(device)

            # Entropy computation with resampled actions
            entropy = -(newlogproba.exp() * newlogproba).mean()
            entropys.append(entropy.item())

            policy_loss = -torch.min(ratio * advantages[minibatch_ind], clip_adv) + args.ent_coef * entropy
            policy_loss = policy_loss.mean()

            pg_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(pg.parameters(), args.max_grad_norm)
            pg_optimizer.step()

            approx_kl = (logprobs[:,minibatch_ind] - newlogproba).mean()

            # Optimizing the value network
            new_values = vf.forward(obs[minibatch_ind]).view(-1)

            # Value loss clipping
            if args.clip_vloss:
                v_loss_unclipped = ((new_values - returns[minibatch_ind]) ** 2)
                v_clipped = values[minibatch_ind] + torch.clamp(new_values - values[minibatch_ind], -args.clip_coef, args.clip_coef)
                v_loss_clipped = (v_clipped - returns[minibatch_ind])**2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()
            else:
                v_loss = torch.mean((returns[minibatch_ind] - new_values).pow(2))

            v_optimizer.zero_grad()
            v_loss.backward()
            nn.utils.clip_grad_norm_(vf.parameters(), args.max_grad_norm)
            v_optimizer.step()

        if args.kle_stop:
            if approx_kl > args.target_kl:
                break
        if args.kle_rollback:
            if (logprobs[:,minibatch_ind] -
                pg.get_action(
                    obs[minibatch_ind],
                    torch.LongTensor(actions[minibatch_ind].astype(np.int)).to(device).T,
                    invalid_action_masks[minibatch_ind])[1]).mean() > args.target_kl:
                pg.load_state_dict(target_pg.state_dict())
                break

    # TRY NOT TO MODIFY: record rewards for plotting purposes
    writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
    writer.add_scalar("charts/policy_learning_rate", pg_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("charts/value_learning_rate", v_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("losses/policy_loss", policy_loss.item(), global_step)
    writer.add_scalar("losses/entropy", np.mean(entropys), global_step)
    writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
    if args.kle_stop or args.kle_rollback:
        writer.add_scalar("debug/pg_stop_iter", i_epoch_pi, global_step)

env.close()
writer.close()
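
To see the core trick in isolation, the following is a minimal, self-contained sketch (separate from the script above; the 4-action head, the logits, and the mask values are made up purely for illustration) of how masked sampling and masked entropy behave:

import torch
from torch.distributions.categorical import Categorical

# toy logits for a single discrete action head with 4 choices
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
# suppose only actions 0 and 2 are valid in the current state
mask = torch.tensor([[True, False, True, False]])

# replace the logits of invalid actions with a large negative number,
# so softmax assigns them (numerically) zero probability
masked_logits = torch.where(mask, logits, torch.tensor(-1e8))
dist = Categorical(logits=masked_logits)

print(dist.probs)        # approx. [[0.119, 0.000, 0.881, 0.000]]
print(dist.sample())     # only ever returns 0 or 2

# entropy restricted to the valid actions (masked entries contribute 0),
# mirroring CategoricalMasked.entropy() in the script above
p_log_p = dist.logits * dist.probs
p_log_p = torch.where(mask, p_log_p, torch.tensor(0.0))
print(-p_log_p.sum(-1))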