# Implementation of invalid action masking
#
# Diff summary (exported from Diffchecker): the original script had 401 lines
# and the revised script has 439 lines; 8 removals and 45 additions.
import torch
import torch
import torch.nn as nn
import torch.nn as nn
import torch.optim as optim
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard import SummaryWriter


import argparse
import argparse
from distutils.util import strtobool
from distutils.util import strtobool
import numpy as np
import numpy as np
import gym
import gym
import gym_microrts
import gym_microrts
from gym.wrappers import TimeLimit, Monitor
from gym.wrappers import TimeLimit, Monitor


from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import time
import random
import random
import os
import os
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnvWrapper


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PPO agent')
    # Common arguments
    # os.path.splitext removes exactly the file extension. The previous
    # str.rstrip(".py") removed any trailing '.', 'p' or 'y' characters
    # (e.g. "happy.py" -> "ha"), silently mangling experiment names.
    parser.add_argument('--exp-name', type=str, default=os.path.splitext(os.path.basename(__file__))[0],
                        help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="MicrortsCombinedReward10x10F9BuildCombatUnits-v0",
                        help='the id of the gym environment')
    parser.add_argument('--learning-rate', type=float, default=2.5e-4,
                        help='the learning rate of the optimizer')
    parser.add_argument('--seed', type=int, default=1,
                        help='seed of the experiment')
    parser.add_argument('--total-timesteps', type=int, default=10000000,
                        help='total timesteps of the experiments')
    parser.add_argument('--torch-deterministic', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
                        help='if toggled, `torch.backends.cudnn.deterministic=False`')
    parser.add_argument('--cuda', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
                        help='if toggled, cuda will not be enabled by default')
    parser.add_argument('--prod-mode', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
                        help='run the script in production mode and use wandb to log outputs')
    parser.add_argument('--capture-video', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
                        help='weather to capture videos of the agent performances (check out `videos` folder)')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
                        help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
                        help="the entity (team) of wandb's project")

    # Algorithm specific arguments
    parser.add_argument('--n-minibatch', type=int, default=4,
                        help='the number of mini batch')
    parser.add_argument('--num-envs', type=int, default=4,
                        help='the number of parallel game environment')
    parser.add_argument('--num-steps', type=int, default=128,
                        help='the number of steps per game environment')
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.95,
                        help='the lambda for the general advantage estimation')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                        help="coefficient of the entropy")
    parser.add_argument('--vf-coef', type=float, default=0.5,
                        help="coefficient of the value function")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                        help='the maximum norm for the gradient clipping')
    parser.add_argument('--clip-coef', type=float, default=0.1,
                        help="the surrogate clipping coefficient")
    parser.add_argument('--update-epochs', type=int, default=4,
                        help="the K epochs to update the policy")
    parser.add_argument('--kle-stop', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
                        help='If toggled, the policy updates will be early stopped w.r.t target-kl')
    parser.add_argument('--kle-rollback', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
                        help='If toggled, the policy updates will roll back to previous policy if KL exceeds target-kl')
    parser.add_argument('--target-kl', type=float, default=0.03,
                        help='the target-kl variable that is referred by --kl')
    parser.add_argument('--gae', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
                        help='Use GAE for advantage computation')
    parser.add_argument('--norm-adv', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
                        help="Toggles advantages normalization")
    parser.add_argument('--anneal-lr', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
                        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--clip-vloss', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
                        help='Toggles wheter or not to use a clipped loss for the value function, as per the paper.')

    args = parser.parse_args()
    # A falsy seed (i.e. 0) is replaced with a time-based seed.
    if not args.seed:
        args.seed = int(time.time())

# Derived sizes: one rollout collects `num_steps` transitions from each of
# `num_envs` environments, and the update splits it into `n_minibatch` parts.
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.n_minibatch)


