ppo_atari_vs_continuous_action

Created Diff never expires
# https://github.com/facebookresearch/torchbeast/blob/master/torchbeast/core/environment.py

import numpy as np
from collections import deque
import gym
from gym import spaces
import cv2
cv2.ocl.setUseOpenCL(False)


class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """Randomise the initial state by issuing a random number of no-ops on reset.

        Action index 0 must be the ALE ``NOOP`` action; this is asserted.

        Args:
            env: the environment to wrap.
            noop_max: upper bound (inclusive) on the number of no-ops per reset.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        # When set to an int, reset() uses exactly that many no-ops instead of sampling.
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self, **kwargs):
        """Reset the env, then step ``NOOP`` between 1 and ``noop_max`` times.

        If the episode ends during the no-ops, the env is reset again and the
        reset observation is kept. Returns the last observation produced.
        """
        self.env.reset(**kwargs)
        num_noops = self.override_num_noops
        if num_noops is None:
            num_noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)  # pylint: disable=E1101
        assert num_noops > 0
        last_obs = None
        for _ in range(num_noops):
            last_obs, _, episode_over, _ = self.env.step(self.noop_action)
            if episode_over:
                last_obs = self.env.reset(**kwargs)
        return last_obs

    def step(self, ac):
        return self.env.step(ac)

class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Press FIRE after every reset, for games that stay frozen until firing.

        Requires action 1 to be ``FIRE`` and at least three actions in total,
        because reset() also issues action 2.
        """
        gym.Wrapper.__init__(self, env)
        meanings = env.unwrapped.get_action_meanings()
        assert meanings[1] == 'FIRE'
        assert len(meanings) >= 3

    def reset(self, **kwargs):
        """Reset, then take actions 1 and 2 in turn, re-resetting if either ends the episode.

        Returns the observation from the step with action 2.
        """
        self.env.reset(**kwargs)
        for kick_action in (1, 2):
            obs, _, finished, _ = self.env.step(kick_action)
            if finished:
                self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        return self.env.step(ac)

class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Report ``done=True`` when a life is lost, but only truly reset on game over.

        This is the DeepMind DQN convention, which helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        # Whether the last step's `done` came from the underlying env itself.
        self.was_real_done = True

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        current_lives = self.env.unwrapped.ale.lives()
        # A drop in lives is terminal. Requiring lives > 0 avoids a repeated
        # "done" while some games (e.g. Qbert) sit at 0 lives for a few frames
        # before the env itself reports done.
        if 0 < current_lives < self.lives:
            done = True
        # Update after the check so that bonus lives raise the baseline.
        self.lives = current_lives
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Reset fully only after a real game over.

        After a lost life, take one no-op step (action 0) to move past the
        life-loss state instead. All states stay reachable and the learner
        needs no special handling.
        """
        if not self.was_real_done:
            obs = self.env.step(0)[0]
        else:
            obs = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()
        return obs

class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Repeat each action ``skip`` times and return the max of the last two frames.

        Args:
            env: the environment to wrap.
            skip: number of env steps taken per agent action.
        """
        gym.Wrapper.__init__(self, env)
        # Two most recent raw observations, for max-pooling over time. The
        # buffer uses the wrapped space's dtype instead of forcing uint8, so
        # non-uint8 observations are not silently truncated.
        self._obs_buffer = np.zeros(
            (2,) + env.observation_space.shape, dtype=env.observation_space.dtype
        )
        self._skip = skip

    def step(self, action):
        """Repeat ``action``, sum the rewards, and max-pool the last two observations.

        Returns:
            (max_frame, total_reward, done, info) from the final inner step taken.
        """
        total_reward = 0.0
        done = None
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                # On an early stop the buffer can hold frames from a previous
                # call. That is accepted: the observation on a done=True frame
                # is not used.
                break
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

class ClipRewardEnv(gym.RewardWrapper):
    """Replace every reward by its sign."""

    def __init__(self, env):
        super().__init__(env)

    def reward(self, reward):
        """Return -1, 0 or +1 according to the sign of ``reward``."""
        clipped = np.sign(reward)
        return clipped


