#!POPCORN leaderboard amd-mixture-of-experts
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple

import triton
import triton.language as tl

from task import input_t, output_t  # Assuming task.py defines these types


# --- Triton FFN Kernel ---
@triton.jit
def silu(x):
    # SiLU with the sigmoid computed in fp32 for accuracy; result cast back to fp16
    return (x * tl.sigmoid(x.to(tl.float32))).to(tl.float16)


@triton.jit
def fused_ffn_kernel_v2(
    # Pointers to Matrices
    x_ptr, w_gate_ptr, w_up_ptr, w_down_ptr, output_ptr,
    # Matrix dimensions
    M, N_out, K_in, K_inter,
    # Strides
    stride_x_m, stride_x_k,
    stride_w_gate_k, stride_w_gate_inter,
    stride_w_up_k, stride_w_up_inter,
    stride_w_down_inter, stride_w_down_n,
    stride_out_m, stride_out_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N_OUT: tl.constexpr,
    BLOCK_SIZE_K_IN: tl.constexpr, BLOCK_SIZE_K_INTER: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """
    Computes FFN: output = W_down(silu(x @ W_gate) * (x @ W_up))
    Mimics the reference precision:
    - Matmuls (W_gate, W_up, W_down) accumulate in FP32.
    - Outputs of W_gate and W_up are cast to FP16.
    - SiLU is applied to the FP16 W_gate result.
    - The element-wise multiply happens in FP16.
    - The final W_down matmul takes FP16 input, accumulates in FP32, and outputs FP16.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_out, BLOCK_SIZE_N_OUT)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n_out = pid_n * BLOCK_SIZE_N_OUT + tl.arange(0, BLOCK_SIZE_N_OUT)
    x_start_ptr = x_ptr + offs_m[:, None] * stride_x_m

    # Dtypes
    weight_dtype = w_gate_ptr.dtype.element_ty  # Typically fp16 (intermediate dtype)
    # Use float32 for accumulation inside the kernel
    accum_dtype = tl.float32
    # Initialize the final accumulator for W_down
    acc_down = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N_OUT), dtype=accum_dtype)

    for k_inter_start in range(0, tl.cdiv(K_inter, BLOCK_SIZE_K_INTER)):
        offs_k_inter = k_inter_start * BLOCK_SIZE_K_INTER + tl.arange(0, BLOCK_SIZE_K_INTER)

        # Accumulators for W_gate and W_up (accumulate in FP32)
        acc_gate_fp32 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K_INTER), dtype=accum_dtype)
        acc_up_fp32 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K_INTER), dtype=accum_dtype)

        w_gate_k_inter_ptr = w_gate_ptr + offs_k_inter[None, :] * stride_w_gate_inter
        w_up_k_inter_ptr = w_up_ptr + offs_k_inter[None, :] * stride_w_up_inter

        for k_in_start in range(0, tl.cdiv(K_in, BLOCK_SIZE_K_IN)):
            offs_k_in = k_in_start * BLOCK_SIZE_K_IN + tl.arange(0, BLOCK_SIZE_K_IN)

            # Load x block (native dtype, typically fp16)
            x_ptrs = x_start_ptr + offs_k_in[None, :] * stride_x_k
            x_mask = (offs_m[:, None] < M) & (offs_k_in[None, :] < K_in)
            x_block = tl.load(x_ptrs, mask=x_mask, other=0.0)

            # Load W_gate block (fp16)
            w_gate_ptrs = w_gate_k_inter_ptr + offs_k_in[:, None] * stride_w_gate_k
            w_gate_mask = (offs_k_in[:, None] < K_in) & (offs_k_inter[None, :] < K_inter)
            w_gate_block = tl.load(w_gate_ptrs, mask=w_gate_mask, other=0.0)

            # Load W_up block (fp16)
            w_up_ptrs = w_up_k_inter_ptr + offs_k_in[:, None] * stride_w_up_k
            w_up_mask = (offs_k_in[:, None] < K_in) & (offs_k_inter[None, :] < K_inter)
            w_up_block = tl.load(w_up_ptrs, mask=w_up_mask, other=0.0)

            # Accumulate Gate and Up results in FP32
            # (x_block is cast to the weight dtype in case it differs, which it usually does not)
            acc_gate_fp32 += tl.dot(x_block.to(weight_dtype), w_gate_block, out_dtype=accum_dtype)
            acc_up_fp32 += tl.dot(x_block.to(weight_dtype), w_up_block, out_dtype=accum_dtype)

        # --- Apply Activation and Element-wise Product (Mimic Reference FP16 Path) ---
        # 1. Cast FP32 accumulated results down to FP16 (like the output of nn.Linear)
        acc_gate_fp16 = acc_gate_fp32.to(weight_dtype)
        acc_up_fp16 = acc_up_fp32.to(weight_dtype)

        # 2. Apply SiLU on the FP16 gate result
        gate_activated_fp16 = silu(acc_gate_fp16)  # fp16 in -> fp16 out

        # 3. Perform the element-wise multiply in FP16
        intermediate_block_fp16 = gate_activated_fp16 * acc_up_fp16  # fp16 * fp16 -> fp16

        # --- Compute Down Projection Dot Product ---
        # Load W_down block (fp16)
        w_down_start_ptr = w_down_ptr + offs_n_out[None, :] * stride_w_down_n
        w_down_ptrs = w_down_start_ptr + offs_k_inter[:, None] * stride_w_down_inter
        w_down_mask = (offs_k_inter[:, None] < K_inter) & (offs_n_out[None, :] < N_out)
        w_down_block = tl.load(w_down_ptrs, mask=w_down_mask, other=0.0)

        # 4. Accumulate the W_down result in FP32 using the FP16 intermediate input
        acc_down += tl.dot(intermediate_block_fp16, w_down_block, out_dtype=accum_dtype)

    # --- Store Final Output ---
    output_ptrs = output_ptr + offs_m[:, None] * stride_out_m + offs_n_out[None, :] * stride_out_n
    output_mask = (offs_m[:, None] < M) & (offs_n_out[None, :] < N_out)
    # Cast the final FP32 accumulator to the output dtype (e.g., FP16)
    tl.store(output_ptrs, acc_down.to(output_ptr.dtype.element_ty), mask=output_mask)
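

# For reference, a rough PyTorch sketch of the precision path fused_ffn_kernel_v2
# mimics (illustrative only; the helper name and this function are not part of the task code):
#
#   def ffn_reference(x, W_gate, W_up, W_down):            # all tensors fp16
#       gate = (x.float() @ W_gate.float()).half()         # fp32 accumulate, fp16 result
#       up   = (x.float() @ W_up.float()).half()
#       act  = (gate * torch.sigmoid(gate.float())).half() # SiLU: sigmoid in fp32
#       inter = act * up                                   # element-wise multiply in fp16
#       return (inter.float() @ W_down.float()).half()     # fp32 accumulate, fp16 output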
# --- PyTorch MoE Implementation using Triton Kernel ---

class TritonExpert(nn.Module):
    """Wrapper for the Triton FFN kernel."""
    def __init__(self, config: Dict, W_gate: torch.Tensor, W_up: torch.Tensor, W_down: torch.Tensor, d_expert_actual: int):
        super().__init__()
        self.config = config
        self.d_hidden: int = config["d_hidden"]
        self.d_expert_actual = d_expert_actual

        # Store weights directly, assuming they are already in the layout the kernel expects:
        #   W_gate: [K_in, K_inter]  = [d_hidden, d_expert_actual]
        #   W_up:   [K_in, K_inter]  = [d_hidden, d_expert_actual]
        #   W_down: [K_inter, N_out] = [d_expert_actual, d_hidden]
        # We assume the input weights dict already has these shapes.
        self.W_gate = W_gate
        self.W_up = W_up
        self.W_down = W_down

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape [num_tokens, d_hidden]
        Returns:
            Output tensor of shape [num_tokens, d_hidden]
        """
        if x.shape[0] == 0:  # Handle empty input case
            return torch.zeros((0, self.d_hidden), device=x.device, dtype=x.dtype)

        M, K_in = x.shape
        assert K_in == self.d_hidden
        K_inter = self.d_expert_actual
        N_out = self.d_hidden

        # Ensure contiguous inputs
        x = x.contiguous()
        # Weights should already be contiguous from loading
        assert self.W_gate.is_contiguous()
        assert self.W_up.is_contiguous()
        assert self.W_down.is_contiguous()

        # Allocate output tensor
        output = torch.empty((M, N_out), device=x.device, dtype=x.dtype)

        # --- Kernel Launch ---
        grid = lambda META: (
            triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N_out, META['BLOCK_SIZE_N_OUT']),
        )

        BLOCK_SIZE_M = 64         # A smaller M block is fine if M (tokens per expert) isn't huge
        BLOCK_SIZE_N_OUT = 64
        BLOCK_SIZE_K_IN = 128     # Tiles the d_hidden dimension
        BLOCK_SIZE_K_INTER = 128  # Tiles the d_expert_actual dimension
        GROUP_SIZE_M = 4          # For wave scheduling / locality
        num_warps = 8
        num_stages = 1            # Adjust based on shared memory usage and latency-hiding needs

        fused_ffn_kernel_v2[grid](
            x, self.W_gate, self.W_up, self.W_down, output,
            M, N_out, K_in, K_inter,
            # Strides: torch.Tensor.stride() returns element strides, which is what Triton expects.
            x.stride(0), x.stride(1),
            self.W_gate.stride(0), self.W_gate.stride(1),
            self.W_up.stride(0), self.W_up.stride(1),
            self.W_down.stride(0), self.W_down.stride(1),
            output.stride(0), output.stride(1),
            BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N_OUT=BLOCK_SIZE_N_OUT,
            BLOCK_SIZE_K_IN=BLOCK_SIZE_K_IN, BLOCK_SIZE_K_INTER=BLOCK_SIZE_K_INTER,
            GROUP_SIZE_M=GROUP_SIZE_M,
            num_warps=num_warps,
            num_stages=num_stages
        )
        return output
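
# Illustrative standalone use of TritonExpert (hypothetical shapes, for a quick smoke
# test only; the competition harness calls it through TritonMoE/custom_kernel):
#
#   cfg = {"d_hidden": 7168, "d_expert": 2048}
#   Wg = torch.randn(7168, 2048, device="cuda", dtype=torch.float16)
#   Wu = torch.randn(7168, 2048, device="cuda", dtype=torch.float16)
#   Wd = torch.randn(2048, 7168, device="cuda", dtype=torch.float16)
#   expert = TritonExpert(cfg, Wg, Wu, Wd, d_expert_actual=2048)
#   y = expert(torch.randn(64, 7168, device="cuda", dtype=torch.float16))  # -> [64, 7168]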




@triton.jit
def _triton_gate_kernel(
    X_ptr,        # Pointer to input tensor x [M, K]
    Wg_ptr,       # Pointer to gate weight tensor W_g [N, K]
    Indices_ptr,  # Pointer to output indices tensor [M, top_k] (int64)
    Scores_ptr,   # Pointer to output scores tensor [M, top_k]
    M,            # Number of tokens (bs * seq_len)
    N: tl.constexpr,  # Number of experts
    K,            # Hidden dimension
    stride_xm, stride_xk,                # Strides for X (element stride)
    stride_wn, stride_wk,                # Strides for W_g (element stride)
    stride_indices_m, stride_indices_k,  # Strides for Indices (element stride)
    stride_scores_m, stride_scores_k,    # Strides for Scores (element stride)
    top_k: tl.constexpr,   # Number of experts per token (compile-time constant)
    BLOCK_K: tl.constexpr, # Tile size for the K dimension
):
    """
    Triton kernel for MoE gating: fused Matmul(x, W_g.T) + Softmax + iterative TopK.
    Each program instance computes the gating result for a single token (one row of x).
    """
    pid_m = tl.program_id(axis=0)
    # --- Compute Logits (x[pid_m, :] @ W_g.T) ---
    x_row_ptr = X_ptr + pid_m * stride_xm
    offs_k = tl.arange(0, BLOCK_K)
    offs_n = tl.arange(0, N)  # Use the actual N directly
    wg_ptrs = Wg_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk

    accumulator = tl.zeros((N,), dtype=tl.float32)

    for k_start in range(0, tl.cdiv(K, BLOCK_K)):
        current_offs_k = k_start * BLOCK_K + offs_k
        x_mask = (pid_m < M) & (current_offs_k < K)
        x_chunk = tl.load(x_row_ptr + current_offs_k * stride_xk, mask=x_mask)[:, None]

        wg_mask = (current_offs_k[None, :] < K)
        wg_chunk = tl.load(wg_ptrs + k_start * BLOCK_K * stride_wk, mask=wg_mask)
        dot_result = tl.dot(wg_chunk, x_chunk)
        dot_result_reshaped = tl.reshape(dot_result, (N,))  # Shape is now [N]
        accumulator += dot_result_reshaped

    if pid_m < M:
        current_scores_fp16 = tl.softmax(accumulator.to(tl.float16)).to(tl.float16)

        indices_row_ptr = Indices_ptr + pid_m * stride_indices_m
        scores_row_ptr = Scores_ptr + pid_m * stride_scores_m

        loop_scores = current_scores_fp16
        neg_inf_fp16 = -float('inf')

        for k_iter in tl.static_range(top_k):
            max_val_fp16, max_idx = tl.max(loop_scores, axis=0, return_indices=True)  # fp16 max value
            tl.store(indices_row_ptr + k_iter * stride_indices_k, max_idx.to(tl.int64))
            tl.store(scores_row_ptr + k_iter * stride_scores_k, max_val_fp16)
            loop_scores = tl.where(tl.arange(0, N) == max_idx, neg_inf_fp16, loop_scores)
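
# For reference, the iterative top-k above is roughly equivalent to (PyTorch sketch,
# illustrative only):
#   scores = torch.softmax((x @ W_g.T).half(), dim=-1)
#   topk_scores, topk_indices = torch.topk(scores, k=top_k, dim=-1)
# up to tie-breaking order, except that the kernel applies the softmax to fp16 logits
# and writes the indices as int64.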


class MoEGate(nn.Module):
    # Gating network: router linear layer + softmax + top-k selection
    def __init__(self, config: Dict, W_g_weight: torch.Tensor):
        super().__init__()
        self.top_k: int = config["n_experts_per_token"]
        self.d_hidden: int = config["d_hidden"]
        self.n_experts: int = config["n_routed_experts"]
        self.register_buffer('W_g', W_g_weight)  # Register as buffer

    def forward_pt(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Reference PyTorch gating path, kept for debugging/comparison
        # x:   [bs*seq_len, d_hidden]
        # W_g: [num_experts, d_hidden]
        logits = F.linear(x, self.W_g)  # W_g needs to be [out, in] = [num_experts, d_hidden]
        scores = logits.softmax(dim=-1)
        topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
        return topk_indices, topk_scores

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Uses the Triton kernel for MoE gating (Linear + Softmax + Iterative TopK).
        x: Input tensor of shape [M, K] = [bs*seq_len, d_hidden]
        Returns: Tuple[Tensor[M, top_k], Tensor[M, top_k]] (Indices: int64, Scores: x.dtype)
        """
        x = x.contiguous()
        M, K = x.shape
        N = self.n_experts

        topk_indices = torch.empty((M, self.top_k), device=x.device, dtype=torch.int64)
        topk_scores = torch.empty((M, self.top_k), device=x.device, dtype=x.dtype)

        # --- Kernel Config ---
        BLOCK_K = 64
        grid = (M,)  # One program per token
        _triton_gate_kernel[grid](
            x, self.W_g, topk_indices, topk_scores,
            M, N, K,
            x.stride(0), x.stride(1),
            self.W_g.stride(0), self.W_g.stride(1),
            topk_indices.stride(0), topk_indices.stride(1),
            topk_scores.stride(0), topk_scores.stride(1),
            top_k=self.top_k,
            BLOCK_K=BLOCK_K,
            # num_warps=4,  # Optional tuning
            # num_stages=2  # Optional tuning
        )

        return topk_indices, topk_scores
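
    # Consistency check sketch (hypothetical, not part of the submission): the Triton
    # gating path should agree with the PyTorch fallback up to fp16 rounding and
    # top-k tie-breaking, e.g.:
    #   idx_t, sc_t = gate(x_flat)             # Triton forward
    #   idx_p, sc_p = gate.forward_pt(x_flat)  # PyTorch forward_pt
    #   torch.testing.assert_close(sc_t.sort(-1).values, sc_p.sort(-1).values,
    #                              rtol=1e-2, atol=1e-2)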


class TritonMoE(nn.Module):
    def __init__(self, config: Dict, weights: Dict[str, torch.Tensor]):
        super().__init__()
        self.config = config
        self.num_experts = config["n_routed_experts"]
        self.top_k = config["n_experts_per_token"]
        self.d_hidden = config["d_hidden"]
        self.d_expert = config["d_expert"]

        # --- Gating Network ---
        # W_g weight shape from dict: [num_experts, d_hidden]
        self.gating_network = MoEGate(config, weights['router.weight'])

        # --- Experts ---
        self.experts = nn.ModuleList()
        for i in range(self.num_experts):
            # Weights from dict:
            #   W_gate: [d_hidden, d_expert]
            #   W_up:   [d_hidden, d_expert]
            #   W_down: [d_expert, d_hidden]
            w_gate = weights[f'experts.{i}.0.weight']
            w_up = weights[f'experts.{i}.1.weight']
            w_down = weights[f'experts.{i}.2.weight']
            self.experts.append(TritonExpert(config, w_gate, w_up, w_down, self.d_expert))

        # --- Shared Expert ---
        shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
        # Weights from dict:
        #   W_gate: [d_hidden, shared_expert_dim]
        #   W_up:   [d_hidden, shared_expert_dim]
        #   W_down: [shared_expert_dim, d_hidden]
        w_gate_shared = weights['shared_experts.0.weight']
        w_up_shared = weights['shared_experts.1.weight']
        w_down_shared = weights['shared_experts.2.weight']
        self.shared_expert = TritonExpert(config, w_gate_shared, w_up_shared, w_down_shared, shared_expert_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_shape = x.shape                # [batch_size, seq_len, d_hidden]
        x_flat = x.view(-1, self.d_hidden)  # [bs*seq_len, d_hidden]

        # --- Gating & Routing ---
        expert_indices, expert_scores = self.gating_network(x_flat)
        flat_expert_indices = expert_indices.view(-1)    # [bs*seq_len * top_k]
        flat_expert_weights = expert_scores.view(-1, 1)  # [bs*seq_len * top_k, 1]
        token_ids = torch.arange(x_flat.shape[0], device=x.device).repeat_interleave(self.top_k)  # [bs*seq_len * top_k]

        # Sort by expert index so each expert processes a contiguous, more coalesced slice of tokens
        idxs_sorted = flat_expert_indices.argsort()
        sorted_expert_indices = flat_expert_indices[idxs_sorted]
        sorted_token_ids = token_ids[idxs_sorted]
        sorted_weights = flat_expert_weights[idxs_sorted]

        # Find the boundaries for each expert in the sorted list (see the worked example below)
        expert_boundaries = torch.searchsorted(sorted_expert_indices, torch.arange(self.num_experts + 1, device=x.device))
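        # Worked example (hypothetical values): with num_experts = 4 and
        # sorted_expert_indices = [0, 0, 1, 3, 3, 3], searching for [0, 1, 2, 3, 4]
        # gives expert_boundaries = [0, 2, 3, 3, 6]: expert 0 owns slots 0:2,
        # expert 1 owns 2:3, expert 2 receives no tokens, and expert 3 owns 3:6.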
        N_tokens = x_flat.shape[0]
        all_weighted_expert_outputs = torch.zeros(N_tokens * self.top_k, self.d_hidden, device=x.device, dtype=x.dtype)
        for expert_id in range(self.num_experts):
            start_idx = expert_boundaries[expert_id]
            end_idx = expert_boundaries[expert_id + 1]

            if start_idx == end_idx:
                continue  # No tokens routed to this expert

            expert = self.experts[expert_id]
            current_token_indices = sorted_token_ids[start_idx:end_idx]
            current_weights = sorted_weights[start_idx:end_idx]

            # Gather tokens for the current expert
            expert_tokens = x_flat[current_token_indices]

            # Run the expert FFN using the Triton kernel
            expert_out = expert(expert_tokens)  # Calls TritonExpert.forward -> fused_ffn_kernel_v2

            # Weight by the routing score and store the fp16 result
            weighted_expert_out = expert_out * current_weights
            all_weighted_expert_outputs[start_idx:end_idx] = weighted_expert_out

        scatter_indices = sorted_token_ids.view(-1, 1).expand(-1, self.d_hidden)

        # The shared expert output doubles as the accumulation buffer for the routed outputs
        shared_output_flat = self.shared_expert(x_flat)
        shared_output_flat.scatter_reduce_(
            dim=0,
            index=scatter_indices,            # Indices corresponding to all_weighted_expert_outputs
            src=all_weighted_expert_outputs,  # The permuted per-expert results
            reduce="sum",
            include_self=True                 # Add the routed outputs on top of the shared expert output
        )
        # Reshape back to the original shape
        output = shared_output_flat.view(orig_shape)

        return output
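
    # Note: the scatter_reduce_ combine above implements the standard MoE combine rule,
    #   output[t] = shared_expert(x[t]) + sum_k scores[t, k] * experts[indices[t, k]](x[t])
    # for every token t, with the shared-expert output reused as the accumulation buffer.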




def custom_kernel(data: input_t) -> output_t:
    """
    Triton implementation of DeepSeek-style Mixture of Experts.
    Uses Triton for the FFN computations and PyTorch for routing/gather/scatter.

    Args:
        data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
            - input: Input tensor of shape [batch_size, seq_len, hidden_size]
            - weights: Dictionary containing model weights
            - config: Dictionary containing model configuration parameters

    Returns:
        Output tensor [batch_size, seq_len, d_model]
        (Aux data dictionary is omitted as per template/inference focus)
    """
    input_tensor, weights, config = data

    # Ensure weights are contiguous (important for Triton) before the model stores references to them
    for k, v in weights.items():
        weights[k] = v.contiguous()

    # Instantiate the MoE model with Triton experts
    triton_moe_model = TritonMoE(config, weights)

    # Ensure model and input are on the same device (should be CUDA)
    input_tensor = input_tensor.to('cuda')
    triton_moe_model.to('cuda')

    # The input dtype must match what the kernel expects (usually float16)
    input_dtype = next(iter(weights.values())).dtype  # Get dtype from weights
    input_tensor = input_tensor.to(input_dtype)

    # Run the model
    output = triton_moe_model(input_tensor)

    # The competition framework expects just the tensor output, not an aux_data tuple
    return output
'''
from reference import generate_input, ref_kernel
from utils import make_match_reference
args = {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 4, "nsharedexperts": 1, "nexpertspertoken": 4, "bs": 1, "seqlen": 512, "seed": 9371}
args = {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 8, "nsharedexperts": 1, "nexpertspertoken": 4, "bs": 1, "seqlen": 8192, "seed": 1212}
inp = generate_input(**args)
out = custom_kernel(inp)
ref = make_match_reference(ref_kernel)
print(ref(inp, out))
'''