class ImageToPyTorch(gym.ObservationWrapper):
    """Observation wrapper that moves the last observation axis to the front.

    For a 3-D observation this is an (H, W, C) -> (C, H, W) transpose, the
    layout expected by ``nn.Conv2d``. The declared observation space is
    replaced by an int32 ``Box`` in [0, 1] with the transposed shape.
    """

    def __init__(self, env):
        super().__init__(env)
        src_shape = self.observation_space.shape
        chw_shape = (src_shape[-1], src_shape[0], src_shape[1])
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=chw_shape, dtype=np.int32)

    def observation(self, observation):
        """Return ``observation`` with axes reordered as (2, 0, 1)."""
        channels_first = (2, 0, 1)
        return np.transpose(observation, axes=channels_first)


class VecPyTorch(VecEnvWrapper):
    """Vectorised-env wrapper that speaks torch tensors instead of numpy arrays.

    - Observations are returned as float tensors on ``device``.
    - Rewards are returned as float tensors with an extra trailing axis
      (``unsqueeze(dim=1)``) and are left on the CPU.
    - Actions may be torch tensors on any device; they are moved to numpy
      before being forwarded to the wrapped env.
    """

    def __init__(self, venv, device):
        super().__init__(venv)
        self.device = device

    def reset(self):
        return torch.from_numpy(self.venv.reset()).float().to(self.device)

    def step_async(self, actions):
        self.venv.step_async(actions.cpu().numpy())

    def step_wait(self):
        raw_obs, raw_reward, done, info = self.venv.step_wait()
        obs_t = torch.from_numpy(raw_obs).float().to(self.device)
        reward_t = torch.from_numpy(raw_reward).unsqueeze(dim=1).float()
        return obs_t, reward_t, done, info


class MicroRTSStatsRecorder(gym.Wrapper):
    """Accumulates ``info["raw_rewards"]`` over an episode.

    When an episode ends, the per-step raw reward vectors are summed
    element-wise and stored in ``info['microrts_stats']``, keyed by
    ``str(rf)`` for each entry of ``self.rfs``.
    NOTE(review): ``self.rfs`` is presumably the wrapped env's list of reward
    functions, reached through gym.Wrapper attribute forwarding. Confirm.
    """

    def reset(self, **kwargs):
        first_obs = super().reset(**kwargs)
        self.raw_rewards = []
        return first_obs

    def step(self, action):
        observation, reward, done, info = super().step(action)
        self.raw_rewards.append(info["raw_rewards"])
        if not done:
            return observation, reward, done, info
        # Episode finished: publish per-reward-function totals and start afresh.
        episode_totals = np.sum(self.raw_rewards, axis=0)
        reward_names = [str(rf) for rf in self.rfs]
        info['microrts_stats'] = {name: total for name, total in zip(reward_names, episode_totals)}
        self.raw_rewards = []
        return observation, reward, done, info




# TRY NOT TO MODIFY: setup the environment
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
# Exactly one SummaryWriter is created. Previously, prod mode opened a writer
# under runs/, logged the hyperparameters to it, and then replaced it with a
# /tmp writer without closing it. That leaked the first writer and lost the
# hyperparameter table for the writer actually used.
if args.prod_mode:
    import wandb
    run = wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, sync_tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True, save_code=True)
    # Created after wandb.init (as before), presumably so that
    # sync_tensorboard can pick up this writer's event files.
    writer = SummaryWriter(f"/tmp/{experiment_name}")
else:
    writer = SummaryWriter(f"runs/{experiment_name}")
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
    '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))

# TRY NOT TO MODIFY: seeding
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
def make_env(gym_id, seed, idx):
    """Return a zero-argument factory that builds one wrapped, seeded env.

    Wrapper order (innermost first): ImageToPyTorch, RecordEpisodeStatistics,
    MicroRTSStatsRecorder, plus a video ``Monitor`` on the env with
    ``idx == 0`` when ``--capture-video`` is set. The env and both of its
    spaces are seeded with ``seed``.
    """
    def thunk():
        env = MicroRTSStatsRecorder(
            gym.wrappers.RecordEpisodeStatistics(
                ImageToPyTorch(gym.make(gym_id))))
        if args.capture_video and idx == 0:
            env = Monitor(env, f'videos/{experiment_name}')
        env.seed(seed)
        for space in (env.action_space, env.observation_space):
            space.seed(seed)
        return env
    return thunk
