"""
"""
Read our announcement blog post: https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html.
Read our announcement blog post: https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html.
This script trains a model using FSDP with LoRA & QLoRA. It pulls inspiration from
This script trains a model using FSDP with LoRA & QLoRA. It pulls inspiration from
- llama-recipes (https://github.com/facebookresearch/llama-recipes/blob/main/src/llama_recipes/finetuning.py)
- llama-recipes (https://github.com/facebookresearch/llama-recipes/blob/main/src/llama_recipes/finetuning.py)
- PyTorch FSDP docs (https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- PyTorch FSDP docs (https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- bitsandbytes (https://github.com/TimDettmers/bitsandbytes)
- bitsandbytes (https://github.com/TimDettmers/bitsandbytes)
For information on the different arguments, run `python train.py --help`
For information on the different arguments, run `python train.py --help`
You should treat this script as an alpha/preview release. If you're not comfortable with testing and debugging
You should treat this script as an alpha/preview release. If you're not comfortable with testing and debugging
models, we'd suggest holding off for a few months while the community more fully tests the approach.
models, we'd suggest holding off for a few months while the community more fully tests the approach.
"""
"""
# Imports

# General
import copy
import functools
import gc
import math
import os
import sys
import time
import types
from contextlib import nullcontext
from glob import glob
from pathlib import Path
from typing import Dict, List

import bitsandbytes as bnb
import safetensors
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.optim as optim
from accelerate import init_empty_weights
from accelerate.utils import set_seed

# Model loading
from bitsandbytes.nn import Linear4bit, Params4bit
from fastcore.parallel import parallel

# Argument parsing
from fastcore.script import Param, bool_arg, call_parse
from packaging.version import parse
from safetensors.torch import save_file

# Torch + distributed training
from torch import Tensor, nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
    offload_wrapper,
)

# FSDP
from torch.distributed.fsdp import FullStateDictConfig, MixedPrecision, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import BackwardPrefetch, CPUOffload, ShardingStrategy
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import LambdaLR
from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.optimization import get_linear_schedule_with_warmup
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub

try:
    from hqq.core.quantize import BaseQuantizeConfig, HQQBackend, HQQLinear
except ImportError:
    HQQLinear = None
    pass
# To add a new model, import the transformer, attention, & MLP layers
# for the wrapping policy and `check_fn` in activation checkpointing
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaMLP,
)
from transformers.models.mistral.modeling_mistral import (
    MistralAttention,
    MistralDecoderLayer,
    MistralMLP,
)
# To get rid of tokenizers warnings for now
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# For logging things during training
try:
    import wandb
except ImportError:
    pass

# LoRA and DORA modules
sys.path.append("./scripts")
from dora import BNBDORA, HQQDORA, DORALayer, MagnitudeLayer
from lora import LORA
from profiling_utils import profiling_context
class Logger:
    def __init__(self, args, log_to="stdout", project_name="fsdp_qlora", entity=None, group=None, name=None, rank=0):
        # self.log_every_n_steps = log_every_n_steps TODO: add this back as an option
        self.log_to = log_to
        if self.log_to == "wandb" and rank==0:
            import wandb
            wandb.init(project=project_name, entity=entity, group=group, name=name, config=args)

    def log(self, d:Dict, rank:int):
        if rank != 0: return
        if self.log_to == "tqdm":
            for k,v in d.items():
                tqdm.write(f'{k}: {v}')
        elif self.log_to == "wandb":
            wandb.log(d)
        elif self.log_to == "stdout":
            for k,v in d.items():
                print(f'{k}: {v}')

    def finish(self, rank=0):
        if self.log_to == "wandb" and rank==0: wandb.finish()
def update_progress_bar(progress_bar:tqdm, epoch:int, log_loss:float, log_lr:float, rank:int):
    """Updates the progress bar with the current epoch, loss, and learning rate"""
    if rank == 0:
        if log_lr >= 0:
            progress_bar.set_description(f"Epoch {epoch}, Loss {log_loss:.3f}, LR {log_lr:.2e}", refresh=True)
        else:
            progress_bar.set_description(f"Epoch {epoch}, Loss {log_loss:.3f}", refresh=True)
def n_loading_workers(quant_method:str, param_count:float):
    devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
    left = int(os.cpu_count()/torch.cuda.device_count())
    right = int((4 if quant_method == "hqq" else 8) * (devprops.total_memory/1e9/40) * (70/(param_count/1e9)))
    return min(left, right)
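
# Rough intuition for the heuristic above (illustrative numbers, not from the original source):
# `left` caps the worker count at the CPU cores available per GPU, while `right` scales with GPU
# memory and inversely with model size. For example, quantizing a 70B model with bnb on a ~40GB
# card gives roughly 8 * (40/40) * (70/70) = 8 workers, so min(left, 8) workers are used.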
# Utilities related to model loading
def replace_linear(model:nn.Module, linear_replacement:nn.Module, quant_config:dict|None=None,
                   skip_modules:List[str]=["lm_head"], **kwargs):
    """
    Replace linear modules with a new Linear module.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`torch.nn.Module`):
            The linear module that replaces the old one. Only expects standard arguments.
            If other arguments need to be passed, use a lambda.
        skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
            List of module names not to convert.
    """
    for name, module in model.named_children():
        if name in skip_modules:
            continue

        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, quant_config, skip_modules, **kwargs)

        if isinstance(module, torch.nn.Linear):
            if issubclass(linear_replacement, Linear4bit):
                model._modules[name] = linear_replacement(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    **kwargs
                )
            elif HQQLinear is not None and issubclass(linear_replacement, HQQLinear):
                model._modules[name] = linear_replacement(module, quant_config, **kwargs)
            else:
                raise ValueError(f"Unsupported linear replacement: {linear_replacement}")
    return model
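
# Illustrative usage (a sketch, not the exact call made later in this script): swapping every
# nn.Linear except `lm_head` for a 4-bit bitsandbytes layer on a meta-initialized model, with
# `compute_dtype` and `torch_dtype` as set in `fsdp_main` below:
#   with init_empty_weights():
#       model = AutoModelForCausalLM.from_config(cfg)
#   model = replace_linear(model, Linear4bit, compute_dtype=compute_dtype,
#                          quant_type="nf4", quant_storage=torch_dtype)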
def setup_quantized_meta_for_peft(model:nn.Module):
    """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""
    def temp_to_method(self, *args, **kwargs):
        return self
    for param in model.parameters():
        if isinstance(param, Params4bit):
            param.quant_state._orig_to = param.quant_state.to
            param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)

def setup_quantized_peft_meta_for_training(model:nn.Module):
    """Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
    for param in model.parameters():
        if isinstance(param, Params4bit) and hasattr(param.quant_state, '_orig_to'):
            param.quant_state.to = param.quant_state._orig_to
            param.quant_state._orig_to = None
def load_and_quantize(module:nn.Module, name:str, value:Tensor, device:torch.device=None, dtype:torch.dtype=None,
                      skip_names:list[str]=[], to_cpu:bool=False, to_meta:bool=False, verbose:bool=False, quant_method:str='bnb',
                      is_dora:bool=False):
    """
    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.

    Quantizes `Params4bit` on `device` then places on "cpu" if to_cpu=True or "meta" if to_meta=True.
    """
    def place_on_device(value):
        # Use a local name so we don't shadow `device` from the enclosing scope, which would
        # otherwise raise UnboundLocalError when neither `to_meta` nor `to_cpu` is set.
        if to_meta:
            target_device = 'meta'
        elif to_cpu:
            target_device = 'cpu'
        else:
            target_device = device
        return value.to(device=target_device, dtype=dtype)

    if any(skip_name in name for skip_name in skip_names):
        if verbose:
            print(f"Skipping {name} because it is in skip_names")
        return

    module_key, _, value_key = name.rpartition('.')
    try:
        submodule = module.get_submodule(module_key)
    except AttributeError as e:
        print(f"Module {module_key} not found:\n{e}")
        return

    try:
        if quant_method == 'bnb':
            param = submodule.get_parameter(value_key)
            if isinstance(param, Params4bit):
                # With `sync_module_states=True`, a meta device Params4bit needs to be the same
                # shape as the quantized Params4bit with an initialized quant_state. However,
                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
                # workaround quantizes Params4bit to initialize quant_state on all ranks, then
                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
                if is_dora:
                    setattr(submodule, "dora_scale", value.norm(p=2, dim=1).to(dtype=dtype).to("cpu"))
                value = type(param)(value.to(device=device, dtype=dtype).data, **param.__dict__).cuda(device)
                if to_meta:
                    value = type(param)(value.data.to("meta"), **value.__dict__)
                elif to_cpu:
                    value = type(param)(value.data.to("cpu"), **value.__dict__)
            else:
                value = type(param)(place_on_device(value).data)

        elif quant_method == 'hqq':
            if isinstance(submodule, HQQLinear):
                if value_key == "weight":
                    # Like `Params4bit`, this workaround quantizes `HQQLinear` per device so the quantization
                    # meta dictionary is created on all ranks, before converting to meta on non-rank 0.
                    submodule.linear_layer.to_empty(device=device)
                    submodule.linear_layer.weight.data.copy_(value.to(device=device, dtype=dtype))
                    if is_dora:
                        setattr(submodule, "dora_scale", value.norm(p=2, dim=1).to(dtype=dtype).to("cpu"))
                    submodule.initialize()

                    if to_meta:
                        setattr(submodule, "W_q", nn.Parameter(submodule.W_q.to("meta")))
                    elif to_cpu:
                        setattr(submodule, "W_q", nn.Parameter(submodule.W_q.to("cpu")))
                    submodule.in_gpu = False

                if value_key == "bias":
                    raise ValueError("Bias not supported in HQQLinear yet!")
            else:
                param = submodule.get_parameter(value_key)
                value = type(param)(place_on_device(value).data)

    except AttributeError:
        # it's a buffer
        value = place_on_device(value)

    if HQQLinear is None or not isinstance(submodule, HQQLinear):
        setattr(submodule, value_key, value)
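
# Illustrative usage (a sketch; the actual loading loop appears later in the script and may
# differ): streaming parameters from a state dict so that rank 0 keeps CPU copies while other
# ranks keep meta tensors, matching the sync_module_states workaround described above.
# `state_dict` is a hypothetical name -> tensor mapping; `local_rank`, `rank`, `torch_dtype`,
# and `load_param_skip_names` are set in `fsdp_main` below.
#   for param_name, param in state_dict.items():
#       load_and_quantize(model, param_name, param, device=local_rank, dtype=torch_dtype,
#                         skip_names=load_param_skip_names, to_cpu=(rank == 0),
#                         to_meta=(rank != 0), quant_method="bnb")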
# DATASET + DATALOADERS (modified from llama recipes)
# Formatting prompts in alpaca
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
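
# For reference, "prompt_no_input" renders as follows for a hypothetical instruction
# ("Summarize the paragraph below." is just an example):
#   Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
#   ### Instruction:
#   Summarize the paragraph below.
#
#   ### Response:
# The target output is appended after "### Response:" when building training examples.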
# Dataset class
class InstructionDataset(Dataset):
    def __init__(self, dataset, tokenizer, style="alpaca"):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.style = style

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss
        if self.style == "guanaco":
            prompt = self.dataset[index]["text"].split("### Assistant: ")[0]
            example = self.dataset[index]["text"]
        elif self.style == "qna":
            prompt_template = "###Context:\n{context}\n###Question:\n{question}\n###Answer:\n"
            sample = self.dataset[index]
            prompt = prompt_template.format_map(sample)
            example = prompt + sample['answer']
        elif self.style == "qna_no_ctx":
            prompt_template = "###Question:\n{question}\n###Answer:\n"
            sample = self.dataset[index]
            prompt = prompt_template.format_map(sample)
            example = prompt + sample['answer']
        else:  # Alpaca
            ann = self.dataset[index]
            if ann.get("input", "") == "":
                prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
            else:
                prompt = PROMPT_DICT["prompt_input"].format_map(ann)
            example = prompt + ann["output"]

        prompt = torch.tensor(
            self.tokenizer.encode(prompt), dtype=torch.int64
        )
        example = self.tokenizer.encode(example)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(
            example, dtype=torch.int64
        )
        labels = copy.deepcopy(example)
        labels[: len(prompt)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX

        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
        }
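
# Note on the masking above: prompt positions are first marked with -1, then `labels.ge(0)`
# turns those positions into IGNORE_INDEX (-100), so CrossEntropyLoss only scores the response
# tokens. For example, a 10-token prompt followed by a 5-token answer yields labels of the form
# [-100]*10 + answer_token_ids (plus EOS), while input_ids keep the full token sequence.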
# And to get the dataloader
def get_dataloader(tokenizer:PreTrainedTokenizerFast, args:Dict):
    """Creates a dataset and appropriate dataloader with distributed sampler."""
    # Importing here rather than at the start to avoid multiprocessing issues
    from datasets import Dataset, load_dataset

    # Load the source dataset
    if args["dataset"] == "alpaca":
        dataset = load_dataset("yahma/alpaca-cleaned")['train']
    elif args["dataset"] == "alpaca_sample":
        dataset = load_dataset("yahma/alpaca-cleaned", split=f"train[:{args['dataset_samples']}]")
    elif args["dataset"] == "dummy":
        dataset = Dataset.from_dict({
            'instruction': ["instruction"]*args["dataset_samples"],
            'input': ["input"]*args["dataset_samples"],
            'output': ["output"*args["context_length"]*2]*args["dataset_samples"],  # A long output to test memory usage (gets truncated)
        })
    elif args["dataset"] == "guanaco":
        dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
    elif args["dataset"] == "sql":
        dataset = load_dataset("knowrohit07/know_sql")['validation']
        dataset = dataset.shuffle(seed=args["seed"])
        dataset = dataset.select(range(1000, len(dataset)))
    elif args["dataset"] == "orca_math":
        dataset = load_dataset("microsoft/orca-math-word-problems-200k")['train'].shuffle(seed=42)
        # train with 10k for starters. Then 100k.
        dataset = dataset.select(range(0, args['dataset_samples']))
    elif args["dataset"] == "uganda_clinical_guidelines":
        dataset = load_dataset("silvaKenpachi/uganda-clinical-guidelines")['train'].shuffle(seed=42)
        # limit to the first `dataset_samples` examples
        dataset = dataset.select(range(0, args['dataset_samples']))

    # truncate dataset so it's evenly divisible by grad_accumulation_steps
    dataset = dataset.select(range(0, len(dataset)-len(dataset)%(args["batch_size"]*args["gradient_accumulation_steps"])))

    # Create the InstructionDataset
    if args["dataset"] == "guanaco":
        dataset = InstructionDataset(dataset, tokenizer, style="guanaco")
    elif args["dataset"] == "sql":
        dataset = InstructionDataset(dataset, tokenizer, style="qna")
    elif args["dataset"] == "orca_math":
        dataset = InstructionDataset(dataset, tokenizer, style="qna_no_ctx")
    else:  # (w/ alpaca prompt formatting)
        dataset = InstructionDataset(dataset, tokenizer, style="alpaca")

    # Collate function
    def collate_fn(batch, with_attention_mask=False):
        # To list of tensors
        input_ids = [torch.tensor(item['input_ids']) for item in batch]
        attention_masks = [torch.tensor(item['attention_mask']) for item in batch]
        labels = [torch.tensor(item['labels']) for item in batch]

        # Pad + truncate
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)[:, :args["context_length"]]
        if with_attention_mask:
            attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0)[:, :args["context_length"]]
        else:
            attention_masks = None
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)[:, :args["context_length"]]

        # Return dict
        return {'input_ids': input_ids, 'attention_mask': attention_masks, 'labels': labels}

    # For distributed training, use DistributedSampler
    sampler = DistributedSampler(dataset, seed=args["seed"])

    # Use the custom collate function in DataLoader
    dataloader = DataLoader(dataset, batch_size=args["batch_size"], collate_fn=collate_fn, sampler=sampler)

    return dataloader
# LR scheduler.
def _get_cosine_one_cycle_lr_lambda(
    current_step: int, *, num_warmup_steps: int, num_training_steps: int, min_lr_fraction=0.1,
):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    scale_term = (1 - min_lr_fraction)
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return (math.cos(math.pi * progress)+1) * 0.5 * scale_term + min_lr_fraction
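
# Sanity check on the schedule above (derived from the formula, not additional behaviour):
# during warmup the multiplier rises linearly from 0 to 1; afterwards it follows
# 0.5*(cos(pi*progress)+1)*(1-min_lr_fraction) + min_lr_fraction, which equals 1.0 at
# progress=0, 0.55 at progress=0.5 (with the default min_lr_fraction=0.1), and bottoms out
# at min_lr_fraction when progress=1.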
def get_cosine_one_cycle_scheduler(optimizer:optim.Optimizer, num_warmup_steps:int, num_training_steps:int, min_lr_fraction:float=0.1):
    "A more general cosine scheduler with control over the minimum learning rate"
    lr_lambda = functools.partial(
        _get_cosine_one_cycle_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        min_lr_fraction=min_lr_fraction
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch=-1)

def get_lr_scheduler(optimizer:optim.Optimizer, dataloader:DataLoader, gradient_accumulation_steps:int, args:Dict):
    """Returns linear, cosine, or constant learning rate scheduler"""
    num_training_steps = args['num_epochs'] * len(dataloader) // gradient_accumulation_steps
    num_warmup_steps = int(num_training_steps * 0.1)
    if args['lr_scheduler'] == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
    elif args['lr_scheduler'] == "cosine":
        lr_scheduler = get_cosine_one_cycle_scheduler(optimizer, num_warmup_steps, num_training_steps, min_lr_fraction=0.1)
    elif args['lr_scheduler'] == "constant":
        lr_scheduler = None
    else:
        raise NotImplementedError(f"{args['lr_scheduler']} LR scheduler not implemented yet")
    return lr_scheduler, num_training_steps
# Optimizer
def get_optimizer(model:nn.Module, args:Dict):
    """Returns an optimizer. We can add more options here if needed."""
    if args["optimizer"] in ["adam", "fused_adam"]:
        return optim.Adam(model.parameters(), lr=args['lr'], fused=args["optimizer"]=="fused_adam")
    elif args["optimizer"] == "sgd":
        return optim.SGD(model.parameters(), lr=args['lr'])
    elif args["optimizer"] == "adadelta":
        return optim.Adadelta(model.parameters(), lr=args['lr'])
    elif args["optimizer"] in ["adamw", "fused_adamw"]:
        return torch.optim.AdamW(model.parameters(), lr=args['lr'], betas=(0.9,0.95),
                                 eps=1e-5, weight_decay=args['wd'], fused=args["optimizer"]=="fused_adamw")
    else:
        raise ValueError("Invalid optimizer")
# Wrap the model using LoRA policy from llama-recipes or custom policy:
# This checks for lora layers (has weight and requires_grad)
def get_wrapping_policy(custom_policy:bool=False, vanilla_policy:bool=False):
    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

    if custom_policy:
        def lambda_policy_fn(module):
            # LoRA and DoRA trainable layers.
            return (isinstance(module, nn.Sequential) and all(m.weight.requires_grad for m in module)) or (isinstance(module, (DORALayer, MagnitudeLayer)))
    else:
        def lambda_policy_fn(module):
            return (
                len(list(module.named_children())) == 0
                and getattr(module, "weight", None) is not None
                and module.weight.requires_grad
            )

    def self_attn_policy_fn(module):
        # Check module is a self-attention layer.
        return isinstance(module, (LlamaAttention, MistralAttention))

    def mlp_policy_fn(module):
        # Check module is an MLP layer.
        return isinstance(module, (LlamaMLP, MistralMLP))

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    self_attn_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=self_attn_policy_fn)
    mlp_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=mlp_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(LlamaDecoderLayer, MistralDecoderLayer),
    )
    if vanilla_policy:
        return transformer_wrap_policy

    policies = [lambda_policy, transformer_wrap_policy]
    if custom_policy:
        policies.extend([self_attn_policy, mlp_policy])
    return functools.partial(_or_policy, policies=policies)
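
# Illustrative usage (a sketch; the exact call later in the script may differ): the combined
# policy makes FSDP wrap the small trainable LoRA/DoRA modules separately from the frozen
# quantized decoder layers, so trainable and frozen weights land in separate FlatParameters.
# `use_dora_layers` and `full_finetune` are hypothetical booleans used only for illustration:
#   my_auto_wrap_policy = get_wrapping_policy(custom_policy=use_dora_layers, vanilla_policy=full_finetune)
#   model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy, mixed_precision=mp_policy, ...)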
# Main function, run on each process
def fsdp_main(local_rank:int, world_size:int, args:Dict):
    # Setup and initialize the process group
    os.environ['MASTER_ADDR'] = args["master_addr"]
    os.environ['MASTER_PORT'] = args["master_port"]

    if 'SLURM_PROCID' in os.environ:
        # assumes same number of GPUs per node.
        rank = int(os.environ['SLURM_PROCID']) * torch.cuda.device_count() + local_rank
    else:
        rank = local_rank

    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)

    if args["use_cpu_offload"]:
        torch.set_num_threads(os.cpu_count()//(min(world_size, torch.cuda.device_count())))

    # Start logging
    logger = Logger(args, log_to=args["log_to"], project_name=args["project_name"],
                    entity=args["entity"], group=args["group"], name=args["name"], rank=rank)

    # Timing stuff
    init_start_event = torch.cuda.Event(enable_timing=True)
    init_end_event = torch.cuda.Event(enable_timing=True)

    # Model precision, QLoRA compute precision, and FSDP mixed precision policy.
    # The Linear4bit quant_storage dtype should always match the FSDP param_dtype. The compute_dtype should match the AMP compute dtype.
    # MixedPrecision(param_dtype=fp32, reduce_dtype=fp32, buffer_dtype=fp32) uses `torch.amp.autocast` to control precision.
    # Limited QLoRA testing shows that fp16 only works with autocast, while bf16 trains with both pure and autocast modes.
    # TODO: test how often this holds for mp_fp16
    mp_policy = None
    load_param_skip_names = []
    if args["precision"] == "bf16":
        torch_dtype, compute_dtype = torch.bfloat16, torch.bfloat16
    elif args["precision"] == "fp32":
        torch_dtype, compute_dtype = torch.float32, torch.float16
    elif args["precision"] == "fp16_autocast":
        compute_dtype, torch_dtype = torch.float16, torch.float32
        mp_policy = MixedPrecision(param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32)
    elif args["precision"] == "bf16_autocast":
        compute_dtype, torch_dtype = torch.bfloat16, torch.float32
        mp_policy = MixedPrecision(param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32)
    elif args["precision"] == "bf16_buffers_autocast":
        compute_dtype, torch_dtype = torch.bfloat16, torch.bfloat16
        mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
        load_param_skip_names = ['inv_freq']
    else:
        raise ValueError("Invalid precision")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args["model_name"])
    tokenizer.pad_token_id = tokenizer.eos_token_id  # TODO check if it exists first

    # Set up dataloader
    dataloader = get_dataloader(tokenizer, args)

    # Create model
    cfg = None
    attn_impl = "sdpa"  # torch 2.2 sdpa uses flash attn 2
    if rank == 0 or args['verbose']:
        print("Creating model", rank)
    if args["train_type"] in ["full", "lora", "custom_lora"]:
        if (arg