class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None):
        """
        Warp frames to 84x84 as done in the Nature paper and later work.
        If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
        observation should be warped.

        Args:
            env: environment whose (selected) observation is a uint8 rank-3 image.
            width, height: output frame size in pixels.
            grayscale: convert RGB to a single channel before resizing.
            dict_space_key: key of the image inside a Dict observation, or None.
        """
        super().__init__(env)
        self._width = width
        self._height = height
        self._grayscale = grayscale
        self._key = dict_space_key
        num_colors = 1 if self._grayscale else 3

        new_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(self._height, self._width, num_colors),
            dtype=np.uint8,
        )
        if self._key is None:
            original_space = self.observation_space
            self.observation_space = new_space
        else:
            import copy
            # By default a gym.Wrapper's observation_space is the wrapped env's
            # own space object. Editing its `.spaces` mapping in place would
            # also change the inner env's space, so edit a copy instead.
            dict_space = copy.copy(self.observation_space)
            dict_space.spaces = copy.copy(dict_space.spaces)
            original_space = dict_space.spaces[self._key]
            dict_space.spaces[self._key] = new_space
            self.observation_space = dict_space
        assert original_space.dtype == np.uint8 and len(original_space.shape) == 3

    def observation(self, obs):
        """Grayscale (optionally) and resize the frame, keeping a trailing channel axis."""
        if self._key is None:
            frame = obs
        else:
            frame = obs[self._key]

        if self._grayscale:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(
            frame, (self._width, self._height), interpolation=cv2.INTER_AREA
        )
        if self._grayscale:
            # cvtColor drops the channel axis; restore it so frames are (H, W, 1).
            frame = np.expand_dims(frame, -1)

        if self._key is None:
            obs = frame
        else:
            # Shallow-copy the dict so the caller's observation is not mutated.
            obs = obs.copy()
            obs[self._key] = frame
        return obs


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack the k last frames along the last (channel) axis.

        Returns lazy array, which is much more memory efficient.
        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        inner_space = env.observation_space
        shp = inner_space.shape
        # Concatenate the per-frame bounds k times, the same way LazyFrames
        # concatenates the frames. The stacked space then keeps the inner
        # range (e.g. [0, 1] after ScaledFloatFrame) instead of assuming [0, 255].
        self.observation_space = spaces.Box(
            low=np.concatenate([inner_space.low] * k, axis=-1),
            high=np.concatenate([inner_space.high] * k, axis=-1),
            shape=(shp[:-1] + (shp[-1] * k,)),
            dtype=inner_space.dtype,
        )

    def reset(self, **kwargs):
        """Reset the env and fill the stack with k copies of the first observation.

        Keyword arguments are forwarded to the wrapped env, as the other wrappers do.
        """
        ob = self.env.reset(**kwargs)
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))

class ScaledFloatFrame(gym.ObservationWrapper):
    """Convert observations to float32 in [0, 1] by dividing by 255."""

    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        same_shape = env.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=same_shape, dtype=np.float32)

    def observation(self, observation):
        # This materialises a full float32 array (and forces LazyFrames), so
        # it gives up the memory savings. Only use it with small replay buffers.
        dense = np.array(observation)
        return dense.astype(np.float32) / 255.0

class LazyFrames(object):
    def __init__(self, frames):
        """This object ensures that common frames between the observations are only stored once.
        It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
        buffers.
        This object should only be converted to numpy array before being passed to the model.

        Args:
            frames: list of arrays, concatenated along their last axis on first access.
        """
        self._frames = frames
        self._out = None

    def _force(self):
        # Concatenate once, cache the result, and drop the list of references.
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=-1)
            self._frames = None
        return self._out

    def __array__(self, dtype=None, copy=None):
        """NumPy array protocol.

        Accepts the ``copy`` keyword that NumPy 2 passes. When ``copy`` is
        truthy, the result never aliases the cached array. Otherwise the cached
        array is returned as-is unless a dtype conversion is requested
        (``astype`` always returns a new array).
        """
        out = self._force()
        if dtype is not None:
            out = out.astype(dtype)
        elif copy:
            out = out.copy()
        return out

    def __len__(self):
        # Length of the first axis of the concatenated array (not the frame count).
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]

    def count(self):
        """Size of the last (concatenation) axis."""
        frames = self._force()
        return frames.shape[frames.ndim - 1]

    def frame(self, i):
        """Slice ``i`` of the last axis. This equals frame ``i`` only when each frame has one channel."""
        return self._force()[..., i]

def wrap_atari(env, max_episode_steps=None):
    """Apply the base Atari wrappers (random no-op starts, 4-frame skip with max-pooling).

    Args:
        env: a gym env whose spec id contains ``NoFrameskip``.
        max_episode_steps: if not None, also cap episodes at this many
            (frame-skipped) steps with ``gym.wrappers.TimeLimit``.
            Previously any non-None value failed an assertion.

    Returns:
        The wrapped environment.
    """
    assert 'NoFrameskip' in env.spec.id
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if max_episode_steps is not None:
        env = gym.wrappers.TimeLimit(env, max_episode_steps=max_episode_steps)
    return env