# One factory per environment; env i is seeded with args.seed + i.
env_thunks = [make_env(args.gym_id, args.seed + i, i) for i in range(args.num_envs)]
envs = VecPyTorch(DummyVecEnv(env_thunks), device)
# NOTE: all envs run in-process (DummyVecEnv). A SubprocVecEnv(env_thunks, "fork")
# variant for prod mode was sketched here previously but is disabled.
assert isinstance(envs.action_space, MultiDiscrete), "only MultiDiscrete action space is supported"


# ALGO LOGIC: initialize agent here:
class CategoricalMasked(Categorical):
    """Categorical distribution with invalid actions masked out.

    Where ``masks`` is false, the logit is replaced by -1e8, so that action
    gets (numerically) zero probability. Masked entries are also excluded
    from the entropy.

    Args:
        probs: forwarded to ``Categorical``; only supported without masks.
        logits: unnormalized log-probabilities; the last dim indexes actions.
        validate_args: forwarded to ``Categorical``.
        masks: optional mask broadcastable to ``logits``; nonzero/True marks a
            valid action. ``None`` (default) or an empty sequence disables
            masking.

    Raises:
        ValueError: if ``masks`` is given without ``logits``.
    """

    def __init__(self, probs=None, logits=None, validate_args=None, masks=None):
        # `masks=None` replaces the old mutable `masks=[]` default; an empty
        # sequence is still accepted and means "no masking".
        if masks is None or len(masks) == 0:
            self.masks = None
        else:
            if logits is None:
                raise ValueError("CategoricalMasked requires `logits` when `masks` is given")
            # Follow the device/dtype of `logits` instead of a module-level global.
            self.masks = torch.as_tensor(masks).to(dtype=torch.bool, device=logits.device)
            logits = torch.where(
                self.masks, logits,
                torch.tensor(-1e+8, dtype=logits.dtype, device=logits.device))
        super(CategoricalMasked, self).__init__(probs, logits, validate_args)

    def entropy(self):
        """Entropy over the valid actions only (masked terms contribute 0)."""
        if self.masks is None:
            return super(CategoricalMasked, self).entropy()
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, torch.zeros_like(p_log_p))
        return -p_log_p.sum(-1)

