log num tokens and unnormalized loss on main
584 lines
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# All rights reserved.
#
#
# This source code is licensed under the BSD-style license found in the
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
import sys
import sys
import time
import time
from functools import partial
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from warnings import warn
from warnings import warn
import torch
import torch
from omegaconf import DictConfig, ListConfig
from omegaconf import DictConfig, ListConfig
from torch import nn
from torch import nn
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed import destroy_process_group, init_process_group
from torch.optim import Optimizer
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
from torchtune import config, modules, training, utils
from torchtune.config._utils import _get_component_from_path
from torchtune.config._utils import _get_component_from_path
from torchtune.data import padded_collate_packed
from torchtune.data import padded_collate_packed
from torchtune.datasets import ConcatDataset
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training.activations import apply_selective_activation_checkpointing
from torchtune.training.activations import apply_selective_activation_checkpointing
from torchtune.training.lr_schedulers import get_lr
from torchtune.training.lr_schedulers import get_lr
from tqdm import tqdm
from tqdm import tqdm
# Module-level logger used throughout the recipe. The "DEBUG" argument is
# presumably the log level -- confirm against ``utils.get_logger``.
log = utils.get_logger("DEBUG")
class FullFinetuneRecipeDistributed(FTRecipeInterface):
class FullFinetuneRecipeDistributed(FTRecipeInterface):
"""
"""
Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports
distributed training and can be run on a single node (1 to 8 GPUs).

Features:
    - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
        is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
        done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
        ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
        DDP is currently not supported. Training on CPU is not supported.

    - Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
        flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
        activations in memory and instead recompute them during the backward pass. This is especially
        helpful for larger batch sizes when you're memory constrained. But these savings in memory
        come at the cost of training performance. In most cases training can slow down quite a bit as
        a result of this activation recomputation.

    - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
        flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
        most cases this should halve the memory footprint of full precision (fp32) training, without
        loss in model quality (will depend on the model, training data and other settings). For
        GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
        precision are currently not supported.

    - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
        controlled using the ``gradient_accumulation_steps`` flag.

            Total Batch Size = batch_size * number of GPUs * gradient accumulation steps.

        For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a
        total batch size of 64.

        Gradient accumulation is especially useful when you are memory constrained. In this case,
        accumulating gradients might give you better training speed than enabling activation
        checkpointing.

    - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
        training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are
        only saved at the end of a given epoch and used in case of resuming training.

        Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
        currently not supported.

        For more details on the checkpointer, please take a look at
        our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html).

    - Logging. Terminal, Disk, WandB and TensorBoard are all supported.

    - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
        ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
        ``clip_grad_norm='inf'``.

For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
has example commands for how to kick-off training.

Args:
    cfg (DictConfig): OmegaConf object parsed from yaml file

Raises:
    ValueError: If ``dtype`` is set to fp16.
    RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
    RuntimeError: If ``left_pad_sequence`` is set as the data collator.
"""
"""
def __init__(self, cfg: DictConfig) -> None:
    """
    Read the device, dtype, logging and training settings from ``cfg`` and
    initialize the recipe's bookkeeping state.

    Args:
        cfg (DictConfig): OmegaConf object parsed from yaml file

    Raises:
        ValueError: If the resolved dtype is ``torch.float16``.
        RuntimeError: If ``fsdp_cpu_offload`` and a fused optimizer are both requested
            while ``utils.torch_version_ge("2.4.0")`` is False.
    """
    self._device = utils.get_device(device=cfg.device)
    self._dtype = training.get_dtype(cfg.dtype, device=self._device)

    if self._dtype == torch.float16:
        raise ValueError(
            "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
        )

    # A fused optimizer combined with CPU offload is rejected on torch < 2.4.0
    if (
        cfg.get("fsdp_cpu_offload", False)
        and cfg.optimizer.get("fused", False)
        and not utils.torch_version_ge("2.4.0")
    ):
        raise RuntimeError(
            "Using fused optimizer on CPU is only supported in PyTorch nightly."
        )

    # logging attributes
    self._output_dir = cfg.output_dir
    self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
    self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)

    # Peak memory stats are only kept when training runs on cuda
    if self._log_peak_memory_stats and self._device.type != "cuda":
        log.info(
            "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
        )
        self._log_peak_memory_stats = False

    # _is_rank_zero is used primarily for logging. In the future, the logger
    # should directly take care of this
    _, rank = training.get_world_size_and_rank()
    self._rank = rank
    self._is_rank_zero = rank == 0

    # Training cfg
    self._resume_from_checkpoint = cfg.resume_from_checkpoint
    self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
    self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)

    # These are public properties which are updated by the checkpoint loader
    # when ``resume_from_checkpoint`` is `True` or validated in tests
    self.seed = training.set_seed(seed=cfg.seed)
    self.epochs_run = 0
    self.total_epochs = cfg.epochs
    self.max_steps_per_epoch = cfg.max_steps_per_epoch
    self.global_step = 0
    self._clip_grad_norm = cfg.get("clip_grad_norm", None)
