FSDP with integers solution
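The trick: instead of registering the int8 weight codes as integer parameters (which FSDP cannot flat-shard), IntCodes stores their raw bytes inside a frozen float64 parameter and reinterprets them back to integers in forward(); the auto_wrap_policy then gives each IntCodes its own FSDP unit, separate from the trainable LoRA, scale, and bias parameters.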
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel

# Wraps integer quantization codes, keeping them bit-packed inside a float-typed parameter that FSDP can shard
class IntCodes(nn.Module):
    def __init__(self, codes: torch.Tensor, storage_dtype: torch.dtype = torch.float64):
        super().__init__()
        assert torch.finfo(storage_dtype).bits % torch.iinfo(codes.dtype).bits == 0
        self.dtype, self.shape, self.numel = codes.dtype, codes.shape, codes.numel()
        size_ratio = torch.finfo(storage_dtype).bits // torch.iinfo(codes.dtype).bits
        codes = F.pad(codes.flatten().clone(), pad=[0, -codes.numel() % size_ratio])
        assert len(codes.untyped_storage()) == codes.nbytes  # no offset / stride / tail
        self.storage_dtype = storage_dtype
        # reinterpret the padded integer buffer as a frozen float-typed parameter that FSDP can flat-shard
        self.data = nn.Parameter(
            torch.as_tensor(codes.untyped_storage(), device=codes.device, dtype=storage_dtype),
            requires_grad=False)

    def forward(self):
        assert self.data.is_contiguous() and self.data.dtype == self.storage_dtype
        byte_offset = self.data.storage_offset() * self.data.nbytes // self.data.numel()
        # view the float-typed storage back as the original integer tensor (accounting for any storage offset)
        return torch.as_tensor(
            self.data.untyped_storage()[byte_offset: byte_offset + self.data.nbytes],
            device=self.data.device, dtype=self.dtype
        )[:self.numel].view(*self.shape)

class Linear8bit(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.bias = nn.Parameter(torch.randn(out_features)) if bias else None
        self.scales = nn.Parameter(torch.rand(out_features) / 128 / out_features ** 0.5)
        self.codes = IntCodes(
            torch.randint(-128, 128, size=(out_features, in_features), dtype=torch.int8),
        )  # ^-- example with random weights; in practice, you use pre-trained weights

    def forward(self, input):
        weight = self.scales.unsqueeze(1) * self.codes()  # unpack the int8 codes and rescale them to float weights
        return F.linear(input, weight, self.bias)

class QLoRA(nn.Module):
    def __init__(self, base_layer: Linear8bit, rank: int):
        super().__init__()
        self.base_layer = base_layer
        self.adapter = nn.Sequential(
            nn.Linear(base_layer.in_features, rank, bias=False),
            nn.Linear(rank, base_layer.out_features, bias=False)
        )

    def forward(self, input):
        return self.base_layer(input) + self.adapter(input)

if __name__ == '__main__':
    torch.manual_seed(1337)
    torch.distributed.init_process_group()
    rank = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    inputs = torch.randn(100, 32, device=device)
    labels = torch.randint(0, 10, size=(100,), device=device)
    model = nn.Sequential(
        QLoRA(Linear8bit(32, 128), rank=8), nn.GELU(),
        QLoRA(Linear8bit(128, 128), rank=8), nn.GELU(),
        nn.Linear(128, 10)
    ).to(device)
    if torch.distributed.is_initialized():
        # wrap each IntCodes into its own FSDP unit so its frozen packed parameter
        # is not flattened together with the trainable float parameters
        model = FullyShardedDataParallel(
            model, auto_wrap_policy=lambda module, recurse, **_: recurse or isinstance(module, IntCodes)
        )
    opt = torch.optim.Adam(model.parameters())
    for i in range(1000):
        loss = F.cross_entropy(model(inputs), labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if rank == 0:
            print(f"Step {i}\tloss = {loss.item():.8f}")
    torch.distributed.destroy_process_group()
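To try it out (an example invocation, not part of the solution itself): the script expects a distributed launcher to set LOCAL_RANK and the rendezvous variables that init_process_group() reads, e.g. `torchrun --nproc_per_node=2 fsdp_int8.py`, where `fsdp_int8.py` stands in for whatever name you save this file under.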