class Scale(nn.Module):
    """Parameter-free module that multiplies its input by a fixed ``scale``."""

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        scaled = self.scale * x
        return scaled


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return it for inline use.

    The weight is orthogonally initialised with gain ``std``, and every bias
    entry is set to ``bias_const``.
    """
    nn.init.orthogonal_(layer.weight, gain=std)
    nn.init.constant_(layer.bias, val=bias_const)
    return layer


class Agent(nn.Module):
    """Actor-critic network for MicroRTS with invalid-action masking.

    A shared convolutional trunk feeds two heads. ``actor`` emits one logit
    per entry of every MultiDiscrete component (``sum(nvec)`` logits), and
    ``critic`` emits a scalar state value.

    NOTE(review): the trunk hard-codes 27 input channels and a 32*3*3 flatten
    size. That size matches a 10x10 map (10 -> 4 after the stride-2 3x3 conv,
    then 4 -> 3 after the 2x2 conv); other map sizes need different layer
    sizes. ``frames`` is accepted but unused.
    """

    def __init__(self, frames=4):
        super(Agent, self).__init__()
        self.network = nn.Sequential(
            layer_init(nn.Conv2d(27, 16, kernel_size=3, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(16, 32, kernel_size=2)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(32*3*3, 128)),
            nn.ReLU(),)
        # The actor head is sized from the module-level `envs`.
        self.actor = layer_init(nn.Linear(128, envs.action_space.nvec.sum()), std=0.01)
        self.critic = layer_init(nn.Linear(128, 1), std=1)

    def forward(self, x):
        """Return the shared 128-dim feature vector for observations ``x``."""
        return self.network(x)

    def get_action(self, x, action=None, invalid_action_masks=None, envs=None):
        """Sample (or evaluate) a masked MultiDiscrete action.

        Args:
            x: batch of observations.
            action: ``None`` during rollouts (a new action is sampled).
                Otherwise a (num_components, batch) long tensor to evaluate.
            invalid_action_masks: required when ``action`` is given. This is
                the flat (batch, sum(nvec)) mask returned by the rollout call
                that produced ``action``.
            envs: the vectorised env. Its ``action_space.nvec`` splits the
                logits, and in rollout mode it is queried for masks through
                ``env_method``.

        Returns:
            ``(action, logprob, entropy, invalid_action_masks)``. ``logprob``
            and ``entropy`` are summed over action components.
        """
        logits = self.actor(self.forward(x))
        nvec = envs.action_space.nvec.tolist()
        split_logits = torch.split(logits, nvec, dim=1)

        if action is None:
            # 1. Select the source unit, restricted by the env's unit-location
            #    mask for player 1.
            source_unit_mask = torch.Tensor(np.array(envs.env_method("get_unit_location_mask", player=1)))
            multi_categoricals = [CategoricalMasked(logits=split_logits[0], masks=source_unit_mask)]
            source_unit = multi_categoricals[0].sample()
            # 2. Ask each env for the action-type/parameter mask of the unit
            #    sampled for that env, and mask the remaining components.
            source_unit_action_mask = torch.Tensor(np.array([
                envs.env_method("get_unit_action_mask", unit=source_unit[i], player=1, indices=i)[0]
                for i in range(envs.num_envs)]))
            split_suam = torch.split(source_unit_action_mask, nvec[1:], dim=1)
            multi_categoricals += [
                CategoricalMasked(logits=comp_logits, masks=comp_mask)
                for comp_logits, comp_mask in zip(split_logits[1:], split_suam)]
            invalid_action_masks = torch.cat((source_unit_mask, source_unit_action_mask), 1)
            # 3. Keep the source unit sampled in step 1. Re-sampling it, as the
            #    code previously did, could emit a different unit than the one
            #    the step-2 masks were computed for, yielding invalid actions
            #    and stored masks inconsistent with the stored action.
            action = torch.stack([source_unit] + [c.sample() for c in multi_categoricals[1:]])
        else:
            split_invalid_action_masks = torch.split(invalid_action_masks, nvec, dim=1)
            multi_categoricals = [
                CategoricalMasked(logits=comp_logits, masks=comp_mask)
                for comp_logits, comp_mask in zip(split_logits, split_invalid_action_masks)]
        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        entropy = torch.stack([categorical.entropy() for categorical in multi_categoricals])
        return action, logprob.sum(0), entropy.sum(0), invalid_action_masks

    def get_value(self, x):
        """Return the critic's value estimate, shape (batch, 1)."""
        return self.critic(self.forward(x))


agent = Agent().to(device)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
if args.anneal_lr:
    # https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/defaults.py#L20
    def lr(f):
        """Return the base learning rate scaled by the annealing fraction ``f``."""
        return f * args.learning_rate


# ALGO Logic: Storage for epoch data
# One slot per (rollout step, env); filled during the rollout and flattened
# into a batch for the policy update.
rollout_dims = (args.num_steps, args.num_envs)
obs = torch.zeros(rollout_dims + envs.observation_space.shape).to(device)
actions = torch.zeros(rollout_dims + envs.action_space.shape).to(device)
logprobs, rewards, dones, values = (torch.zeros(rollout_dims).to(device) for _ in range(4))
# Flat validity mask covering every MultiDiscrete component (sum(nvec) entries).
invalid_action_masks = torch.zeros(rollout_dims + (envs.action_space.nvec.sum(),)).to(device)