def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
    """
    Extract the checkpoint state from file and validate. If resume_from_checkpoint
    is True, this also includes the recipe state.

    Args:
        cfg_checkpointer (DictConfig): config used to instantiate the checkpointer;
            ``resume_from_checkpoint`` is forwarded to it as a keyword argument.

    Returns:
        Dict[str, Any]: the dictionary returned by the checkpointer's ``load_checkpoint``.
    """
    # The checkpointer is kept on the instance after it is built here
    self._checkpointer = config.instantiate(
        cfg_checkpointer,
        resume_from_checkpoint=self._resume_from_checkpoint,
    )
    checkpoint_dict = self._checkpointer.load_checkpoint()

    # Recipe state (epochs run, seed, ...) is only restored when resuming
    if self._resume_from_checkpoint:
        self._update_recipe_state(checkpoint_dict)
    return checkpoint_dict
def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
    """
    Updates the recipe state from checkpoint.

    ``epochs_run`` is always taken from the checkpoint. For ``seed`` and
    ``max_steps_per_epoch`` a mismatch with the config emits a warning and the
    checkpoint value wins. For ``total_epochs`` a mismatch emits a warning and
    the config value is kept.

    Args:
        ckpt_dict (Dict[str, Any]): checkpoint dictionary containing the recipe state keys.

    Raises:
        KeyError: If any of the required recipe state keys is missing from ``ckpt_dict``.
    """
    try:
        self.epochs_run = ckpt_dict[training.EPOCHS_KEY]

        # on mismatch, warn the user and prevent the override
        if self.seed != ckpt_dict[training.SEED_KEY]:
            warn(
                message=(
                    "Config value for seed does not match the checkpoint value, "
                    f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}"
                )
            )
            # NOTE(review): only the attribute is updated here; no re-seeding is
            # visible in this method -- confirm whether that is intended.
            self.seed = ckpt_dict[training.SEED_KEY]
        if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]:
            warn(
                message=(
                    "Config value for max_steps_per_epoch does not match the checkpoint value, "
                    f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}"
                )
            )
            self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY]

        # on mismatch, warn the user but allow the override
        if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]:
            warn(
                message=(
                    "Config value for total_epochs does not match the checkpoint value, "
                    f"using the config value: {self.total_epochs}"
                )
            )

    except KeyError as e:
        # Re-raise with a hint, keeping the original error as the cause
        raise KeyError(
            "Checkpoint does not contain the required keys needed for updating recipe state. "
            "Are you sure you passed in the right recipe checkpoint?"
        ) from e
