on_off_policy
import time
import tqdm
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List, Union, Callable, Optional

from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import test_episode, gather_info

def onpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    collect_per_step: int,
    repeat_per_collect: int,
    episode_per_test: Union[int, List[int]],
    batch_size: int,
    train_fn: Optional[Callable[[int, int], None]] = None,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    writer: Optional[SummaryWriter] = None,
    log_interval: int = 1,
    verbose: bool = True,
    test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
"""A wrapper for on-policy trainer procedure.
"""A wrapper for off-policy trainer procedure.
The "step" in trainer means a policy network update.
The "step" in trainer means a policy network update.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
class.
class.
:param train_collector: the collector used for training.
:param train_collector: the collector used for training.
:type train_collector: :class:`~tianshou.data.Collector`
:type train_collector: :class:`~tianshou.data.Collector`
:param test_collector: the collector used for testing.
:param test_collector: the collector used for testing.
:type test_collector: :class:`~tianshou.data.Collector`
:type test_collector: :class:`~tianshou.data.Collector`
:param int max_epoch: the maximum of epochs for training. The training
:param int max_epoch: the maximum of epochs for training. The training
process might be finished before reaching the ``max_epoch``.
process might be finished before reaching the ``max_epoch``.
:param int step_per_epoch: the number of step for updating policy network
:param int step_per_epoch: the number of step for updating policy network
in one epoch.
in one epoch.
:param int collect_per_step: the number of episodes the collector would
:param int collect_per_step: the number of frames the collector would
collect before the network update. In other words, collect some
collect before the network update. In other words, collect some frames
episodes and do one policy network update.
and do some policy network update.
:param int repeat_per_collect: the number of repeat time for policy
learning, for example, set it to 2 means the policy needs to learn each
given batch data twice.
:param episode_per_test: the number of episodes for one policy evaluation.
:param episode_per_test: the number of episodes for one policy evaluation.
:type episode_per_test: int or list of ints
:param int batch_size: the batch size of sample data, which is going to
:param int batch_size: the batch size of sample data, which is going to
feed in the policy network.
feed in the policy network.
:param int update_per_step: the number of times the policy network would
be updated after frames are collected, for example, set it to 256 means
it updates policy 256 times once after ``collect_per_step`` frames are
collected.
:param function train_fn: a function receives the current number of epoch
:param function train_fn: a function receives the current number of epoch
and step index, and performs some operations at the beginning of
and step index, and performs some operations at the beginning of
training in this poch.
training in this epoch.
:param function test_fn: a function receives the current number of epoch
:param function test_fn: a function receives the current number of epoch
and step index, and performs some operations at the beginning of
and step index, and performs some operations at the beginning of
testing in this epoch.
testing in this epoch.
:param function save_fn: a function for saving policy when the undiscounted
:param function save_fn: a function for saving policy when the undiscounted
average mean reward in evaluation phase gets better.
average mean reward in evaluation phase gets better.
:param function stop_fn: a function receives the average undiscounted
:param function stop_fn: a function receives the average undiscounted
returns of the testing result, return a boolean which indicates whether
returns of the testing result, return a boolean which indicates whether
reaching the goal.
reaching the goal.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
SummaryWriter.
SummaryWriter.
:param int log_interval: the log interval of the writer.
:param int log_interval: the log interval of the writer.
:param bool verbose: whether to print the information.
:param bool verbose: whether to print the information.
:param bool test_in_train: whether to test in the training phase.
:param bool test_in_train: whether to test in the training phase.
:return: See :func:`~tianshou.trainer.gather_info`.
:return: See :func:`~tianshou.trainer.gather_info`.
"""
"""
    env_step, gradient_step = 0, 0
    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
    stat: Dict[str, MovAvg] = {}
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(
            total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, env_step)
                result = train_collector.collect(n_episode=collect_per_step)
                env_step += int(result["n/st"])
                data = {
                    "env_step": str(env_step),
                    "rew": f"{result['rew']:.2f}",
                    "len": str(int(result["len"])),
                    "n/ep": str(int(result["n/ep"])),
                    "n/st": str(int(result["n/st"])),
                    "v/ep": f"{result['v/ep']:.2f}",
                    "v/st": f"{result['v/st']:.2f}",
                }
                if writer and env_step % log_interval == 0:
                    for k in result.keys():
                        writer.add_scalar(
                            "train/" + k, result[k], global_step=env_step)
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test, writer, env_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f"{result[k]:.2f}"
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result["rew"], test_result["rew_std"])
                    else:
                        policy.train()
                losses = policy.update(
                    0, train_collector.buffer,
                    batch_size=batch_size, repeat=repeat_per_collect)
                train_collector.reset_buffer()
                step = max([1] + [
                    len(v) for v in losses.values() if isinstance(v, list)])
                gradient_step += step
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f"{stat[k].get():.6f}"
                    if writer and gradient_step % log_interval == 0:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=gradient_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test, writer, env_step)
        if best_epoch == -1 or best_reward < result["rew"]:
            best_reward, best_reward_std = result["rew"], result["rew_std"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± "
                  f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± "
                  f"{best_reward_std:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward, best_reward_std)

def offpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    collect_per_step: int,
    episode_per_test: Union[int, List[int]],
    batch_size: int,
    update_per_step: int = 1,
    train_fn: Optional[Callable[[int, int], None]] = None,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    writer: Optional[SummaryWriter] = None,
    log_interval: int = 1,
    verbose: bool = True,
    test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for the off-policy trainer procedure.

    The "step" in this trainer means a policy network update.

    The parameters are the same as in :func:`onpolicy_trainer`, except that
    there is no ``repeat_per_collect`` and the following two differ:

    :param int collect_per_step: the number of frames the collector would
        collect before a network update, i.e. collect some frames and do
        some policy network updates.
    :param int update_per_step: the number of times the policy network is
        updated after frames are collected; for example, setting it to 256
        means the policy is updated 256 times once ``collect_per_step``
        frames have been collected.
    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    env_step, gradient_step = 0, 0
    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
    stat: Dict[str, MovAvg] = {}
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(
            total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, env_step)
                result = train_collector.collect(n_step=collect_per_step)
                env_step += int(result["n/st"])
                data = {
                    "env_step": str(env_step),
                    "rew": f"{result['rew']:.2f}",
                    "len": str(int(result["len"])),
                    "n/ep": str(int(result["n/ep"])),
                    "n/st": str(int(result["n/st"])),
                    "v/ep": f"{result['v/ep']:.2f}",
                    "v/st": f"{result['v/st']:.2f}",
                }
                if writer and env_step % log_interval == 0:
                    for k in result.keys():
                        writer.add_scalar(
                            "train/" + k, result[k], global_step=env_step)
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test, writer, env_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f"{result[k]:.2f}"
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result["rew"], test_result["rew_std"])
                    else:
                        policy.train()
                for i in range(update_per_step * min(
                        result["n/st"] // collect_per_step, t.total - t.n)):
                    gradient_step += 1
                    losses = policy.update(batch_size, train_collector.buffer)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f"{stat[k].get():.6f}"
                        if writer and gradient_step % log_interval == 0:
                            writer.add_scalar(
                                k, stat[k].get(), global_step=gradient_step)
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test, writer, env_step)
        if best_epoch == -1 or best_reward < result["rew"]:
            best_reward, best_reward_std = result["rew"], result["rew_std"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± "
                  f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± "
                  f"{best_reward_std:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward, best_reward_std)
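

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original trainer code). It assumes
    # a Tianshou version contemporaneous with the trainers above (~0.3.x),
    # gym's CartPole-v0, and a hand-rolled Q-network; the DQNPolicy, Collector
    # and DummyVectorEnv constructor arguments may differ in other versions.
    import gym
    import torch

    from tianshou.data import ReplayBuffer
    from tianshou.env import DummyVectorEnv
    from tianshou.policy import DQNPolicy

    class QNet(torch.nn.Module):
        """A small MLP mapping observations to Q-values."""

        def __init__(self, state_dim: int, action_dim: int):
            super().__init__()
            self.model = torch.nn.Sequential(
                torch.nn.Linear(state_dim, 128), torch.nn.ReLU(),
                torch.nn.Linear(128, action_dim))

        def forward(self, obs, state=None, info={}):
            obs = torch.as_tensor(obs, dtype=torch.float32)
            return self.model(obs), state

    env = gym.make("CartPole-v0")
    train_envs = DummyVectorEnv(
        [lambda: gym.make("CartPole-v0") for _ in range(8)])
    test_envs = DummyVectorEnv(
        [lambda: gym.make("CartPole-v0") for _ in range(2)])

    net = QNet(env.observation_space.shape[0], env.action_space.n)
    optim = torch.optim.Adam(net.parameters(), lr=1e-3)
    policy = DQNPolicy(net, optim, discount_factor=0.99,
                       estimation_step=1, target_update_freq=320)

    train_collector = Collector(policy, train_envs, ReplayBuffer(20000))
    test_collector = Collector(policy, test_envs)

    # Off-policy training: collect 10 frames per step, update once per step,
    # and stop early once the test reward reaches CartPole-v0's 195 threshold.
    result = offpolicy_trainer(
        policy, train_collector, test_collector, max_epoch=10,
        step_per_epoch=1000, collect_per_step=10, episode_per_test=10,
        batch_size=64,
        train_fn=lambda epoch, env_step: policy.set_eps(0.1),
        test_fn=lambda epoch, env_step: policy.set_eps(0.05),
        stop_fn=lambda mean_rew: mean_rew >= 195)
    print(result)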