# TRY NOT TO MODIFY: start the game
global_step = 0
num_updates = args.total_timesteps // args.batch_size
# Note how `next_obs` and `next_done` are used; their usage is equivalent to
# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/84a7582477fb0d5c82ad6d850fe476829dddd2e1/a2c_ppo_acktr/storage.py#L60
next_obs = envs.reset()
next_done = torch.zeros(args.num_envs).to(device)


## CRASH AND RESUME LOGIC:
# On a resumed wandb run, restore the update counter, step counter and the
# latest saved agent weights.
starting_update = 1
if args.prod_mode and wandb.run.resumed:
    print("previous run.summary", run.summary)
    starting_update = run.summary['charts/update'] + 1
    # Each completed update adds args.batch_size env steps, so after update k
    # global_step == k * batch_size. Resuming at k + 1 must therefore restore
    # (starting_update - 1) * batch_size; the previous formula overshot by
    # one batch.
    global_step = (starting_update - 1) * args.batch_size
    api = wandb.Api()
    # Build the "entity/project/run_id" path directly instead of slicing a
    # hard-coded "https://app.wandb.ai/" prefix off the run URL, which breaks
    # when the URL host differs.
    run = api.run(f"{run.entity}/{run.project}/{run.id}")
    model = run.file('agent.pt')
    model.download(f"models/{experiment_name}/")
    # map_location lets a checkpoint saved on GPU be resumed on CPU (and vice versa).
    agent.load_state_dict(torch.load(f"models/{experiment_name}/agent.pt", map_location=device))
    agent.eval()
    print(f"resumed at update {starting_update}")
