output_code_before_after

Diff stats: 73 removals, 66 additions.
Before: 469 lines, 2,382 words (66 lines removed, -14.1%; 273 words removed, -11.5%).
After: 463 lines, 2,327 words (60 lines added, +13.0%; 218 words added, +9.4%).


# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_eellison/wj/cwjmdogaso56hzkb4pdlo4k625tmwevhq5krfbpmpx5szmhdjbsf.py
# Source Nodes: [embeddings, embeddings_1, embeddings_2, inputs_embeds, position_embeddings, token_type_embeddings], Original ATen: [aten.add, aten.embedding, aten.native_layer_norm]
# embeddings => add
# embeddings_1 => add_1
- # embeddings_2 => add_2, add_3, convert_element_type_1, convert_element_type_2, mul_1, mul_2, rsqrt, sub_1, var_mean
+ # embeddings_2 => add_2, add_3, convert_element_type, convert_element_type_1, mul, mul_1, rsqrt, sub, var_mean
# inputs_embeds => embedding
# position_embeddings => embedding_2
# token_type_embeddings => embedding_1
triton_per_fused_add_embedding_native_layer_norm_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: '*bf16', 2: '*i64', 3: '*bf16', 4: '*i64', 5: '*bf16', 6: '*bf16', 7: '*bf16', 8: '*bf16', 9: '*bf16', 10: 'i32', 11: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_embedding_native_layer_norm_0', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr0, out_ptr3, xnumel, rnumel):
    xnumel = 8192
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    x3 = xindex
    r2 = rindex
    x0 = xindex % 512
    tmp0 = tl.load(in_ptr0 + (x3), None, eviction_policy='evict_last')
    tmp7 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
    tmp15 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
    tmp47 = tl.load(in_ptr6 + (r2), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp50 = tl.load(in_ptr7 + (r2), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp1 = tl.full([RBLOCK], 30522, tl.int32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    tl.device_assert((0 <= tmp4) & (tmp4 < 30522), "index out of bounds: 0 <= tmp4 < 30522")
    tmp6 = tl.load(in_ptr1 + (r2 + (768*tmp4)), rmask, other=0.0).to(tl.float32)
    tmp8 = tl.full([RBLOCK], 2, tl.int32)
    tmp9 = tmp7 + tmp8
    tmp10 = tmp7 < 0
    tmp11 = tl.where(tmp10, tmp9, tmp7)
    tl.device_assert((0 <= tmp11) & (tmp11 < 2), "index out of bounds: 0 <= tmp11 < 2")
    tmp13 = tl.load(in_ptr3 + (r2 + (768*tmp11)), rmask, other=0.0).to(tl.float32)
    tmp14 = tmp6 + tmp13
    tmp16 = tl.full([RBLOCK], 512, tl.int32)
    tmp17 = tmp15 + tmp16
    tmp18 = tmp15 < 0
    tmp19 = tl.where(tmp18, tmp17, tmp15)
    tl.device_assert((0 <= tmp19) & (tmp19 < 512), "index out of bounds: 0 <= tmp19 < 512")
    tmp21 = tl.load(in_ptr5 + (r2 + (768*tmp19)), rmask, other=0.0).to(tl.float32)
    tmp22 = tmp14 + tmp21
    tmp23 = tmp22.to(tl.float32)
    tmp24 = tl.broadcast_to(tmp23, [RBLOCK])
    tmp26 = tl.where(rmask, tmp24, 0)
    tmp27 = tl.broadcast_to(tmp24, [RBLOCK])
    tmp29 = tl.where(rmask, tmp27, 0)
    tmp30 = triton_helpers.promote_to_tensor(tl.sum(tmp29, 0))
    tmp31 = tl.full([1], 768, tl.int32)
    tmp32 = tmp31.to(tl.float32)
    tmp33 = tmp30 / tmp32
    tmp34 = tmp24 - tmp33
    tmp35 = tmp34 * tmp34
    tmp36 = tl.broadcast_to(tmp35, [RBLOCK])
    tmp38 = tl.where(rmask, tmp36, 0)
    tmp39 = triton_helpers.promote_to_tensor(tl.sum(tmp38, 0))
    tmp40 = tmp23 - tmp33
    tmp41 = 768.0
    tmp42 = tmp39 / tmp41
    tmp43 = 1e-12
    tmp44 = tmp42 + tmp43
    tmp45 = libdevice.rsqrt(tmp44)
    tmp46 = tmp40 * tmp45
    tmp48 = tmp47.to(tl.float32)
    tmp49 = tmp46 * tmp48
    tmp51 = tmp50.to(tl.float32)
    tmp52 = tmp49 + tmp51
    tmp53 = tmp52.to(tl.float32)
    tl.store(out_ptr0 + (r2 + (768*x3)), tmp22, rmask)
    tl.store(out_ptr3 + (r2 + (768*x3)), tmp53, rmask)
''', device_str='cuda')
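
# A minimal eager-mode sketch of what the fused kernel above computes, assuming
# BERT-style embeddings. This sketch is not part of the generated output and the
# tensor names are illustrative assumptions:
import torch.nn.functional as F

def embedding_layer_norm_sketch(input_ids, token_type_ids, position_ids,
                                word_emb, token_type_emb, position_emb, gamma, beta):
    # inputs_embeds / token_type_embeddings / position_embeddings => aten.embedding
    # (indices are range-checked, mirroring the tl.device_assert calls)
    x = (F.embedding(input_ids, word_emb)
         + F.embedding(token_type_ids, token_type_emb)   # embeddings => add
         + F.embedding(position_ids, position_emb))      # embeddings_1 => add_1
    # embeddings_2 => aten.native_layer_norm: stats in fp32, eps=1e-12, result cast back
    y = F.layer_norm(x.float(), (768,), gamma.float(), beta.float(), eps=1e-12)
    return y.to(x.dtype)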


import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream


- # kernel path: /tmp/torchinductor_eellison/ze/czezsvjgwtdzru43t5gig4zb6t4auharyypwotkirkhmzu46fkzp.py
- # Source Nodes: [add_2, hidden_states_2], Original ATen: [aten.add, aten.native_layer_norm]
- # add_2 => add_5
- # hidden_states_2 => add_6, add_7, convert_element_type_21, convert_element_type_22, mul_3, mul_4, rsqrt_1, sub_3, var_mean_1
- triton_per_fused_add_native_layer_norm_1 = async_compile.triton('triton_', '''
+ # kernel path: /tmp/torchinductor_eellison/6q/c6qq6qfsawjfikfoeruueuup5cnzvmpzkusjkel6l6wcw43mgauj.py
+ # Source Nodes: [attn_output], Original ATen: [aten._scaled_dot_product_efficient_attention]
+ # attn_output => _scaled_dot_product_efficient_attention
+ triton_poi_fused__scaled_dot_product_efficient_attention_1 = async_compile.triton('triton_', '''
+ import triton
+ import triton.language as tl
+ from triton.compiler.compiler import AttrsDescriptor
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
+
+ @triton_heuristics.pointwise(
+     size_hints=[67108864],
+     filename=__file__,
+     triton_meta={'signature': {0: '*bf16', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
+     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_efficient_attention_1', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+     min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
+     xnumel = 50331648
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:]
+     xmask = tl.full([XBLOCK], True, tl.int1)
+     x0 = xindex
+     tmp0 = tl.full([1], False, tl.int1)
+     tmp1 = -3.3895313892515355e+38
+     tmp2 = 0.0
+     tmp3 = tl.where(tmp0, tmp1, tmp2)
+     tl.store(out_ptr0 + (x0), tmp3, None)
+ ''', device_str='cuda')
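
# The kernel above is new in the "after" version: it materializes the attention-bias
# buffer consumed by aten._scaled_dot_product_efficient_attention. The predicate is
# the constant False, so tl.where(False, -3.39e38, 0.0) always writes 0.0, i.e. an
# all-zero bias (no masked positions); -3.3895313892515355e+38 would be the fill
# value for masked-out positions. A rough eager-mode equivalent, with the shape
# being one plausible factorization of xnumel = 50331648 (e.g. batch 16, 12 heads,
# 512 x 512; an assumption, not stated in the source):
attn_bias = torch.zeros((16, 12, 512, 512), dtype=torch.bfloat16, device='cuda')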


+ # kernel path: /tmp/torchinductor_eellison/at/catipo3niy6sjxmx7lban4oumf6lplhx7qd3qxdhlcdrzt4xmv65.py
+ # Source Nodes: [add_1, hidden_states_2], Original ATen: [aten.add, aten.native_layer_norm]
+ # add_1 => add_4
+ # hidden_states_2 => add_5, add_6, convert_element_type_16, convert_element_type_17, mul_2, mul_3, rsqrt_1, sub_2, var_mean_1
+ triton_per_fused_add_native_layer_norm_2 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
-     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_native_layer_norm_1', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_native_layer_norm_2', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel):
    xnumel = 8192
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp29 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp32 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp5 = tmp4.to(tl.float32)
    tmp6 = tl.broadcast_to(tmp5, [RBLOCK])
    tmp8 = tl.where(rmask, tmp6, 0)
    tmp9 = tl.broadcast_to(tmp6, [RBLOCK])
    tmp11 = tl.where(rmask, tmp9, 0)
    tmp12 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0))
    tmp13 = tl.full([1], 768, tl.int32)
    tmp14 = tmp13.to(tl.float32)
    tmp15 = tmp12 / tmp14
    tmp16 = tmp6 - tmp15
    tmp17 = tmp16 * tmp16
    tmp18 = tl.broadcast_to(tmp17, [RBLOCK])
    tmp20 = tl.where(rmask, tmp18, 0)
    tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0))
    tmp22 = tmp5 - tmp15
    tmp23 = 768.0
    tmp24 = tmp21 / tmp23
    tmp25 = 1e-12
    tmp26 = tmp24 + tmp25
    tmp27 = libdevice.rsqrt(tmp26)
    tmp28 = tmp22 * tmp27
    tmp30 = tmp29.to(tl.float32)
    tmp31 = tmp28 * tmp30
    tmp33 = tmp32.to(tl.float32)
    tmp34 = tmp31 + tmp33
    tmp35 = tmp34.to(tl.float32)
    tl.store(out_ptr2 + (r1 + (768*x0)), tmp35, rmask)
''', device_str='cuda')
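
# Eager-mode sketch of the fused pattern above: a residual add followed by
# LayerNorm, as in a BERT layer's output block. Illustrative only; the argument
# names are assumptions, not taken from the source:
def add_layer_norm_sketch(dense_out, dense_bias, residual, gamma, beta):
    x = dense_out + dense_bias + residual            # add_1/add_2 => aten.add
    x32 = x.float()                                  # bf16 -> fp32 for the reduction
    mu = x32.mean(-1, keepdim=True)
    var = x32.var(-1, unbiased=False, keepdim=True)  # var_mean_1
    y = (x32 - mu) * torch.rsqrt(var + 1e-12)        # rsqrt_1; eps matches the kernel
    return (y * gamma.float() + beta.float()).to(x.dtype)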




- # kernel path: /tmp/torchinductor_eellison/h7/ch7fp7gaeqqseqnc45vn7aus3pytqrkcm2gt6tswtsp4robtviff.py
+ # kernel path: /tmp/torchinductor_eellison/lp/clpuyqd2qh624e2ye7osnyieorymdus3kflfhssozom7b33frjei.py
# Source Nodes: [hidden_states_4], Original ATen: [aten.gelu]
- # hidden_states_4 => add_8, convert_element_type_26, convert_element_type_27, erf, mul_5, mul_6, mul_7
+ # hidden_states_4 => add_7, convert_element_type_21, convert_element_type_22, erf, mul_4, mul_5, mul_6
- triton_poi_fused_gelu_2 = async_compile.triton('triton_', '''
+ triton_poi_fused_gelu_3 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[33554432],
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
-     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_2', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_3', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 25165824
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x2 = xindex
    x0 = xindex % 3072
    tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp2.to(tl.float32)
    tmp4 = 0.5
    tmp5 = tmp3 * tmp4
    tmp6 = 0.7071067811865476
    tmp7 = tmp3 * tmp6
    tmp8 = libdevice.erf(tmp7)
    tmp9 = 1.0
    tmp10 = tmp8 + tmp9
    tmp11 = tmp5 * tmp10
    tmp12 = tmp11.to(tl.float32)
    tl.store(in_out_ptr0 + (x2), tmp12, None)
''', device_str='cuda')
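
# The constants above implement exact (erf-based) GELU fused with the preceding
# bias add: tmp5 is 0.5*x, tmp7 is x/sqrt(2) (0.7071067811865476 = 1/sqrt(2)), so
# tmp11 = 0.5 * x * (1 + erf(x / sqrt(2))). A one-line eager check (illustrative):
x = torch.randn(8)
assert torch.allclose(torch.nn.functional.gelu(x), 0.5 * x * (1 + torch.erf(x / 2**0.5)))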




- # kernel path: /tmp/torchinductor_eellison/kv/ckv6n5lz6ocro4pxsewmvlwwdchxn7ycxtd2mm4zi5mmbccvo7j4.py
+ # kernel path: /tmp/torchinductor_eellison/ai/caiiwspixmf6ozjdgwruqmkld4lyd2nnoyrhxzmgg3mafv6a6yka.py
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten.gelu, aten.native_layer_norm]
- # hidden_states_97 => add_100, convert_element_type_366, convert_element_type_367, erf_12, mul_87, mul_88, mul_89
+ # hidden_states_97 => add_88, convert_element_type_295, convert_element_type_296, erf_12, mul_86, mul_87, mul_88
- # hidden_states_98 => add_101, add_102, convert_element_type_368, convert_element_type_369, mul_90, mul_91, rsqrt_25, sub_38, var_mean_25
+ # hidden_states_98 => add_89, add_90, convert_element_type_297, convert_element_type_298, mul_89, mul_90, rsqrt_25, sub_26, var_mean_25
- triton_per_fused_gelu_native_layer_norm_3 = async_compile.triton('triton_', '''
+ triton_per_fused_gelu_native_layer_norm_4 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
-     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_gelu_native_layer_norm_3', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 4, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_gelu_native_layer_norm_4', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 4, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr3, xnumel, rnumel):
    xnumel = 8192
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp37 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp40 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp2.to(tl.float32)
    tmp4 = 0.5
    tmp5 = tmp3 * tmp4
    tmp6 = 0.7071067811865476
    tmp7 = tmp3 * tmp6
    tmp8 = libdevice.erf(tmp7)
    tmp9 = 1.0
    tmp10 = tmp8 + tmp9
    tmp11 = tmp5 * tmp10
    tmp12 = tmp11.to(tl.float32)
    tmp13 = tmp12.to(tl.float32)
    tmp14 = tl.broadcast_to(tmp13, [RBLOCK])
    tmp16 = tl.where(rmask, tmp14, 0)
    tmp17 = tl.broadcast_to(tmp14, [RBLOCK])
    tmp19 = tl.where(rmask, tmp17, 0)
    tmp20 = triton_helpers.promote_to_tensor(tl.sum(tmp19, 0))
    tmp21 = tl.full([1], 768, tl.int32)
    tmp22 = tmp21.to(tl.float32)
    tmp23 = tmp20 / tmp22
    tmp24 = tmp14 - tmp23
    tmp25 = tmp24 * tmp24
    tmp26 = tl.broadcast_to(tmp25, [RBLOCK])
    tmp28 = tl.where(rmask, tmp26, 0)
    tmp29 = triton_helpers.promote_to_tensor(tl.sum(tmp28, 0))
    tmp30 = tmp13 - tmp23
    tmp31 = 768.0
    tmp32 = tmp29 / tmp31
    tmp33 = 1e-12
    tmp34 = tmp32 + tmp33
    tmp35 = libdevice.rsqrt(tmp34)
    tmp36 = tmp30 * tmp35
    tmp38 = tmp37.to(tl.float32)
    tmp39 = tmp36 * tmp38
    tmp41 = tmp40.to(tl.float32)
    tmp42 = tmp39 + tmp41
    tmp43 = tmp42.to(tl.float32)
    tl.store(out_ptr3 + (r1 + (768*x0)), tmp43, rmask)
''', device_str='cuda')
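
# Reduction structure of the LayerNorm kernels above, mapped to the math. These are
# persistent reductions: RBLOCK = 1024 covers the whole rnumel = 768 row in one pass,
# so each program normalizes one row without a reduction loop:
#   mean = sum(x) / 768                      (tmp20 / tmp23 here)
#   var  = sum((x - mean)^2) / 768           (tmp29 / tmp32; biased variance, as in aten.native_layer_norm)
#   y    = (x - mean) * rsqrt(var + 1e-12) * gamma + beta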




- # kernel path: /tmp/torchinductor_eellison/ry/cryghuddhmky5nywmwltcxvq5qstdsr5rotzufalizyup7kvjzmb.py
+ # kernel path: /tmp/torchinductor_eellison/fl/cfllfdzx4v3f3zccah7zu5u634j2vrlvbkru74wmaerhgadlizat.py
# Source Nodes: [], Original ATen: []

- triton_poi_fused_4 = async_compile.triton('triton_', '''
+ triton_poi_fused_5 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[33554432],
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
-     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 23445504
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex % 30528
    x1 = (xindex // 30528)
    x2 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 30522, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (x1 + (768*x0)), tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp6 = tmp0 >= tmp3
    tmp7 = tl.full([1], 30528, tl.int64)
    tmp8 = tmp0 < tmp7
    tmp9 = 0.0
    tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
    tmp11 = tl.where(tmp6, tmp9, tmp10)
    tmp12 = tl.where(tmp4, tmp5, tmp11)
    tl.store(out_ptr0 + (x2), tmp12, None)
''', device_str='cuda')
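
# This kernel has no Source Nodes: it was introduced by the compiler. Reading
# in_ptr0 + (x1 + 768*x0) while writing contiguously means it transposes a
# (30522, 768) weight, presumably the LM-head decoder weight, and zero-pads the
# vocab dimension from 30522 to 30528 (the next multiple of 64) so the logits
# matmul gets aligned strides: 23445504 = 768 * 30528, and the tmp4/tmp6
# predicates select "copy" vs "pad". A hedged eager-mode equivalent (the weight
# here is a stand-in, not from the source):
decoder_weight = torch.randn(30522, 768, dtype=torch.bfloat16)
padded_weight = torch.nn.functional.pad(decoder_weight.t(), (0, 30528 - 30522))  # -> (768, 30528)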




- # kernel path: /tmp/torchinductor_eellison/iq/ciq3xiils5u6jz63putfodpdyv6oshlb3gbss7hjki6u3ogkhwqq.py
+ # kernel path: /tmp/torchinductor_eellison/si/csikk4r46efsgklpubt6iamwjm3jsevq2h2pkkvjbgcc7sj4p24c.py
# Source Nodes: [], Original ATen: []

- triton_poi_fused_5 = async_compile.triton('triton_', '''
+ triton_poi_fused_6 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[32768],
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
-     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 30528
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 30522, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (x0), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp6 = tmp0 >= tmp3
    tmp7 = tl.full([1], 30528, tl.int64)
    tmp8 = tmp0 < tmp7
    tmp9 = 0.0
    tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
    tmp11 = tl.where(tmp6, tmp9, tmp10)
    tmp12 = tl.where(tmp4, tmp5, tmp11)
    tl.store(out_ptr0 + (x0), tmp12, xmask)
''', device_str='cuda')
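
# The 1-D counterpart of the previous kernel: the same zero-pad (30522 -> 30528)
# applied to what is presumably the LM-head bias, roughly
# torch.nn.functional.pad(decoder_bias, (0, 6)) in eager terms.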




- # kernel path: /tmp/torchinductor_eellison/vf/cvfw2hcsiy2zivm66wjz4xfhs2ysdjbvqk2rsig4hp62bpxn2vvd.py
+ # kernel path: /tmp/torchinductor_eellison/z7/cz72sjk2qwjjxgfdcyk6de22sryhbhr6pptrvx4aq2hq2soai2p5.py
# Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax]
- # masked_lm_loss => amax_12, convert_element_type_373, exp_12, sub_39, sum_13
+ # masked_lm_loss => amax, convert_element_type_302, exp, sub_27, sum_1
- triton_red_fused__log_softmax_6 = async_compile.triton('triton_', '''
+ triton_red_fused__log_softmax_7 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.reduction(
    size_hints=[8192, 32768],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
-     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax_7', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 8192
    rnumel = 30522
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp3 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (30528*x0)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp4 = triton_helpers.maximum(_tmp3, tmp2)
        _tmp3 = tl.where(rmask, tmp4, _tmp3)
    tmp3 = triton_helpers.max2(_tmp3, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp3, None)
    _tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp5 = tl.load(in_ptr0 + (r1 + (30528*x0)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp6 = tmp5.to(tl.float32)
        tmp7 = tmp6 - tmp3
        tmp8 = tl_math.exp(tmp7)
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(rmask, tmp11, _tmp10)
    tmp10 = tl.sum(_tmp10, 1)[:, None]
    tl.store(out_ptr1 + (x0), tmp10, None)
''', device_str='cuda')
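
# The reduction above is the standard two-pass, numerically stable log-softmax
# prologue over the padded vocab (rows of 30522 valid entries read with stride
# 30528): pass 1 computes m = max(x) into out_ptr0, pass 2 computes
# s = sum(exp(x - m)) into out_ptr1; log_softmax(x) = x - m - log(s) is then
# presumably consumed by the (truncated, not shown) masked_lm_loss kernel.
# An eager check of the identity (illustrative):
logits = torch.randn(4, 30522)
m = logits.amax(-1, keepdim=True)                 # out_ptr0 (amax)
s = (logits - m).exp().sum(-1, keepdim=True)      # out_ptr1 (sum of exps)
assert torch.allclose(logits.log_softmax(-1), logits - m - s.log(), atol=1e-6)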


# kernel path: /tmp/torchinductor_eellison/bt/cbtmzvri2vmmhtlgzxhr7zowbeqsafuqeui5gh3j2bsgiibilujp