#!POPCORN leaderboard amd-mixture-of-experts
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple

import triton
import triton.language as tl

from task import input_t, output_t  # Assuming task.py defines these types


# --- Triton FFN Kernel ---
@triton.jit
def silu(x):
    # SiLU computed in fp32 for numerical stability, result cast back to fp16.
    return (x * tl.sigmoid(x.to(tl.float32))).to(tl.float16)


# --- PyTorch MoE Implementation using Triton Kernel ---

@triton.jit
def fused_ffn_kernel_v2(
    # Pointers to matrices
    x_ptr, w_gate_ptr, w_up_ptr, w_down_ptr, output_ptr,
    # Matrix dimensions
    M, N_out, K_in, K_inter,
    # Strides
    stride_x_m, stride_x_k,
    stride_w_gate_k, stride_w_gate_inter,
    stride_w_up_k, stride_w_up_inter,
    stride_w_down_inter, stride_w_down_n,
    stride_out_m, stride_out_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N_OUT: tl.constexpr,
    BLOCK_SIZE_K_IN: tl.constexpr, BLOCK_SIZE_K_INTER: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """
    Computes the FFN: output = (silu(x @ W_gate) * (x @ W_up)) @ W_down
    Mimics the reference precision:
    - Matmuls (W_gate, W_up, W_down) accumulate in FP32.
    - Outputs of W_gate and W_up are treated as FP16.
    - SiLU is applied to the FP16 W_gate result.
    - The element-wise multiply happens on FP16 inputs.
    - The final W_down matmul takes FP16 input, accumulates in FP32, and outputs FP16.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_out, BLOCK_SIZE_N_OUT)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n_out = pid_n * BLOCK_SIZE_N_OUT + tl.arange(0, BLOCK_SIZE_N_OUT)
    x_start_ptr = x_ptr + offs_m[:, None] * stride_x_m

    # Dtypes
    weight_dtype = w_gate_ptr.dtype.element_ty  # Typically fp16 (intermediate dtype)
    accum_dtype = tl.float32                    # Use fp32 for accumulation

    # Initialize the final accumulator for the W_down projection
    acc_down = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N_OUT), dtype=accum_dtype)

    for k_inter_start in range(0, tl.cdiv(K_inter, BLOCK_SIZE_K_INTER)):
        offs_k_inter = k_inter_start * BLOCK_SIZE_K_INTER + tl.arange(0, BLOCK_SIZE_K_INTER)

        # Accumulators for W_gate and W_up (accumulate in FP32)
        acc_gate_fp32 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K_INTER), dtype=accum_dtype)
        acc_up_fp32 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K_INTER), dtype=accum_dtype)

        w_gate_k_inter_ptr = w_gate_ptr + offs_k_inter[None, :] * stride_w_gate_inter
        w_up_k_inter_ptr = w_up_ptr + offs_k_inter[None, :] * stride_w_up_inter

        for k_in_start in range(0, tl.cdiv(K_in, BLOCK_SIZE_K_IN)):
            offs_k_in = k_in_start * BLOCK_SIZE_K_IN + tl.arange(0, BLOCK_SIZE_K_IN)

            # Load x block (use its native dtype, likely fp16)
            x_ptrs = x_start_ptr + offs_k_in[None, :] * stride_x_k
            x_mask = (offs_m[:, None] < M) & (offs_k_in[None, :] < K_in)
            x_block = tl.load(x_ptrs, mask=x_mask, other=0.0)  # Loaded as the input dtype (fp16)

            # Load W_gate block (fp16)
            w_gate_ptrs = w_gate_k_inter_ptr + offs_k_in[:, None] * stride_w_gate_k
            w_gate_mask = (offs_k_in[:, None] < K_in) & (offs_k_inter[None, :] < K_inter)
            w_gate_block = tl.load(w_gate_ptrs, mask=w_gate_mask, other=0.0)

            # Load W_up block (fp16)
            w_up_ptrs = w_up_k_inter_ptr + offs_k_in[:, None] * stride_w_up_k
            w_up_mask = (offs_k_in[:, None] < K_in) & (offs_k_inter[None, :] < K_inter)
            w_up_block = tl.load(w_up_ptrs, mask=w_up_mask, other=0.0)

            # Accumulate gate and up results in FP32.
            # Cast x_block to the weight dtype in case they differ (usually they do not).
            acc_gate_fp32 += tl.dot(x_block.to(weight_dtype), w_gate_block, out_dtype=accum_dtype)
            acc_up_fp32 += tl.dot(x_block.to(weight_dtype), w_up_block, out_dtype=accum_dtype)

        # --- Apply activation and element-wise product (mimic the reference FP16 path) ---
        # 1. Cast the FP32 accumulators down to FP16 (like the output of nn.Linear).
        acc_gate_fp16 = acc_gate_fp32.to(weight_dtype)
        acc_up_fp16 = acc_up_fp32.to(weight_dtype)

        # 2. Apply SiLU to the FP16 gate result.
        gate_activated_fp16 = silu(acc_gate_fp16)  # fp16 in -> fp16 out

        # 3. Element-wise multiply in FP16.
        intermediate_block_fp16 = gate_activated_fp16 * acc_up_fp16  # fp16 * fp16 -> fp16

        # --- Compute the down-projection dot product ---
        # Load W_down block (fp16)
        w_down_start_ptr = w_down_ptr + offs_n_out[None, :] * stride_w_down_n
        w_down_ptrs = w_down_start_ptr + offs_k_inter[:, None] * stride_w_down_inter
        w_down_mask = (offs_k_inter[:, None] < K_inter) & (offs_n_out[None, :] < N_out)
        w_down_block = tl.load(w_down_ptrs, mask=w_down_mask, other=0.0)  # Should be fp16

        # 4. Accumulate the W_down result in FP32 from the FP16 intermediate input.
        acc_down += tl.dot(intermediate_block_fp16, w_down_block, out_dtype=accum_dtype)

    # --- Store final output ---
    output_ptrs = output_ptr + offs_m[:, None] * stride_out_m + offs_n_out[None, :] * stride_out_n
    output_mask = (offs_m[:, None] < M) & (offs_n_out[None, :] < N_out)
    # Cast the final FP32 accumulator to the output dtype (e.g., FP16)
    tl.store(output_ptrs, acc_down.to(output_ptr.dtype.element_ty), mask=output_mask)
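

# Illustrative only: a minimal PyTorch sketch of the precision path the kernel above is meant
# to mimic (fp32 accumulation per matmul, fp16 intermediates). It is not used by the model
# below, and exact numerical parity with the kernel is not guaranteed.
def _ffn_fp16_reference_sketch(x: torch.Tensor,
                               w_gate: torch.Tensor,
                               w_up: torch.Tensor,
                               w_down: torch.Tensor) -> torch.Tensor:
    # x: [M, d_hidden], w_gate/w_up: [d_hidden, d_expert], w_down: [d_expert, d_hidden]
    gate = (x.float() @ w_gate.float()).to(torch.float16)   # fp32 accumulate, fp16 output
    up = (x.float() @ w_up.float()).to(torch.float16)
    inter = F.silu(gate) * up                                # fp16 SiLU and multiply
    return (inter.float() @ w_down.float()).to(torch.float16)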


class TritonExpert(nn.Module):
    """Wrapper for the Triton FFN kernel."""

    def __init__(self, config: Dict, W_gate: torch.Tensor, W_up: torch.Tensor, W_down: torch.Tensor, d_expert_actual: int):
        super().__init__()
        self.config = config
        self.d_hidden: int = config["d_hidden"]
        self.d_expert_actual = d_expert_actual

        # Store weights directly, assuming they are already in the layout the kernel expects:
        #   W_gate: [K_in, K_inter]  = [d_hidden, d_expert_actual]
        #   W_up:   [K_in, K_inter]  = [d_hidden, d_expert_actual]
        #   W_down: [K_inter, N_out] = [d_expert_actual, d_hidden]
        # We assume the input weights dict already has these shapes.
        self.W_gate = W_gate
        self.W_up = W_up
        self.W_down = W_down

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape [num_tokens, d_hidden]
        Returns:
            Output tensor of shape [num_tokens, d_hidden]
        """
        if x.shape[0] == 0:  # Handle the empty-input case
            return torch.zeros((0, self.d_hidden), device=x.device, dtype=x.dtype)

        M, K_in = x.shape
        assert K_in == self.d_hidden
        K_inter = self.d_expert_actual
        N_out = self.d_hidden

        # Ensure contiguous inputs
        x = x.contiguous()
        # Weights should already be contiguous from loading
        assert self.W_gate.is_contiguous()
        assert self.W_up.is_contiguous()
        assert self.W_down.is_contiguous()

        # Allocate the output tensor
        output = torch.empty((M, N_out), device=x.device, dtype=x.dtype)

        # --- Kernel launch ---
        grid = lambda META: (
            triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N_out, META['BLOCK_SIZE_N_OUT']),
        )

        BLOCK_SIZE_M = 64         # A smaller M block is fine if M (tokens per expert) is not huge
        BLOCK_SIZE_N_OUT = 64
        BLOCK_SIZE_K_IN = 128     # Tiles the d_hidden dimension
        BLOCK_SIZE_K_INTER = 128  # Tiles the d_expert_actual dimension
        GROUP_SIZE_M = 4          # For wave scheduling / locality
        num_warps = 8
        num_stages = 1            # Adjust based on shared-memory usage and latency-hiding needs

        fused_ffn_kernel_v2[grid](
            x, self.W_gate, self.W_up, self.W_down, output,
            M, N_out, K_in, K_inter,
            # Tensor.stride() returns element strides, which is what Triton expects.
            x.stride(0), x.stride(1),
            self.W_gate.stride(0), self.W_gate.stride(1),
            self.W_up.stride(0), self.W_up.stride(1),
            self.W_down.stride(0), self.W_down.stride(1),
            output.stride(0), output.stride(1),
            BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N_OUT=BLOCK_SIZE_N_OUT,
            BLOCK_SIZE_K_IN=BLOCK_SIZE_K_IN, BLOCK_SIZE_K_INTER=BLOCK_SIZE_K_INTER,
            GROUP_SIZE_M=GROUP_SIZE_M,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        return output
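

# Illustrative only: a small smoke test for TritonExpert (never called). It assumes a CUDA
# device, fp16 weights in the [in_features, out_features] layout documented in __init__, and
# shapes echoing the benchmark config at the bottom of this file; the weights are scaled down
# to keep intermediate values inside the fp16 range.
def _expert_smoke_test(d_hidden: int = 7168, d_expert: int = 2048, n_tokens: int = 16) -> torch.Tensor:
    dev, dt = 'cuda', torch.float16
    w_gate = (torch.randn(d_hidden, d_expert, device=dev, dtype=dt) * 0.02).contiguous()
    w_up = (torch.randn(d_hidden, d_expert, device=dev, dtype=dt) * 0.02).contiguous()
    w_down = (torch.randn(d_expert, d_hidden, device=dev, dtype=dt) * 0.02).contiguous()
    expert = TritonExpert({"d_hidden": d_hidden}, w_gate, w_up, w_down, d_expert)
    x = torch.randn(n_tokens, d_hidden, device=dev, dtype=dt)
    return expert(x)  # [n_tokens, d_hidden]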




@triton.jit
def _triton_gate_kernel(
    X_ptr,            # Pointer to input tensor x [M, K]
    Wg_ptr,           # Pointer to gate weight tensor W_g [N, K]
    Indices_ptr,      # Pointer to output indices tensor [M, top_k] (int64)
    Scores_ptr,       # Pointer to output scores tensor [M, top_k]
    M,                # Number of tokens (bs * seq_len)
    N: tl.constexpr,  # Number of experts
    K,                # Hidden dimension
    stride_xm, stride_xk,                # Strides for X (element stride)
    stride_wn, stride_wk,                # Strides for W_g (element stride)
    stride_indices_m, stride_indices_k,  # Strides for Indices (element stride)
    stride_scores_m, stride_scores_k,    # Strides for Scores (element stride)
    top_k: tl.constexpr,    # Number of experts per token (compile-time constant)
    BLOCK_K: tl.constexpr,  # Tile size for K dimension
):
    """
    Triton kernel for MoE gating: fused matmul(x, W_g.T) + softmax + iterative top-k.
    Each program instance computes the gating result for a single token (one row of x).
    """
    pid_m = tl.program_id(axis=0)

    # --- Compute logits (x[pid_m, :] @ W_g.T) ---
    x_row_ptr = X_ptr + pid_m * stride_xm

    offs_k = tl.arange(0, BLOCK_K)
    offs_n = tl.arange(0, N)  # Use the actual number of experts directly
    wg_ptrs = Wg_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk

    accumulator = tl.zeros((N,), dtype=tl.float32)

    for k_start in range(0, tl.cdiv(K, BLOCK_K)):
        current_offs_k = k_start * BLOCK_K + offs_k
        x_mask = (pid_m < M) & (current_offs_k < K)
        x_chunk = tl.load(x_row_ptr + current_offs_k * stride_xk, mask=x_mask, other=0.0)[:, None]

        wg_mask = (current_offs_k[None, :] < K)
        wg_chunk = tl.load(wg_ptrs + k_start * BLOCK_K * stride_wk, mask=wg_mask, other=0.0)

        dot_result = tl.dot(wg_chunk, x_chunk)            # [N, BLOCK_K] @ [BLOCK_K, 1] -> [N, 1]
        dot_result_reshaped = tl.reshape(dot_result, (N,))  # Shape is now [N]
        accumulator += dot_result_reshaped

    if pid_m < M:
        current_scores_fp16 = tl.softmax(accumulator.to(tl.float16)).to(tl.float16)

        indices_row_ptr = Indices_ptr + pid_m * stride_indices_m
        scores_row_ptr = Scores_ptr + pid_m * stride_scores_m

        # --- Iterative top-k: repeatedly take the max and mask it out ---
        loop_scores = current_scores_fp16
        neg_inf_fp16 = -float('inf')

        for k_iter in tl.static_range(top_k):
            max_val_fp16, max_idx = tl.max(loop_scores, axis=0, return_indices=True)  # max_val_fp16 is fp16
            tl.store(indices_row_ptr + k_iter * stride_indices_k, max_idx.to(tl.int64))
            tl.store(scores_row_ptr + k_iter * stride_scores_k, max_val_fp16)
            # Prevent re-selection of the chosen expert
            loop_scores = tl.where(tl.arange(0, N) == max_idx, neg_inf_fp16, loop_scores)
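

# Illustrative only: the iterative top-k used by the kernel above, written in plain PyTorch
# for a single token. Repeatedly taking the argmax and masking it to -inf yields the top_k
# scores in descending order (the same set torch.topk would return); this sketch is not used
# by the model below.
def _iterative_topk_sketch(scores: torch.Tensor, top_k: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # scores: [n_experts] softmax probabilities for one token
    scores = scores.clone()
    idxs, vals = [], []
    for _ in range(top_k):
        idx = torch.argmax(scores)
        idxs.append(idx)
        vals.append(scores[idx].clone())
        scores[idx] = float('-inf')  # mask out, as the kernel does with tl.where
    return torch.stack(idxs), torch.stack(vals)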


class MoEGate(nn.Module):
    # Gating network: forward() uses the Triton kernel, forward_pt() keeps the PyTorch reference path.
    def __init__(self, config: Dict, W_g_weight: torch.Tensor):
        super().__init__()
        self.top_k: int = config["n_experts_per_token"]
        self.n_experts: int = config["n_routed_experts"]
        self.register_buffer('W_g', W_g_weight)  # Register as a buffer

    def forward_pt(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x: [bs*seq_len, d_hidden]
        # W_g: [num_experts, d_hidden]
        logits = F.linear(x, self.W_g)  # W_g must be [out, in] = [num_experts, d_hidden]
        scores = logits.softmax(dim=-1)
        topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
        return topk_indices, topk_scores

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Uses the Triton kernel for MoE gating (linear + softmax + iterative top-k).
        x: Input tensor of shape [M, K] = [bs*seq_len, d_hidden]
        Returns: Tuple[Tensor[M, top_k], Tensor[M, top_k]] (indices: int64, scores: x.dtype)
        """
        x = x.contiguous()
        M, K = x.shape
        N = self.n_experts

        topk_indices = torch.empty((M, self.top_k), device=x.device, dtype=torch.int64)
        topk_scores = torch.empty((M, self.top_k), device=x.device, dtype=x.dtype)

        # --- Kernel config ---
        BLOCK_K = 64  # Or 128, 256. Depends on K and shared memory. Must be a power of 2.
        grid = (M,)   # One program per token

        _triton_gate_kernel[grid](
            x, self.W_g, topk_indices, topk_scores,
            M, N, K,
            x.stride(0), x.stride(1),
            self.W_g.stride(0), self.W_g.stride(1),
            topk_indices.stride(0), topk_indices.stride(1),
            topk_scores.stride(0), topk_scores.stride(1),
            top_k=self.top_k,
            BLOCK_K=BLOCK_K,
            # num_warps=4,   # Optional tuning
            # num_stages=2,  # Optional tuning
        )

        return topk_indices, topk_scores
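

# Illustrative only: a quick parity check between the Triton gating path and the PyTorch
# reference path above (never called). Because forward_pt uses sorted=False and the kernel
# emits scores in descending order, both sides are sorted before comparison; tolerances are
# loose to allow for the fp16 softmax in the kernel.
def _gate_parity_check_sketch(gate: MoEGate, x: torch.Tensor) -> bool:
    idx_triton, scores_triton = gate(x)
    idx_ref, scores_ref = gate.forward_pt(x)
    idx_match = torch.equal(idx_triton.sort(dim=-1).values, idx_ref.sort(dim=-1).values)
    score_match = torch.allclose(scores_triton.sort(dim=-1).values,
                                 scores_ref.sort(dim=-1).values.to(scores_triton.dtype),
                                 atol=1e-2, rtol=1e-2)
    return idx_match and score_match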


class TritonMoE(nn.Module):
    def __init__(self, config: Dict, weights: Dict[str, torch.Tensor]):
        super().__init__()
        self.config = config
        self.num_experts = config["n_routed_experts"]
        self.top_k = config["n_experts_per_token"]
        self.d_hidden = config["d_hidden"]
        self.d_expert = config["d_expert"]

        # --- Gating Network ---
        # W_g weight shape from dict: [num_experts, d_hidden]
        self.gating_network = MoEGate(config, weights['router.weight'])

        # --- Experts ---
        self.experts = nn.ModuleList()
        for i in range(self.num_experts):
            # Weights from dict:
            #   W_gate: [d_hidden, d_expert]
            #   W_up:   [d_hidden, d_expert]
            #   W_down: [d_expert, d_hidden]
            w_gate = weights[f'experts.{i}.0.weight']
            w_up = weights[f'experts.{i}.1.weight']
            w_down = weights[f'experts.{i}.2.weight']
            self.experts.append(TritonExpert(config, w_gate, w_up, w_down, self.d_expert))

        # --- Shared Expert ---
        shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
        # Weights from dict:
        #   W_gate: [d_hidden, shared_expert_dim]
        #   W_up:   [d_hidden, shared_expert_dim]
        #   W_down: [shared_expert_dim, d_hidden]
        w_gate_shared = weights['shared_experts.0.weight']
        w_up_shared = weights['shared_experts.1.weight']
        w_down_shared = weights['shared_experts.2.weight']
        self.shared_expert = TritonExpert(config, w_gate_shared, w_up_shared, w_down_shared, shared_expert_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_shape = x.shape                # [batch_size, seq_len, d_hidden]
        x_flat = x.view(-1, self.d_hidden)  # [bs*seq_len, d_hidden]

        # --- Gating & Routing ---
        expert_indices, expert_scores = self.gating_network(x_flat)
        flat_expert_indices = expert_indices.view(-1)    # [bs*seq_len * top_k]
        flat_expert_weights = expert_scores.view(-1, 1)  # [bs*seq_len * top_k, 1]
        token_ids = torch.arange(x_flat.shape[0], device=x.device).repeat_interleave(self.top_k)  # [bs*seq_len * top_k]

        # Sort by expert index for potentially more coalesced processing within each expert call
        idxs_sorted = flat_expert_indices.argsort()
        sorted_expert_indices = flat_expert_indices[idxs_sorted]
        sorted_token_ids = token_ids[idxs_sorted]
        sorted_weights = flat_expert_weights[idxs_sorted]

        # Find boundaries for each expert in the sorted list
        expert_boundaries = torch.searchsorted(sorted_expert_indices, torch.arange(self.num_experts + 1, device=x.device))
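
        # Tiny worked example of the boundary computation above (illustrative numbers only):
        # with num_experts = 4 and sorted_expert_indices = [0, 0, 1, 3, 3, 3], searching for
        # [0, 1, 2, 3, 4] gives expert_boundaries = [0, 2, 3, 3, 6], so expert 0 owns slice
        # [0, 2), expert 1 owns [2, 3), expert 2 is empty ([3, 3)), and expert 3 owns [3, 6).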
        all_weighted_expert_outputs = torch.zeros(x_flat.shape[0] * self.top_k, self.d_hidden, device=x.device, dtype=x.dtype)
        for expert_id in range(self.num_experts):
            start_idx = expert_boundaries[expert_id]
            end_idx = expert_boundaries[expert_id + 1]

            if start_idx == end_idx:
                continue  # No tokens routed to this expert

            expert = self.experts[expert_id]
            current_token_indices = sorted_token_ids[start_idx:end_idx]
            current_weights = sorted_weights[start_idx:end_idx]

            # Gather the tokens routed to the current expert
            expert_tokens = x_flat[current_token_indices]

            # Run the expert FFN using the Triton kernel
            expert_out = expert(expert_tokens)  # Calls TritonExpert.forward -> fused_ffn_kernel_v2

            # Weight by the router score (kept in fp16) and store in permuted (sorted) order
            weighted_expert_out = expert_out * current_weights
            all_weighted_expert_outputs[start_idx:end_idx] = weighted_expert_out

        scatter_indices = sorted_token_ids.view(-1, 1).expand(-1, self.d_hidden)

        # Start from the shared-expert output and scatter-add the routed expert outputs into it
        shared_output_flat = self.shared_expert(x_flat)
        shared_output_flat.scatter_reduce_(
            dim=0,
            index=scatter_indices,            # Token index for each row of all_weighted_expert_outputs
            src=all_weighted_expert_outputs,  # The permuted per-expert results
            reduce="sum",
            include_self=True,                # Keep the shared-expert contribution in the sum
        )
        # Reshape back to the original shape
        output = shared_output_flat.view(orig_shape)

        return output




def custom_kernel(data: input_t) -> output_t:
    """
    Triton implementation of DeepSeek-style Mixture of Experts.
    Uses Triton for the FFN computations and PyTorch for routing/gather/scatter.

    Args:
        data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
            - input: Input tensor of shape [batch_size, seq_len, d_hidden]
            - weights: Dictionary containing model weights
            - config: Dictionary containing model configuration parameters

    Returns:
        Output tensor of shape [batch_size, seq_len, d_hidden]
        (The aux-data dictionary is omitted; this path targets inference only.)
    """
    input_tensor, weights, config = data

    # Ensure weights are contiguous (important for Triton) before the model captures them
    for k, v in weights.items():
        weights[k] = v.contiguous()

    # Instantiate the MoE model with Triton experts
    triton_moe_model = TritonMoE(config, weights)

    # Ensure the model and input are on the same device (should be CUDA)
    input_tensor = input_tensor.to('cuda')
    triton_moe_model.to('cuda')

    # The input dtype must match what the kernels expect (usually float16); take it from the weights
    input_dtype = next(iter(weights.values())).dtype
    input_tensor = input_tensor.to(input_dtype)

    output = triton_moe_model(input_tensor)

    # The competition framework expects just the output tensor, not an (output, aux_data) tuple
    return output
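

# Illustrative only: builds a tiny random problem in the weight/config layout that TritonMoE
# expects (see the shape comments in TritonMoE.__init__) and runs custom_kernel on it. Never
# called; the real harness uses generate_input from reference.py (see the commented-out block
# below). Shapes here are small and arbitrary.
def _custom_kernel_demo() -> torch.Tensor:
    cfg = {"d_hidden": 128, "d_expert": 64, "n_routed_experts": 4,
           "n_shared_experts": 1, "n_experts_per_token": 2}
    dev, dt = 'cuda', torch.float16
    w = {"router.weight": torch.randn(cfg["n_routed_experts"], cfg["d_hidden"], device=dev, dtype=dt)}
    for i in range(cfg["n_routed_experts"]):
        w[f"experts.{i}.0.weight"] = torch.randn(cfg["d_hidden"], cfg["d_expert"], device=dev, dtype=dt)
        w[f"experts.{i}.1.weight"] = torch.randn(cfg["d_hidden"], cfg["d_expert"], device=dev, dtype=dt)
        w[f"experts.{i}.2.weight"] = torch.randn(cfg["d_expert"], cfg["d_hidden"], device=dev, dtype=dt)
    shared_dim = cfg["d_expert"] * cfg["n_shared_experts"]
    w["shared_experts.0.weight"] = torch.randn(cfg["d_hidden"], shared_dim, device=dev, dtype=dt)
    w["shared_experts.1.weight"] = torch.randn(cfg["d_hidden"], shared_dim, device=dev, dtype=dt)
    w["shared_experts.2.weight"] = torch.randn(shared_dim, cfg["d_hidden"], device=dev, dtype=dt)
    x = torch.randn(1, 8, cfg["d_hidden"], device=dev, dtype=dt)
    return custom_kernel((x, w, cfg))
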
'''
from reference import generate_input, ref_kernel
from utils import make_match_reference

args = {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 4, "nsharedexperts": 1, "nexpertspertoken": 4, "bs": 1, "seqlen": 512, "seed": 9371}
args = {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 8, "nsharedexperts": 1, "nexpertspertoken": 4, "bs": 1, "seqlen": 8192, "seed": 1212}
inp = generate_input(**args)
out = custom_kernel(inp)
ref = make_match_reference(ref_kernel)
print(ref(inp, out))
'''
'''