for update in range(starting_update, num_updates+1):
    # Annealing the rate if instructed to do so.
    if args.anneal_lr:
        # frac goes linearly from 1.0 at update 1 towards 0 at the last update.
        frac = 1.0 - (update - 1.0) / num_updates
        lrnow = lr(frac)
        optimizer.param_groups[0]['lr'] = lrnow

    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(0, args.num_steps):
        # NOTE(review): renders env 0 on every step; consider removing for
        # headless or faster training.
        envs.env_method("render", indices=0)
        global_step += 1 * args.num_envs
        obs[step] = next_obs
        dones[step] = next_done

        # ALGO LOGIC: put action logic here
        with torch.no_grad():
            values[step] = agent.get_value(obs[step]).flatten()
            # The masks used for sampling are stored so the update can rebuild
            # exactly the same masked distributions.
            action, logproba, _, invalid_action_masks[step] = agent.get_action(obs[step], envs=envs)

        actions[step] = action.T
        logprobs[step] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rs, ds, infos = envs.step(action.T)
        rewards[step], next_done = rs.view(-1), torch.Tensor(ds).to(device)

        for info in infos:
            if 'episode' in info.keys():
                print(f"global_step={global_step}, episode_reward={info['episode']['r']}")
                writer.add_scalar("charts/episode_reward", info['episode']['r'], global_step)
                for key in info['microrts_stats']:
                    writer.add_scalar(f"charts/episode_reward/{key}", info['microrts_stats'][key], global_step)
                # Only the first finished episode in this step is logged.
                break

    # bootstrap reward if not done. reached the batch limit
    with torch.no_grad():
        last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)
        if args.gae:
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = last_value
                else:
                    nextnonterminal = 1.0 - dones[t+1]
                    nextvalues = values[t+1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values
        else:
            returns = torch.zeros_like(rewards).to(device)
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    next_return = last_value
                else:
                    nextnonterminal = 1.0 - dones[t+1]
                    next_return = returns[t+1]
                returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
            advantages = returns - values

    # flatten the batch
    b_obs = obs.reshape((-1,)+envs.observation_space.shape)
    b_logprobs = logprobs.reshape(-1)
    b_actions = actions.reshape((-1,)+envs.action_space.shape)
    b_advantages = advantages.reshape(-1)
    b_returns = returns.reshape(-1)
    b_values = values.reshape(-1)
    b_invalid_action_masks = invalid_action_masks.reshape((-1, invalid_action_masks.shape[-1]))

    # Optimizing the policy and value network
    target_agent = Agent().to(device)
    inds = np.arange(args.batch_size,)
    # Set when the KL check fires. It ends the whole update phase, not just
    # the current epoch's minibatch loop.
    stop_updates = False
    for i_epoch_pi in range(args.update_epochs):
        np.random.shuffle(inds)
        # Snapshot of the policy at the start of this epoch, used by --kle-rollback.
        target_agent.load_state_dict(agent.state_dict())
        for start in range(0, args.batch_size, args.minibatch_size):
            end = start + args.minibatch_size
            minibatch_ind = inds[start:end]
            mb_advantages = b_advantages[minibatch_ind]
            if args.norm_adv:
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

            _, newlogproba, entropy, _ = agent.get_action(
                b_obs[minibatch_ind],
                b_actions.long()[minibatch_ind].T,
                b_invalid_action_masks[minibatch_ind],
                envs)
            ratio = (newlogproba - b_logprobs[minibatch_ind]).exp()

            # Stats
            approx_kl = (b_logprobs[minibatch_ind] - newlogproba).mean()

            # Policy loss
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * torch.clamp(ratio, 1-args.clip_coef, 1+args.clip_coef)
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()
            entropy_loss = entropy.mean()

            # Value loss
            new_values = agent.get_value(b_obs[minibatch_ind]).view(-1)
            if args.clip_vloss:
                v_loss_unclipped = ((new_values - b_returns[minibatch_ind]) ** 2)
                v_clipped = b_values[minibatch_ind] + torch.clamp(new_values - b_values[minibatch_ind], -args.clip_coef, args.clip_coef)
                v_loss_clipped = (v_clipped - b_returns[minibatch_ind])**2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()
            else:
                # Reduce to a scalar. Without .mean(), loss.backward() and
                # v_loss.item() fail on the per-sample tensor.
                v_loss = 0.5 * ((new_values - b_returns[minibatch_ind]) ** 2).mean()

            loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
            optimizer.step()

            if args.kle_stop:
                if approx_kl > args.target_kl:
                    stop_updates = True
                    break
            if args.kle_rollback:
                # KL probe only; no gradients are needed.
                with torch.no_grad():
                    rollback_logproba = agent.get_action(
                        b_obs[minibatch_ind],
                        b_actions.long()[minibatch_ind].T,
                        b_invalid_action_masks[minibatch_ind],
                        envs)[1]
                if (b_logprobs[minibatch_ind] - rollback_logproba).mean() > args.target_kl:
                    agent.load_state_dict(target_agent.state_dict())
                    stop_updates = True
                    break
        if stop_updates:
            break

    ## CRASH AND RESUME LOGIC:
    if args.prod_mode:
        if not os.path.exists(f"models/{experiment_name}"):
            os.makedirs(f"models/{experiment_name}")
        torch.save(agent.state_dict(), f"{wandb.run.dir}/agent.pt")
        wandb.save(f"agent.pt")

    # TRY NOT TO MODIFY: record rewards for plotting purposes
    writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("charts/update", update, global_step)
    writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
    writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
    writer.add_scalar("losses/entropy", entropy.mean().item(), global_step)
    writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
    if args.kle_stop or args.kle_rollback:
        writer.add_scalar("debug/pg_stop_iter", i_epoch_pi, global_step)


# Shut down: close the environments first, then the TensorBoard writer.
for closable in (envs, writer):
    closable.close()