Fwd 02/09 vs. 02/14

Diffchecker result: 383 lines compared, 0 additions, 0 removals (the 02/09 and 02/14 files are identical).
# AOT ID: ['0_forward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import (
    grid,
    split_scan_grid,
    grid_combo_kernels,
    start_graph,
    end_graph,
    cooperative_reduction_grid,
)
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch._inductor.kernel.flex_attention


aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p




# kernel path: /tmp/torchinductor_root/py/cpy464bp6ukc5g2wp7hq2jbqpvoi3yp2gkjew2p7up4q2t7wzimf.py
# Topologically Sorted Source Nodes: [embedding_3, x, mul, mul_1, x_1], Original ATen: [aten.embedding, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   embedding_3 => embedding_3
#   mul => mul_1
#   mul_1 => mul_2
#   x => add, convert_element_type_10, convert_element_type_11, mean, mul, pow_1, rsqrt
#   x_1 => add_1
# Graph fragment:
#   %embedding_3 : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%primals_6, %primals_1), kwargs = {})
#   %convert_element_type_10 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
#   %pow_1 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_10, 2), kwargs = {})
#   %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [2], True), kwargs = {})
#   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Scalar](args = (%mean, 1.1920928955078125e-07), kwargs = {})
#   %rsqrt : [num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {})
#   %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
#   %convert_element_type_11 : [num_users=17] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
#   %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%select_4, %convert_element_type_11), kwargs = {})
#   %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%select_5, %convert_element_type_11), kwargs = {})
#   %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_1, %mul_2), kwargs = {})
triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 65536, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, xnumel, r0_numel):
    xnumel = 65536
    XBLOCK: tl.constexpr = 1
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[:]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    x0 = xindex
    r0_1 = r0_index
    tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
    tmp17 = tl.load(in_ptr2 + (16))
    tmp18 = tl.broadcast_to(tmp17, [R0_BLOCK])
    tmp23 = tl.load(in_ptr2 + (17))
    tmp24 = tl.broadcast_to(tmp23, [R0_BLOCK])
    tmp1 = tl.full([R0_BLOCK], 50257, tl.int32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    tl.device_assert((0 <= tmp4) & (tmp4 < 50257), "index out of bounds: 0 <= tmp4 < 50257")
    tmp6 = tl.load(in_ptr1 + (r0_1 + 1024*tmp4), None).to(tl.float32)
    tmp7 = tmp6.to(tl.float32)
    tmp8 = tmp7 * tmp7
    tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK])
    tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
    tmp12 = 1024.0
    tmp13 = tmp11 / tmp12
    tmp14 = 1.1920928955078125e-07
    tmp15 = tmp13 + tmp14
    tmp16 = libdevice.rsqrt(tmp15)
    tmp19 = tmp18.to(tl.float32)
    tmp20 = tmp7 * tmp16
    tmp21 = tmp20.to(tl.float32)
    tmp22 = tmp19 * tmp21
    tmp25 = tmp24.to(tl.float32)
    tmp26 = tmp25 * tmp21
    tmp27 = tmp22 + tmp26
    tl.store(out_ptr0 + (r0_1 + 1024*x0), tmp6, None)
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x0), tmp16, None)
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp27, None)
''', device_str='cuda')
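
# --- Illustrative reference (editor-added sketch; not part of the generated Inductor output) ---
# The fused kernel above performs an embedding lookup followed by an RMSNorm-style
# rescale: square-mean over the 1024-wide channel dim, add eps, rsqrt, then scale the
# normalized bf16 activations by two learned gains read from a small parameter vector
# (indices 16 and 17 of in_ptr2). A hedged eager-mode approximation; the argument
# names below are hypothetical stand-ins for the wrapper's primals.
def _reference_embed_rmsnorm_sketch(idx, emb_weight, gains, eps=1.1920928955078125e-07):
    x = emb_weight[idx].float()                                # aten.embedding + _to_copy
    rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)   # mean -> add -> rsqrt
    xn = (x * rms).to(torch.bfloat16)                          # normalized activations (x)
    return gains[16] * xn + gains[17] * xn                     # mul + mul_1 -> x_1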




# kernel path: /tmp/torchinductor_root/ur/curpusl7dnnsyt2hqus4f27u67wivplzuyflwhcwb2vy5hcbp3bi.py
# Topologically Sorted Source Nodes: [embedding, v_1, mul_10, mul_11, v_2], Original ATen: [aten.embedding, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   embedding => embedding
#   mul_10 => mul_14
#   mul_11 => mul_15
#   v_1 => add_8, convert_element_type_22, convert_element_type_23, mean_3, mul_13, pow_4, rsqrt_3
#   v_2 => add_9
# Graph fragment:
#   %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%primals_2, %primals_1), kwargs = {})
#   %convert_element_type_22 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_14, torch.float32), kwargs = {})
#   %pow_4 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_22, 2), kwargs = {})
#   %mean_3 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_4, [3], True), kwargs = {})
#   %add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Scalar](args = (%mean_3, 1.1920928955078125e-07), kwargs = {})
#   %rsqrt_3 : [num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_8,), kwargs = {})
#   %mul_13 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_22, %rsqrt_3), kwargs = {})
#   %convert_element_type_23 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_13, torch.bfloat16), kwargs = {})
#   %mul_14 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%select_6, %convert_element_type_23), kwargs = {})
#   %mul_15 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%select_7, %view_12), kwargs = {})
#   %add_9 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_14, %mul_15), kwargs = {})
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_1 = async_compile.triton('triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_1', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 524288, 'r0_': 128},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': '*i32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr0': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_1', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_1(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 524288
    r0_numel = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = (xindex % 8)
    x1 = xindex // 8
    _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    x3 = xindex
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = tl.load(in_ptr0 + (2048 + r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tmp1 * tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(r0_mask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    tmp6 = 128.0
    tmp7 = tmp4 / tmp6
    tmp8 = 1.1920928955078125e-07
    tmp9 = tmp7 + tmp8
    tmp10 = libdevice.rsqrt(tmp9)
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x3), tmp10, None)
    tmp11 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last')
    tmp18 = tl.load(in_ptr3 + (48))
    tmp19 = tl.broadcast_to(tmp18, [XBLOCK, R0_BLOCK])
    tmp26 = tl.load(in_ptr3 + (49))
    tmp27 = tl.broadcast_to(tmp26, [XBLOCK, R0_BLOCK])
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp21 = tl.load(in_ptr0 + (2048 + r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tl.full([XBLOCK, R0_BLOCK], 50257, tl.int32)
        tmp13 = tmp11 + tmp12
        tmp14 = tmp11 < 0
        tmp15 = tl.where(tmp14, tmp13, tmp11)
        tl.device_assert((0 <= tmp15) & (tmp15 < 50257), "index out of bounds: 0 <= tmp15 < 50257")
        tmp17 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 1024*tmp15), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp20 = tmp19.to(tl.float32)
        tmp22 = tmp21.to(tl.float32)
        tmp23 = tmp22 * tmp10
        tmp24 = tmp23.to(tl.float32)
        tmp25 = tmp20 * tmp24
        tmp28 = tmp27.to(tl.float32)
        tmp29 = tmp28 * tmp17
        tmp30 = tmp25 + tmp29
        tl.store(out_ptr0 + (r0_2 + 128*x3), tmp17, r0_mask)
        tl.store(out_ptr1 + (r0_2 + 128*x3), tmp30, r0_mask)
''', device_str='cuda')
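
# --- Illustrative reference (editor-added sketch; not part of the generated Inductor output) ---
# The reduction kernel above RMS-normalizes the value slice of a fused QKV projection
# per 128-wide head (the `2048 + ...` offset selects the last third of a 3072-wide
# buffer) and then blends it with the raw token embedding using two learned gains
# (indices 48 and 49 of in_ptr3). A hedged eager-mode approximation; the argument
# names are hypothetical stand-ins for the wrapper's primals.
def _reference_value_rmsnorm_sketch(qkv, idx, emb_weight, gains, head_dim=128,
                                    eps=1.1920928955078125e-07):
    v = qkv[..., 2048:].reshape(*qkv.shape[:-1], -1, head_dim).float()  # value slice, split per head
    rms = torch.rsqrt(v.pow(2).mean(-1, keepdim=True) + eps)            # mean -> add -> rsqrt
    vn = (v * rms).to(torch.bfloat16)                                   # normalized values (v_1)
    emb = emb_weight[idx].reshape(vn.shape).to(torch.bfloat16)          # token embedding, per head
    return gains[48] * vn + gains[49] * emb                             # mul_10 + mul_11 -> v_2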




# kernel path: /tmp/torchinductor_root/6m/c6mw7nbczjgpvhrjjk6pggm7fvu5xptqqgclx5n5k4tqobkftzww.py
# Topologically Sorted Source Nodes: [eq, cumsum], Original ATen: [aten.eq, aten.cumsum]
# Source node to ATen node mapping:
#   cumsum => cumsum
#   eq => eq
# Graph fragment:
#   %eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%primals_1, 50256), kwargs = {})
#   %cumsum : [num_users=17] = call_function[target=torch.ops.aten.cumsum.default](args = (%eq, 0), kwargs = {})
triton_spl_fused_cumsum_eq_2 = async_compile.triton('triton_spl_fused_cumsum_eq_2', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton.jit
def _triton_helper_fn_add0(arg0_0, arg1_0):
    tmp0 = arg0_0 + arg1_0
    return tmp0

@triton_heuristics.split_scan(
    size_hints={'x': 1, 'r0_': 65536},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i64', 'ws_ptr': '*u8', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 4), 'tt.equal_to': (3,)}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_spl_fused_cumsum_eq_2', 'mutated_arg_names': ['ws_ptr'], 'optimize_mem': False, 'no_x_dim': True, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_spl_fused_cumsum_eq_2(in_ptr0, out_ptr0, ws_ptr, xnumel, r0_numel, R0_BLOCK : tl.constexpr):
    xnumel = 1
    XBLOCK: tl.constexpr = 1
    r0_numel = 65536
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(1) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([R0_BLOCK], True, tl.int1)
    r0_offset = tl.program_id(0) * R0_BLOCK
    r0_index = r0_offset + tl.arange(0, R0_BLOCK)[:]
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_0 = r0_index
    tmp0 = tl.load(in_ptr0 + (r0_0), None, eviction_policy='evict_last')
    tmp4 = tl.num_programs(0)
    tmp5 = ws_ptr.to(tl.pointer_type(tl.uint64)) + xoffset * 3 * tmp4
    tmp1 = tl.full([1], 50256, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tmp2.to(tl.int64)
    tmp6 = tmp3.to(tl.int64)
    tmp7 = tl.broadcast_to(tmp6, [R0_BLOCK])
    tmp8 = tl.reduce(tmp7, 0, _triton_helper_fn_add0)
    tmp9 = triton_helpers.exclusive_scan_decoupled_lookback_64(
        tmp5,
        tmp8,
        tl.program_id(0),
        _triton_helper_fn_add0,
    )
    tmp10 = tl.associative_scan(tmp7, 0, _triton_helper_fn_add0)
    tmp11 = _triton_helper_fn_add0(tmp9, tmp10)
    tmp12 = tl.where(roffset == 0, tmp10, tmp11)
    tl.store(out_ptr0 + (tl.broadcast_to(r0_0, [R0_BLOCK])), tmp12, None)
''', device_str='cuda')
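
# --- Illustrative reference (editor-added sketch; not part of the generated Inductor output) ---
# The split-scan kernel above computes a running document id over the 65536-token
# sequence: positions equal to the end-of-document token (50256) are flagged and
# cumulatively summed (aten.eq -> aten.cumsum). A hedged eager-mode approximation:
def _reference_doc_ids_sketch(token_ids, eod_token=50256):
    return (token_ids == eod_token).cumsum(0)   # document id per position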




# kernel path: /tmp/torchinductor_root/bk/cbk7z3iym6kwqv7oruouptcabnpwb6ym23vicolbtktxjwhgmno4.py
# Topologically Sorted Source Nodes: [block_idx, causal_blockmask_any, causal_blockmask_all, docs_low, docs_high, le, ge_1, document_blockmask_any, eq_1, eq_2, document_blockmask_all, blockmask_any, blockmask_all, invert, and__4, num_blocks, argsort, num_blocks_1, argsort_1, sub, clamp_min, child, sub_1, child_4, floordiv, sub_2, clamp_min_1, child_8, sub_3, child_11], Original ATen: [aten.arange, aten.ge, aten.gt, aten.clone, aten.le, aten.bitwise_and, aten.eq, aten.bitwise_not, aten.sum, aten.sort, aten.sub, aten.clamp_min, aten.clamp_max, aten.floor_divide]
# Source node to ATen node mapping:
#   and__4 => bitwise_and_4
#   argsort => sort
#   argsort_1 => sort_1
#   block_idx => iota
#   blockmask_all => bitwise_and_3
#   blockmask_any => bitwise_and_2
#   causal_blockmask_all => gt
#   causal_blockmask_any => ge
#   child => clamp_max
#   child_11 => clamp_max_3
#   child_4 => clamp_max_1
#   child_8 => clamp_max_2
#   clamp_min => clamp_min
#   clamp_min_1 => clamp_min_1
#   docs_high => clone_1
#   docs_low => clone
#   document_blockmask_all => bitwise_and_1
#   document_blockmask_any => bitwise_and
#   eq_1 => eq_1
#   eq_2 => eq_2
#   floordiv => div
#   ge_1 => ge_1
#   invert => bitwise_not
#   le => le
#   num_blocks => sum_1
#   num_blocks_1 => sum_2
#   sub => sub
#   sub_1 => sub_1
#   sub_2 => sub_2
#   sub_3 => sub_3
# Graph fragment:
#   %iota : [num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (512,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda, requires_grad: False})
#   %ge : [num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%unsqueeze, %iota), kwargs = {})
#   %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Tensor](args = (%unsqueeze, %iota), kwargs = {})
#   %clone : [num_users=3] = call_function[target=torch.ops.aten.clone.default](args = (%select,), kwargs = {memory_format: torch.contiguous_format})
#   %clone_1 : [num_users=3] = call_function[target=torch.ops.aten.clone.default](args = (%select_1,), kwargs = {memory_format: torch.contiguous_format})
#   %le : [num_users=1] = call_function[target=torch.ops.aten.le.Tensor](args = (%unsqueeze_2, %clone_1), kwargs = {})
#   %ge_1 : [num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%unsqueeze_3, %clone), kwargs = {})
#   %bitwise_and : [num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%le, %ge_1), kwargs = {})
#   %eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%unsqueeze_2, %clone_1), kwargs = {})
#   %eq_2 : [num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%unsqueeze_3, %clone), kwargs = {})
#   %bitwise_and_1 : [num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%eq_1, %eq_2), kwargs = {})
#   %bitwise_and_2 : [num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %bitwise_and), kwargs = {})
#   %bitwise_and_3 : [num_users=3] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %bitwise_and_1), kwargs = {})
#   %bitwise_not : [num_users=1] = call_function[target=torch.ops.aten.bitwise_not.default](args = (%bitwise_and_3,), kwargs = {})
#   %bitwise_and_4 : [num_users=2] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%bitwise_and_2, %bitwise_not), kwargs = {})
#   %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%bitwise_and_4, [-1]), kwargs = {dtype: torch.int32})
#   %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%bitwise_and_4,), kwargs = {stable: True})
#   %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%bitwise_and_3, [-1]), kwargs = {dtype: torch.int32})
#   %sort_1 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%bitwise_and_3,), kwargs = {stable: True})
#   %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%primals_5, %unsqueeze_11), kwargs = {})
#   %clamp_min : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sub, 1), kwargs = {})
#   %clamp_max : [num_users=6] = call_function[target=torch.ops.aten.clamp_max.Tensor](args = (%unsqueeze_7, %clamp_min), kwargs = {})
#   %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%primals_5, 1), kwargs = {})
#   %clamp_max_1 : [num_users=6] = call_function[target=torch.ops.aten.clamp_max.Tensor](args = (%unsqueeze_11, %sub_1), kwargs = {})
#   %div : [num_users=2] = call_function[target=torch.ops.aten.div.Tensor_mode](args = (%primals_5, 2), kwargs = {rounding_mode: floor})
#   %sub_2 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%div, %unsqueeze_11), kwargs = {})
#   %clamp_min_1 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sub_2, 1), kwargs = {})
#   %clamp_max_2 : [num_users=13] = call_function[target=torch.ops.aten.clamp_max.Tensor](args = (%unsqueeze_7, %clamp_min_1), kwargs = {})
#   %sub_3 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%div, 1), kwargs = {})
#   %clamp_max_3 : [num_users=13] = call_function[target=torch.ops.aten.clamp_max.Tensor](args = (%unsqueeze_11, %sub_3), kwargs = {})
triton_per_fused_arange_bitwise_and_bitwise_not_clamp_max_clamp_min_clone_eq_floor_divide_ge_gt_le_sort_sub_sum_3 = async_compile.triton('triton_per_fused_arange_bitwise_and_bitwise_not_clamp_max_clamp_min_clone_eq_floor_divide_ge_gt_le_sort_sub_sum_3', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 512, 'r0_': 512},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr3': '*i16', 'out_ptr5': '*i16', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_arange_bitwise_and_bitwise_not_clamp_max_clamp_min_clone_eq_floor_divide_ge_gt_le_s
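
# --- Illustrative reference (editor-added sketch; not part of the generated Inductor output) ---
# The kernel above (its metadata line is truncated in this capture) builds the
# flex-attention block-mask inputs described in the graph fragment: a causal block
# mask combined with a per-document block mask derived from the cumulative document
# ids, plus per-row block counts. A hedged eager-mode approximation of the mask logic
# only; the stable argsort/compaction and the clamp-based sliding-window bookkeeping
# are omitted, and the argument names are hypothetical.
def _reference_blockmask_sketch(docs_low, docs_high):
    # docs_low / docs_high: document id of the first / last token in each of the
    # 512 sequence blocks.
    block_idx = torch.arange(docs_low.numel(), dtype=torch.int32, device=docs_low.device)
    causal_any = block_idx[:, None] >= block_idx                 # aten.ge
    causal_all = block_idx[:, None] > block_idx                  # aten.gt
    doc_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
    doc_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
    blockmask_any = causal_any & doc_any
    blockmask_all = causal_all & doc_all
    partial = blockmask_any & ~blockmask_all                     # invert + and__4
    num_partial = partial.sum(-1, dtype=torch.int32)             # num_blocks
    num_full = blockmask_all.sum(-1, dtype=torch.int32)          # num_blocks_1
    return partial, blockmask_all, num_partial, num_full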