#!POPCORN leaderboard amd-mixture-of-experts
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple

import triton
import triton.language as tl

from task import input_t, output_t  # Assuming task.py defines these types


# --- Triton FFN Kernel ---
@triton.jit
def silu(x):
    # SiLU with the sigmoid evaluated in FP32 for accuracy; the result is cast back
    # to FP16 to match the precision of the reference FFN path.
    return (x * tl.sigmoid(x.to(tl.float32))).to(tl.float16)


@triton.jit
def fused_ffn_kernel_v2(
    # Pointers to Matrices
    x_ptr, w_gate_ptr, w_up_ptr, w_down_ptr, output_ptr,
    # Matrix dimensions
    M, N_out, K_in, K_inter,
    # Strides
    stride_x_m, stride_x_k,
    stride_w_gate_k, stride_w_gate_inter,
    stride_w_up_k, stride_w_up_inter,
    stride_w_down_inter, stride_w_down_n,
    stride_out_m, stride_out_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N_OUT: tl.constexpr,
    BLOCK_SIZE_K_IN: tl.constexpr, BLOCK_SIZE_K_INTER: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
"""
"""
Computes FFN: output = W_down(silu(x @ W_gate) * (x @ W_up))
Computes FFN: output = W_down(silu(x @ W_gate) * (x @ W_up))
Input x is cast to weight dtype (fp16) for first matmul.
Mimics reference precision:
Intermediate result is cast to weight dtype (fp16) before final matmul.
- Matmuls (W_gate, W_up, W_down) accumulate in FP32.
Final accumulation is fp32, cast to output dtype at the end.
- Outputs of W_gate and W_up are treated as FP16.
- SiLU is applied to the FP16 W_gate result.
- Element-wise multiply happens with FP16 inputs.
- Final W_down matmul takes FP16 input, accumulates FP32, outputs FP16.
"""
"""
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_out, BLOCK_SIZE_N_OUT)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
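    # The remapping above is the standard Triton "grouped ordering" swizzle: GROUP_SIZE_M
    # row-blocks are processed before moving to the next column-block, improving reuse of
    # the weight tiles in cache.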

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n_out = pid_n * BLOCK_SIZE_N_OUT + tl.arange(0, BLOCK_SIZE_N_OUT)
    x_start_ptr = x_ptr + offs_m[:, None] * stride_x_m

    # Dtypes
    weight_dtype = w_gate_ptr.dtype.element_ty        # typically fp16; also the intermediate dtype
    intermediate_dtype = w_down_ptr.dtype.element_ty  # dtype fed into the W_down matmul
    accum_dtype = tl.float32                          # accumulate in fp32 inside the kernel

    # Initialize the final accumulator for W_down
    acc_down = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N_OUT), dtype=accum_dtype)

    for k_inter_start in range(0, tl.cdiv(K_inter, BLOCK_SIZE_K_INTER)):
        offs_k_inter = k_inter_start * BLOCK_SIZE_K_INTER + tl.arange(0, BLOCK_SIZE_K_INTER)

        # Accumulators for the W_gate and W_up matmuls (accumulate in FP32)
        acc_gate_fp32 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K_INTER), dtype=accum_dtype)
        acc_up_fp32 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K_INTER), dtype=accum_dtype)

        w_gate_k_inter_ptr = w_gate_ptr + offs_k_inter[None, :] * stride_w_gate_inter
        w_up_k_inter_ptr = w_up_ptr + offs_k_inter[None, :] * stride_w_up_inter

        for k_in_start in range(0, tl.cdiv(K_in, BLOCK_SIZE_K_IN)):
            offs_k_in = k_in_start * BLOCK_SIZE_K_IN + tl.arange(0, BLOCK_SIZE_K_IN)

            # Load x block and cast it to the weight dtype (fp16) for the first matmuls
            x_ptrs = x_start_ptr + offs_k_in[None, :] * stride_x_k
            x_mask = (offs_m[:, None] < M) & (offs_k_in[None, :] < K_in)
            x_block = tl.load(x_ptrs, mask=x_mask, other=0.0).to(weight_dtype)

            # Load W_gate block (native dtype, e.g. fp16)
            w_gate_ptrs = w_gate_k_inter_ptr + offs_k_in[:, None] * stride_w_gate_k
            w_gate_mask = (offs_k_in[:, None] < K_in) & (offs_k_inter[None, :] < K_inter)
            w_gate_block = tl.load(w_gate_ptrs, mask=w_gate_mask, other=0.0)

            # Load W_up block (native dtype, e.g. fp16)
            w_up_ptrs = w_up_k_inter_ptr + offs_k_in[:, None] * stride_w_up_k
            w_up_mask = (offs_k_in[:, None] < K_in) & (offs_k_inter[None, :] < K_inter)
            w_up_block = tl.load(w_up_ptrs, mask=w_up_mask, other=0.0)

            # Accumulate the Gate and Up results in FP32
            acc_gate_fp32 += tl.dot(x_block, w_gate_block, out_dtype=accum_dtype)
            acc_up_fp32 += tl.dot(x_block, w_up_block, out_dtype=accum_dtype)

        # --- Apply Activation and Element-wise Product (mimic the reference FP16 path) ---
        # 1. Cast the FP32 accumulators down to FP16 (like the output of nn.Linear)
        acc_gate_fp16 = acc_gate_fp32.to(weight_dtype)
        acc_up_fp16 = acc_up_fp32.to(weight_dtype)
        # 2. Apply SiLU to the FP16 gate result
        gate_activated_fp16 = silu(acc_gate_fp16)  # fp16 in -> fp16 out
        # 3. Element-wise multiply in FP16
        intermediate_block_fp16 = gate_activated_fp16 * acc_up_fp16

        # --- Compute Down Projection Dot Product ---
        # Load W_down block and ensure it carries the expected intermediate dtype (fp16)
        w_down_start_ptr = w_down_ptr + offs_n_out[None, :] * stride_w_down_n
        w_down_ptrs = w_down_start_ptr + offs_k_inter[:, None] * stride_w_down_inter
        w_down_mask = (offs_k_inter[:, None] < K_inter) & (offs_n_out[None, :] < N_out)
        w_down_block = tl.load(w_down_ptrs, mask=w_down_mask, other=0.0).to(intermediate_dtype)

        # 4. Accumulate the W_down result in FP32 from the FP16 intermediate
        acc_down += tl.dot(intermediate_block_fp16, w_down_block, out_dtype=accum_dtype)

    # --- Store Final Output ---
    output_ptrs = output_ptr + offs_m[:, None] * stride_out_m + offs_n_out[None, :] * stride_out_n
    output_mask = (offs_m[:, None] < M) & (offs_n_out[None, :] < N_out)
    # Cast the final FP32 accumulator to the output dtype (e.g., FP16)
    tl.store(output_ptrs, acc_down.to(output_ptr.dtype.element_ty), mask=output_mask)
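

# The FFN kernel runs over a 1-D grid of cdiv(M, BLOCK_SIZE_M) * cdiv(N_out, BLOCK_SIZE_N_OUT)
# programs, each producing one [BLOCK_SIZE_M, BLOCK_SIZE_N_OUT] tile of the output;
# see TritonExpert.forward below for the launch.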


# --- PyTorch MoE Implementation using Triton Kernel ---
class TritonExpert(nn.Module):
    """Wrapper for the Triton FFN kernel."""

    def __init__(self, config: Dict, W_gate: torch.Tensor, W_up: torch.Tensor, W_down: torch.Tensor, d_expert_actual: int):
        super().__init__()
        self.config = config
        self.d_hidden: int = config["d_hidden"]
        self.d_expert_actual = d_expert_actual
        # Store weights directly, assuming they already have the layout the kernel expects:
        #   W_gate: [K_in, K_inter]  = [d_hidden, d_expert_actual]
        #   W_up:   [K_in, K_inter]  = [d_hidden, d_expert_actual]
        #   W_down: [K_inter, N_out] = [d_expert_actual, d_hidden]
        self.W_gate = W_gate
        self.W_up = W_up
        self.W_down = W_down

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape [num_tokens, d_hidden]
        Returns:
            Output tensor of shape [num_tokens, d_hidden]
        """
        if x.shape[0] == 0:  # Handle the empty-input case
            return torch.zeros((0, self.d_hidden), device=x.device, dtype=x.dtype)

        M, K_in = x.shape
        assert K_in == self.d_hidden
        K_inter = self.d_expert_actual
        N_out = self.d_hidden

        # Ensure a contiguous input; the weights should already be contiguous from loading
        x = x.contiguous()
        assert self.W_gate.is_contiguous()
        assert self.W_up.is_contiguous()
        assert self.W_down.is_contiguous()

        # Allocate the output tensor
        output = torch.empty((M, N_out), device=x.device, dtype=x.dtype)

        # --- Kernel Launch ---
        grid = lambda META: (
            triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N_out, META['BLOCK_SIZE_N_OUT']),
        )
        BLOCK_SIZE_M = 64         # A smaller M block is fine since M (tokens per expert) may be modest
        BLOCK_SIZE_N_OUT = 64
        BLOCK_SIZE_K_IN = 128     # Tiles the d_hidden dimension
        BLOCK_SIZE_K_INTER = 128  # Tiles the d_expert_actual dimension
        GROUP_SIZE_M = 4          # For wave scheduling / locality
        num_warps = 8
        num_stages = 1            # Adjust based on shared-memory usage and latency-hiding needs
        fused_ffn_kernel_v2[grid](
            x, self.W_gate, self.W_up, self.W_down, output,
            M, N_out, K_in, K_inter,
            # Strides: torch.Tensor.stride() already returns element strides, which is what Triton expects.
            x.stride(0), x.stride(1),
            self.W_gate.stride(0), self.W_gate.stride(1),
            self.W_up.stride(0), self.W_up.stride(1),
            self.W_down.stride(0), self.W_down.stride(1),
            output.stride(0), output.stride(1),
            BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N_OUT=BLOCK_SIZE_N_OUT,
            BLOCK_SIZE_K_IN=BLOCK_SIZE_K_IN, BLOCK_SIZE_K_INTER=BLOCK_SIZE_K_INTER,
            GROUP_SIZE_M=GROUP_SIZE_M,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        return output
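

# --- Triton Gating Kernel ---
# Fuses the router matmul, softmax, and iterative top-k selection; one program per token.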
@triton.jit
def _triton_gate_kernel(
X_ptr, # Pointer to input tensor x [M, K]
Wg_ptr, # Pointer to gate weight tensor W_g [N, K]
Indices_ptr, # Pointer to output indices tensor [M, top_k] (int64)
Scores_ptr, # Pointer to output scores tensor [M, top_k]
M, # Number of tokens (bs * seq_len)
N: tl.constexpr, # Number of experts
K, # Hidden dimension
stride_xm, stride_xk, # Strides for X (element stride)
stride_wn, stride_wk, # Strides for W_g (element stride)
stride_indices_m, stride_indices_k, # Strides for Indices (element stride)
stride_scores_m, stride_scores_k, # Strides for Scores (element stride)
top_k: tl.constexpr, # Number of experts per token (compile-time constant)
BLOCK_K: tl.constexpr, # Tile size for K dimension
):
"""
Triton kernel for MoE Gating: Fused Matmul(x, W_g.T) + Softmax + Iterative TopK.
Each program instance computes the results for BLOCK_M tokens (rows in x).
Improves Wg reuse.
"""
pid_m = tl.program_id(axis=0)
# --- Compute Logits (x[pid_m, :] @ W_g.T) ---
x_row_ptr = X_ptr + pid_m * stride_xm
offs_k = tl.arange(0, BLOCK_K)
offs_n = tl.arange(0, N) # Use actual N directly
wg_ptrs = Wg_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
accumulator = tl.zeros((N,), dtype=tl.float32)
for k_start in range(0, tl.cdiv(K, BLOCK_K)):
current_offs_k = k_start * BLOCK_K + offs_k
x_mask = (pid_m < M) & (current_offs_k < K)
        x_chunk = tl.load(x_row_ptr + current_offs_k * stride_xk, mask=x_mask, other=0.0)[:, None]
        wg_mask = (current_offs_k[None, :] < K)
        wg_chunk = tl.load(wg_ptrs + k_start * BLOCK_K * stride_wk, mask=wg_mask, other=0.0)
dot_result = tl.dot(wg_chunk, x_chunk)
dot_result_reshaped = tl.reshape(dot_result, (N,)) # Shape is now [N]
accumulator += dot_result_reshaped
if pid_m < M:
current_scores_fp16 = tl.softmax(accumulator.to(tl.float16)).to(tl.float16)
indices_row_ptr = Indices_ptr + pid_m * stride_indices_m
scores_row_ptr = Scores_ptr + pid_m * stride_scores_m
loop_scores = current_scores_fp16
neg_inf_fp16 = -float('inf')
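        # Iterative top-k: repeatedly take the argmax and mask it out with -inf;
        # for small top_k this avoids a full sort of the expert scores.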
for k_iter in tl.static_range(top_k):
max_val_fp16, max_idx = tl.max(loop_scores, axis=0, return_indices=True) # max_val_fp16 is fp16
tl.store(indices_row_ptr + k_iter * stride_indices_k, max_idx.to(tl.int64))
tl.store(scores_row_ptr + k_iter * stride_scores_k, max_val_fp16)
loop_scores = tl.where(tl.arange(0, N) == max_idx, neg_inf_fp16, loop_scores)


class MoEGate(nn.Module):
    # Gating: forward() uses the Triton kernel above; forward_pt() keeps the PyTorch path for reference.
    def __init__(self, config: Dict, W_g_weight: torch.Tensor):
        super().__init__()
        self.top_k: int = config["n_experts_per_token"]
        self.d_hidden: int = config["d_hidden"]
        self.n_experts: int = config["n_routed_experts"]
        self.register_buffer('W_g', W_g_weight)  # Register as a buffer

    def forward_pt(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # PyTorch reference path, kept for debugging.
        # x:   [bs*seq_len, d_hidden]
        # W_g: [num_experts, d_hidden]
        logits = F.linear(x, self.W_g)  # W_g is [out, in] = [num_experts, d_hidden]
        scores = logits.softmax(dim=-1)
        topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
        return topk_indices, topk_scores

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Uses the Triton kernel for MoE gating (Linear + Softmax + Iterative TopK).
x: Input tensor of shape [M, K] = [bs*seq_len, d_hidden]
Returns: Tuple[Tensor[M, top_k], Tensor[M, top_k]] (Indices: int64, Scores: x.dtype)
"""
x = x.contiguous()
M, K = x.shape
N = self.n_experts
topk_indices = torch.empty((M, self.top_k), device=x.device, dtype=torch.int64)
topk_scores = torch.empty((M, self.top_k), device=x.device, dtype=x.dtype)
# --- Kernel Config ---
BLOCK_K = 64
grid = (M,) # One program per token
_triton_gate_kernel[grid](
x, self.W_g, topk_indices, topk_scores,
M, N, K,
x.stride(0), x.stride(1),
self.W_g.stride(0), self.W_g.stride(1),
topk_indices.stride(0), topk_indices.stride(1),
topk_scores.stride(0), topk_scores.stride(1),
top_k=self.top_k,
BLOCK_K=BLOCK_K,
# num_warps=4, # Optional tuning
# num_stages=2 # Optional tuning
)
return topk_indices, topk_scores
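
# Gating output shapes, for orientation: for x_flat of shape [M, d_hidden], MoEGate.forward
# returns expert indices of shape [M, top_k] (int64) and routing scores of shape [M, top_k]
# in x.dtype; TritonMoE.forward below consumes both.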

class TritonMoE(nn.Module):
    def __init__(self, config: Dict, weights: Dict[str, torch.Tensor]):
        super().__init__()
        self.config = config
        self.num_experts = config["n_routed_experts"]
        self.top_k = config["n_experts_per_token"]
        self.d_hidden = config["d_hidden"]
        self.d_expert = config["d_expert"]

        # --- Gating Network ---
        # W_g weight shape from the dict: [num_experts, d_hidden]
        self.gating_network = MoEGate(config, weights['router.weight'])

        # --- Experts ---
        self.experts = nn.ModuleList()
        for i in range(self.num_experts):
            # Weights from the dict:
            #   W_gate: [d_hidden, d_expert]
            #   W_up:   [d_hidden, d_expert]
            #   W_down: [d_expert, d_hidden]
            w_gate = weights[f'experts.{i}.0.weight']
            w_up = weights[f'experts.{i}.1.weight']
            w_down = weights[f'experts.{i}.2.weight']
            self.experts.append(TritonExpert(config, w_gate, w_up, w_down, self.d_expert))

        # --- Shared Expert ---
        shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
        # Weights from the dict:
        #   W_gate: [d_hidden, shared_expert_dim]
        #   W_up:   [d_hidden, shared_expert_dim]
        #   W_down: [shared_expert_dim, d_hidden]
        w_gate_shared = weights['shared_experts.0.weight']
        w_up_shared = weights['shared_experts.1.weight']
        w_down_shared = weights['shared_experts.2.weight']
        self.shared_expert = TritonExpert(config, w_gate_shared, w_up_shared, w_down_shared, shared_expert_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_shape = x.shape                # [batch_size, seq_len, d_hidden]
        x_flat = x.view(-1, self.d_hidden)  # [bs*seq_len, d_hidden]

        # --- Shared Expert Forward ---
        # Run the shared expert on the flattened input; the routed expert outputs are
        # scatter-added onto this tensor at the end of forward().
        shared_output_flat = self.shared_expert(x_flat)

        # --- Gating & Routing ---
        expert_indices, expert_scores = self.gating_network(x_flat)
        flat_expert_indices = expert_indices.view(-1)    # [bs*seq_len * top_k]
        flat_expert_weights = expert_scores.view(-1, 1)  # [bs*seq_len * top_k, 1]
        token_ids = torch.arange(x_flat.shape[0], device=x.device).repeat_interleave(self.top_k)  # [bs*seq_len * top_k]

        # Sort by expert index for potentially more coalesced processing within each expert call
        idxs_sorted = flat_expert_indices.argsort()
        sorted_expert_indices = flat_expert_indices[idxs_sorted]
        sorted_token_ids = token_ids[idxs_sorted]
        sorted_weights = flat_expert_weights[idxs_sorted]

        # Find the boundaries of each expert's slice in the sorted list
        expert_boundaries = torch.searchsorted(sorted_expert_indices, torch.arange(self.num_experts + 1, device=x.device))
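        # Worked example of the bucketing above (hypothetical values, 3 experts, top_k = 2):
        #   flat_expert_indices = [2, 0, 1, 2]  ->  sorted_expert_indices = [0, 1, 2, 2]
        #   expert_boundaries = searchsorted([0, 1, 2, 2], [0, 1, 2, 3]) = [0, 1, 2, 4]
        # so expert e owns the slice [expert_boundaries[e], expert_boundaries[e + 1]).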

        N_tokens = x_flat.shape[0]
        all_weighted_expert_outputs = torch.zeros(N_tokens * self.top_k, self.d_hidden, device=x.device, dtype=x.dtype)

        for expert_id in range(self.num_experts):
            start_idx = expert_boundaries[expert_id]
            end_idx = expert_boundaries[expert_id + 1]
            if start_idx == end_idx:
                continue  # No tokens routed to this expert

            expert = self.experts[expert_id]
            current_token_indices = sorted_token_ids[start_idx:end_idx]
            current_weights = sorted_weights[start_idx:end_idx]

            # Gather the tokens routed to the current expert
            expert_tokens = x_flat[current_token_indices]

            # Run the expert FFN (TritonExpert.forward -> fused_ffn_kernel_v2) and apply the routing weight
            expert_out = expert(expert_tokens)
            weighted_expert_out = expert_out * current_weights
            all_weighted_expert_outputs[start_idx:end_idx] = weighted_expert_out  # fp16 result

        # Scatter-add the routed outputs back onto the shared-expert output.
        # include_self=True keeps the shared-expert contribution in the sum.
        scatter_indices = sorted_token_ids.view(-1, 1).expand(-1, self.d_hidden)
        shared_output_flat.scatter_reduce_(
            dim=0,
            index=scatter_indices,            # token index of each row in all_weighted_expert_outputs
            src=all_weighted_expert_outputs,  # the permuted, weighted expert results
            reduce="sum",
            include_self=True,
        )

        # Reshape back to the original shape
        output = shared_output_flat.view(orig_shape)
        return output


def custom_kernel(data: input_t) -> output_t:
    """
    Triton implementation of DeepSeek-style Mixture of Experts.
    Uses Triton for the FFN and gating computations and PyTorch for the routing
    gather/scatter.

    Args:
        data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
            - input: Input tensor of shape [batch_size, seq_len, hidden_size]
            - weights: Dictionary containing model weights
            - config: Dictionary containing model configuration parameters
    Returns:
        Output tensor [batch_size, seq_len, d_model]
        (Aux data dictionary is omitted as per template/inference focus)
    """
    input_tensor, weights, config = data

    # Ensure weights are contiguous (important for Triton) before the model captures them
    for k, v in weights.items():
        weights[k] = v.contiguous()

    # Instantiate the MoE model with Triton experts
    triton_moe_model = TritonMoE(config, weights)

    # Ensure the model and input are on the same device (should be CUDA)
    input_tensor = input_tensor.to('cuda')
    triton_moe_model.to('cuda')

    # The input dtype must match what the kernels expect (usually float16); take it from the weights
    input_dtype = next(iter(weights.values())).dtype
    input_tensor = input_tensor.to(input_dtype)

    # Run the model; the competition framework expects just the output tensor, not an aux-data tuple
    output = triton_moe_model(input_tensor)
    return output


'''
# Quick local check against the reference implementation:
from reference import generate_input, ref_kernel
from utils import make_match_reference

# args = {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 4, "nsharedexperts": 1, "nexpertspertoken": 4, "bs": 1, "seqlen": 512, "seed": 9371}
# args = {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 8, "nsharedexperts": 1, "nexpertspertoken": 4, "bs": 1, "seqlen": 8192, "seed": 81934}
args = {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 8, "nsharedexperts": 1, "nexpertspertoken": 4, "bs": 1, "seqlen": 8192, "seed": 1212}

inp = generate_input(**args)
out = custom_kernel(inp)
ref = make_match_reference(ref_kernel)
print(ref(inp, out))
'''