def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
    """Configure environment for DeepMind-style Atari.

    Wrapping order: episodic lives (optional) -> FIRE on reset (if the game has
    a FIRE action) -> 84x84 warp -> float scaling (optional) -> reward sign
    clipping (optional) -> 4-frame stacking (optional).
    """
    if episode_life:
        env = EpisodicLifeEnv(env)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env)
    optional_wrappers = (
        (scale, ScaledFloatFrame),
        (clip_rewards, ClipRewardEnv),
        (frame_stack, lambda inner: FrameStack(inner, 4)),
    )
    for enabled, apply_wrapper in optional_wrappers:
        if enabled:
            env = apply_wrapper(env)
    return env


class ImageToPyTorch(gym.ObservationWrapper):
    """
    Move the channel axis first: (height, width, channels) -> (channels, height, width).

    Expects a rank-3 observation, because it transposes with axes (2, 0, 1).
    """

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_space = self.observation_space
        old_shape = old_space.shape
        # Transpose the per-element bounds and keep the inner dtype. Otherwise
        # the space would claim uint8 0..255 even for scaled float inputs.
        self.observation_space = gym.spaces.Box(
            low=np.transpose(old_space.low, axes=(2, 0, 1)),
            high=np.transpose(old_space.high, axes=(2, 0, 1)),
            shape=(old_shape[-1], old_shape[0], old_shape[1]),
            dtype=old_space.dtype,
        )

    def observation(self, observation):
        return np.transpose(observation, axes=(2, 0, 1))

def wrap_pytorch(env):
    """Return ``env`` wrapped so its image observations are channel-first."""
    channel_first_env = ImageToPyTorch(env)
    return channel_first_env

import torch
import torch
import torch.nn as nn
import torch.nn as nn
import torch.optim as optim
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard import SummaryWriter


import argparse
import argparse
from distutils.util import strtobool
from distutils.util import strtobool
import numpy as np
import numpy as np
import gym
import gym
from gym.wrappers import TimeLimit, Monitor
from gym.wrappers import TimeLimit, Monitor
import pybullet_envs
import pybullet_envs
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import time
import random
import random
import os
import os
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnvWrapper


# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd(object):
    """Running mean and variance, updated one sample at a time."""

    def __init__(self, epsilon=1e-4, shape=()):
        # float64 accumulators. `count` starts at `epsilon`, so the initial
        # (mean=0, var=1) values carry almost no weight in the first merge.
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        """Fold a single sample ``x`` into the statistics as a batch of size 1."""
        one_sample = [x]
        self.update_from_moments(np.mean(one_sample, axis=0), np.var(one_sample, axis=0), 1)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge precomputed batch moments into the running statistics."""
        merged = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)
        self.mean, self.var, self.count = merged

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    """Combine running (mean, var, count) with the moments of a new batch.

    Merges the two populations' sums of squared deviations, with a correction
    term for the difference between their means. Works element-wise on numpy
    arrays as well as scalars.

    Returns:
        (new_mean, new_var, new_count)
    """
    total = count + batch_count
    delta = batch_mean - mean
    combined_m2 = (var * count) + (batch_var * batch_count) + np.square(delta) * count * batch_count / total
    return mean + delta * batch_count / total, combined_m2 / total, total

class NormalizedEnv(gym.core.Wrapper):
    """Normalise observations and rewards with running statistics.

    Args:
        env: single (non-vectorised) env whose ``info`` is a dict.
        ob: normalise observations by their running mean/std, clipped to ±clipob.
        ret: scale rewards by the running std of the discounted return, clipped to ±cliprew.
        gamma: discount used for the running return.
        epsilon: added to variances before the square root.
    """

    def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
        super(NormalizedEnv, self).__init__(env)
        self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
        self.ret_rms = RunningMeanStd(shape=(1,)) if ret else None
        self.clipob = clipob
        self.cliprew = cliprew
        self.ret = np.zeros(())
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, dones, infos = self.env.step(action)
        # Keep the unnormalised reward available to callers.
        infos['real_reward'] = rews
        self.ret = self.ret * self.gamma + rews
        obs = self._obfilt(obs)
        if self.ret_rms is not None:
            self.ret_rms.update(np.array([self.ret]))
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        # Restart the discounted return at episode boundaries.
        self.ret = self.ret * (1 - float(dones))
        return obs, rews, dones, infos

    def _obfilt(self, obs):
        """Update the observation stats and return the normalised, clipped observation."""
        if self.ob_rms is None:
            return obs
        self.ob_rms.update(obs)
        return np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)

    def reset(self, **kwargs):
        """Reset the running return and the env, forwarding keyword arguments like the other wrappers."""
        self.ret = np.zeros(())
        obs = self.env.reset(**kwargs)
        return self._obfilt(obs)

if __name__ == "__main__":
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='PPO agent')
parser = argparse.ArgumentParser(description='PPO agent')
# Common arguments
# Common arguments
parser.add_argument('--exp-name', type=str, default=os.path.basename(__file__).rstrip(".py"),
parser.add_argument('--exp-name', type=str, default=os.path.basename(__file__).rstrip(".py"),
help='the name of this experiment')
help='the name of this experiment')
parser.add_argument('--gym-id', type=str, default="HopperBulletEnv-v0",
parser.add_argument('--gym-id', type=str, default="BreakoutNoFrameskip-v4",
help='the id of the gym environment')
help='the id of the gym environment')
parser.add_argument('--learning-rate', type=float, default=3e-4,
parser.add_argument('--learning-rate', type=float, default=2.5e-4,
help='the learning rate of the optimizer')
help='the learning rate of the optimizer')
parser.add_argument('--seed', type=int, default=1,
parser.add_argument('--seed', type=int, default=1,
help='seed of the experiment')
help='seed of the experiment')
parser.add_argument('--total-timesteps', type=int, default=2000000,
parser.add_argument('--total-timesteps', type=int, default=10000000,
help='total timesteps of the experiments')
help='total timesteps of the experiments')
parser.add_argument('--torch-deterministic', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
parser.add_argument('--torch-deterministic', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
help='if toggled, `torch.backends.cudnn.deterministic=False`')
help='if toggled, `torch.backends.cudnn.deterministic=False`')
parser.add_argument('--cuda', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
parser.add_argument('--cuda', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
help='if toggled, cuda will not be enabled by default')
help='if toggled, cuda will not be enabled by default')
parser.add_argument('--prod-mode', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
parser.add_argument('--prod-mode', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
help='run the script in production mode and use wandb to log outputs')
help='run the script in production mode and use wandb to log outputs')
parser.add_argument('--capture-video', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
parser.add_argument('--capture-video', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
help='weather to capture videos of the agent performances (check out `videos` folder)')
help='weather to capture videos of the agent performances (check out `videos` folder)')
parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
help="the wandb's project name")
help="the wandb's project name")
parser.add_argument('--wandb-entity', type=str, default=None,
parser.add_argument('--wandb-entity', type=str, default=None,
help="the entity (team) of wandb's project")
help="the entity (team) of wandb's project")


# Algorithm specific arguments
# Algorithm specific arguments
parser.add_argument('--n-minibatch', type=int, default=32,
parser.add_argument('--n-minibatch', type=int, default=4,
help='the number of mini batch')
help='the number of mini batch')
parser.add_argument('--num-envs', type=int, default=1,
parser.add_argument('--num-envs', type=int, default=8,
help='the number of parallel game environment')
help='the number of parallel game environment')
parser.add_argument('--num-steps', type=int, default=2048,
parser.add_argument('--num-steps', type=int, default=128,
help='the number of steps per game environment')
help='the number of steps per game environment')
parser.add_argument('--gamma', type=float, default=0.99,
parser.add_argument('--gamma', type=float, default=0.99,
help='the discount factor gamma')
help='the discount factor gamma')
parser.add_argument('--gae-lambda', type=float, default=0.95,
parser.add_argument('--gae-lambda', type=float, default=0.95,
help='the lambda for the general advantage estimation')
help='the lambda for the general advantage estimation')
parser.add_argument('--ent-coef', type=float, default=0.0,
parser.add_argument('--ent-coef', type=float, default=0.01,
help="coefficient of the entropy")
help="coefficient of the entropy")
parser.add_argument('--vf-coef', type=float, default=0.5,
parser.add_argument('--vf-coef', type=float, default=0.5,
help="coefficient of the value function")
help="coefficient of the value function")
parser.add_argument('--max-grad-norm', type=float, default=0.5,
parser.add_argument('--max-grad-norm', type=float, default=0.5,
help='the maximum norm for the gradient clipping')
help='the maximum norm for the gradient clipping')
parser.add_argument('--clip-coef', type=float, default=0.2,
parser.add_argument('--clip-coef', type=float, default=0.1,
help="the surrogate clipping coefficient")
help="the surrogate clipping coefficient")
parser.add_argument('--update-epochs', type=int, default=10,
parser.add_argument('--update-epochs', type=int, default=4,
help="the K epochs to update the policy")
help="the K epochs to update the policy")
parser.add_argument('--kle-stop', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
parser.add_argument('--kle-stop', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
help='If toggled, the policy updates will be early stopped w.r.t target-kl')
help='If toggled, the policy updates will be early stopped w.r.t target-kl')
parser.add_argument('--kle-rollback', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
parser.add_argument('--kle-rollback', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
help='If toggled, the policy updates will roll back to previous policy if KL exceeds target-kl')
help='If toggled, the policy updates will roll back to previous policy if KL exceeds target-kl')
parser.add_argument('--target-kl', type=float, default=0.03,
parser.add_argument('--target-kl', type=float, default=0.03,
help='the target-kl variable that is referred by --kl')
help='the target-kl variable that is referred by --kl')
parser.add_argument('--gae', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
parser.add_argument('--gae', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
help='Use GAE for advantage computation')
help='Use GAE for advantage computation')
parser.add_argument('--norm-adv', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
parser.add_argument('--norm-adv', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
help="Toggles advantages normalization")
help="Toggles advantages normalization")
parser.add_argument('--anneal-lr', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
parser.add_argument('--anneal-lr', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
help="Toggle learning rate annealing for policy and value networks")
help="Toggle learning rate annealing for policy and value networks")
parser.add_argument('--clip-vloss', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
parser.add_argument('--clip-vloss', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
help='Toggles wheter or not to use a clipped loss for the value function, as per the paper.')
help='Toggles wheter or not to use a clipped loss for the value function, as per the paper.')


args = parser.parse_args()
args = parser.parse_args()
if not args.seed:
if not args.seed:
args.seed = int(time.time())
args.seed = int(time.time())


args.batch_size = int(args.num_envs * args.num_steps)
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.n_minibatch)
args.minibatch_size = int(args.batch_size // args.n_minibatch)


class ClipActionsWrapper(gym.Wrapper):
    """Sanitise continuous actions before forwarding them to the wrapped env."""

    def step(self, action):
        # Replace NaN/inf with finite numbers first, then clamp into the Box bounds.
        finite_action = np.nan_to_num(action)
        bounded_action = np.clip(finite_action, self.action_space.low, self.action_space.high)
        return self.env.step(bounded_action)

class VecPyTorch(VecEnvWrapper):
    """Vectorised-env wrapper that exchanges torch tensors with the agent.

    Observations come back as float tensors on ``device``. Rewards come back
    as float tensors of shape (num_envs, 1) and stay on the CPU. Actions are
    moved to the CPU and converted to numpy before being sent to the envs.
    """

    def __init__(self, venv, device):
        super(VecPyTorch, self).__init__(venv)
        self.device = device

    def reset(self):
        obs = self.venv.reset()
        return torch.from_numpy(obs).float().to(self.device)

    def step_async(self, actions):
        self.venv.step_async(actions.cpu().numpy())

    def step_wait(self):
        obs, reward, done, info = self.venv.step_wait()
        obs = torch.from_numpy(obs).float().to(self.device)
        reward = torch.from_numpy(reward).unsqueeze(dim=1).float()
        return obs, reward, done, info




# TRY NOT TO MODIFY: setup the environment
# TRY NOT TO MODIFY: setup the environment
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{experiment_name}")
writer = SummaryWriter(f"runs/{experiment_name}")
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
'\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))
'\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))
if args.prod_mode:
if args.prod_mode:
import wandb
import wandb
wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, sync_tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True, save_code=True)
wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, sync_tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True, save_code=True)
writer = SummaryWriter(f"/tmp/{experiment_name}")
writer = SummaryWriter(f"/tmp/{experiment_name}")


# TRY NOT TO MODIFY: seeding
# TRY NOT TO MODIFY: seeding
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
random.seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
torch.backends.cudnn.deterministic = args.torch_deterministic
# NOTE(review): this span is a side-by-side diff collapsed into one column, so
# every line appears either twice or as two alternatives, and it is not
# runnable as written. Two variants of `make_env` appear to be interleaved:
#   - continuous-action variant: gym.make -> ClipActionsWrapper ->
#     RecordEpisodeStatistics -> (Monitor on env idx 0 when capturing video)
#     -> NormalizedEnv
#   - Atari variant: gym.make -> wrap_atari -> RecordEpisodeStatistics ->
#     (Monitor on env idx 0 when capturing video)
#     -> wrap_pytorch(wrap_deepmind(..., clip_rewards=True, frame_stack=True, scale=False))
# Both return a zero-argument `thunk` that builds the env and seeds the env,
# its action_space and its observation_space with `seed`. The thunk reads the
# module-level `args` and `experiment_name` globals.
def make_env(gym_id, seed, idx):
def make_env(gym_id, seed, idx):
def thunk():
def thunk():
env = gym.make(gym_id)
env = gym.make(gym_id)
env = ClipActionsWrapper(env)
env = wrap_atari(env)
env = gym.wrappers.RecordEpisodeStatistics(env)
env = gym.wrappers.RecordEpisodeStatistics(env)
if args.capture_video:
if args.capture_video:
if idx == 0:
if idx == 0:
env = Monitor(env, f'videos/{experiment_name}')
env = Monitor(env, f'videos/{experiment_name}')
env = NormalizedEnv(env)
env = wrap_pytorch(
wrap_deepmind(
env,
clip_rewards=True,
frame_stack=True,
scale=False,
)
)
env.seed(seed)
env.seed(seed)
env.action_space.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
env.observation_space.seed(seed)
return env
return env
return thunk
return thunk
envs = VecPyTorch(DummyVecEnv([make_env(args.gym_id, args.seed+i, i) for i in range(args.num_envs)]), device)
envs = VecPyTorch(DummyVecEnv([make_env(args.gym_id, args.seed+i, i) for i in range(args.num_envs)]), device)
if args.prod_mode:
if args.prod_mode:
envs = VecPyTorch(
envs = VecPyTorch(
SubprocVecEnv([make_env(args.gym_id, args.seed+i, i) for i in range(args.num_envs)], "fork"),
SubprocVecEnv([make_env(args.gym_id, args.seed+i, i) for i in range(args.num_envs)], "fork"),
device
device
)
)
assert isinstance(envs.action_space, Box), "only continuous action space is supported"
assert isinstance(envs.action_space, Discrete), "only discrete action space is supported"


# ALGO LOGIC: initialize agent here:
# ALGO LOGIC: initialize agent here:
class Scale(nn.Module):
    """Multiply inputs by a fixed factor (used as ``Scale(1/255)`` for pixel inputs)."""

    def __init__(self, scale):
        super().__init__()
        # Plain attribute: not a registered parameter or buffer.
        self.scale = scale

    def forward(self, x):
        scaled = x * self.scale
        return scaled

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer.weight`` orthogonally with gain ``std`` and fill its bias.

    Layers built with ``bias=False`` have ``layer.bias is None``. Their bias
    step is skipped instead of crashing.

    Returns:
        The same layer, so the call can be used inline (e.g. inside nn.Sequential).
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer


# NOTE(review): this span is a side-by-side diff collapsed into one column and
# is not runnable as written. Two `Agent` classes are interleaved line by line:
#   - continuous-action variant: `__init__(self)`; separate MLP `critic` and
#     `actor_mean` (Linear/Tanh, 64 hidden units); a state-independent
#     `actor_logstd` parameter. get_action samples from a Normal and sums
#     log-prob/entropy over dim 1. get_value returns critic(x).
#   - Atari variant: `__init__(self, frames=4)`; shared conv trunk `network`
#     (Scale(1/255) -> three Conv2d/ReLU layers -> Flatten -> Linear(3136, 512)
#     -> ReLU); linear `actor` head (logits over envs.action_space.n) and
#     linear `critic` head. get_action samples from a Categorical.
#     get_value returns critic(forward(x)).
# Both variants read the module-level `envs` global for space sizes.
class Agent(nn.Module):
class Agent(nn.Module):
def __init__(self):
def __init__(self, frames=4):
super(Agent, self).__init__()
super(Agent, self).__init__()
self.critic = nn.Sequential(
self.network = nn.Sequential(
layer_init(nn.Linear(np.array(envs.observation_space.shape).prod(), 64)),
Scale(1/255),
nn.Tanh(),
layer_init(nn.Conv2d(frames, 32, 8, stride=4)),
layer_init(nn.Linear(64, 64)),
nn.ReLU(),
nn.Tanh(),
layer_init(nn.Conv2d(32, 64, 4, stride=2)),
layer_init(nn.Linear(64, 1), std=1.),
nn.ReLU(),
)
layer_init(nn.Conv2d(64, 64, 3, stride=1)),
self.actor_mean = nn.Sequential(
nn.ReLU(),
layer_init(nn.Linear(np.array(envs.observation_space.shape).prod(), 64)),
nn.Flatten(),
nn.Tanh(),
layer_init(nn.Linear(3136, 512)),
layer_init(nn.Linear(64, 64)),
nn.ReLU()
nn.Tanh(),
layer_init(nn.Linear(64, np.prod(envs.action_space.shape)), std=0.01),
)
)
self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.action_space.shape)))
self.actor = layer_init(nn.Linear(512, envs.action_space.n), std=0.01)
self.critic = layer_init(nn.Linear(512, 1), std=1)

def forward(self, x):
return self.network(x)


def get_action(self, x, action=None):
def get_action(self, x, action=None):
action_mean = self.actor_mean(x)
logits = self.actor(self.forward(x))
action_logstd = self.actor_logstd.expand_as(action_mean)
probs = Categorical(logits=logits)
action_std = torch.exp(action_logstd)
probs = Normal(action_mean, action_std)
if action is None:
if action is None:
action = probs.sample()
action = probs.sample()
return action, probs.log_prob(action).sum(1), probs.entropy().sum(1)
return action, probs.log_prob(action), probs.entropy()


def get_value(self, x):
def get_value(self, x):
return self.critic(x)
return self.critic(self.forward(x))


agent = Agent().to(device)
agent = Agent().to(device)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
if args.anneal_lr:
if args.anneal_lr:
# https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/defaults.py#L20
# https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/defaults.py#L20
lr = lambda f: f * args.learning_rate
lr = lambda f: f * args.learning_rate


# ALGO Logic: Storage for epoch data
# ALGO Logic: Storage for epoch data
obs = torch.zeros((args.num_steps, args.num_envs) + envs.observation_space.shape).to(device)
obs = torch.zeros((args.num_steps, args.num_envs) + envs.observation_space.shape).to(device)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.action_space.shape).to(device)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.action_space.shape).to(device)
logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
values = torch.zeros((args.num_steps, args.num_envs)).to(device)
values = torch.zeros((args.num_steps, args.num_envs)).to(device)


# TRY NOT TO MODIFY: start the game
# TRY NOT TO MODIFY: start the game
global_step = 0
global_step = 0
episode_step = 0
episode_step = 0
# Note how `next_obs` and `next_done` are used; their usage is equivalent to
# Note how `next_obs` and `next_done` are used; their usage is equivalent to
# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/84a7582477fb0d5c82ad6d850fe476829dddd2e1/a2c_ppo_acktr/storage.py#L60
# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/84a7582477fb0d5c82ad6d850fe476829dddd2e1/a2c_ppo_acktr/storage.py#L60
next_obs = envs.reset()
next_obs = envs.reset()
next_done = torch.zeros(args.num_envs).to(device)
next_done = torch.zeros(args.num_envs).to(device)
num_updates = args.total_timesteps // args.batch_size
num_updates = args.total_timesteps // args.batch_size
for update in range(1, num_updates+1):
for update in range(1, num_updates+1):
# Annealing the rate if instructed to do so.
# Annealing the rate if instructed to do so.
if args.anneal_lr:
if args.anneal_lr:
frac = 1.0 - (update - 1.0) / num_updates
frac = 1.0 - (update - 1.0) / num_updates
lrnow = lr(frac)
lrnow = lr(frac)
optimizer.param_groups[0]['lr'] = lrnow
optimizer.param_groups[0]['lr'] = lrnow


# TRY NOT TO MODIFY: prepare the execution of the game.
# TRY NOT TO MODIFY: prepare the execution of the game.
for step in range(0, args.num_steps):
for step in range(0, args.num_steps):
global_step += 1 * args.num_envs
global_step += 1 * args.num_envs
obs[step] = next_obs
obs[step] = next_obs
dones[step] = next_done
dones[step] = next_done


# ALGO LOGIC: put action logic here
# ALGO LOGIC: put action logic here
with torch.no_grad():
with torch.no_grad():
values[step] = agent.get_value(obs[step]).flatten()
values[step] = agent.get_value(obs[step]).flatten()
action, logproba, _ = agent.get_action(obs[step])
action, logproba, _ = agent.get_action(obs[step])


actions[step] = action
actions[step] = action
logprobs[step] = logproba
logprobs[step] = logproba


# TRY NOT TO MODIFY: execute the game and log data.
# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rs, ds, infos = envs.step(action)
next_obs, rs, ds, infos = envs.step(action)
rewards[step], next_done = rs.view(-1), torch.Tensor(ds).to(device)
rewards[step], next_done = rs.view(-1), torch.Tensor(ds).to(device)


for info in infos:
for info in infos:
if 'episode' in info.keys():
if 'episode' in info.keys():
episode_step += info['episode']['l']
episode_step += info['episode']['l']
print(f"global_step={episode_step}, episode_reward={info['episode']['r']}")
print(f"global_step={episode_step}, episode_reward={info['episode']['r']}")
writer.add_scalar("charts/episode_reward", info['episode']['r'], episode_step)
writer.add_scalar("charts/episode_reward", info['episode']['r'], episode_step)


# bootstrap reward if not done. reached the batch limit
# bootstrap reward if not done. reached the batch limit
with torch.no_grad():
with torch.no_grad():
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)
if args.gae:
if args.gae:
advantages = torch.zeros_like(rewards).to(device)
advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
lastgaelam = 0
for t in reversed(range(args.num_steps)):
for t in reversed(range(args.num_steps)):
if t == args.num_steps - 1:
if t == args.num_steps - 1:
nextnonterminal = 1.0 - next_done
nextnonterminal = 1.0 - next_done
nextvalues = last_value
nextvalues = last_value
else:
else:
nextnonterminal = 1.0 - dones[t+1]
nextnonterminal = 1.0 - dones[t+1]
nextvalues = values[t+1]
nextvalues = values[t+1]
delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values
returns = advantages + values
else:
else:
returns = torch.zeros_like(rewards).to(device)
returns = torch.zeros_like(rewards).to(device)
for t in reversed(range(args.num_steps)):
for t in reversed(range(args.num_steps)):
if t == args.num_steps - 1:
if t == args.num_steps - 1:
nextnonterminal = 1.0 - next_done
nextnonterminal = 1.0 - next_done
next_return = last_value
next_return = last_value
else:
else:
nextnonterminal = 1.0 - dones[t+1]
nextnonterminal = 1.0 - dones[t+1]
next_return = returns[t+1]
next_return = returns[t+1]
returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
advantages = returns - values
advantages = returns - values


# flatten the batch
# flatten the batch
b_obs = obs.reshape((-1,)+envs.observation_space.shape)
b_obs = obs.reshape((-1,)+envs.observation_space.shape)
b_logprobs = logprobs.reshape(-1)
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,)+envs.action_space.shape)
b_actions = actions.reshape((-1,)+envs.action_space.shape)
b_advantages = advantages.reshape(-1)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
b_values = values.reshape(-1)


# Optimizaing the policy and value network
# Optimizaing the policy and value network
target_agent = Agent().to(device)
target_agent = Agent().to(device)
inds = np.arange(args.batch_size,)
inds = np.arange(args.batch_size,)
for i_epoch_pi in range(args.update_epochs):
for i_epoch_pi in range(args.update_epochs):
np.random.shuffle(inds)
np.random.shuffle(inds)
target_agent.load_state_dict(agent.state_dict())
target_agent.load_state_dict(agent.state_dict())
for start in range(0, args.batch_size, args.minibatch_size):
for start in range(0, args.batch_size, args.minibatch_size):
end = start + args.minibatch_size
end = start + args.minibatch_size
minibatch_ind = inds[start:end]
minibatch_ind = inds[start:end]
mb_advantages = b_advantages[minibatch_ind]
mb_advantages = b_advantages[minibatch_ind]
if args.norm_adv:
if args.norm_adv:
mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)


_, newlogproba, entropy = agent.get_action(b_obs[minibatch_ind], b_actions[minibatch_ind])
_, newlogproba, entropy = agent.get_action(b_obs[minibatch_ind], b_actions.long()[minibatch_ind])
ratio = (newlogproba - b_logprobs[minibatch_ind]).exp()
ratio = (newlogproba - b_logprobs[minibatch_ind]).exp()


# Stats
# Stats
approx_kl = (b_logprobs[minibatch_ind] - newlogproba).mean()
approx_kl = (b_logprobs[minibatch_ind] - newlogproba).mean()


# Policy loss
# Policy loss
pg_loss1 = -mb_advantages * ratio
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1-args.clip_coef, 1+args.clip_coef)
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1-args.clip_coef, 1+args.clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
entropy_loss = entropy.mean()
entropy_loss = entropy.mean()


# Value loss
# Value loss
new_values = agent.get_value(b_obs[minibatch_ind]).view(-1)
new_values = agent.get_value(b_obs[minibatch_ind]).view(-1)
if args.clip_vloss:
if args.clip_vloss:
v_loss_unclipped = ((new_values - b_returns[minibatch_ind]) ** 2)
v_loss_unclipped = ((new_values - b_returns[minibatch_ind]) ** 2)
v_clipped = b_values[minibatch_ind] + torch.clamp(new_values - b_values[minibatch_ind], -args.clip_coef, args.clip_coef)
v_clipped = b_values[minibatch_ind] + torch.clamp(new_values - b_values[minibatch_ind], -args.clip_coef, args.clip_coef)
v_loss_clipped = (v_clipped - b_returns[minibatch_ind])**2
v_loss_clipped
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max.mean()
else:
v_loss = 0.5 *((new_values - b_returns[minibatch_ind]) ** 2)

loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
optimizer.step()

if args.kle_stop:
if approx_kl > args.target_kl:
break
if args.kle_rollback:
if (b_logprobs[minibatch_ind] - agent.get_action(b_obs[minibatch_ind], b_actions[minibatch_ind])[1]).mean() > args.target_kl:
agent.load_state_dict(target_agent.state_dict())
break

# TRY NOT TO MODIFY: record rewards for plotting purposes
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]['lr'], global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy.mean().item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
if args.kle_stop or args.kle_rollback:
writer.add_scalar("debug/pg_stop_iter", i_epoch_pi, global_step)

envs.close()
writer.close()