def setup(self, cfg: DictConfig) -> None:
    """
    Setup the recipe. This includes training state (if resume_from_checkpoint is True),
    model, tokenizer, loss, optimizer, sampler, and dataloader.

    Args:
        cfg (DictConfig): OmegaConf object parsed from yaml file
    """
    # The metric logger only exists on rank zero
    if self._is_rank_zero:
        self._metric_logger = config.instantiate(cfg.metric_logger)

        # log config with parameter override
        self._metric_logger.log_config(cfg)

    checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

    # ``_compile`` must be set before ``_setup_model`` runs, since that method reads it
    self._compile = cfg.get("compile", False)
    self._model = self._setup_model(
        cfg_model=cfg.model,
        enable_activation_checkpointing=cfg.enable_activation_checkpointing,
        custom_sharded_layers=cfg.get("custom_sharded_layers", None),
        fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
        reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
        model_state_dict=checkpoint_dict[training.MODEL_KEY],
        ac_mode=cfg.get("ac_mode", None),
        ac_option=cfg.get("ac_option", None),
    )
    self._tokenizer = config.instantiate(cfg.tokenizer)

    # Optimizer state is only restored when resuming from a checkpoint
    self._optimizer = self._setup_optimizer(
        cfg_optimizer=cfg.optimizer,
        optimizer_in_bwd=self._optimizer_in_bwd,
        opt_state_dict=(
            checkpoint_dict[training.OPT_KEY]
            if self._resume_from_checkpoint
            else None
        ),
    )

    # initialize loss
    self._loss_fn = config.instantiate(cfg.loss)

    if self._compile:
        training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)

    if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
        # set num_output_chunks for model
        self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)

    if self._is_rank_zero:
        log.info("Loss is initialized.")

    # sampler and dataloader depend on the tokenizer and loss_fn and should be
    # setup after both of these are initialized
    collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
    self._sampler, self._dataloader = self._setup_data(
        cfg_dataset=cfg.dataset,
        shuffle=cfg.shuffle,
        batch_size=cfg.batch_size,
        collate_fn=collate_name,
    )

    # Finally update the recipe state which can only be correctly set after all of the
    # other components have been initialized and updated.
    #
    # Number of training steps in each epoch depends on the number of batches produced
    # by the dataloader, the max_steps_per_epoch param set by the user and the
    # gradient_accumulation_steps param. This value is used for logging and tracking
    # training state. The computation should happen after the dataloader has been setup
    self._steps_per_epoch = (
        len(self._dataloader) // self._gradient_accumulation_steps
    )
    if (
        self.max_steps_per_epoch is not None
        and self.max_steps_per_epoch < self._steps_per_epoch
    ):
        self._steps_per_epoch = self.max_steps_per_epoch
    self.global_step = self.epochs_run * self._steps_per_epoch

    # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
    # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
    self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))

    # Used to ignore labels for loss computation
    self.ignore_labels_cache = torch.full(
        (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device
    )
def _setup_profiler(
    self, cfg_profiler: Optional[DictConfig] = None
) -> Union[torch.profiler.profile, DummyProfiler]:
    """
    Parses the `profiler` section of top-level `cfg` and sets up profiler

    Args:
        cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to
            `recipe.main`). Default None.

    Returns:
        profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods
        for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such
        that the instrumented training loop does not need to be changed when profiling is disabled.

    The profiler config can be provided in configs under the `profiler` key with the following layout:

    .. code-block:: yaml
        profiler:
            enabled: bool

            #Output directory of trace artifacts
            output_dir: str

            #`torch.profiler.ProfilerActivity` types to trace
            cpu: bool
            cuda: bool

            #Trace options
            profile_memory: bool
            with_stack: bool
            record_shapes: bool
            with_flops: bool

            # `torch.profiler.schedule` options:
            # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
            wait_steps: int
            warmup_steps: int
            active_steps: int
            num_cycles: int
    """
    # Missing profiler section in config, assume disabled
    if cfg_profiler is None:
        cfg_profiler = DictConfig({"enabled": False})

    # Check that component is included and set correctly
    # (note: a missing ``_component_`` is written into the passed-in config in place)
    if cfg_profiler.get("_component_", None) is None:
        cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler"
    else:
        assert (
            cfg_profiler.get("_component_")
            == "torchtune.training.setup_torch_profiler"
        ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`"

    profiler, profiler_cfg = config.instantiate(cfg_profiler)

    if self._is_rank_zero:
        log.info(f" Profiler config after instantiation: {profiler_cfg}")

        # NOTE(review): the ``profiler_*`` attributes below are assigned only on
        # rank zero in this reconstruction -- confirm the intended nesting.
        self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
        if profiler_cfg["enabled"]:
            self.profiler_wait_steps = profiler_cfg["wait_steps"]
            self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
            self.profiler_active_steps = profiler_cfg["active_steps"]

    return profiler
def _setup_model(
    self,
    cfg_model: DictConfig,
    enable_activation_checkpointing: bool,
    fsdp_cpu_offload: bool,
    reshard_after_forward: bool,
    model_state_dict: Dict[str, Any],
    custom_sharded_layers: Optional[List[str]] = None,
    ac_mode: Optional[str] = None,
    ac_option: Optional[int] = None,
) -> nn.Module:
    """
    Model initialization has some important considerations:
       a. To minimize GPU peak memory, we initialize the model on meta device with
          the right dtype
       b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since
          full state dicts are loaded with ``torch.load(mmap=True)``

    Args:
        cfg_model (DictConfig): config used to instantiate the model.
        enable_activation_checkpointing (bool): when True and ``ac_mode`` is None, full
            activation checkpointing is applied.
        fsdp_cpu_offload (bool): forwarded as ``cpu_offload`` to sharding and state dict loading.
        reshard_after_forward (bool): forwarded to ``training.shard_model``.
        model_state_dict (Dict[str, Any]): full model state dict to load into the sharded model.
        custom_sharded_layers (Optional[List[str]]): forwarded as ``names_to_match`` to
            ``training.get_shard_conditions``.
        ac_mode (Optional[str]): selective activation checkpointing mode; only used when
            ``enable_activation_checkpointing`` is False.
        ac_option (Optional[int]): option passed alongside ``ac_mode``.

    Returns:
        nn.Module: the sharded model with the state dict loaded.
    """
    if self._is_rank_zero:
        log.info(
            "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
        )
        # Only rank zero times the setup; the value is read again only on rank zero below
        init_start = time.perf_counter()

    # Build the model on the meta device so no real memory is allocated yet
    with training.set_default_dtype(self._dtype), torch.device("meta"):
        model = config.instantiate(cfg_model)

    if self._compile:
        training.compile_model(model, verbose=self._is_rank_zero)

    # We currently have two versions of activation checkpointing in this recipe
    # for testing and BC purposes. ``enable_activation_checkpointing`` controls
    # the older version of AC and this behavior is unchanged
    # ac_mode and ac_option together control selective AC. This is only enabled
    # when these are set AND ``enable_activation_checkpointing`` is set to False
    # We'll clean this up as soon as testing of AC is complete
    # NOTE: if ``enable_activation_checkpointing`` is True and ``ac_mode`` is also
    # set, neither of the two branches below runs.
    if (not enable_activation_checkpointing) and (ac_mode is not None):
        apply_selective_activation_checkpointing(
            model,
            ac_mode,
            ac_option,
        )

    # original activation checkpointing (full) - flip the condition above
    if enable_activation_checkpointing and ac_mode is None:
        training.set_activation_checkpointing(
            model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
        )

    # For FSDP sharding
    fsdp_shard_conditions = [
        partial(
            training.get_shard_conditions,
            names_to_match=custom_sharded_layers,
        )
    ]
    training.shard_model(
        model=model,
        shard_conditions=fsdp_shard_conditions,
        cpu_offload=fsdp_cpu_offload,
        reshard_after_forward=reshard_after_forward,
    )

    with training.set_default_dtype(self._dtype), self._device:
        for m in model.modules():
            # RoPE is not covered in state dict
            if hasattr(m, "rope_init"):
                m.rope_init()

    # This method will convert the full model state dict into a sharded state
    # dict and load into the model
    training.load_from_full_model_state_dict(
        model,
        model_state_dict,
        self._device,
        self._is_rank_zero,
        strict=True,
        cpu_offload=fsdp_cpu_offload,
    )

    # Ensure no params and buffers are on meta device
    training.validate_no_params_on_meta_device(model)

    if self._is_rank_zero:
        log.info(
            f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
        )
        memory_stats = training.get_memory_stats(device=self._device)
        training.log_memory_stats(memory_stats)

    # synchronize before training begins
    torch.distributed.barrier()

    return model
def _setup_optimizer(
    self,
    cfg_optimizer: DictConfig,
    optimizer_in_bwd: bool = False,
    opt_state_dict: Optional[Dict[str, Any]] = None,
) -> Optional[Optimizer]:
    """
    Build the optimizer(s) for ``self._model`` and optionally restore their state.

    Args:
        cfg_optimizer (DictConfig): config handed to ``config.instantiate`` to
            build each optimizer.
        optimizer_in_bwd (bool): when True, one optimizer is instantiated per model
            parameter, hooks are registered so they step during backward, and a
            checkpoint wrapper is stored on ``self._optim_ckpt_wrapper``.
            Default is False.
        opt_state_dict (Optional[Dict[str, Any]]): previously saved optimizer state
            to load. In the in-backward path it is iterated key by key and each
            key is looked up in the wrapper's state dict.

    Returns:
        Optional[Optimizer]: the instantiated optimizer, or None when
            ``optimizer_in_bwd`` is True.

    Raises:
        RuntimeError: if restoring any per-parameter state fails in the
            in-backward path.
    """
    if optimizer_in_bwd:
        # Maintain a dict of optims for every parameter.
        optim_dict = {
            param: config.instantiate(cfg_optimizer, [param])
            for param in self._model.parameters()
        }
        # Register optimizer step hooks on the model to run optimizer in backward.
        training.register_optim_in_bwd_hooks(
            model=self._model, optim_dict=optim_dict
        )
        # Create a wrapper for checkpoint save/load of optimizer states when running in backward.
        self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
            model=self._model, optim_dict=optim_dict
        )
        # Load optimizer states for each param. If optimizer states are being restored in an optimizer in
        # backward run, these need to have been saved with the same setting. Cannot restore from runs that
        # did not use optimizer in backward.
        if opt_state_dict is not None:
            for param in opt_state_dict:
                try:
                    training.load_from_full_optimizer_state_dict(
                        self._optim_ckpt_wrapper.state_dict()[param],
                        opt_state_dict[param],
                        self._device,
                    )
                # Catch Exception rather than BaseException so that
                # KeyboardInterrupt / SystemExit propagate untouched instead of
                # being re-labelled as a checkpoint-loading failure.
                except Exception as e:
                    raise RuntimeError(
                        "Failed loading in-backward optimizer checkpoints. "
                        "Please make sure run being restored from was using in-backward optimizer."
                    ) from e
        if self._is_rank_zero:
            log.info("In-backward optimizers are set up.")
        return None

    optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
    if opt_state_dict:
        training.load_from_full_optimizer_state_dict(
            optimizer,
            opt_state_dict,
            self._device,
        )

    if self._is_rank_zero:
        log.info("Optimizer is initialized.")
    return optimizer
def _setup_data(
    self,
    cfg_dataset: DictConfig,
    shuffle: bool,
    batch_size: int,
    collate_fn: str,
) -> Tuple[DistributedSampler, DataLoader]:
    """
    Construct the dataset, its distributed sampler and the dataloader.

    Only ``DistributedSampler`` over a map-style dataset that fits in memory is
    supported by this recipe; other samplers, iterable datasets and streaming
    datasets are not.

    Args:
        cfg_dataset (DictConfig): a single dataset config, or a ``ListConfig`` of
            them, in which case the instantiated datasets are wrapped in a
            ``ConcatDataset`` and treated as not packed.
        shuffle (bool): forwarded to the sampler.
        batch_size (int): forwarded to the dataloader.
        collate_fn (str): dotted path resolved with ``_get_component_from_path``;
            only used when the dataset is not packed.

    Returns:
        Tuple[DistributedSampler, DataLoader]: the sampler and the dataloader built on it.

    Raises:
        RuntimeError: if ``collate_fn`` contains ``"left_pad_sequence"``.
    """
    world_size, rank = training.get_world_size_and_rank()

    if not isinstance(cfg_dataset, ListConfig):
        ds = config.instantiate(cfg_dataset, self._tokenizer)
        packed = cfg_dataset.get("packed", False)
    else:
        ds = ConcatDataset(
            datasets=[
                config.instantiate(single_cfg_dataset, self._tokenizer)
                for single_cfg_dataset in cfg_dataset
            ]
        )
        packed = False

    # Resolve the collator from its dotted path.
    if "left_pad_sequence" in collate_fn:
        raise RuntimeError("left_pad_sequence collator is only for inference.")
    collate_callable = _get_component_from_path(collate_fn)

    sampler = DistributedSampler(
        ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
    )

    # Packed datasets use the dedicated packed collator; everything else gets the
    # configured collator bound to the tokenizer pad id and the loss ignore index.
    if packed:
        batch_collator = padded_collate_packed
    else:
        batch_collator = partial(
            collate_callable,
            padding_idx=self._tokenizer.pad_id,
            ignore_idx=self._loss_fn.ignore_index,
        )

    dataloader = DataLoader(
        dataset=ds,
        batch_size=batch_size,
        sampler=sampler,
        # dropping last avoids shape issues with compile + flex attention
        drop_last=True,
        collate_fn=batch_collator,
    )

    if self._is_rank_zero:
        log.info("Dataset and Sampler are initialized.")

    return sampler, dataloader
def save_checkpoint(
def save_checkpoint(
self,
self,
epoch: int,
epoch: int,
) -> None:
) -> None:
"""
"""
Checkpoint the state of the recipe. The constructed checkpoint state dict
Checkpoint the state of the recipe. The constructed checkpoint state dict
contains the following information:
contains the following information:
- Model weights with key training.MODEL_KEY
- Model weights with key training.MODEL_KEY
- Relevant recipe state if training is not complete
- Relevant recipe state if training is not complete
Checkpointer will save the model weights and recipe state in
Checkpointer will save the model weights and recipe state in
different checkpoint files. To correctly resume training from an intermediate checkpoint,
different checkpoint files. To correctly resume training from an intermediate checkpoint,
the model weights and recipe state must be provided.
the model weights and recipe state must be provided.
"""
"""
# final dict passed onto the checkpointer
# final dict passed onto the checkpointer
checkpoint_dict = {}
checkpoint_dict = {}
intermediate_checkpoint = epoch + 1 < self.total_epochs
intermediate_checkpoint = epoch + 1 < self.total_epochs
if self._is_rank_zero:
if self._is_rank_zero:
log.info(
log.info(
"Saving checkpoint. This may take some time. Retrieving full model state dict..."
"Saving checkpoint. This may take some time. Retrieving full model state dict..."
)
)
start = time.perf_counter(