FSDP with integers solution

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel




class IntCodes(nn.Module):
    """Store integer quantization codes packed into a float-typed parameter so that FSDP can shard them."""

    def __init__(self, codes: torch.Tensor, storage_dtype: torch.dtype = torch.float64):
        super().__init__()
        assert torch.finfo(storage_dtype).bits % torch.iinfo(codes.dtype).bits == 0
        self.dtype, self.shape, self.numel = codes.dtype, codes.shape, codes.numel()
        size_ratio = torch.finfo(storage_dtype).bits // torch.iinfo(codes.dtype).bits
        # pad the flattened codes so their byte size is a whole multiple of the storage dtype
        codes = F.pad(codes.flatten().clone(), pad=[0, -codes.numel() % size_ratio])
        assert len(codes.untyped_storage()) == codes.nbytes  # no offset / stride / tail
        self.storage_dtype = storage_dtype
        # reinterpret the raw bytes as a non-trainable floating-point parameter
        self.data = nn.Parameter(
            torch.as_tensor(codes.untyped_storage(), device=codes.device, dtype=storage_dtype),
            requires_grad=False)

    def forward(self):
        # recover the original integer tensor by viewing the parameter's bytes as the code dtype
        assert self.data.is_contiguous() and self.data.dtype == self.storage_dtype
        byte_offset = self.data.storage_offset() * self.data.nbytes // self.data.numel()
        return torch.as_tensor(
            self.data.untyped_storage()[byte_offset: byte_offset + self.data.nbytes],
            device=self.data.device, dtype=self.dtype
        )[:self.numel].view(*self.shape)
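# Sanity-check sketch (not part of the original script): packing int8 codes into a
# float64-typed parameter and unpacking them in forward() should round-trip bit-exactly, e.g.
#   codes = torch.randint(-128, 128, size=(3, 5), dtype=torch.int8)
#   assert torch.equal(IntCodes(codes)(), codes)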




class Linear8bit(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.bias = nn.Parameter(torch.randn(out_features)) if bias else None
        self.scales = nn.Parameter(torch.rand(out_features) / 128 / out_features ** 0.5)
        self.codes = IntCodes(
            torch.randint(-128, 128, size=(out_features, in_features), dtype=torch.int8),
        )  # ^-- example with random weights; in practice, you use pre-trained weights

    def forward(self, input):
        # dequantize on the fly: per-output-channel scales times the unpacked int8 codes
        weight = self.scales.unsqueeze(1) * self.codes()
        return F.linear(input, weight, self.bias)
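# Sketch of obtaining scales/codes from a pre-trained weight matrix (an assumption, not part of
# the original script): simple per-output-row absmax quantization matching the forward pass above,
#   scales = weight.abs().amax(dim=1) / 127
#   codes = (weight / scales.unsqueeze(1)).round().clamp(-128, 127).to(torch.int8)
# so that scales.unsqueeze(1) * codes approximates weight.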




class QLoRA(nn.Module):
    def __init__(self, base_layer: Linear8bit, rank: int):
        super().__init__()
        self.base_layer = base_layer
        self.adapter = nn.Sequential(
            nn.Linear(base_layer.in_features, rank, bias=False),
            nn.Linear(rank, base_layer.out_features, bias=False)
        )

    def forward(self, input):
        return self.base_layer(input) + self.adapter(input)






if __name__ == '__main__':
    torch.manual_seed(1337)
    torch.distributed.init_process_group()
    rank = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    inputs = torch.randn(100, 32, device=device)
    labels = torch.randint(0, 10, size=(100,), device=device)
    model = nn.Sequential(
        QLoRA(Linear8bit(32, 128), rank=8), nn.GELU(),
        QLoRA(Linear8bit(128, 128), rank=8), nn.GELU(),
        nn.Linear(128, 10)
    ).to(device)
    if torch.distributed.is_initialized():
        # wrap each IntCodes module into its own FSDP unit so its packed codes end up in a
        # separate flat parameter from the regular floating-point parameters
        model = FullyShardedDataParallel(
            model, auto_wrap_policy=lambda module, recurse, **_: recurse or isinstance(module, IntCodes)
        )
    opt = torch.optim.Adam(model.parameters())
    for i in range(1000):
        loss = F.cross_entropy(model(inputs), labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if rank == 0:
            print(f"Step {i}\tloss = {loss.item():.8f}")
    torch.distributed.destroy_process_group()
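# Launch sketch (assumptions: a single node with 2 GPUs and a hypothetical file name):
#   torchrun --nproc_per_node=2 fsdp_int8_solution.py
# torchrun sets LOCAL_RANK and the rendezvous environment variables that init_process_group() expects.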