Diff
checker
텍스트
텍스트
이미지
문서
Excel
폴더
Legal
Enterprise
데스크톱
요금제
로그인
데스크톱 앱 다운로드
텍스트 비교
두 텍스트 파일의 차이점을 찾아보세요
도구
기록
실시간 편집
변경 없는 행 숨기기
줄바꿈 비활성화
레이아웃
나란히 보기
합쳐 보기
비교 단위
스마트
단어
글자
구문 강조
언어 선택
제외
텍스트 변환
첫 변경으로
수정
Diffchecker Desktop
가장 안전하게 Diffchecker를 사용하는 방법. 데스크톱 앱을 사용하면 비교 데이터가 외부로 전송되지 않습니다!
데스크톱 앱 받기
output_code_before_after
생성일
2년 전
비교 결과 만료 없음
초기화
내보내기
공유
설명
95 삭제
행
총
삭제
글자
총
삭제
이 기능을 계속 사용하려면 업그레이드해 주세요
Diff
checker
Pro
요금제 보기
469 행
복사
82 추가
행
총
추가
글자
총
추가
이 기능을 계속 사용하려면 업그레이드해 주세요
Diff
checker
Pro
요금제 보기
463 행
복사
# AOT ID: ['0_inference']
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long
from ctypes import c_void_p, c_long
import torch
import torch
import math
import math
import random
import random
import os
import os
import tempfile
import tempfile
from math import inf, nan
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_eellison/wj/cwjmdogaso56hzkb4pdlo4k625tmwevhq5krfbpmpx5szmhdjbsf.py
# kernel path: /tmp/torchinductor_eellison/wj/cwjmdogaso56hzkb4pdlo4k625tmwevhq5krfbpmpx5szmhdjbsf.py
# Source Nodes: [embeddings, embeddings_1, embeddings_2, inputs_embeds, position_embeddings, token_type_embeddings], Original ATen: [aten.add, aten.embedding, aten.native_layer_norm]
# Source Nodes: [embeddings, embeddings_1, embeddings_2, inputs_embeds, position_embeddings, token_type_embeddings], Original ATen: [aten.add, aten.embedding, aten.native_layer_norm]
# embeddings => add
# embeddings => add
# embeddings_1 => add_1
# embeddings_1 => add_1
복사
복사됨
복사
복사됨
# embeddings_2 => add_2, add_3, convert_element_type
_1
, convert_element_type_
2
, mul_1,
mul_2,
rsqrt, sub
_1
, var_mean
# embeddings_2 => add_2, add_3, convert_element_type
, convert_element_type_
1, mul
, mul_1,
rsqrt, sub
, var_mean
# inputs_embeds => embedding
# inputs_embeds => embedding
# position_embeddings => embedding_2
# position_embeddings => embedding_2
# token_type_embeddings => embedding_1
# token_type_embeddings => embedding_1
triton_per_fused_add_embedding_native_layer_norm_0 = async_compile.triton('triton_', '''
triton_per_fused_add_embedding_native_layer_norm_0 = async_compile.triton('triton_', '''
import triton
import triton
import triton.language as tl
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
@triton_heuristics.persistent_reduction(
size_hints=[8192, 1024],
size_hints=[8192, 1024],
reduction_hint=ReductionHint.INNER,
reduction_hint=ReductionHint.INNER,
filename=__file__,
filename=__file__,
triton_meta={'signature': {0: '*i64', 1: '*bf16', 2: '*i64', 3: '*bf16', 4: '*i64', 5: '*bf16', 6: '*bf16', 7: '*bf16', 8: '*bf16', 9: '*bf16', 10: 'i32', 11: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), equal_to_1=())]},
triton_meta={'signature': {0: '*i64', 1: '*bf16', 2: '*i64', 3: '*bf16', 4: '*i64', 5: '*bf16', 6: '*bf16', 7: '*bf16', 8: '*bf16', 9: '*bf16', 10: 'i32', 11: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_embedding_native_layer_norm_0', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_embedding_native_layer_norm_0', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
)
@triton.jit
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr0, out_ptr3, xnumel, rnumel):
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr0, out_ptr3, xnumel, rnumel):
xnumel = 8192
xnumel = 8192
XBLOCK: tl.constexpr = 1
XBLOCK: tl.constexpr = 1
rnumel = 768
rnumel = 768
RBLOCK: tl.constexpr = 1024
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
roffset = 0
rmask = rindex < rnumel
rmask = rindex < rnumel
x3 = xindex
x3 = xindex
r2 = rindex
r2 = rindex
x0 = xindex % 512
x0 = xindex % 512
tmp0 = tl.load(in_ptr0 + (x3), None, eviction_policy='evict_last')
tmp0 = tl.load(in_ptr0 + (x3), None, eviction_policy='evict_last')
tmp7 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
tmp7 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
tmp15 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
tmp15 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
tmp47 = tl.load(in_ptr6 + (r2), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp47 = tl.load(in_ptr6 + (r2), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp50 = tl.load(in_ptr7 + (r2), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp50 = tl.load(in_ptr7 + (r2), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tl.full([RBLOCK], 30522, tl.int32)
tmp1 = tl.full([RBLOCK], 30522, tl.int32)
tmp2 = tmp0 + tmp1
tmp2 = tmp0 + tmp1
tmp3 = tmp0 < 0
tmp3 = tmp0 < 0
tmp4 = tl.where(tmp3, tmp2, tmp0)
tmp4 = tl.where(tmp3, tmp2, tmp0)
tl.device_assert((0 <= tmp4) & (tmp4 < 30522), "index out of bounds: 0 <= tmp4 < 30522")
tl.device_assert((0 <= tmp4) & (tmp4 < 30522), "index out of bounds: 0 <= tmp4 < 30522")
tmp6 = tl.load(in_ptr1 + (r2 + (768*tmp4)), rmask, other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr1 + (r2 + (768*tmp4)), rmask, other=0.0).to(tl.float32)
tmp8 = tl.full([RBLOCK], 2, tl.int32)
tmp8 = tl.full([RBLOCK], 2, tl.int32)
tmp9 = tmp7 + tmp8
tmp9 = tmp7 + tmp8
tmp10 = tmp7 < 0
tmp10 = tmp7 < 0
tmp11 = tl.where(tmp10, tmp9, tmp7)
tmp11 = tl.where(tmp10, tmp9, tmp7)
tl.device_assert((0 <= tmp11) & (tmp11 < 2), "index out of bounds: 0 <= tmp11 < 2")
tl.device_assert((0 <= tmp11) & (tmp11 < 2), "index out of bounds: 0 <= tmp11 < 2")
tmp13 = tl.load(in_ptr3 + (r2 + (768*tmp11)), rmask, other=0.0).to(tl.float32)
tmp13 = tl.load(in_ptr3 + (r2 + (768*tmp11)), rmask, other=0.0).to(tl.float32)
tmp14 = tmp6 + tmp13
tmp14 = tmp6 + tmp13
tmp16 = tl.full([RBLOCK], 512, tl.int32)
tmp16 = tl.full([RBLOCK], 512, tl.int32)
tmp17 = tmp15 + tmp16
tmp17 = tmp15 + tmp16
tmp18 = tmp15 < 0
tmp18 = tmp15 < 0
tmp19 = tl.where(tmp18, tmp17, tmp15)
tmp19 = tl.where(tmp18, tmp17, tmp15)
tl.device_assert((0 <= tmp19) & (tmp19 < 512), "index out of bounds: 0 <= tmp19 < 512")
tl.device_assert((0 <= tmp19) & (tmp19 < 512), "index out of bounds: 0 <= tmp19 < 512")
tmp21 = tl.load(in_ptr5 + (r2 + (768*tmp19)), rmask, other=0.0).to(tl.float32)
tmp21 = tl.load(in_ptr5 + (r2 + (768*tmp19)), rmask, other=0.0).to(tl.float32)
tmp22 = tmp14 + tmp21
tmp22 = tmp14 + tmp21
tmp23 = tmp22.to(tl.float32)
tmp23 = tmp22.to(tl.float32)
tmp24 = tl.broadcast_to(tmp23, [RBLOCK])
tmp24 = tl.broadcast_to(tmp23, [RBLOCK])
tmp26 = tl.where(rmask, tmp24, 0)
tmp26 = tl.where(rmask, tmp24, 0)
tmp27 = tl.broadcast_to(tmp24, [RBLOCK])
tmp27 = tl.broadcast_to(tmp24, [RBLOCK])
tmp29 = tl.where(rmask, tmp27, 0)
tmp29 = tl.where(rmask, tmp27, 0)
tmp30 = triton_helpers.promote_to_tensor(tl.sum(tmp29, 0))
tmp30 = triton_helpers.promote_to_tensor(tl.sum(tmp29, 0))
tmp31 = tl.full([1], 768, tl.int32)
tmp31 = tl.full([1], 768, tl.int32)
tmp32 = tmp31.to(tl.float32)
tmp32 = tmp31.to(tl.float32)
tmp33 = tmp30 / tmp32
tmp33 = tmp30 / tmp32
tmp34 = tmp24 - tmp33
tmp34 = tmp24 - tmp33
tmp35 = tmp34 * tmp34
tmp35 = tmp34 * tmp34
tmp36 = tl.broadcast_to(tmp35, [RBLOCK])
tmp36 = tl.broadcast_to(tmp35, [RBLOCK])
tmp38 = tl.where(rmask, tmp36, 0)
tmp38 = tl.where(rmask, tmp36, 0)
tmp39 = triton_helpers.promote_to_tensor(tl.sum(tmp38, 0))
tmp39 = triton_helpers.promote_to_tensor(tl.sum(tmp38, 0))
tmp40 = tmp23 - tmp33
tmp40 = tmp23 - tmp33
tmp41 = 768.0
tmp41 = 768.0
tmp42 = tmp39 / tmp41
tmp42 = tmp39 / tmp41
tmp43 = 1e-12
tmp43 = 1e-12
tmp44 = tmp42 + tmp43
tmp44 = tmp42 + tmp43
tmp45 = libdevice.rsqrt(tmp44)
tmp45 = libdevice.rsqrt(tmp44)
tmp46 = tmp40 * tmp45
tmp46 = tmp40 * tmp45
tmp48 = tmp47.to(tl.float32)
tmp48 = tmp47.to(tl.float32)
tmp49 = tmp46 * tmp48
tmp49 = tmp46 * tmp48
tmp51 = tmp50.to(tl.float32)
tmp51 = tmp50.to(tl.float32)
tmp52 = tmp49 + tmp51
tmp52 = tmp49 + tmp51
tmp53 = tmp52.to(tl.float32)
tmp53 = tmp52.to(tl.float32)
tl.store(out_ptr0 + (r2 + (768*x3)), tmp22, rmask)
tl.store(out_ptr0 + (r2 + (768*x3)), tmp22, rmask)
tl.store(out_ptr3 + (r2 + (768*x3)), tmp53, rmask)
tl.store(out_ptr3 + (r2 + (768*x3)), tmp53, rmask)
''', device_str='cuda')
''', device_str='cuda')
import triton
import triton
import triton.language as tl
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
복사
복사됨
복사
복사됨
# kernel path: /tmp/torchinductor_eellison/
ze/czezsvjgwtdzru43t5gig4zb6t4auharyypwotkirkhmzu46fkzp.py
# kernel path: /tmp/torchinductor_eellison/
6q/c6qq6qfsawjfikfoeruueuup5cnzvmpzkusjkel6l6wcw43mgauj.py
# Source Nodes: [add_
2
, hidden_states_2], Original ATen: [aten.add, aten.native_layer_norm]
# Source Nodes: [attn_output], Original ATen: [aten._scaled_dot_product_efficient_attention]
# add_
2
=> add_
5
# attn_output => _scaled_dot_product_efficient_attention
# hidden_states_2 => add_
6
, add_
7
, convert_element_type_
21
, convert_element_type_
22
, mul_
3
, mul_
4
, rsqrt_1, sub_
3
, var_mean_1
triton_poi_fused__scaled_dot_product_efficient_attention_1 = async_compile.triton('triton_', '''
triton_per_fused_add_native_layer_norm_
1
= async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[67108864],
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_efficient_attention_1', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 50331648
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.full([1], False, tl.int1)
tmp1 = -3.3895313892515355e+38
tmp2 = 0.0
Text moved with changes from lines 464-469 (92.0% similarity)
tmp3 = tl.where(tmp0, tmp1, tmp2)
tl.store(out_ptr0 + (x0), tmp3, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_eellison/at/catipo3niy6sjxmx7lban4oumf6lplhx7qd3qxdhlcdrzt4xmv65.py
# Source Nodes: [add_
1
, hidden_states_2], Original ATen: [aten.add, aten.native_layer_norm]
# add_
1
=> add_
4
# hidden_states_2 => add_
5
, add_
6
, convert_element_type_
16
, convert_element_type_
17
, mul_
2
, mul_
3
, rsqrt_1, sub_
2
, var_mean_1
triton_per_fused_add_native_layer_norm_
2
= async_compile.triton('triton_', '''
import triton
import triton
import triton.language as tl
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
@triton_heuristics.persistent_reduction(
size_hints=[8192, 1024],
size_hints=[8192, 1024],
reduction_hint=ReductionHint.INNER,
reduction_hint=ReductionHint.INNER,
filename=__file__,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
복사
복사됨
복사
복사됨
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_native_layer_norm_
1
', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_native_layer_norm_
2
', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
)
@triton.jit
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel):
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel):
xnumel = 8192
xnumel = 8192
XBLOCK: tl.constexpr = 1
XBLOCK: tl.constexpr = 1
rnumel = 768
rnumel = 768
RBLOCK: tl.constexpr = 1024
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
roffset = 0
rmask = rindex < rnumel
rmask = rindex < rnumel
r1 = rindex
r1 = rindex
x0 = xindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp29 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp29 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp32 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp32 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp2 = tmp0 + tmp1
tmp2 = tmp0 + tmp1
tmp4 = tmp2 + tmp3
tmp4 = tmp2 + tmp3
tmp5 = tmp4.to(tl.float32)
tmp5 = tmp4.to(tl.float32)
tmp6 = tl.broadcast_to(tmp5, [RBLOCK])
tmp6 = tl.broadcast_to(tmp5, [RBLOCK])
tmp8 = tl.where(rmask, tmp6, 0)
tmp8 = tl.where(rmask, tmp6, 0)
tmp9 = tl.broadcast_to(tmp6, [RBLOCK])
tmp9 = tl.broadcast_to(tmp6, [RBLOCK])
tmp11 = tl.where(rmask, tmp9, 0)
tmp11 = tl.where(rmask, tmp9, 0)
tmp12 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0))
tmp12 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0))
tmp13 = tl.full([1], 768, tl.int32)
tmp13 = tl.full([1], 768, tl.int32)
tmp14 = tmp13.to(tl.float32)
tmp14 = tmp13.to(tl.float32)
tmp15 = tmp12 / tmp14
tmp15 = tmp12 / tmp14
tmp16 = tmp6 - tmp15
tmp16 = tmp6 - tmp15
tmp17 = tmp16 * tmp16
tmp17 = tmp16 * tmp16
tmp18 = tl.broadcast_to(tmp17, [RBLOCK])
tmp18 = tl.broadcast_to(tmp17, [RBLOCK])
tmp20 = tl.where(rmask, tmp18, 0)
tmp20 = tl.where(rmask, tmp18, 0)
tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0))
tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0))
tmp22 = tmp5 - tmp15
tmp22 = tmp5 - tmp15
tmp23 = 768.0
tmp23 = 768.0
tmp24 = tmp21 / tmp23
tmp24 = tmp21 / tmp23
tmp25 = 1e-12
tmp25 = 1e-12
tmp26 = tmp24 + tmp25
tmp26 = tmp24 + tmp25
tmp27 = libdevice.rsqrt(tmp26)
tmp27 = libdevice.rsqrt(tmp26)
tmp28 = tmp22 * tmp27
tmp28 = tmp22 * tmp27
tmp30 = tmp29.to(tl.float32)
tmp30 = tmp29.to(tl.float32)
tmp31 = tmp28 * tmp30
tmp31 = tmp28 * tmp30
tmp33 = tmp32.to(tl.float32)
tmp33 = tmp32.to(tl.float32)
tmp34 = tmp31 + tmp33
tmp34 = tmp31 + tmp33
tmp35 = tmp34.to(tl.float32)
tmp35 = tmp34.to(tl.float32)
tl.store(out_ptr2 + (r1 + (768*x0)), tmp35, rmask)
tl.store(out_ptr2 + (r1 + (768*x0)), tmp35, rmask)
''', device_str='cuda')
''', device_str='cuda')
복사
복사됨
복사
복사됨
# kernel path: /tmp/torchinductor_eellison/
h7/ch7fp7gaeqqseqnc45vn7aus3pytqrkcm2gt6tswtsp4robtviff.py
# kernel path: /tmp/torchinductor_eellison/
lp/clpuyqd2qh624e2ye7osnyieorymdus3kflfhssozom7b33frjei.py
# Source Nodes: [hidden_states_4], Original ATen: [aten.gelu]
# Source Nodes: [hidden_states_4], Original ATen: [aten.gelu]
복사
복사됨
복사
복사됨
# hidden_states_4 => add_
8
, convert_element_type_
26
, convert_element_type_
27
, erf, mul_
5
, mul_
6
, mul_
7
# hidden_states_4 => add_
7
, convert_element_type_
21
, convert_element_type_
22
, erf, mul_
4
, mul_
5
, mul_
6
triton_poi_fused_gelu_
2
= async_compile.triton('triton_', '''
triton_poi_fused_gelu_
3
= async_compile.triton('triton_', '''
import triton
import triton
import triton.language as tl
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
@triton_heuristics.pointwise(
size_hints=[33554432],
size_hints=[33554432],
filename=__file__,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
복사
복사됨
복사
복사됨
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_
2
', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_
3
', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
min_elem_per_thread=0
)
)
@triton.jit
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 25165824
xnumel = 25165824
xoffset = tl.program_id(0) * XBLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
xmask = tl.full([XBLOCK], True, tl.int1)
x2 = xindex
x2 = xindex
x0 = xindex % 3072
x0 = xindex % 3072
tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)
tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)
tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
tmp2 = tmp0 + tmp1
tmp2 = tmp0 + tmp1
tmp3 = tmp2.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = 0.5
tmp4 = 0.5
tmp5 = tmp3 * tmp4
tmp5 = tmp3 * tmp4
tmp6 = 0.7071067811865476
tmp6 = 0.7071067811865476
tmp7 = tmp3 * tmp6
tmp7 = tmp3 * tmp6
tmp8 = libdevice.erf(tmp7)
tmp8 = libdevice.erf(tmp7)
tmp9 = 1.0
tmp9 = 1.0
tmp10 = tmp8 + tmp9
tmp10 = tmp8 + tmp9
tmp11 = tmp5 * tmp10
tmp11 = tmp5 * tmp10
tmp12 = tmp11.to(tl.float32)
tmp12 = tmp11.to(tl.float32)
tl.store(in_out_ptr0 + (x2), tmp12, None)
tl.store(in_out_ptr0 + (x2), tmp12, None)
''', device_str='cuda')
''', device_str='cuda')
복사
복사됨
복사
복사됨
# kernel path: /tmp/torchinductor_eellison/
kv/ckv6n5lz6ocro4pxsewmvlwwdchxn7ycxtd2mm4zi5mmbccvo7j4.py
# kernel path: /tmp/torchinductor_eellison/
ai/caiiwspixmf6ozjdgwruqmkld4lyd2nnoyrhxzmgg3mafv6a6yka.py
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten.gelu, aten.native_layer_norm]
# Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten.gelu, aten.native_layer_norm]
복사
복사됨
복사
복사됨
# hidden_states_97 => add_
100
, convert_element_type_
366
, convert_element_type_
367
, erf_12, mul_
87
, mul_
88
, mul_
89
# hidden_states_97 => add_
88
, convert_element_type_
295
, convert_element_type_
296
, erf_12, mul_
86
, mul_
87
, mul_
88
# hidden_states_98 => add_
101
, add_
102
, convert_element_type_
368
, convert_element_type_
369
, mul_90,
mul_91,
rsqrt_25, sub_
38
, var_mean_25
# hidden_states_98 => add_
89
, add_
90
, convert_element_type_
297
, convert_element_type_
298, mul_89
, mul_90,
rsqrt_25, sub_
26
, var_mean_25
triton_per_fused_gelu_native_layer_norm_
3
= async_compile.triton('triton_', '''
triton_per_fused_gelu_native_layer_norm_
4
= async_compile.triton('triton_', '''
import triton
import triton
import triton.language as tl
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
@triton_heuristics.persistent_reduction(
size_hints=[8192, 1024],
size_hints=[8192, 1024],
reduction_hint=ReductionHint.INNER,
reduction_hint=ReductionHint.INNER,
filename=__file__,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
복사
복사됨
복사
복사됨
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_gelu_native_layer_norm_
3
', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 4, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_gelu_native_layer_norm_
4
', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 4, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
)
@triton.jit
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr3, xnumel, rnumel):
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr3, xnumel, rnumel):
xnumel = 8192
xnumel = 8192
XBLOCK: tl.constexpr = 1
XBLOCK: tl.constexpr = 1
rnumel = 768
rnumel = 768
RBLOCK: tl.constexpr = 1024
RBLOCK: tl.constexpr = 1024
xoffset = tl.program_id(0) * XBLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
roffset = 0
rmask = rindex < rnumel
rmask = rindex < rnumel
r1 = rindex
r1 = rindex
x0 = xindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp37 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp37 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp40 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp40 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp2 = tmp0 + tmp1
tmp2 = tmp0 + tmp1
tmp3 = tmp2.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = 0.5
tmp4 = 0.5
tmp5 = tmp3 * tmp4
tmp5 = tmp3 * tmp4
tmp6 = 0.7071067811865476
tmp6 = 0.7071067811865476
tmp7 = tmp3 * tmp6
tmp7 = tmp3 * tmp6
tmp8 = libdevice.erf(tmp7)
tmp8 = libdevice.erf(tmp7)
tmp9 = 1.0
tmp9 = 1.0
tmp10 = tmp8 + tmp9
tmp10 = tmp8 + tmp9
tmp11 = tmp5 * tmp10
tmp11 = tmp5 * tmp10
tmp12 = tmp11.to(tl.float32)
tmp12 = tmp11.to(tl.float32)
tmp13 = tmp12.to(tl.float32)
tmp13 = tmp12.to(tl.float32)
tmp14 = tl.broadcast_to(tmp13, [RBLOCK])
tmp14 = tl.broadcast_to(tmp13, [RBLOCK])
tmp16 = tl.where(rmask, tmp14, 0)
tmp16 = tl.where(rmask, tmp14, 0)
tmp17 = tl.broadcast_to(tmp14, [RBLOCK])
tmp17 = tl.broadcast_to(tmp14, [RBLOCK])
tmp19 = tl.where(rmask, tmp17, 0)
tmp19 = tl.where(rmask, tmp17, 0)
tmp20 = triton_helpers.promote_to_tensor(tl.sum(tmp19, 0))
tmp20 = triton_helpers.promote_to_tensor(tl.sum(tmp19, 0))
tmp21 = tl.full([1], 768, tl.int32)
tmp21 = tl.full([1], 768, tl.int32)
tmp22 = tmp21.to(tl.float32)
tmp22 = tmp21.to(tl.float32)
tmp23 = tmp20 / tmp22
tmp23 = tmp20 / tmp22
tmp24 = tmp14 - tmp23
tmp24 = tmp14 - tmp23
tmp25 = tmp24 * tmp24
tmp25 = tmp24 * tmp24
tmp26 = tl.broadcast_to(tmp25, [RBLOCK])
tmp26 = tl.broadcast_to(tmp25, [RBLOCK])
tmp28 = tl.where(rmask, tmp26, 0)
tmp28 = tl.where(rmask, tmp26, 0)
tmp29 = triton_helpers.promote_to_tensor(tl.sum(tmp28, 0))
tmp29 = triton_helpers.promote_to_tensor(tl.sum(tmp28, 0))
tmp30 = tmp13 - tmp23
tmp30 = tmp13 - tmp23
tmp31 = 768.0
tmp31 = 768.0
tmp32 = tmp29 / tmp31
tmp32 = tmp29 / tmp31
tmp33 = 1e-12
tmp33 = 1e-12
tmp34 = tmp32 + tmp33
tmp34 = tmp32 + tmp33
tmp35 = libdevice.rsqrt(tmp34)
tmp35 = libdevice.rsqrt(tmp34)
tmp36 = tmp30 * tmp35
tmp36 = tmp30 * tmp35
tmp38 = tmp37.to(tl.float32)
tmp38 = tmp37.to(tl.float32)
tmp39 = tmp36 * tmp38
tmp39 = tmp36 * tmp38
tmp41 = tmp40.to(tl.float32)
tmp41 = tmp40.to(tl.float32)
tmp42 = tmp39 + tmp41
tmp42 = tmp39 + tmp41
tmp43 = tmp42.to(tl.float32)
tmp43 = tmp42.to(tl.float32)
tl.store(out_ptr3 + (r1 + (768*x0)), tmp43, rmask)
tl.store(out_ptr3 + (r1 + (768*x0)), tmp43, rmask)
''', device_str='cuda')
''', device_str='cuda')
복사
복사됨
복사
복사됨
# kernel path: /tmp/torchinductor_eellison/
ry/cryghuddhmky5nywmwltcxvq5qstdsr5rotzufalizyup7kvjzmb.py
# kernel path: /tmp/torchinductor_eellison/
fl/cfllfdzx4v3f3zccah7zu5u634j2vrlvbkru74wmaerhgadlizat.py
# Source Nodes: [], Original ATen: []
# Source Nodes: [], Original ATen: []
복사
복사됨
복사
복사됨
triton_poi_fused_
4
= async_compile.triton('triton_', '''
triton_poi_fused_
5
= async_compile.triton('triton_', '''
import triton
import triton
import triton.language as tl
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
@triton_heuristics.pointwise(
size_hints=[33554432],
size_hints=[33554432],
filename=__file__,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
복사
복사됨
복사
복사됨
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_
4
', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_
5
', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
min_elem_per_thread=0
)
)
@triton.jit
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 23445504
xnumel = 23445504
xoffset = tl.program_id(0) * XBLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex % 30528
x0 = xindex % 30528
x1 = (xindex // 30528)
x1 = (xindex // 30528)
x2 = xindex
x2 = xindex
tmp0 = x0
tmp0 = x0
tmp1 = tl.full([1], 0, tl.int64)
tmp1 = tl.full([1], 0, tl.int64)
tmp2 = tmp0 >= tmp1
tmp2 = tmp0 >= tmp1
tmp3 = tl.full([1], 30522, tl.int64)
tmp3 = tl.full([1], 30522, tl.int64)
tmp4 = tmp0 < tmp3
tmp4 = tmp0 < tmp3
tmp5 = tl.load(in_ptr0 + (x1 + (768*x0)), tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp5 = tl.load(in_ptr0 + (x1 + (768*x0)), tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tmp0 >= tmp3
tmp6 = tmp0 >= tmp3
tmp7 = tl.full([1], 30528, tl.int64)
tmp7 = tl.full([1], 30528, tl.int64)
tmp8 = tmp0 < tmp7
tmp8 = tmp0 < tmp7
tmp9 = 0.0
tmp9 = 0.0
tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
tmp11 = tl.where(tmp6, tmp9, tmp10)
tmp11 = tl.where(tmp6, tmp9, tmp10)
tmp12 = tl.where(tmp4, tmp5, tmp11)
tmp12 = tl.where(tmp4, tmp5, tmp11)
tl.store(out_ptr0 + (x2), tmp12, None)
tl.store(out_ptr0 + (x2), tmp12, None)
''', device_str='cuda')
''', device_str='cuda')
복사
복사됨
복사
복사됨
# kernel path: /tmp/torchinductor_eellison/
iq/ciq3xiils5u6jz63putfodpdyv6oshlb3gbss7hjki6u3ogkhwqq.py
# kernel path: /tmp/torchinductor_eellison/
si/csikk4r46efsgklpubt6iamwjm3jsevq2h2pkkvjbgcc7sj4p24c.py
# Source Nodes: [], Original ATen: []
# Source Nodes: [], Original ATen: []
복사
복사됨
복사
복사됨
triton_poi_fused_
5
= async_compile.triton('triton_', '''
triton_poi_fused_
6
= async_compile.triton('triton_', '''
import triton
import triton
import triton.language as tl
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
@triton_heuristics.pointwise(
size_hints=[32768],
size_hints=[32768],
filename=__file__,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
복사
복사됨
복사
복사됨
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_
5
', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_
6
', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
min_elem_per_thread=0
)
)
@triton.jit
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 30528
xnumel = 30528
xoffset = tl.program_id(0) * XBLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
xmask = xindex < xnumel
x0 = xindex
x0 = xindex
tmp0 = x0
tmp0 = x0
tmp1 = tl.full([1], 0, tl.int64)
tmp1 = tl.full([1], 0, tl.int64)
tmp2 = tmp0 >= tmp1
tmp2 = tmp0 >= tmp1
tmp3 = tl.full([1], 30522, tl.int64)
tmp3 = tl.full([1], 30522, tl.int64)
tmp4 = tmp0 < tmp3
tmp4 = tmp0 < tmp3
tmp5 = tl.load(in_ptr0 + (x0), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp5 = tl.load(in_ptr0 + (x0), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tmp0 >= tmp3
tmp6 = tmp0 >= tmp3
tmp7 = tl.full([1], 30528, tl.int64)
tmp7 = tl.full([1], 30528, tl.int64)
tmp8 = tmp0 < tmp7
tmp8 = tmp0 < tmp7
tmp9 = 0.0
tmp9 = 0.0
tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
tmp11 = tl.where(tmp6, tmp9, tmp10)
tmp11 = tl.where(tmp6, tmp9, tmp10)
tmp12 = tl.where(tmp4, tmp5, tmp11)
tmp12 = tl.where(tmp4, tmp5, tmp11)
tl.store(out_ptr0 + (x0), tmp12, xmask)
tl.store(out_ptr0 + (x0), tmp12, xmask)
''', device_str='cuda')
''', device_str='cuda')
복사
복사됨
복사
복사됨
# kernel path: /tmp/torchinductor_eellison/
vf/cvfw2hcsiy2zivm66wjz4xfhs2ysdjbvqk2rsig4hp62bpxn2vvd.py
# kernel path: /tmp/torchinductor_eellison/
z7/cz72sjk2qwjjxgfdcyk6de22sryhbhr6pptrvx4aq2hq2soai2p5.py
# Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax]
# Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax]
복사
복사됨
복사
복사됨
# masked_lm_loss => amax
_12
, convert_element_type_
373
, exp
_12
, sub_
39
, sum_1
3
# masked_lm_loss => amax
, convert_element_type_
302
, exp
, sub_
27
, sum_1
triton_red_fused__log_softmax_
6
= async_compile.triton('triton_', '''
triton_red_fused__log_softmax_
7
= async_compile.triton('triton_', '''
import triton
import triton
import triton.language as tl
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
@triton_heuristics.reduction(
size_hints=[8192, 32768],
size_hints=[8192, 32768],
reduction_hint=ReductionHint.DEFAULT,
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
filename=__file__,
triton_meta={'signature': {0: '*bf16', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
triton_meta={'signature': {0: '*bf16', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
복사
복사됨
복사
복사됨
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax_
6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax_
7', 'mutated_arg_name
)
@triton.jit
def triton_(in_ptr0, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 8192
rnumel = 30522
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
_tmp3 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1 + (30528*x0)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
tmp4 = triton_helpers.maximum(_tmp3, tmp2)
_tmp3 = tl.where(rmask, tmp4, _tmp3)
tmp3 = triton_helpers.max2(_tmp3, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp3, None)
_tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp5 = tl.load(in_ptr0 + (r1 + (30528*x0)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp6 = tmp5.to(tl.float32)
tmp7 = tmp6 - tmp3
tmp8 = tl_math.exp(tmp7)
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(rmask, tmp11, _tmp10)
Text moved with changes to lines 159-164 (92.0% similarity)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tl.store(out_ptr1 + (x0), tmp10, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_eellison/bt/cbtmzvri2vmmhtlgzxhr7zowbeqsafuqeui5gh3j2bsgiibilujp
저장된 비교 결과
원본
파일 열기
# AOT ID: ['0_inference'] from ctypes import c_void_p, c_long import torch import math import random import os import tempfile from math import inf, nan from torch._inductor.hooks import run_intermediate_hooks from torch._inductor.utils import maybe_profile from torch._inductor.codegen.memory_planning import _align as align from torch import device, empty_strided from torch._inductor.async_compile import AsyncCompile from torch._inductor.select_algorithm import extern_kernels from torch._inductor.codegen.multi_kernel import MultiKernelCall aten = torch.ops.aten inductor_ops = torch.ops.inductor _quantized = torch.ops._quantized assert_size_stride = torch._C._dynamo.guards.assert_size_stride empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor alloc_from_pool = torch.ops.inductor._alloc_from_pool async_compile = AsyncCompile() # kernel path: /tmp/torchinductor_eellison/wj/cwjmdogaso56hzkb4pdlo4k625tmwevhq5krfbpmpx5szmhdjbsf.py # Source Nodes: [embeddings, embeddings_1, embeddings_2, inputs_embeds, position_embeddings, token_type_embeddings], Original ATen: [aten.add, aten.embedding, aten.native_layer_norm] # embeddings => add # embeddings_1 => add_1 # embeddings_2 => add_2, add_3, convert_element_type_1, convert_element_type_2, mul_1, mul_2, rsqrt, sub_1, var_mean # inputs_embeds => embedding # position_embeddings => embedding_2 # token_type_embeddings => embedding_1 triton_per_fused_add_embedding_native_layer_norm_0 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.persistent_reduction( size_hints=[8192, 1024], reduction_hint=ReductionHint.INNER, filename=__file__, triton_meta={'signature': {0: '*i64', 1: '*bf16', 2: '*i64', 3: '*bf16', 4: '*i64', 5: '*bf16', 6: '*bf16', 7: '*bf16', 8: '*bf16', 9: '*bf16', 10: 'i32', 11: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_embedding_native_layer_norm_0', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} ) @triton.jit def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr0, out_ptr3, xnumel, rnumel): xnumel = 8192 XBLOCK: tl.constexpr = 1 rnumel = 768 RBLOCK: tl.constexpr = 1024 xoffset = tl.program_id(0) * XBLOCK xindex = tl.full([1], xoffset, tl.int32) xmask = tl.full([RBLOCK], True, tl.int1) rindex = tl.arange(0, RBLOCK)[:] roffset = 0 rmask = rindex < rnumel x3 = xindex r2 = rindex x0 = xindex % 512 tmp0 = tl.load(in_ptr0 + (x3), None, eviction_policy='evict_last') tmp7 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') tmp15 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last') tmp47 = tl.load(in_ptr6 + (r2), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp50 = tl.load(in_ptr7 + (r2), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp1 = tl.full([RBLOCK], 30522, tl.int32) tmp2 = tmp0 + tmp1 tmp3 = tmp0 < 0 tmp4 = tl.where(tmp3, tmp2, tmp0) tl.device_assert((0 <= tmp4) & (tmp4 < 30522), "index out of bounds: 0 <= tmp4 < 30522") tmp6 = tl.load(in_ptr1 + (r2 + (768*tmp4)), rmask, other=0.0).to(tl.float32) tmp8 = tl.full([RBLOCK], 2, tl.int32) tmp9 = tmp7 + tmp8 tmp10 = tmp7 < 0 tmp11 = tl.where(tmp10, tmp9, tmp7) tl.device_assert((0 <= tmp11) & (tmp11 < 2), "index out of bounds: 0 <= tmp11 < 2") tmp13 = tl.load(in_ptr3 + (r2 + (768*tmp11)), rmask, other=0.0).to(tl.float32) tmp14 = tmp6 + tmp13 tmp16 = tl.full([RBLOCK], 512, tl.int32) tmp17 = tmp15 + tmp16 tmp18 = tmp15 < 0 tmp19 = tl.where(tmp18, tmp17, tmp15) tl.device_assert((0 <= tmp19) & (tmp19 < 512), "index out of bounds: 0 <= tmp19 < 512") tmp21 = tl.load(in_ptr5 + (r2 + (768*tmp19)), rmask, other=0.0).to(tl.float32) tmp22 = tmp14 + tmp21 tmp23 = tmp22.to(tl.float32) tmp24 = tl.broadcast_to(tmp23, [RBLOCK]) tmp26 = tl.where(rmask, tmp24, 0) tmp27 = tl.broadcast_to(tmp24, [RBLOCK]) tmp29 = tl.where(rmask, tmp27, 0) tmp30 = triton_helpers.promote_to_tensor(tl.sum(tmp29, 0)) tmp31 = tl.full([1], 768, tl.int32) tmp32 = tmp31.to(tl.float32) tmp33 = tmp30 / tmp32 tmp34 = tmp24 - tmp33 tmp35 = tmp34 * tmp34 tmp36 = tl.broadcast_to(tmp35, [RBLOCK]) tmp38 = tl.where(rmask, tmp36, 0) tmp39 = triton_helpers.promote_to_tensor(tl.sum(tmp38, 0)) tmp40 = tmp23 - tmp33 tmp41 = 768.0 tmp42 = tmp39 / tmp41 tmp43 = 1e-12 tmp44 = tmp42 + tmp43 tmp45 = libdevice.rsqrt(tmp44) tmp46 = tmp40 * tmp45 tmp48 = tmp47.to(tl.float32) tmp49 = tmp46 * tmp48 tmp51 = tmp50.to(tl.float32) tmp52 = tmp49 + tmp51 tmp53 = tmp52.to(tl.float32) tl.store(out_ptr0 + (r2 + (768*x3)), tmp22, rmask) tl.store(out_ptr3 + (r2 + (768*x3)), tmp53, rmask) ''', device_str='cuda') import triton import triton.language as tl from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph from torch._C import _cuda_getCurrentRawStream as get_raw_stream # kernel path: /tmp/torchinductor_eellison/ze/czezsvjgwtdzru43t5gig4zb6t4auharyypwotkirkhmzu46fkzp.py # Source Nodes: [add_2, hidden_states_2], Original ATen: [aten.add, aten.native_layer_norm] # add_2 => add_5 # hidden_states_2 => add_6, add_7, convert_element_type_21, convert_element_type_22, mul_3, mul_4, rsqrt_1, sub_3, var_mean_1 triton_per_fused_add_native_layer_norm_1 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.persistent_reduction( size_hints=[8192, 1024], reduction_hint=ReductionHint.INNER, filename=__file__, triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_native_layer_norm_1', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} ) @triton.jit def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel): xnumel = 8192 XBLOCK: tl.constexpr = 1 rnumel = 768 RBLOCK: tl.constexpr = 1024 xoffset = tl.program_id(0) * XBLOCK xindex = tl.full([1], xoffset, tl.int32) xmask = tl.full([RBLOCK], True, tl.int1) rindex = tl.arange(0, RBLOCK)[:] roffset = 0 rmask = rindex < rnumel r1 = rindex x0 = xindex tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32) tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp3 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32) tmp29 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp32 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp2 = tmp0 + tmp1 tmp4 = tmp2 + tmp3 tmp5 = tmp4.to(tl.float32) tmp6 = tl.broadcast_to(tmp5, [RBLOCK]) tmp8 = tl.where(rmask, tmp6, 0) tmp9 = tl.broadcast_to(tmp6, [RBLOCK]) tmp11 = tl.where(rmask, tmp9, 0) tmp12 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0)) tmp13 = tl.full([1], 768, tl.int32) tmp14 = tmp13.to(tl.float32) tmp15 = tmp12 / tmp14 tmp16 = tmp6 - tmp15 tmp17 = tmp16 * tmp16 tmp18 = tl.broadcast_to(tmp17, [RBLOCK]) tmp20 = tl.where(rmask, tmp18, 0) tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0)) tmp22 = tmp5 - tmp15 tmp23 = 768.0 tmp24 = tmp21 / tmp23 tmp25 = 1e-12 tmp26 = tmp24 + tmp25 tmp27 = libdevice.rsqrt(tmp26) tmp28 = tmp22 * tmp27 tmp30 = tmp29.to(tl.float32) tmp31 = tmp28 * tmp30 tmp33 = tmp32.to(tl.float32) tmp34 = tmp31 + tmp33 tmp35 = tmp34.to(tl.float32) tl.store(out_ptr2 + (r1 + (768*x0)), tmp35, rmask) ''', device_str='cuda') # kernel path: /tmp/torchinductor_eellison/h7/ch7fp7gaeqqseqnc45vn7aus3pytqrkcm2gt6tswtsp4robtviff.py # Source Nodes: [hidden_states_4], Original ATen: [aten.gelu] # hidden_states_4 => add_8, convert_element_type_26, convert_element_type_27, erf, mul_5, mul_6, mul_7 triton_poi_fused_gelu_2 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.pointwise( size_hints=[33554432], filename=__file__, triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_2', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, min_elem_per_thread=0 ) @triton.jit def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 25165824 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = tl.full([XBLOCK], True, tl.int1) x2 = xindex x0 = xindex % 3072 tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32) tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32) tmp2 = tmp0 + tmp1 tmp3 = tmp2.to(tl.float32) tmp4 = 0.5 tmp5 = tmp3 * tmp4 tmp6 = 0.7071067811865476 tmp7 = tmp3 * tmp6 tmp8 = libdevice.erf(tmp7) tmp9 = 1.0 tmp10 = tmp8 + tmp9 tmp11 = tmp5 * tmp10 tmp12 = tmp11.to(tl.float32) tl.store(in_out_ptr0 + (x2), tmp12, None) ''', device_str='cuda') # kernel path: /tmp/torchinductor_eellison/kv/ckv6n5lz6ocro4pxsewmvlwwdchxn7ycxtd2mm4zi5mmbccvo7j4.py # Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten.gelu, aten.native_layer_norm] # hidden_states_97 => add_100, convert_element_type_366, convert_element_type_367, erf_12, mul_87, mul_88, mul_89 # hidden_states_98 => add_101, add_102, convert_element_type_368, convert_element_type_369, mul_90, mul_91, rsqrt_25, sub_38, var_mean_25 triton_per_fused_gelu_native_layer_norm_3 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.persistent_reduction( size_hints=[8192, 1024], reduction_hint=ReductionHint.INNER, filename=__file__, triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_gelu_native_layer_norm_3', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 4, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} ) @triton.jit def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr3, xnumel, rnumel): xnumel = 8192 XBLOCK: tl.constexpr = 1 rnumel = 768 RBLOCK: tl.constexpr = 1024 xoffset = tl.program_id(0) * XBLOCK xindex = tl.full([1], xoffset, tl.int32) xmask = tl.full([RBLOCK], True, tl.int1) rindex = tl.arange(0, RBLOCK)[:] roffset = 0 rmask = rindex < rnumel r1 = rindex x0 = xindex tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32) tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp37 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp40 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp2 = tmp0 + tmp1 tmp3 = tmp2.to(tl.float32) tmp4 = 0.5 tmp5 = tmp3 * tmp4 tmp6 = 0.7071067811865476 tmp7 = tmp3 * tmp6 tmp8 = libdevice.erf(tmp7) tmp9 = 1.0 tmp10 = tmp8 + tmp9 tmp11 = tmp5 * tmp10 tmp12 = tmp11.to(tl.float32) tmp13 = tmp12.to(tl.float32) tmp14 = tl.broadcast_to(tmp13, [RBLOCK]) tmp16 = tl.where(rmask, tmp14, 0) tmp17 = tl.broadcast_to(tmp14, [RBLOCK]) tmp19 = tl.where(rmask, tmp17, 0) tmp20 = triton_helpers.promote_to_tensor(tl.sum(tmp19, 0)) tmp21 = tl.full([1], 768, tl.int32) tmp22 = tmp21.to(tl.float32) tmp23 = tmp20 / tmp22 tmp24 = tmp14 - tmp23 tmp25 = tmp24 * tmp24 tmp26 = tl.broadcast_to(tmp25, [RBLOCK]) tmp28 = tl.where(rmask, tmp26, 0) tmp29 = triton_helpers.promote_to_tensor(tl.sum(tmp28, 0)) tmp30 = tmp13 - tmp23 tmp31 = 768.0 tmp32 = tmp29 / tmp31 tmp33 = 1e-12 tmp34 = tmp32 + tmp33 tmp35 = libdevice.rsqrt(tmp34) tmp36 = tmp30 * tmp35 tmp38 = tmp37.to(tl.float32) tmp39 = tmp36 * tmp38 tmp41 = tmp40.to(tl.float32) tmp42 = tmp39 + tmp41 tmp43 = tmp42.to(tl.float32) tl.store(out_ptr3 + (r1 + (768*x0)), tmp43, rmask) ''', device_str='cuda') # kernel path: /tmp/torchinductor_eellison/ry/cryghuddhmky5nywmwltcxvq5qstdsr5rotzufalizyup7kvjzmb.py # Source Nodes: [], Original ATen: [] triton_poi_fused_4 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.pointwise( size_hints=[33554432], filename=__file__, triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, min_elem_per_thread=0 ) @triton.jit def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 23445504 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = tl.full([XBLOCK], True, tl.int1) x0 = xindex % 30528 x1 = (xindex // 30528) x2 = xindex tmp0 = x0 tmp1 = tl.full([1], 0, tl.int64) tmp2 = tmp0 >= tmp1 tmp3 = tl.full([1], 30522, tl.int64) tmp4 = tmp0 < tmp3 tmp5 = tl.load(in_ptr0 + (x1 + (768*x0)), tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp6 = tmp0 >= tmp3 tmp7 = tl.full([1], 30528, tl.int64) tmp8 = tmp0 < tmp7 tmp9 = 0.0 tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype) tmp11 = tl.where(tmp6, tmp9, tmp10) tmp12 = tl.where(tmp4, tmp5, tmp11) tl.store(out_ptr0 + (x2), tmp12, None) ''', device_str='cuda') # kernel path: /tmp/torchinductor_eellison/iq/ciq3xiils5u6jz63putfodpdyv6oshlb3gbss7hjki6u3ogkhwqq.py # Source Nodes: [], Original ATen: [] triton_poi_fused_5 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.pointwise( size_hints=[32768], filename=__file__, triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, min_elem_per_thread=0 ) @triton.jit def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 30528 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = x0 tmp1 = tl.full([1], 0, tl.int64) tmp2 = tmp0 >= tmp1 tmp3 = tl.full([1], 30522, tl.int64) tmp4 = tmp0 < tmp3 tmp5 = tl.load(in_ptr0 + (x0), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp6 = tmp0 >= tmp3 tmp7 = tl.full([1], 30528, tl.int64) tmp8 = tmp0 < tmp7 tmp9 = 0.0 tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype) tmp11 = tl.where(tmp6, tmp9, tmp10) tmp12 = tl.where(tmp4, tmp5, tmp11) tl.store(out_ptr0 + (x0), tmp12, xmask) ''', device_str='cuda') # kernel path: /tmp/torchinductor_eellison/vf/cvfw2hcsiy2zivm66wjz4xfhs2ysdjbvqk2rsig4hp62bpxn2vvd.py # Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax] # masked_lm_loss => amax_12, convert_element_type_373, exp_12, sub_39, sum_13 triton_red_fused__log_softmax_6 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.reduction( size_hints=[8192, 32768], reduction_hint=ReductionHint.DEFAULT, filename=__file__, triton_meta={'signature': {0: '*bf16', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} ) @triton.jit def triton_(in_ptr0, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): xnumel = 8192 rnumel = 30522 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:, None] xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1) rbase = tl.arange(0, RBLOCK)[None, :] x0 = xindex _tmp3 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32) for roffset in range(0, rnumel, RBLOCK): rindex = roffset + rbase rmask = rindex < rnumel r1 = rindex tmp0 = tl.load(in_ptr0 + (r1 + (30528*x0)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp1 = tmp0.to(tl.float32) tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK]) tmp4 = triton_helpers.maximum(_tmp3, tmp2) _tmp3 = tl.where(rmask, tmp4, _tmp3) tmp3 = triton_helpers.max2(_tmp3, 1)[:, None] tl.store(out_ptr0 + (x0), tmp3, None) _tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) for roffset in range(0, rnumel, RBLOCK): rindex = roffset + rbase rmask = rindex < rnumel r1 = rindex tmp5 = tl.load(in_ptr0 + (r1 + (30528*x0)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) tmp6 = tmp5.to(tl.float32) tmp7 = tmp6 - tmp3 tmp8 = tl_math.exp(tmp7) tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK]) tmp11 = _tmp10 + tmp9 _tmp10 = tl.where(rmask, tmp11, _tmp10) tmp10 = tl.sum(_tmp10, 1)[:, None] tl.store(out_ptr1 + (x0), tmp10, None) ''', device_str='cuda') # kernel path: /tmp/torchinductor_eellison/bt/cbtmzvri2vmmhtlgzxhr7zowbeqsafuqeui5gh3j2bsgiibilujp.py # Source Nodes: [masked_lm_loss], Original ATen: [aten.nll_loss_forward] # masked_lm_loss => convert_element_type_375, div_24, full_default_2, ne_1, ne_2, neg, sum_14, sum_15, where_1 triton_red_fused_nll_loss_forward_7 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.reduction( size_hints=[1, 8192], reduction_hint=ReductionHint.INNER, filename=__file__, triton_meta={'signature': {0: '*bf16', 1: '*i64', 2: '*bf16', 3: '*fp32', 4: '*fp32', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {5: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 6), equal_to_1=(5,))]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_nll_loss_forward_7', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 2, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} ) @triton.jit def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): xnumel = 1 rnumel = 8192 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:, None] xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1) rbase = tl.arange(0, RBLOCK)[None, :] _tmp22 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) _tmp26 = tl.full([XBLOCK, RBLOCK], 0, tl.int64) for roffset in range(0, rnumel, RBLOCK): rindex = roffset + rbase rmask = rindex < rnumel r0 = rindex tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0) tmp12 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0) tmp14 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0) tmp1 = tl.full([1, 1], -100, tl.int64) tmp2 = tmp0 != tmp1 tmp3 = tl.full([1, 1], 0, tl.int64) tmp4 = tl.where(tmp2, tmp0, tmp3) tmp5 = tl.full([XBLOCK, RBLOCK], 30522, tl.int32) tmp6 = tmp4 + tmp5 tmp7 = tmp4 < 0 tmp8 = tl.where(tmp7, tmp6, tmp4) tl.device_assert(((0 <= tmp8) & (tmp8 < 30522)) | ~(rmask), "index out of bounds: 0 <= tmp8 < 30522") tmp10 = tl.load(in_ptr1 + (tmp8 + (30528*r0)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp11 = tmp10.to(tl.float32) tmp13 = tmp11 - tmp12 tmp15 = tl_math.log(tmp14) tmp16 = tmp13 - tmp15 tmp17 = tmp16.to(tl.float32) tmp18 = -tmp17 tmp19 = 0.0 tmp20 = tl.where(tmp2, tmp18, tmp19) tmp21 = tl.broadcast_to(tmp20, [XBLOCK, RBLOCK]) tmp23 = _tmp22 + tmp21 _tmp22 = tl.where(rmask, tmp23, _tmp22) tmp24 = tmp2.to(tl.int64) tmp25 = tl.broadcast_to(tmp24, [XBLOCK, RBLOCK]) tmp27 = _tmp26 + tmp25 _tmp26 = tl.where(rmask, tmp27, _tmp26) tmp22 = tl.sum(_tmp22, 1)[:, None] tmp26 = tl.sum(_tmp26, 1)[:, None] tmp28 = tmp26.to(tl.float32) tmp29 = tmp22 / tmp28 tl.debug_barrier() tl.store(in_out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp29, None) ''', device_str='cuda') async_compile.wait(globals()) del async_compile def call(args): arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1 = args args.clear() assert_size_stride(arg0_1, (30522, 768), (768, 1)) assert_size_stride(arg1_1, (2, 768), (768, 1)) assert_size_stride(arg2_1, (512, 768), (768, 1)) assert_size_stride(arg3_1, (768, ), (1, )) assert_size_stride(arg4_1, (768, ), (1, )) assert_size_stride(arg5_1, (768, 768), (768, 1)) assert_size_stride(arg6_1, (768, ), (1, )) assert_size_stride(arg7_1, (768, 768), (768, 1)) assert_size_stride(arg8_1, (768, ), (1, )) assert_size_stride(arg9_1, (768, 768), (768, 1)) assert_size_stride(arg10_1, (768, ), (1, )) assert_size_stride(arg11_1, (768, 768), (768, 1)) assert_size_stride(arg12_1, (768, ), (1, )) assert_size_stride(arg13_1, (768, ), (1, )) assert_size_stride(arg14_1, (768, ), (1, )) assert_size_stride(arg15_1, (3072, 768), (768, 1)) assert_size_stride(arg16_1, (3072, ), (1, )) assert_size_stride(arg17_1, (768, 3072), (3072, 1)) assert_size_stride(arg18_1, (768, ), (1, )) assert_size_stride(arg19_1, (768, ), (1, )) assert_size_stride(arg20_1, (768, ), (1, )) assert_size_stride(arg21_1, (768, 768), (768, 1)) assert_size_stride(arg22_1, (768, ), (1, )) assert_size_stride(arg23_1, (768, 768), (768, 1)) assert_size_stride(arg24_1, (768, ), (1, )) assert_size_stride(arg25_1, (768, 768), (768, 1)) assert_size_stride(arg26_1, (768, ), (1, )) assert_size_stride(arg27_1, (768, 768), (768, 1)) assert_size_stride(arg28_1, (768, ), (1, )) assert_size_stride(arg29_1, (768, ), (1, )) assert_size_stride(arg30_1, (768, ), (1, )) assert_size_stride(arg31_1, (3072, 768), (768, 1)) assert_size_stride(arg32_1, (3072, ), (1, )) assert_size_stride(arg33_1, (768, 3072), (3072, 1)) assert_size_stride(arg34_1, (768, ), (1, )) assert_size_stride(arg35_1, (768, ), (1, )) assert_size_stride(arg36_1, (768, ), (1, )) assert_size_stride(arg37_1, (768, 768), (768, 1)) assert_size_stride(arg38_1, (768, ), (1, )) assert_size_stride(arg39_1, (768, 768), (768, 1)) assert_size_stride(arg40_1, (768, ), (1, )) assert_size_stride(arg41_1, (768, 768), (768, 1)) assert_size_stride(arg42_1, (768, ), (1, )) assert_size_stride(arg43_1, (768, 768), (768, 1)) assert_size_stride(arg44_1, (768, ), (1, )) assert_size_stride(arg45_1, (768, ), (1, )) assert_size_stride(arg46_1, (768, ), (1, )) assert_size_stride(arg47_1, (3072, 768), (768, 1)) assert_size_stride(arg48_1, (3072, ), (1, )) assert_size_stride(arg49_1, (768, 3072), (3072, 1)) assert_size_stride(arg50_1, (768, ), (1, )) assert_size_stride(arg51_1, (768, ), (1, )) assert_size_stride(arg52_1, (768, ), (1, )) assert_size_stride(arg53_1, (768, 768), (768, 1)) assert_size_stride(arg54_1, (768, ), (1, )) assert_size_stride(arg55_1, (768, 768), (768, 1)) assert_size_stride(arg56_1, (768, ), (1, )) assert_size_stride(arg57_1, (768, 768), (768, 1)) assert_size_stride(arg58_1, (768, ), (1, )) assert_size_stride(arg59_1, (768, 768), (768, 1)) assert_size_stride(arg60_1, (768, ), (1, )) assert_size_stride(arg61_1, (768, ), (1, )) assert_size_stride(arg62_1, (768, ), (1, )) assert_size_stride(arg63_1, (3072, 768), (768, 1)) assert_size_stride(arg64_1, (3072, ), (1, )) assert_size_stride(arg65_1, (768, 3072), (3072, 1)) assert_size_stride(arg66_1, (768, ), (1, )) assert_size_stride(arg67_1, (768, ), (1, )) assert_size_stride(arg68_1, (768, ), (1, )) assert_size_stride(arg69_1, (768, 768), (768, 1)) assert_size_stride(arg70_1, (768, ), (1, )) assert_size_stride(arg71_1, (768, 768), (768, 1)) assert_size_stride(arg72_1, (768, ), (1, )) assert_size_stride(arg73_1, (768, 768), (768, 1)) assert_size_stride(arg74_1, (768, ), (1, )) assert_size_stride(arg75_1, (768, 768), (768, 1)) assert_size_stride(arg76_1, (768, ), (1, )) assert_size_stride(arg77_1, (768, ), (1, )) assert_size_stride(arg78_1, (768, ), (1, )) assert_size_stride(arg79_1, (3072, 768), (768, 1)) assert_size_stride(arg80_1, (3072, ), (1, )) assert_size_stride(arg81_1, (768, 3072), (3072, 1)) assert_size_stride(arg82_1, (768, ), (1, )) assert_size_stride(arg83_1, (768, ), (1, )) assert_size_stride(arg84_1, (768, ), (1, )) assert_size_stride(arg85_1, (768, 768), (768, 1)) assert_size_stride(arg86_1, (768, ), (1, )) assert_size_stride(arg87_1, (768, 768), (768, 1)) assert_size_stride(arg88_1, (768, ), (1, )) assert_size_stride(arg89_1, (768, 768), (768, 1)) assert_size_stride(arg90_1, (768, ), (1, )) assert_size_stride(arg91_1, (768, 768), (768, 1)) assert_size_stride(arg92_1, (768, ), (1, )) assert_size_stride(arg93_1, (768, ), (1, )) assert_size_stride(arg94_1, (768, ), (1, )) assert_size_stride(arg95_1, (3072, 768), (768, 1)) assert_size_stride(arg96_1, (3072, ), (1, )) assert_size_stride(arg97_1, (768, 3072), (3072, 1)) assert_size_stride(arg98_1, (768, ), (1, )) assert_size_stride(arg99_1, (768, ), (1, )) assert_size_stride(arg100_1, (768, ), (1, )) assert_size_stride(arg101_1, (768, 768), (768, 1)) assert_size_stride(arg102_1, (768, ), (1, )) assert_size_stride(arg103_1, (768, 768), (768, 1)) assert_size_stride(arg104_1, (768, ), (1, )) assert_size_stride(arg105_1, (768, 768), (768, 1)) assert_size_stride(arg106_1, (768, ), (1, )) assert_size_stride(arg107_1, (768, 768), (768, 1)) assert_size_stride(arg108_1, (768, ), (1, )) assert_size_stride(arg109_1, (768, ), (1, )) assert_size_stride(arg110_1, (768, ), (1, )) assert_size_stride(arg111_1, (3072, 768), (768, 1)) assert_size_stride(arg112_1, (3072, ), (1, )) assert_size_stride(arg113_1, (768, 3072), (3072, 1)) assert_size_stride(arg114_1, (768, ), (1, )) assert_size_stride(arg115_1, (768, ), (1, )) assert_size_stride(arg116_1, (768, ), (1, )) assert_size_stride(arg117_1, (768, 768), (768, 1)) assert_size_stride(arg118_1, (768, ), (1, )) assert_size_stride(arg119_1, (768, 768), (768, 1)) assert_size_stride(arg120_1, (768, ), (1, )) assert_size_stride(arg121_1, (768, 768), (768, 1)) assert_size_stride(arg122_1, (768, ), (1, )) assert_size_stride(arg123_1, (768, 768), (768, 1)) assert_size_stride(arg124_1, (768, ), (1, )) assert_size_stride(arg125_1, (768, ), (1, )) assert_size_stride(arg126_1, (768, ), (1, )) assert_size_stride(arg127_1, (3072, 768), (768, 1)) assert_size_stride(arg128_1, (3072, ), (1, )) assert_size_stride(arg129_1, (768, 3072), (3072, 1)) assert_size_stride(arg130_1, (768, ), (1, )) assert_size_stride(arg131_1, (768, ), (1, )) assert_size_stride(arg132_1, (768, ), (1, )) assert_size_stride(arg133_1, (768, 768), (768, 1)) assert_size_stride(arg134_1, (768, ), (1, )) assert_size_stride(arg135_1, (768, 768), (768, 1)) assert_size_stride(arg136_1, (768, ), (1, )) assert_size_stride(arg137_1, (768, 768), (768, 1)) assert_size_stride(arg138_1, (768, ), (1, )) assert_size_stride(arg139_1, (768, 768), (768, 1)) assert_size_stride(arg140_1, (768, ), (1, )) assert_size_stride(arg141_1, (768, ), (1, )) assert_size_stride(arg142_1, (768, ), (1, )) assert_size_stride(arg143_1, (3072, 768), (768, 1)) assert_size_stride(arg144_1, (3072, ), (1, )) assert_size_stride(arg145_1, (768, 3072), (3072, 1)) assert_size_stride(arg146_1, (768, ), (1, )) assert_size_stride(arg147_1, (768, ), (1, )) assert_size_stride(arg148_1, (768, ), (1, )) assert_size_stride(arg149_1, (768, 768), (768, 1)) assert_size_stride(arg150_1, (768, ), (1, )) assert_size_stride(arg151_1, (768, 768), (768, 1)) assert_size_stride(arg152_1, (768, ), (1, )) assert_size_stride(arg153_1, (768, 768), (768, 1)) assert_size_stride(arg154_1, (768, ), (1, )) assert_size_stride(arg155_1, (768, 768), (768, 1)) assert_size_stride(arg156_1, (768, ), (1, )) assert_size_stride(arg157_1, (768, ), (1, )) assert_size_stride(arg158_1, (768, ), (1, )) assert_size_stride(arg159_1, (3072, 768), (768, 1)) assert_size_stride(arg160_1, (3072, ), (1, )) assert_size_stride(arg161_1, (768, 3072), (3072, 1)) assert_size_stride(arg162_1, (768, ), (1, )) assert_size_stride(arg163_1, (768, ), (1, )) assert_size_stride(arg164_1, (768, ), (1, )) assert_size_stride(arg165_1, (768, 768), (768, 1)) assert_size_stride(arg166_1, (768, ), (1, )) assert_size_stride(arg167_1, (768, 768), (768, 1)) assert_size_stride(arg168_1, (768, ), (1, )) assert_size_stride(arg169_1, (768, 768), (768, 1)) assert_size_stride(arg170_1, (768, ), (1, )) assert_size_stride(arg171_1, (768, 768), (768, 1)) assert_size_stride(arg172_1, (768, ), (1, )) assert_size_stride(arg173_1, (768, ), (1, )) assert_size_stride(arg174_1, (768, ), (1, )) assert_size_stride(arg175_1, (3072, 768), (768, 1)) assert_size_stride(arg176_1, (3072, ), (1, )) assert_size_stride(arg177_1, (768, 3072), (3072, 1)) assert_size_stride(arg178_1, (768, ), (1, )) assert_size_stride(arg179_1, (768, ), (1, )) assert_size_stride(arg180_1, (768, ), (1, )) assert_size_stride(arg181_1, (768, 768), (768, 1)) assert_size_stride(arg182_1, (768, ), (1, )) assert_size_stride(arg183_1, (768, 768), (768, 1)) assert_size_stride(arg184_1, (768, ), (1, )) assert_size_stride(arg185_1, (768, 768), (768, 1)) assert_size_stride(arg186_1, (768, ), (1, )) assert_size_stride(arg187_1, (768, 768), (768, 1)) assert_size_stride(arg188_1, (768, ), (1, )) assert_size_stride(arg189_1, (768, ), (1, )) assert_size_stride(arg190_1, (768, ), (1, )) assert_size_stride(arg191_1, (3072, 768), (768, 1)) assert_size_stride(arg192_1, (3072, ), (1, )) assert_size_stride(arg193_1, (768, 3072), (3072, 1)) assert_size_stride(arg194_1, (768, ), (1, )) assert_size_stride(arg195_1, (768, ), (1, )) assert_size_stride(arg196_1, (768, ), (1, )) assert_size_stride(arg197_1, (768, 768), (768, 1)) assert_size_stride(arg198_1, (768, ), (1, )) assert_size_stride(arg199_1, (768, ), (1, )) assert_size_stride(arg200_1, (768, ), (1, )) assert_size_stride(arg201_1, (30522, 768), (768, 1)) assert_size_stride(arg202_1, (30522, ), (1, )) assert_size_stride(arg203_1, (1, 512), (512, 1)) assert_size_stride(arg204_1, (1, 512), (512, 1)) assert_size_stride(arg205_1, (16, 512), (512, 1)) assert_size_stride(arg206_1, (16, 512), (512, 1)) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) buf0 = empty_strided_cuda((16, 512, 768), (393216, 768, 1), torch.bfloat16) buf4 = empty_strided_cuda((16, 512, 768), (393216, 768, 1), torch.bfloat16) # Source Nodes: [embeddings, embeddings_1, embeddings_2, inputs_embeds, position_embeddings, token_type_embeddings], Original ATen: [aten.add, aten.embedding, aten.native_layer_norm] stream0 = get_raw_stream(0) triton_per_fused_add_embedding_native_layer_norm_0.run(arg205_1, arg0_1, arg203_1, arg1_1, arg204_1, arg2_1, arg3_1, arg4_1, buf0, buf4, 8192, 768, grid=grid(8192), stream=stream0) del arg0_1 del arg1_1 del arg203_1 del arg204_1 del arg205_1 del arg2_1 del arg3_1 del arg4_1 buf5 = reinterpret_tensor(buf0, (8192, 768), (768, 1), 0); del buf0 # reuse # Source Nodes: [mixed_query_layer], Original ATen: [aten.addmm] extern_kernels.addmm(arg6_1, reinterpret_tensor(buf4, (8192, 768), (768, 1), 0), reinterpret_tensor(arg5_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf5) del arg5_1 del arg6_1 buf6 = empty_strided_cuda((8192, 768), (768, 1), torch.bfloat16) # Source Nodes: [l__mod___bert_encoder_layer_0_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg8_1, reinterpret_tensor(buf4, (8192, 768), (768, 1), 0), reinterpret_tensor(arg7_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf6) del arg7_1 del arg8_1 buf7 = empty_strided_cuda((8192, 768), (768, 1), torch.bfloat16) # Source Nodes: [l__mod___bert_encoder_layer_0_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg10_1, reinterpret_tensor(buf4, (8192, 768), (768, 1), 0), reinterpret_tensor(arg9_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf7) del arg10_1 del arg9_1 # Source Nodes: [], Original ATen: [] buf8 = torch.ops.aten._scaled_dot_product_flash_attention.default(reinterpret_tensor(buf5, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf6, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf7, (16, 12, 512, 64), (393216, 64, 768, 1), 0), scale=0.125) del buf5 buf9 = buf8[0] del buf8 buf14 = buf7; del buf7 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf9, (8192, 768), (768, 1), 0), reinterpret_tensor(arg11_1, (768, 768), (1, 768), 0), out=buf14) del arg11_1 buf18 = reinterpret_tensor(buf9, (16, 512, 768), (393216, 768, 1), 0); del buf9 # reuse # Source Nodes: [add_2, hidden_states_2], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf14, arg12_1, buf4, arg13_1, arg14_1, buf18, 8192, 768, grid=grid(8192), stream=stream0) del arg12_1 del arg13_1 del arg14_1 buf19 = empty_strided_cuda((8192, 3072), (3072, 1), torch.bfloat16) # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf18, (8192, 768), (768, 1), 0), reinterpret_tensor(arg15_1, (768, 3072), (1, 768), 0), out=buf19) del arg15_1 buf20 = reinterpret_tensor(buf19, (16, 512, 3072), (1572864, 3072, 1), 0); del buf19 # reuse # Source Nodes: [hidden_states_4], Original ATen: [aten.gelu] triton_poi_fused_gelu_2.run(buf20, arg16_1, 25165824, grid=grid(25165824), stream=stream0) del arg16_1 buf21 = reinterpret_tensor(buf4, (8192, 768), (768, 1), 0); del buf4 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf20, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg17_1, (3072, 768), (1, 3072), 0), out=buf21) del arg17_1 buf25 = reinterpret_tensor(buf14, (16, 512, 768), (393216, 768, 1), 0); del buf14 # reuse # Source Nodes: [add_3, hidden_states_7], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf21, arg18_1, buf18, arg19_1, arg20_1, buf25, 8192, 768, grid=grid(8192), stream=stream0) del arg18_1 del arg19_1 del arg20_1 buf26 = buf21; del buf21 # reuse # Source Nodes: [mixed_query_layer_1], Original ATen: [aten.addmm] extern_kernels.addmm(arg22_1, reinterpret_tensor(buf25, (8192, 768), (768, 1), 0), reinterpret_tensor(arg21_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf26) del arg21_1 del arg22_1 buf27 = reinterpret_tensor(buf18, (8192, 768), (768, 1), 0); del buf18 # reuse # Source Nodes: [l__mod___bert_encoder_layer_1_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg24_1, reinterpret_tensor(buf25, (8192, 768), (768, 1), 0), reinterpret_tensor(arg23_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf27) del arg23_1 del arg24_1 buf28 = buf6; del buf6 # reuse # Source Nodes: [l__mod___bert_encoder_layer_1_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg26_1, reinterpret_tensor(buf25, (8192, 768), (768, 1), 0), reinterpret_tensor(arg25_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf28) del arg25_1 del arg26_1 # Source Nodes: [], Original ATen: [] buf29 = torch.ops.aten._scaled_dot_product_flash_attention.default(reinterpret_tensor(buf26, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf27, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf28, (16, 12, 512, 64), (393216, 64, 768, 1), 0), scale=0.125) del buf26 buf30 = buf29[0] del buf29 buf35 = buf28; del buf28 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf30, (8192, 768), (768, 1), 0), reinterpret_tensor(arg27_1, (768, 768), (1, 768), 0), out=buf35) del arg27_1 buf39 = reinterpret_tensor(buf30, (16, 512, 768), (393216, 768, 1), 0); del buf30 # reuse # Source Nodes: [add_5, hidden_states_10], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf35, arg28_1, buf25, arg29_1, arg30_1, buf39, 8192, 768, grid=grid(8192), stream=stream0) del arg28_1 del arg29_1 del arg30_1 buf40 = reinterpret_tensor(buf20, (8192, 3072), (3072, 1), 0); del buf20 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf39, (8192, 768), (768, 1), 0), reinterpret_tensor(arg31_1, (768, 3072), (1, 768), 0), out=buf40) del arg31_1 buf41 = reinterpret_tensor(buf40, (16, 512, 3072), (1572864, 3072, 1), 0); del buf40 # reuse # Source Nodes: [hidden_states_12], Original ATen: [aten.gelu] triton_poi_fused_gelu_2.run(buf41, arg32_1, 25165824, grid=grid(25165824), stream=stream0) del arg32_1 buf42 = buf35; del buf35 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf41, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg33_1, (3072, 768), (1, 3072), 0), out=buf42) del arg33_1 buf46 = buf25; del buf25 # reuse # Source Nodes: [add_6, hidden_states_15], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf42, arg34_1, buf39, arg35_1, arg36_1, buf46, 8192, 768, grid=grid(8192), stream=stream0) del arg34_1 del arg35_1 del arg36_1 buf47 = buf42; del buf42 # reuse # Source Nodes: [mixed_query_layer_2], Original ATen: [aten.addmm] extern_kernels.addmm(arg38_1, reinterpret_tensor(buf46, (8192, 768), (768, 1), 0), reinterpret_tensor(arg37_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf47) del arg37_1 del arg38_1 buf48 = reinterpret_tensor(buf39, (8192, 768), (768, 1), 0); del buf39 # reuse # Source Nodes: [l__mod___bert_encoder_layer_2_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg40_1, reinterpret_tensor(buf46, (8192, 768), (768, 1), 0), reinterpret_tensor(arg39_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf48) del arg39_1 del arg40_1 buf49 = buf27; del buf27 # reuse # Source Nodes: [l__mod___bert_encoder_layer_2_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg42_1, reinterpret_tensor(buf46, (8192, 768), (768, 1), 0), reinterpret_tensor(arg41_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf49) del arg41_1 del arg42_1 # Source Nodes: [], Original ATen: [] buf50 = torch.ops.aten._scaled_dot_product_flash_attention.default(reinterpret_tensor(buf47, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf48, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf49, (16, 12, 512, 64), (393216, 64, 768, 1), 0), scale=0.125) del buf47 buf51 = buf50[0] del buf50 buf56 = buf49; del buf49 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf51, (8192, 768), (768, 1), 0), reinterpret_tensor(arg43_1, (768, 768), (1, 768), 0), out=buf56) del arg43_1 buf60 = reinterpret_tensor(buf51, (16, 512, 768), (393216, 768, 1), 0); del buf51 # reuse # Source Nodes: [add_8, hidden_states_18], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf56, arg44_1, buf46, arg45_1, arg46_1, buf60, 8192, 768, grid=grid(8192), stream=stream0) del arg44_1 del arg45_1 del arg46_1 buf61 = reinterpret_tensor(buf41, (8192, 3072), (3072, 1), 0); del buf41 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf60, (8192, 768), (768, 1), 0), reinterpret_tensor(arg47_1, (768, 3072), (1, 768), 0), out=buf61) del arg47_1 buf62 = reinterpret_tensor(buf61, (16, 512, 3072), (1572864, 3072, 1), 0); del buf61 # reuse # Source Nodes: [hidden_states_20], Original ATen: [aten.gelu] triton_poi_fused_gelu_2.run(buf62, arg48_1, 25165824, grid=grid(25165824), stream=stream0) del arg48_1 buf63 = buf56; del buf56 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf62, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg49_1, (3072, 768), (1, 3072), 0), out=buf63) del arg49_1 buf67 = buf46; del buf46 # reuse # Source Nodes: [add_9, hidden_states_23], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf63, arg50_1, buf60, arg51_1, arg52_1, buf67, 8192, 768, grid=grid(8192), stream=stream0) del arg50_1 del arg51_1 del arg52_1 buf68 = buf63; del buf63 # reuse # Source Nodes: [mixed_query_layer_3], Original ATen: [aten.addmm] extern_kernels.addmm(arg54_1, reinterpret_tensor(buf67, (8192, 768), (768, 1), 0), reinterpret_tensor(arg53_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf68) del arg53_1 del arg54_1 buf69 = reinterpret_tensor(buf60, (8192, 768), (768, 1), 0); del buf60 # reuse # Source Nodes: [l__mod___bert_encoder_layer_3_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg56_1, reinterpret_tensor(buf67, (8192, 768), (768, 1), 0), reinterpret_tensor(arg55_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf69) del arg55_1 del arg56_1 buf70 = buf48; del buf48 # reuse # Source Nodes: [l__mod___bert_encoder_layer_3_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg58_1, reinterpret_tensor(buf67, (8192, 768), (768, 1), 0), reinterpret_tensor(arg57_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf70) del arg57_1 del arg58_1 # Source Nodes: [], Original ATen: [] buf71 = torch.ops.aten._scaled_dot_product_flash_attention.default(reinterpret_tensor(buf68, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf69, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf70, (16, 12, 512, 64), (393216, 64, 768, 1), 0), scale=0.125) del buf68 buf72 = buf71[0] del buf71 buf77 = buf70; del buf70 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf72, (8192, 768), (768, 1), 0), reinterpret_tensor(arg59_1, (768, 768), (1, 768), 0), out=buf77) del arg59_1 buf81 = reinterpret_tensor(buf72, (16, 512, 768), (393216, 768, 1), 0); del buf72 # reuse # Source Nodes: [add_11, hidden_states_26], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf77, arg60_1, buf67, arg61_1, arg62_1, buf81, 8192, 768, grid=grid(8192), stream=stream0) del arg60_1 del arg61_1 del arg62_1 buf82 = reinterpret_tensor(buf62, (8192, 3072), (3072, 1), 0); del buf62 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf81, (8192, 768), (768, 1), 0), reinterpret_tensor(arg63_1, (768, 3072), (1, 768), 0), out=buf82) del arg63_1 buf83 = reinterpret_tensor(buf82, (16, 512, 3072), (1572864, 3072, 1), 0); del buf82 # reuse # Source Nodes: [hidden_states_28], Original ATen: [aten.gelu] triton_poi_fused_gelu_2.run(buf83, arg64_1, 25165824, grid=grid(25165824), stream=stream0) del arg64_1 buf84 = buf77; del buf77 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf83, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg65_1, (3072, 768), (1, 3072), 0), out=buf84) del arg65_1 buf88 = buf67; del buf67 # reuse # Source Nodes: [add_12, hidden_states_31], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf84, arg66_1, buf81, arg67_1, arg68_1, buf88, 8192, 768, grid=grid(8192), stream=stream0) del arg66_1 del arg67_1 del arg68_1 buf89 = buf84; del buf84 # reuse # Source Nodes: [mixed_query_layer_4], Original ATen: [aten.addmm] extern_kernels.addmm(arg70_1, reinterpret_tensor(buf88, (8192, 768), (768, 1), 0), reinterpret_tensor(arg69_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf89) del arg69_1 del arg70_1 buf90 = reinterpret_tensor(buf81, (8192, 768), (768, 1), 0); del buf81 # reuse # Source Nodes: [l__mod___bert_encoder_layer_4_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg72_1, reinterpret_tensor(buf88, (8192, 768), (768, 1), 0), reinterpret_tensor(arg71_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf90) del arg71_1 del arg72_1 buf91 = buf69; del buf69 # reuse # Source Nodes: [l__mod___bert_encoder_layer_4_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg74_1, reinterpret_tensor(buf88, (8192, 768), (768, 1), 0), reinterpret_tensor(arg73_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf91) del arg73_1 del arg74_1 # Source Nodes: [], Original ATen: [] buf92 = torch.ops.aten._scaled_dot_product_flash_attention.default(reinterpret_tensor(buf89, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf90, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf91, (16, 12, 512, 64), (393216, 64, 768, 1), 0), scale=0.125) del buf89 buf93 = buf92[0] del buf92 buf98 = buf91; del buf91 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf93, (8192, 768), (768, 1), 0), reinterpret_tensor(arg75_1, (768, 768), (1, 768), 0), out=buf98) del arg75_1 buf102 = reinterpret_tensor(buf93, (16, 512, 768), (393216, 768, 1), 0); del buf93 # reuse # Source Nodes: [add_14, hidden_states_34], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf98, arg76_1, buf88, arg77_1, arg78_1, buf102, 8192, 768, grid=grid(8192), stream=stream0) del arg76_1 del arg77_1 del arg78_1 buf103 = reinterpret_tensor(buf83, (8192, 3072), (3072, 1), 0); del buf83 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf102, (8192, 768), (768, 1), 0), reinterpret_tensor(arg79_1, (768, 3072), (1, 768), 0), out=buf103) del arg79_1 buf104 = reinterpret_tensor(buf103, (16, 512, 3072), (1572864, 3072, 1), 0); del buf103 # reuse # Source Nodes: [hidden_states_36], Original ATen: [aten.gelu] triton_poi_fused_gelu_2.run(buf104, arg80_1, 25165824, grid=grid(25165824), stream=stream0) del arg80_1 buf105 = buf98; del buf98 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf104, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg81_1, (3072, 768), (1, 3072), 0), out=buf105) del arg81_1 buf109 = buf88; del buf88 # reuse # Source Nodes: [add_15, hidden_states_39], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf105, arg82_1, buf102, arg83_1, arg84_1, buf109, 8192, 768, grid=grid(8192), stream=stream0) del arg82_1 del arg83_1 del arg84_1 buf110 = buf105; del buf105 # reuse # Source Nodes: [mixed_query_layer_5], Original ATen: [aten.addmm] extern_kernels.addmm(arg86_1, reinterpret_tensor(buf109, (8192, 768), (768, 1), 0), reinterpret_tensor(arg85_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf110) del arg85_1 del arg86_1 buf111 = reinterpret_tensor(buf102, (8192, 768), (768, 1), 0); del buf102 # reuse # Source Nodes: [l__mod___bert_encoder_layer_5_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg88_1, reinterpret_tensor(buf109, (8192, 768), (768, 1), 0), reinterpret_tensor(arg87_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf111) del arg87_1 del arg88_1 buf112 = buf90; del buf90 # reuse # Source Nodes: [l__mod___bert_encoder_layer_5_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg90_1, reinterpret_tensor(buf109, (8192, 768), (768, 1), 0), reinterpret_tensor(arg89_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf112) del arg89_1 del arg90_1 # Source Nodes: [], Original ATen: [] buf113 = torch.ops.aten._scaled_dot_product_flash_attention.default(reinterpret_tensor(buf110, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf111, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf112, (16, 12, 512, 64), (393216, 64, 768, 1), 0), scale=0.125) del buf110 buf114 = buf113[0] del buf113 buf119 = buf112; del buf112 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf114, (8192, 768), (768, 1), 0), reinterpret_tensor(arg91_1, (768, 768), (1, 768), 0), out=buf119) del arg91_1 buf123 = reinterpret_tensor(buf114, (16, 512, 768), (393216, 768, 1), 0); del buf114 # reuse # Source Nodes: [add_17, hidden_states_42], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf119, arg92_1, buf109, arg93_1, arg94_1, buf123, 8192, 768, grid=grid(8192), stream=stream0) del arg92_1 del arg93_1 del arg94_1 buf124 = reinterpret_tensor(buf104, (8192, 3072), (3072, 1), 0); del buf104 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf123, (8192, 768), (768, 1), 0), reinterpret_tensor(arg95_1, (768, 3072), (1, 768), 0), out=buf124) del arg95_1 buf125 = reinterpret_tensor(buf124, (16, 512, 3072), (1572864, 3072, 1), 0); del buf124 # reuse # Source Nodes: [hidden_states_44], Original ATen: [aten.gelu] triton_poi_fused_gelu_2.run(buf125, arg96_1, 25165824, grid=grid(25165824), stream=stream0) del arg96_1 buf126 = buf119; del buf119 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf125, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg97_1, (3072, 768), (1, 3072), 0), out=buf126) del arg97_1 buf130 = buf109; del buf109 # reuse # Source Nodes: [add_18, hidden_states_47], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf126, arg98_1, buf123, arg99_1, arg100_1, buf130, 8192, 768, grid=grid(8192), stream=stream0) del arg100_1 del arg98_1 del arg99_1 buf131 = buf126; del buf126 # reuse # Source Nodes: [mixed_query_layer_6], Original ATen: [aten.addmm] extern_kernels.addmm(arg102_1, reinterpret_tensor(buf130, (8192, 768), (768, 1), 0), reinterpret_tensor(arg101_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf131) del arg101_1 del arg102_1 buf132 = reinterpret_tensor(buf123, (8192, 768), (768, 1), 0); del buf123 # reuse # Source Nodes: [l__mod___bert_encoder_layer_6_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg104_1, reinterpret_tensor(buf130, (8192, 768), (768, 1), 0), reinterpret_tensor(arg103_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf132) del arg103_1 del arg104_1 buf133 = buf111; del buf111 # reuse # Source Nodes: [l__mod___bert_encoder_layer_6_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg106_1, reinterpret_tensor(buf130, (8192, 768), (768, 1), 0), reinterpret_tensor(arg105_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf133) del arg105_1 del arg106_1 # Source Nodes: [], Original ATen: [] buf134 = torch.ops.aten._scaled_dot_product_flash_attention.default(reinterpret_tensor(buf131, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf132, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf133, (16, 12, 512, 64), (393216, 64, 768, 1), 0), scale=0.125) del buf131 buf135 = buf134[0] del buf134 buf140 = buf133; del buf133 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf135, (8192, 768), (768, 1), 0), reinterpret_tensor(arg107_1, (768, 768), (1, 768), 0), out=buf140) del arg107_1 buf144 = reinterpret_tensor(buf135, (16, 512, 768), (393216, 768, 1), 0); del buf135 # reuse # Source Nodes: [add_20, hidden_states_50], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf140, arg108_1, buf130, arg109_1, arg110_1, buf144, 8192, 768, grid=grid(8192), stream=stream0) del arg108_1 del arg109_1 del arg110_1 buf145 = reinterpret_tensor(buf125, (8192, 3072), (3072, 1), 0); del buf125 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf144, (8192, 768), (768, 1), 0), reinterpret_tensor(arg111_1, (768, 3072), (1, 768), 0), out=buf145) del arg111_1 buf146 = reinterpret_tensor(buf145, (16, 512, 3072), (1572864, 3072, 1), 0); del buf145 # reuse # Source Nodes: [hidden_states_52], Original ATen: [aten.gelu] triton_poi_fused_gelu_2.run(buf146, arg112_1, 25165824, grid=grid(25165824), stream=stream0) del arg112_1 buf147 = buf140; del buf140 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf146, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg113_1, (3072, 768), (1, 3072), 0), out=buf147) del arg113_1 buf151 = buf130; del buf130 # reuse # Source Nodes: [add_21, hidden_states_55], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf147, arg114_1, buf144, arg115_1, arg116_1, buf151, 8192, 768, grid=grid(8192), stream=stream0) del arg114_1 del arg115_1 del arg116_1 buf152 = buf147; del buf147 # reuse # Source Nodes: [mixed_query_layer_7], Original ATen: [aten.addmm] extern_kernels.addmm(arg118_1, reinterpret_tensor(buf151, (8192, 768), (768, 1), 0), reinterpret_tensor(arg117_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf152) del arg117_1 del arg118_1 buf153 = reinterpret_tensor(buf144, (8192, 768), (768, 1), 0); del buf144 # reuse # Source Nodes: [l__mod___bert_encoder_layer_7_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg120_1, reinterpret_tensor(buf151, (8192, 768), (768, 1), 0), reinterpret_tensor(arg119_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf153) del arg119_1 del arg120_1 buf154 = buf132; del buf132 # reuse # Source Nodes: [l__mod___bert_encoder_layer_7_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg122_1, reinterpret_tensor(buf151, (8192, 768), (768, 1), 0), reinterpret_tensor(arg121_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf154) del arg121_1 del arg122_1 # Source Nodes: [], Original ATen: [] buf155 = torch.ops.aten._scaled_dot_product_flash_attention.default(reinterpret_tensor(buf152, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf153, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf154, (16, 12, 512, 64), (393216, 64, 768, 1), 0), scale=0.125) del buf152 buf156 = buf155[0] del buf155 buf161 = buf154; del buf154 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf156, (8192, 768), (768, 1), 0), reinterpret_tensor(arg123_1, (768, 768), (1, 768), 0), out=buf161) del arg123_1 buf165 = reinterpret_tensor(buf156, (16, 512, 768), (393216, 768, 1), 0); del buf156 # reuse # Source Nodes: [add_23, hidden_states_58], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf161, arg124_1, buf151, arg125_1, arg126_1, buf165, 8192, 768, grid=grid(8192), stream=stream0) del arg124_1 del arg125_1 del arg126_1 buf166 = reinterpret_tensor(buf146, (8192, 3072), (3072, 1), 0); del buf146 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf165, (8192, 768), (768, 1), 0), reinterpret_tensor(arg127_1, (768, 3072), (1, 768), 0), out=buf166) del arg127_1 buf167 = reinterpret_tensor(buf166, (16, 512, 3072), (1572864, 3072, 1), 0); del buf166 # reuse # Source Nodes: [hidden_states_60], Original ATen: [aten.gelu] triton_poi_fused_gelu_2.run(buf167, arg128_1, 25165824, grid=grid(25165824), stream=stream0) del arg128_1 buf168 = buf161; del buf161 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf167, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg129_1, (3072, 768), (1, 3072), 0), out=buf168) del arg129_1 buf172 = buf151; del buf151 # reuse # Source Nodes: [add_24, hidden_states_63], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf168, arg130_1, buf165, arg131_1, arg132_1, buf172, 8192, 768, grid=grid(8192), stream=stream0) del arg130_1 del arg131_1 del arg132_1 buf173 = buf168; del buf168 # reuse # Source Nodes: [mixed_query_layer_8], Original ATen: [aten.addmm] extern_kernels.addmm(arg134_1, reinterpret_tensor(buf172, (8192, 768), (768, 1), 0), reinterpret_tensor(arg133_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf173) del arg133_1 del arg134_1 buf174 = reinterpret_tensor(buf165, (8192, 768), (768, 1), 0); del buf165 # reuse # Source Nodes: [l__mod___bert_encoder_layer_8_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg136_1, reinterpret_tensor(buf172, (8192, 768), (768, 1), 0), reinterpret_tensor(arg135_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf174) del arg135_1 del arg136_1 buf175 = buf153; del buf153 # reuse # Source Nodes: [l__mod___bert_encoder_layer_8_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg138_1, reinterpret_tensor(buf172, (8192, 768), (768, 1), 0), reinterpret_tensor(arg137_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf175) del arg137_1 del arg138_1 # Source Nodes: [], Original ATen: [] buf176 = torch.ops.aten._scaled_dot_product_flash_attention.default(reinterpret_tensor(buf173, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf174, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf175, (16, 12, 512, 64), (393216, 64, 768, 1), 0), scale=0.125) del buf173 buf177 = buf176[0] del buf176 buf182 = buf175; del buf175 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf177, (8192, 768), (768, 1), 0), reinterpret_tensor(arg139_1, (768, 768), (1, 768), 0), out=buf182) del arg139_1 buf186 = reinterpret_tensor(buf177, (16, 512, 768), (393216, 768, 1), 0); del buf177 # reuse # Source Nodes: [add_26, hidden_states_66], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf182, arg140_1, buf172, arg141_1, arg142_1, buf186, 8192, 768, grid=grid(8192), stream=stream0) del arg140_1 del arg141_1 del arg142_1 buf187 = reinterpret_tensor(buf167, (8192, 3072), (3072, 1), 0); del buf167 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf186, (8192, 768), (768, 1), 0), reinterpret_tensor(arg143_1, (768, 3072), (1, 768), 0), out=buf187) del arg143_1 buf188 = reinterpret_tensor(buf187, (16, 512, 3072), (1572864, 3072, 1), 0); del buf187 # reuse # Source Nodes: [hidden_states_68], Original ATen: [aten.gelu] triton_poi_fused_gelu_2.run(buf188, arg144_1, 25165824, grid=grid(25165824), stream=stream0) del arg144_1 buf189 = buf182; del buf182 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf188, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg145_1, (3072, 768), (1, 3072), 0), out=buf189) del arg145_1 buf193 = buf172; del buf172 # reuse # Source Nodes: [add_27, hidden_states_71], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf189, arg146_1, buf186, arg147_1, arg148_1, buf193, 8192, 768, grid=grid(8192), stream=stream0) del arg146_1 del arg147_1 del arg148_1 buf194 = buf189; del buf189 # reuse # Source Nodes: [mixed_query_layer_9], Original ATen: [aten.addmm] extern_kernels.addmm(arg150_1, reinterpret_tensor(buf193, (8192, 768), (768, 1), 0), reinterpret_tensor(arg149_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf194) del arg149_1 del arg150_1 buf195 = reinterpret_tensor(buf186, (8192, 768), (768, 1), 0); del buf186 # reuse # Source Nodes: [l__mod___bert_encoder_layer_9_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg152_1, reinterpret_tensor(buf193, (8192, 768), (768, 1), 0), reinterpret_tensor(arg151_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf195) del arg151_1 del arg152_1 buf196 = buf174; del buf174 # reuse # Source Nodes: [l__mod___bert_encoder_layer_9_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg154_1, reinterpret_tensor(buf193, (8192, 768), (768, 1), 0), reinterpret_tensor(arg153_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf196) del arg153_1 del arg154_1 # Source Nodes: [], Original ATen: [] buf197 = torch.ops.aten._scaled_dot_product_flash_attention.default(reinterpret_tensor(buf194, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf195, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf196, (16, 12, 512, 64), (393216, 64, 768, 1), 0), scale=0.125) del buf194 buf198 = buf197[0] del buf197 buf203 = buf196; del buf196 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf198, (8192, 768), (768, 1), 0), reinterpret_tensor(arg155_1, (768, 768), (1, 768), 0), out=buf203) del arg155_1 buf207 = reinterpret_tensor(buf198, (16, 512, 768), (393216, 768, 1), 0); del buf198 # reuse # Source Nodes: [add_29, hidden_states_74], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf203, arg156_1, buf193, arg157_1, arg158_1, buf207, 8192, 768, grid=grid(8192), stream=stream0) del arg156_1 del arg157_1 del arg158_1 buf208 = reinterpret_tensor(buf188, (8192, 3072), (3072, 1), 0); del buf188 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf207, (8192, 768), (768, 1), 0), reinterpret_tensor(arg159_1, (768, 3072), (1, 768), 0), out=buf208) del arg159_1 buf209 = reinterpret_tensor(buf208, (16, 512, 3072), (1572864, 3072, 1), 0); del buf208 # reuse # Source Nodes: [hidden_states_76], Original ATen: [aten.gelu] triton_poi_fused_gelu_2.run(buf209, arg160_1, 25165824, grid=grid(25165824), stream=stream0) del arg160_1 buf210 = buf203; del buf203 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf209, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg161_1, (3072, 768), (1, 3072), 0), out=buf210) del arg161_1 buf214 = buf193; del buf193 # reuse # Source Nodes: [add_30, hidden_states_79], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf210, arg162_1, buf207, arg163_1, arg164_1, buf214, 8192, 768, grid=grid(8192), stream=stream0) del arg162_1 del arg163_1 del arg164_1 buf215 = buf210; del buf210 # reuse # Source Nodes: [mixed_query_layer_10], Original ATen: [aten.addmm] extern_kernels.addmm(arg166_1, reinterpret_tensor(buf214, (8192, 768), (768, 1), 0), reinterpret_tensor(arg165_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf215) del arg165_1 del arg166_1 buf216 = reinterpret_tensor(buf207, (8192, 768), (768, 1), 0); del buf207 # reuse # Source Nodes: [l__mod___bert_encoder_layer_10_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg168_1, reinterpret_tensor(buf214, (8192, 768), (768, 1), 0), reinterpret_tensor(arg167_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf216) del arg167_1 del arg168_1 buf217 = buf195; del buf195 # reuse # Source Nodes: [l__mod___bert_encoder_layer_10_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg170_1, reinterpret_tensor(buf214, (8192, 768), (768, 1), 0), reinterpret_tensor(arg169_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf217) del arg169_1 del arg170_1 # Source Nodes: [], Original ATen: [] buf218 = torch.ops.aten._scaled_dot_product_flash_attention.default(reinterpret_tensor(buf215, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf216, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf217, (16, 12, 512, 64), (393216, 64, 768, 1), 0), scale=0.125) del buf215 buf219 = buf218[0] del buf218 buf224 = buf217; del buf217 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf219, (8192, 768), (768, 1), 0), reinterpret_tensor(arg171_1, (768, 768), (1, 768), 0), out=buf224) del arg171_1 buf228 = reinterpret_tensor(buf219, (16, 512, 768), (393216, 768, 1), 0); del buf219 # reuse # Source Nodes: [add_32, hidden_states_82], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf224, arg172_1, buf214, arg173_1, arg174_1, buf228, 8192, 768, grid=grid(8192), stream=stream0) del arg172_1 del arg173_1 del arg174_1 buf229 = reinterpret_tensor(buf209, (8192, 3072), (3072, 1), 0); del buf209 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf228, (8192, 768), (768, 1), 0), reinterpret_tensor(arg175_1, (768, 3072), (1, 768), 0), out=buf229) del arg175_1 buf230 = reinterpret_tensor(buf229, (16, 512, 3072), (1572864, 3072, 1), 0); del buf229 # reuse # Source Nodes: [hidden_states_84], Original ATen: [aten.gelu] triton_poi_fused_gelu_2.run(buf230, arg176_1, 25165824, grid=grid(25165824), stream=stream0) del arg176_1 buf231 = buf224; del buf224 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf230, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg177_1, (3072, 768), (1, 3072), 0), out=buf231) del arg177_1 buf235 = buf214; del buf214 # reuse # Source Nodes: [add_33, hidden_states_87], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf231, arg178_1, buf228, arg179_1, arg180_1, buf235, 8192, 768, grid=grid(8192), stream=stream0) del arg178_1 del arg179_1 del arg180_1 buf236 = buf231; del buf231 # reuse # Source Nodes: [mixed_query_layer_11], Original ATen: [aten.addmm] extern_kernels.addmm(arg182_1, reinterpret_tensor(buf235, (8192, 768), (768, 1), 0), reinterpret_tensor(arg181_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf236) del arg181_1 del arg182_1 buf237 = reinterpret_tensor(buf228, (8192, 768), (768, 1), 0); del buf228 # reuse # Source Nodes: [l__mod___bert_encoder_layer_11_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg184_1, reinterpret_tensor(buf235, (8192, 768), (768, 1), 0), reinterpret_tensor(arg183_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf237) del arg183_1 del arg184_1 buf238 = buf216; del buf216 # reuse # Source Nodes: [l__mod___bert_encoder_layer_11_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg186_1, reinterpret_tensor(buf235, (8192, 768), (768, 1), 0), reinterpret_tensor(arg185_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf238) del arg185_1 del arg186_1 # Source Nodes: [], Original ATen: [] buf239 = torch.ops.aten._scaled_dot_product_flash_attention.default(reinterpret_tensor(buf236, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf237, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf238, (16, 12, 512, 64), (393216, 64, 768, 1), 0), scale=0.125) del buf236 del buf237 buf240 = buf239[0] del buf239 buf245 = buf238; del buf238 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf240, (8192, 768), (768, 1), 0), reinterpret_tensor(arg187_1, (768, 768), (1, 768), 0), out=buf245) del arg187_1 buf249 = reinterpret_tensor(buf240, (16, 512, 768), (393216, 768, 1), 0); del buf240 # reuse # Source Nodes: [add_35, hidden_states_90], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf245, arg188_1, buf235, arg189_1, arg190_1, buf249, 8192, 768, grid=grid(8192), stream=stream0) del arg188_1 del arg189_1 del arg190_1 buf250 = reinterpret_tensor(buf230, (8192, 3072), (3072, 1), 0); del buf230 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf249, (8192, 768), (768, 1), 0), reinterpret_tensor(arg191_1, (768, 3072), (1, 768), 0), out=buf250) del arg191_1 buf251 = reinterpret_tensor(buf250, (16, 512, 3072), (1572864, 3072, 1), 0); del buf250 # reuse # Source Nodes: [hidden_states_92], Original ATen: [aten.gelu] triton_poi_fused_gelu_2.run(buf251, arg192_1, 25165824, grid=grid(25165824), stream=stream0) del arg192_1 buf252 = buf245; del buf245 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf251, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg193_1, (3072, 768), (1, 3072), 0), out=buf252) del arg193_1 del buf251 buf256 = buf235; del buf235 # reuse # Source Nodes: [add_36, hidden_states_95], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_1.run(buf252, arg194_1, buf249, arg195_1, arg196_1, buf256, 8192, 768, grid=grid(8192), stream=stream0) del arg194_1 del arg195_1 del arg196_1 del buf249 buf257 = buf252; del buf252 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf256, (8192, 768), (768, 1), 0), reinterpret_tensor(arg197_1, (768, 768), (1, 768), 0), out=buf257) del arg197_1 buf262 = buf256; del buf256 # reuse # Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten.gelu, aten.native_layer_norm] triton_per_fused_gelu_native_layer_norm_3.run(buf257, arg198_1, arg199_1, arg200_1, buf262, 8192, 768, grid=grid(8192), stream=stream0) del arg198_1 del arg199_1 del arg200_1 del buf257 buf263 = empty_strided_cuda((768, 30528), (30528, 1), torch.bfloat16) # Source Nodes: [], Original ATen: [] triton_poi_fused_4.run(arg201_1, buf263, 23445504, grid=grid(23445504), stream=stream0) del arg201_1 buf264 = empty_strided_cuda((30528, ), (1, ), torch.bfloat16) # Source Nodes: [], Original ATen: [] triton_poi_fused_5.run(arg202_1, buf264, 30528, grid=grid(30528), stream=stream0) del arg202_1 buf265 = empty_strided_cuda((8192, 30528), (30528, 1), torch.bfloat16) # Source Nodes: [], Original ATen: [] extern_kernels.addmm(buf264, reinterpret_tensor(buf262, (8192, 768), (768, 1), 0), buf263, alpha=1, beta=1, out=buf265) del buf262 del buf263 del buf264 buf266 = empty_strided_cuda((8192, 1), (1, 8192), torch.float32) buf267 = empty_strided_cuda((8192, 1), (1, 8192), torch.float32) # Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax] triton_red_fused__log_softmax_6.run(buf265, buf266, buf267, 8192, 30522, grid=grid(8192), stream=stream0) buf268 = empty_strided_cuda((), (), torch.bfloat16) buf270 = buf268; del buf268 # reuse # Source Nodes: [masked_lm_loss], Original ATen: [aten.nll_loss_forward] triton_red_fused_nll_loss_forward_7.run(buf270, arg206_1, buf265, buf266, buf267, 1, 8192, grid=grid(1), stream=stream0) del arg206_1 del buf266 del buf267 return (buf270, reinterpret_tensor(buf265, (16, 512, 30522), (15630336, 30528, 1), 0), ) def benchmark_compiled_module(times=10, repeat=10): from torch._dynamo.testing import rand_strided from torch._inductor.utils import print_performance arg0_1 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg1_1 = rand_strided((2, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg2_1 = rand_strided((512, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg3_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg4_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg5_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg6_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg7_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg8_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg9_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg10_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg11_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg12_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg13_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg14_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg15_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg16_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg17_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg18_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg19_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg20_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg21_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg22_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg23_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg24_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg25_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg26_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg27_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg28_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg29_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg30_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg31_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg32_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg33_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg34_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg35_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg36_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg37_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg38_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg39_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg40_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg41_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg42_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg43_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg44_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg45_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg46_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg47_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg48_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg49_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg50_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg51_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg52_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg53_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg54_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg55_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg56_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg57_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg58_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg59_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg60_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg61_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg62_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg63_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg64_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg65_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg66_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg67_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg68_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg69_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg70_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg71_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg72_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg73_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg74_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg75_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg76_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg77_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg78_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg79_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg80_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg81_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg82_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg83_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg84_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg85_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg86_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg87_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg88_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg89_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg90_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg91_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg92_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg93_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg94_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg95_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg96_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg97_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg98_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg99_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg100_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg101_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg102_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg103_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg104_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg105_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg106_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg107_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg108_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg109_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg110_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg111_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg112_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg113_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg114_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg115_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg116_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg117_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg118_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg119_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg120_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg121_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg122_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg123_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg124_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg125_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg126_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg127_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg128_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg129_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg130_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg131_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg132_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg133_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg134_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg135_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg136_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg137_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg138_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg139_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg140_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg141_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg142_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg143_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg144_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg145_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg146_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg147_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg148_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg149_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg150_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg151_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg152_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg153_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg154_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg155_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg156_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg157_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg158_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg159_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg160_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg161_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg162_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg163_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg164_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg165_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg166_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg167_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg168_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg169_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg170_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg171_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg172_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg173_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg174_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg175_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg176_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg177_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg178_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg179_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg180_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg181_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg182_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg183_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg184_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg185_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg186_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg187_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg188_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg189_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg190_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg191_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg192_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg193_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg194_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg195_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg196_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg197_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg198_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg199_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg200_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg201_1 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg202_1 = rand_strided((30522, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg203_1 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64) arg204_1 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64) arg205_1 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64) arg206_1 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64) fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1]) return print_performance(fn, times=times, repeat=repeat) if __name__ == "__main__": from torch._inductor.wrapper_benchmark import compiled_module_main compiled_module_main('BertForMaskedLM', benchmark_compiled_module)
수정본
파일 열기
# AOT ID: ['0_inference'] from ctypes import c_void_p, c_long import torch import math import random import os import tempfile from math import inf, nan from torch._inductor.hooks import run_intermediate_hooks from torch._inductor.utils import maybe_profile from torch._inductor.codegen.memory_planning import _align as align from torch import device, empty_strided from torch._inductor.async_compile import AsyncCompile from torch._inductor.select_algorithm import extern_kernels from torch._inductor.codegen.multi_kernel import MultiKernelCall aten = torch.ops.aten inductor_ops = torch.ops.inductor _quantized = torch.ops._quantized assert_size_stride = torch._C._dynamo.guards.assert_size_stride empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor alloc_from_pool = torch.ops.inductor._alloc_from_pool async_compile = AsyncCompile() # kernel path: /tmp/torchinductor_eellison/wj/cwjmdogaso56hzkb4pdlo4k625tmwevhq5krfbpmpx5szmhdjbsf.py # Source Nodes: [embeddings, embeddings_1, embeddings_2, inputs_embeds, position_embeddings, token_type_embeddings], Original ATen: [aten.add, aten.embedding, aten.native_layer_norm] # embeddings => add # embeddings_1 => add_1 # embeddings_2 => add_2, add_3, convert_element_type, convert_element_type_1, mul, mul_1, rsqrt, sub, var_mean # inputs_embeds => embedding # position_embeddings => embedding_2 # token_type_embeddings => embedding_1 triton_per_fused_add_embedding_native_layer_norm_0 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.persistent_reduction( size_hints=[8192, 1024], reduction_hint=ReductionHint.INNER, filename=__file__, triton_meta={'signature': {0: '*i64', 1: '*bf16', 2: '*i64', 3: '*bf16', 4: '*i64', 5: '*bf16', 6: '*bf16', 7: '*bf16', 8: '*bf16', 9: '*bf16', 10: 'i32', 11: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_embedding_native_layer_norm_0', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} ) @triton.jit def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr0, out_ptr3, xnumel, rnumel): xnumel = 8192 XBLOCK: tl.constexpr = 1 rnumel = 768 RBLOCK: tl.constexpr = 1024 xoffset = tl.program_id(0) * XBLOCK xindex = tl.full([1], xoffset, tl.int32) xmask = tl.full([RBLOCK], True, tl.int1) rindex = tl.arange(0, RBLOCK)[:] roffset = 0 rmask = rindex < rnumel x3 = xindex r2 = rindex x0 = xindex % 512 tmp0 = tl.load(in_ptr0 + (x3), None, eviction_policy='evict_last') tmp7 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') tmp15 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last') tmp47 = tl.load(in_ptr6 + (r2), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp50 = tl.load(in_ptr7 + (r2), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp1 = tl.full([RBLOCK], 30522, tl.int32) tmp2 = tmp0 + tmp1 tmp3 = tmp0 < 0 tmp4 = tl.where(tmp3, tmp2, tmp0) tl.device_assert((0 <= tmp4) & (tmp4 < 30522), "index out of bounds: 0 <= tmp4 < 30522") tmp6 = tl.load(in_ptr1 + (r2 + (768*tmp4)), rmask, other=0.0).to(tl.float32) tmp8 = tl.full([RBLOCK], 2, tl.int32) tmp9 = tmp7 + tmp8 tmp10 = tmp7 < 0 tmp11 = tl.where(tmp10, tmp9, tmp7) tl.device_assert((0 <= tmp11) & (tmp11 < 2), "index out of bounds: 0 <= tmp11 < 2") tmp13 = tl.load(in_ptr3 + (r2 + (768*tmp11)), rmask, other=0.0).to(tl.float32) tmp14 = tmp6 + tmp13 tmp16 = tl.full([RBLOCK], 512, tl.int32) tmp17 = tmp15 + tmp16 tmp18 = tmp15 < 0 tmp19 = tl.where(tmp18, tmp17, tmp15) tl.device_assert((0 <= tmp19) & (tmp19 < 512), "index out of bounds: 0 <= tmp19 < 512") tmp21 = tl.load(in_ptr5 + (r2 + (768*tmp19)), rmask, other=0.0).to(tl.float32) tmp22 = tmp14 + tmp21 tmp23 = tmp22.to(tl.float32) tmp24 = tl.broadcast_to(tmp23, [RBLOCK]) tmp26 = tl.where(rmask, tmp24, 0) tmp27 = tl.broadcast_to(tmp24, [RBLOCK]) tmp29 = tl.where(rmask, tmp27, 0) tmp30 = triton_helpers.promote_to_tensor(tl.sum(tmp29, 0)) tmp31 = tl.full([1], 768, tl.int32) tmp32 = tmp31.to(tl.float32) tmp33 = tmp30 / tmp32 tmp34 = tmp24 - tmp33 tmp35 = tmp34 * tmp34 tmp36 = tl.broadcast_to(tmp35, [RBLOCK]) tmp38 = tl.where(rmask, tmp36, 0) tmp39 = triton_helpers.promote_to_tensor(tl.sum(tmp38, 0)) tmp40 = tmp23 - tmp33 tmp41 = 768.0 tmp42 = tmp39 / tmp41 tmp43 = 1e-12 tmp44 = tmp42 + tmp43 tmp45 = libdevice.rsqrt(tmp44) tmp46 = tmp40 * tmp45 tmp48 = tmp47.to(tl.float32) tmp49 = tmp46 * tmp48 tmp51 = tmp50.to(tl.float32) tmp52 = tmp49 + tmp51 tmp53 = tmp52.to(tl.float32) tl.store(out_ptr0 + (r2 + (768*x3)), tmp22, rmask) tl.store(out_ptr3 + (r2 + (768*x3)), tmp53, rmask) ''', device_str='cuda') import triton import triton.language as tl from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph from torch._C import _cuda_getCurrentRawStream as get_raw_stream # kernel path: /tmp/torchinductor_eellison/6q/c6qq6qfsawjfikfoeruueuup5cnzvmpzkusjkel6l6wcw43mgauj.py # Source Nodes: [attn_output], Original ATen: [aten._scaled_dot_product_efficient_attention] # attn_output => _scaled_dot_product_efficient_attention triton_poi_fused__scaled_dot_product_efficient_attention_1 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.pointwise( size_hints=[67108864], filename=__file__, triton_meta={'signature': {0: '*bf16', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_efficient_attention_1', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, min_elem_per_thread=0 ) @triton.jit def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 50331648 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = tl.full([XBLOCK], True, tl.int1) x0 = xindex tmp0 = tl.full([1], False, tl.int1) tmp1 = -3.3895313892515355e+38 tmp2 = 0.0 tmp3 = tl.where(tmp0, tmp1, tmp2) tl.store(out_ptr0 + (x0), tmp3, None) ''', device_str='cuda') # kernel path: /tmp/torchinductor_eellison/at/catipo3niy6sjxmx7lban4oumf6lplhx7qd3qxdhlcdrzt4xmv65.py # Source Nodes: [add_1, hidden_states_2], Original ATen: [aten.add, aten.native_layer_norm] # add_1 => add_4 # hidden_states_2 => add_5, add_6, convert_element_type_16, convert_element_type_17, mul_2, mul_3, rsqrt_1, sub_2, var_mean_1 triton_per_fused_add_native_layer_norm_2 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.persistent_reduction( size_hints=[8192, 1024], reduction_hint=ReductionHint.INNER, filename=__file__, triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_native_layer_norm_2', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} ) @triton.jit def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel): xnumel = 8192 XBLOCK: tl.constexpr = 1 rnumel = 768 RBLOCK: tl.constexpr = 1024 xoffset = tl.program_id(0) * XBLOCK xindex = tl.full([1], xoffset, tl.int32) xmask = tl.full([RBLOCK], True, tl.int1) rindex = tl.arange(0, RBLOCK)[:] roffset = 0 rmask = rindex < rnumel r1 = rindex x0 = xindex tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32) tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp3 = tl.load(in_ptr2 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32) tmp29 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp32 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp2 = tmp0 + tmp1 tmp4 = tmp2 + tmp3 tmp5 = tmp4.to(tl.float32) tmp6 = tl.broadcast_to(tmp5, [RBLOCK]) tmp8 = tl.where(rmask, tmp6, 0) tmp9 = tl.broadcast_to(tmp6, [RBLOCK]) tmp11 = tl.where(rmask, tmp9, 0) tmp12 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0)) tmp13 = tl.full([1], 768, tl.int32) tmp14 = tmp13.to(tl.float32) tmp15 = tmp12 / tmp14 tmp16 = tmp6 - tmp15 tmp17 = tmp16 * tmp16 tmp18 = tl.broadcast_to(tmp17, [RBLOCK]) tmp20 = tl.where(rmask, tmp18, 0) tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0)) tmp22 = tmp5 - tmp15 tmp23 = 768.0 tmp24 = tmp21 / tmp23 tmp25 = 1e-12 tmp26 = tmp24 + tmp25 tmp27 = libdevice.rsqrt(tmp26) tmp28 = tmp22 * tmp27 tmp30 = tmp29.to(tl.float32) tmp31 = tmp28 * tmp30 tmp33 = tmp32.to(tl.float32) tmp34 = tmp31 + tmp33 tmp35 = tmp34.to(tl.float32) tl.store(out_ptr2 + (r1 + (768*x0)), tmp35, rmask) ''', device_str='cuda') # kernel path: /tmp/torchinductor_eellison/lp/clpuyqd2qh624e2ye7osnyieorymdus3kflfhssozom7b33frjei.py # Source Nodes: [hidden_states_4], Original ATen: [aten.gelu] # hidden_states_4 => add_7, convert_element_type_21, convert_element_type_22, erf, mul_4, mul_5, mul_6 triton_poi_fused_gelu_3 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.pointwise( size_hints=[33554432], filename=__file__, triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_3', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, min_elem_per_thread=0 ) @triton.jit def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 25165824 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = tl.full([XBLOCK], True, tl.int1) x2 = xindex x0 = xindex % 3072 tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32) tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32) tmp2 = tmp0 + tmp1 tmp3 = tmp2.to(tl.float32) tmp4 = 0.5 tmp5 = tmp3 * tmp4 tmp6 = 0.7071067811865476 tmp7 = tmp3 * tmp6 tmp8 = libdevice.erf(tmp7) tmp9 = 1.0 tmp10 = tmp8 + tmp9 tmp11 = tmp5 * tmp10 tmp12 = tmp11.to(tl.float32) tl.store(in_out_ptr0 + (x2), tmp12, None) ''', device_str='cuda') # kernel path: /tmp/torchinductor_eellison/ai/caiiwspixmf6ozjdgwruqmkld4lyd2nnoyrhxzmgg3mafv6a6yka.py # Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten.gelu, aten.native_layer_norm] # hidden_states_97 => add_88, convert_element_type_295, convert_element_type_296, erf_12, mul_86, mul_87, mul_88 # hidden_states_98 => add_89, add_90, convert_element_type_297, convert_element_type_298, mul_89, mul_90, rsqrt_25, sub_26, var_mean_25 triton_per_fused_gelu_native_layer_norm_4 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.persistent_reduction( size_hints=[8192, 1024], reduction_hint=ReductionHint.INNER, filename=__file__, triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_gelu_native_layer_norm_4', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 4, 'num_reduction': 4, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} ) @triton.jit def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr3, xnumel, rnumel): xnumel = 8192 XBLOCK: tl.constexpr = 1 rnumel = 768 RBLOCK: tl.constexpr = 1024 xoffset = tl.program_id(0) * XBLOCK xindex = tl.full([1], xoffset, tl.int32) xmask = tl.full([RBLOCK], True, tl.int1) rindex = tl.arange(0, RBLOCK)[:] roffset = 0 rmask = rindex < rnumel r1 = rindex x0 = xindex tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask, other=0.0).to(tl.float32) tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp37 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp40 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp2 = tmp0 + tmp1 tmp3 = tmp2.to(tl.float32) tmp4 = 0.5 tmp5 = tmp3 * tmp4 tmp6 = 0.7071067811865476 tmp7 = tmp3 * tmp6 tmp8 = libdevice.erf(tmp7) tmp9 = 1.0 tmp10 = tmp8 + tmp9 tmp11 = tmp5 * tmp10 tmp12 = tmp11.to(tl.float32) tmp13 = tmp12.to(tl.float32) tmp14 = tl.broadcast_to(tmp13, [RBLOCK]) tmp16 = tl.where(rmask, tmp14, 0) tmp17 = tl.broadcast_to(tmp14, [RBLOCK]) tmp19 = tl.where(rmask, tmp17, 0) tmp20 = triton_helpers.promote_to_tensor(tl.sum(tmp19, 0)) tmp21 = tl.full([1], 768, tl.int32) tmp22 = tmp21.to(tl.float32) tmp23 = tmp20 / tmp22 tmp24 = tmp14 - tmp23 tmp25 = tmp24 * tmp24 tmp26 = tl.broadcast_to(tmp25, [RBLOCK]) tmp28 = tl.where(rmask, tmp26, 0) tmp29 = triton_helpers.promote_to_tensor(tl.sum(tmp28, 0)) tmp30 = tmp13 - tmp23 tmp31 = 768.0 tmp32 = tmp29 / tmp31 tmp33 = 1e-12 tmp34 = tmp32 + tmp33 tmp35 = libdevice.rsqrt(tmp34) tmp36 = tmp30 * tmp35 tmp38 = tmp37.to(tl.float32) tmp39 = tmp36 * tmp38 tmp41 = tmp40.to(tl.float32) tmp42 = tmp39 + tmp41 tmp43 = tmp42.to(tl.float32) tl.store(out_ptr3 + (r1 + (768*x0)), tmp43, rmask) ''', device_str='cuda') # kernel path: /tmp/torchinductor_eellison/fl/cfllfdzx4v3f3zccah7zu5u634j2vrlvbkru74wmaerhgadlizat.py # Source Nodes: [], Original ATen: [] triton_poi_fused_5 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.pointwise( size_hints=[33554432], filename=__file__, triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, min_elem_per_thread=0 ) @triton.jit def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 23445504 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = tl.full([XBLOCK], True, tl.int1) x0 = xindex % 30528 x1 = (xindex // 30528) x2 = xindex tmp0 = x0 tmp1 = tl.full([1], 0, tl.int64) tmp2 = tmp0 >= tmp1 tmp3 = tl.full([1], 30522, tl.int64) tmp4 = tmp0 < tmp3 tmp5 = tl.load(in_ptr0 + (x1 + (768*x0)), tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp6 = tmp0 >= tmp3 tmp7 = tl.full([1], 30528, tl.int64) tmp8 = tmp0 < tmp7 tmp9 = 0.0 tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype) tmp11 = tl.where(tmp6, tmp9, tmp10) tmp12 = tl.where(tmp4, tmp5, tmp11) tl.store(out_ptr0 + (x2), tmp12, None) ''', device_str='cuda') # kernel path: /tmp/torchinductor_eellison/si/csikk4r46efsgklpubt6iamwjm3jsevq2h2pkkvjbgcc7sj4p24c.py # Source Nodes: [], Original ATen: [] triton_poi_fused_6 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.pointwise( size_hints=[32768], filename=__file__, triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, min_elem_per_thread=0 ) @triton.jit def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 30528 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = x0 tmp1 = tl.full([1], 0, tl.int64) tmp2 = tmp0 >= tmp1 tmp3 = tl.full([1], 30522, tl.int64) tmp4 = tmp0 < tmp3 tmp5 = tl.load(in_ptr0 + (x0), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp6 = tmp0 >= tmp3 tmp7 = tl.full([1], 30528, tl.int64) tmp8 = tmp0 < tmp7 tmp9 = 0.0 tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype) tmp11 = tl.where(tmp6, tmp9, tmp10) tmp12 = tl.where(tmp4, tmp5, tmp11) tl.store(out_ptr0 + (x0), tmp12, xmask) ''', device_str='cuda') # kernel path: /tmp/torchinductor_eellison/z7/cz72sjk2qwjjxgfdcyk6de22sryhbhr6pptrvx4aq2hq2soai2p5.py # Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax] # masked_lm_loss => amax, convert_element_type_302, exp, sub_27, sum_1 triton_red_fused__log_softmax_7 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.reduction( size_hints=[8192, 32768], reduction_hint=ReductionHint.DEFAULT, filename=__file__, triton_meta={'signature': {0: '*bf16', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax_7', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} ) @triton.jit def triton_(in_ptr0, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): xnumel = 8192 rnumel = 30522 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:, None] xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1) rbase = tl.arange(0, RBLOCK)[None, :] x0 = xindex _tmp3 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32) for roffset in range(0, rnumel, RBLOCK): rindex = roffset + rbase rmask = rindex < rnumel r1 = rindex tmp0 = tl.load(in_ptr0 + (r1 + (30528*x0)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp1 = tmp0.to(tl.float32) tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK]) tmp4 = triton_helpers.maximum(_tmp3, tmp2) _tmp3 = tl.where(rmask, tmp4, _tmp3) tmp3 = triton_helpers.max2(_tmp3, 1)[:, None] tl.store(out_ptr0 + (x0), tmp3, None) _tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) for roffset in range(0, rnumel, RBLOCK): rindex = roffset + rbase rmask = rindex < rnumel r1 = rindex tmp5 = tl.load(in_ptr0 + (r1 + (30528*x0)), rmask, eviction_policy='evict_first', other=0.0).to(tl.float32) tmp6 = tmp5.to(tl.float32) tmp7 = tmp6 - tmp3 tmp8 = tl_math.exp(tmp7) tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK]) tmp11 = _tmp10 + tmp9 _tmp10 = tl.where(rmask, tmp11, _tmp10) tmp10 = tl.sum(_tmp10, 1)[:, None] tl.store(out_ptr1 + (x0), tmp10, None) ''', device_str='cuda') # kernel path: /tmp/torchinductor_eellison/jo/cjonz7vxay7gmy5qv5ie2jflm5jixsrzswexulnc4drrofc2cr64.py # Source Nodes: [masked_lm_loss], Original ATen: [aten.nll_loss_forward] # masked_lm_loss => convert_element_type_304, div, full_default_1, ne_1, ne_2, neg, sum_2, sum_3, where_2 triton_red_fused_nll_loss_forward_8 = async_compile.triton('triton_', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties @triton_heuristics.reduction( size_hints=[1, 8192], reduction_hint=ReductionHint.INNER, filename=__file__, triton_meta={'signature': {0: '*bf16', 1: '*i64', 2: '*bf16', 3: '*fp32', 4: '*fp32', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {5: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 6), equal_to_1=(5,))]}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_nll_loss_forward_8', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 2, 'backend_hash': 'A4DB0438D2F455BC2DA87EC6800D76C325D555A608FFCEAD009A7F894B0D3C83', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} ) @triton.jit def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): xnumel = 1 rnumel = 8192 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:, None] xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1) rbase = tl.arange(0, RBLOCK)[None, :] _tmp22 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) _tmp26 = tl.full([XBLOCK, RBLOCK], 0, tl.int64) for roffset in range(0, rnumel, RBLOCK): rindex = roffset + rbase rmask = rindex < rnumel r0 = rindex tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0) tmp12 = tl.load(in_ptr2 + (r0), rmask, eviction_policy='evict_first', other=0.0) tmp14 = tl.load(in_ptr3 + (r0), rmask, eviction_policy='evict_first', other=0.0) tmp1 = tl.full([1, 1], -100, tl.int64) tmp2 = tmp0 != tmp1 tmp3 = tl.full([1, 1], 0, tl.int64) tmp4 = tl.where(tmp2, tmp0, tmp3) tmp5 = tl.full([XBLOCK, RBLOCK], 30522, tl.int32) tmp6 = tmp4 + tmp5 tmp7 = tmp4 < 0 tmp8 = tl.where(tmp7, tmp6, tmp4) tl.device_assert(((0 <= tmp8) & (tmp8 < 30522)) | ~(rmask), "index out of bounds: 0 <= tmp8 < 30522") tmp10 = tl.load(in_ptr1 + (tmp8 + (30528*r0)), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) tmp11 = tmp10.to(tl.float32) tmp13 = tmp11 - tmp12 tmp15 = tl_math.log(tmp14) tmp16 = tmp13 - tmp15 tmp17 = tmp16.to(tl.float32) tmp18 = -tmp17 tmp19 = 0.0 tmp20 = tl.where(tmp2, tmp18, tmp19) tmp21 = tl.broadcast_to(tmp20, [XBLOCK, RBLOCK]) tmp23 = _tmp22 + tmp21 _tmp22 = tl.where(rmask, tmp23, _tmp22) tmp24 = tmp2.to(tl.int64) tmp25 = tl.broadcast_to(tmp24, [XBLOCK, RBLOCK]) tmp27 = _tmp26 + tmp25 _tmp26 = tl.where(rmask, tmp27, _tmp26) tmp22 = tl.sum(_tmp22, 1)[:, None] tmp26 = tl.sum(_tmp26, 1)[:, None] tmp28 = tmp26.to(tl.float32) tmp29 = tmp22 / tmp28 tl.debug_barrier() tl.store(in_out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp29, None) ''', device_str='cuda') async_compile.wait(globals()) del async_compile def call(args): arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1 = args args.clear() assert_size_stride(arg0_1, (30522, 768), (768, 1)) assert_size_stride(arg1_1, (2, 768), (768, 1)) assert_size_stride(arg2_1, (512, 768), (768, 1)) assert_size_stride(arg3_1, (768, ), (1, )) assert_size_stride(arg4_1, (768, ), (1, )) assert_size_stride(arg5_1, (768, 768), (768, 1)) assert_size_stride(arg6_1, (768, ), (1, )) assert_size_stride(arg7_1, (768, 768), (768, 1)) assert_size_stride(arg8_1, (768, ), (1, )) assert_size_stride(arg9_1, (768, 768), (768, 1)) assert_size_stride(arg10_1, (768, ), (1, )) assert_size_stride(arg11_1, (768, 768), (768, 1)) assert_size_stride(arg12_1, (768, ), (1, )) assert_size_stride(arg13_1, (768, ), (1, )) assert_size_stride(arg14_1, (768, ), (1, )) assert_size_stride(arg15_1, (3072, 768), (768, 1)) assert_size_stride(arg16_1, (3072, ), (1, )) assert_size_stride(arg17_1, (768, 3072), (3072, 1)) assert_size_stride(arg18_1, (768, ), (1, )) assert_size_stride(arg19_1, (768, ), (1, )) assert_size_stride(arg20_1, (768, ), (1, )) assert_size_stride(arg21_1, (768, 768), (768, 1)) assert_size_stride(arg22_1, (768, ), (1, )) assert_size_stride(arg23_1, (768, 768), (768, 1)) assert_size_stride(arg24_1, (768, ), (1, )) assert_size_stride(arg25_1, (768, 768), (768, 1)) assert_size_stride(arg26_1, (768, ), (1, )) assert_size_stride(arg27_1, (768, 768), (768, 1)) assert_size_stride(arg28_1, (768, ), (1, )) assert_size_stride(arg29_1, (768, ), (1, )) assert_size_stride(arg30_1, (768, ), (1, )) assert_size_stride(arg31_1, (3072, 768), (768, 1)) assert_size_stride(arg32_1, (3072, ), (1, )) assert_size_stride(arg33_1, (768, 3072), (3072, 1)) assert_size_stride(arg34_1, (768, ), (1, )) assert_size_stride(arg35_1, (768, ), (1, )) assert_size_stride(arg36_1, (768, ), (1, )) assert_size_stride(arg37_1, (768, 768), (768, 1)) assert_size_stride(arg38_1, (768, ), (1, )) assert_size_stride(arg39_1, (768, 768), (768, 1)) assert_size_stride(arg40_1, (768, ), (1, )) assert_size_stride(arg41_1, (768, 768), (768, 1)) assert_size_stride(arg42_1, (768, ), (1, )) assert_size_stride(arg43_1, (768, 768), (768, 1)) assert_size_stride(arg44_1, (768, ), (1, )) assert_size_stride(arg45_1, (768, ), (1, )) assert_size_stride(arg46_1, (768, ), (1, )) assert_size_stride(arg47_1, (3072, 768), (768, 1)) assert_size_stride(arg48_1, (3072, ), (1, )) assert_size_stride(arg49_1, (768, 3072), (3072, 1)) assert_size_stride(arg50_1, (768, ), (1, )) assert_size_stride(arg51_1, (768, ), (1, )) assert_size_stride(arg52_1, (768, ), (1, )) assert_size_stride(arg53_1, (768, 768), (768, 1)) assert_size_stride(arg54_1, (768, ), (1, )) assert_size_stride(arg55_1, (768, 768), (768, 1)) assert_size_stride(arg56_1, (768, ), (1, )) assert_size_stride(arg57_1, (768, 768), (768, 1)) assert_size_stride(arg58_1, (768, ), (1, )) assert_size_stride(arg59_1, (768, 768), (768, 1)) assert_size_stride(arg60_1, (768, ), (1, )) assert_size_stride(arg61_1, (768, ), (1, )) assert_size_stride(arg62_1, (768, ), (1, )) assert_size_stride(arg63_1, (3072, 768), (768, 1)) assert_size_stride(arg64_1, (3072, ), (1, )) assert_size_stride(arg65_1, (768, 3072), (3072, 1)) assert_size_stride(arg66_1, (768, ), (1, )) assert_size_stride(arg67_1, (768, ), (1, )) assert_size_stride(arg68_1, (768, ), (1, )) assert_size_stride(arg69_1, (768, 768), (768, 1)) assert_size_stride(arg70_1, (768, ), (1, )) assert_size_stride(arg71_1, (768, 768), (768, 1)) assert_size_stride(arg72_1, (768, ), (1, )) assert_size_stride(arg73_1, (768, 768), (768, 1)) assert_size_stride(arg74_1, (768, ), (1, )) assert_size_stride(arg75_1, (768, 768), (768, 1)) assert_size_stride(arg76_1, (768, ), (1, )) assert_size_stride(arg77_1, (768, ), (1, )) assert_size_stride(arg78_1, (768, ), (1, )) assert_size_stride(arg79_1, (3072, 768), (768, 1)) assert_size_stride(arg80_1, (3072, ), (1, )) assert_size_stride(arg81_1, (768, 3072), (3072, 1)) assert_size_stride(arg82_1, (768, ), (1, )) assert_size_stride(arg83_1, (768, ), (1, )) assert_size_stride(arg84_1, (768, ), (1, )) assert_size_stride(arg85_1, (768, 768), (768, 1)) assert_size_stride(arg86_1, (768, ), (1, )) assert_size_stride(arg87_1, (768, 768), (768, 1)) assert_size_stride(arg88_1, (768, ), (1, )) assert_size_stride(arg89_1, (768, 768), (768, 1)) assert_size_stride(arg90_1, (768, ), (1, )) assert_size_stride(arg91_1, (768, 768), (768, 1)) assert_size_stride(arg92_1, (768, ), (1, )) assert_size_stride(arg93_1, (768, ), (1, )) assert_size_stride(arg94_1, (768, ), (1, )) assert_size_stride(arg95_1, (3072, 768), (768, 1)) assert_size_stride(arg96_1, (3072, ), (1, )) assert_size_stride(arg97_1, (768, 3072), (3072, 1)) assert_size_stride(arg98_1, (768, ), (1, )) assert_size_stride(arg99_1, (768, ), (1, )) assert_size_stride(arg100_1, (768, ), (1, )) assert_size_stride(arg101_1, (768, 768), (768, 1)) assert_size_stride(arg102_1, (768, ), (1, )) assert_size_stride(arg103_1, (768, 768), (768, 1)) assert_size_stride(arg104_1, (768, ), (1, )) assert_size_stride(arg105_1, (768, 768), (768, 1)) assert_size_stride(arg106_1, (768, ), (1, )) assert_size_stride(arg107_1, (768, 768), (768, 1)) assert_size_stride(arg108_1, (768, ), (1, )) assert_size_stride(arg109_1, (768, ), (1, )) assert_size_stride(arg110_1, (768, ), (1, )) assert_size_stride(arg111_1, (3072, 768), (768, 1)) assert_size_stride(arg112_1, (3072, ), (1, )) assert_size_stride(arg113_1, (768, 3072), (3072, 1)) assert_size_stride(arg114_1, (768, ), (1, )) assert_size_stride(arg115_1, (768, ), (1, )) assert_size_stride(arg116_1, (768, ), (1, )) assert_size_stride(arg117_1, (768, 768), (768, 1)) assert_size_stride(arg118_1, (768, ), (1, )) assert_size_stride(arg119_1, (768, 768), (768, 1)) assert_size_stride(arg120_1, (768, ), (1, )) assert_size_stride(arg121_1, (768, 768), (768, 1)) assert_size_stride(arg122_1, (768, ), (1, )) assert_size_stride(arg123_1, (768, 768), (768, 1)) assert_size_stride(arg124_1, (768, ), (1, )) assert_size_stride(arg125_1, (768, ), (1, )) assert_size_stride(arg126_1, (768, ), (1, )) assert_size_stride(arg127_1, (3072, 768), (768, 1)) assert_size_stride(arg128_1, (3072, ), (1, )) assert_size_stride(arg129_1, (768, 3072), (3072, 1)) assert_size_stride(arg130_1, (768, ), (1, )) assert_size_stride(arg131_1, (768, ), (1, )) assert_size_stride(arg132_1, (768, ), (1, )) assert_size_stride(arg133_1, (768, 768), (768, 1)) assert_size_stride(arg134_1, (768, ), (1, )) assert_size_stride(arg135_1, (768, 768), (768, 1)) assert_size_stride(arg136_1, (768, ), (1, )) assert_size_stride(arg137_1, (768, 768), (768, 1)) assert_size_stride(arg138_1, (768, ), (1, )) assert_size_stride(arg139_1, (768, 768), (768, 1)) assert_size_stride(arg140_1, (768, ), (1, )) assert_size_stride(arg141_1, (768, ), (1, )) assert_size_stride(arg142_1, (768, ), (1, )) assert_size_stride(arg143_1, (3072, 768), (768, 1)) assert_size_stride(arg144_1, (3072, ), (1, )) assert_size_stride(arg145_1, (768, 3072), (3072, 1)) assert_size_stride(arg146_1, (768, ), (1, )) assert_size_stride(arg147_1, (768, ), (1, )) assert_size_stride(arg148_1, (768, ), (1, )) assert_size_stride(arg149_1, (768, 768), (768, 1)) assert_size_stride(arg150_1, (768, ), (1, )) assert_size_stride(arg151_1, (768, 768), (768, 1)) assert_size_stride(arg152_1, (768, ), (1, )) assert_size_stride(arg153_1, (768, 768), (768, 1)) assert_size_stride(arg154_1, (768, ), (1, )) assert_size_stride(arg155_1, (768, 768), (768, 1)) assert_size_stride(arg156_1, (768, ), (1, )) assert_size_stride(arg157_1, (768, ), (1, )) assert_size_stride(arg158_1, (768, ), (1, )) assert_size_stride(arg159_1, (3072, 768), (768, 1)) assert_size_stride(arg160_1, (3072, ), (1, )) assert_size_stride(arg161_1, (768, 3072), (3072, 1)) assert_size_stride(arg162_1, (768, ), (1, )) assert_size_stride(arg163_1, (768, ), (1, )) assert_size_stride(arg164_1, (768, ), (1, )) assert_size_stride(arg165_1, (768, 768), (768, 1)) assert_size_stride(arg166_1, (768, ), (1, )) assert_size_stride(arg167_1, (768, 768), (768, 1)) assert_size_stride(arg168_1, (768, ), (1, )) assert_size_stride(arg169_1, (768, 768), (768, 1)) assert_size_stride(arg170_1, (768, ), (1, )) assert_size_stride(arg171_1, (768, 768), (768, 1)) assert_size_stride(arg172_1, (768, ), (1, )) assert_size_stride(arg173_1, (768, ), (1, )) assert_size_stride(arg174_1, (768, ), (1, )) assert_size_stride(arg175_1, (3072, 768), (768, 1)) assert_size_stride(arg176_1, (3072, ), (1, )) assert_size_stride(arg177_1, (768, 3072), (3072, 1)) assert_size_stride(arg178_1, (768, ), (1, )) assert_size_stride(arg179_1, (768, ), (1, )) assert_size_stride(arg180_1, (768, ), (1, )) assert_size_stride(arg181_1, (768, 768), (768, 1)) assert_size_stride(arg182_1, (768, ), (1, )) assert_size_stride(arg183_1, (768, 768), (768, 1)) assert_size_stride(arg184_1, (768, ), (1, )) assert_size_stride(arg185_1, (768, 768), (768, 1)) assert_size_stride(arg186_1, (768, ), (1, )) assert_size_stride(arg187_1, (768, 768), (768, 1)) assert_size_stride(arg188_1, (768, ), (1, )) assert_size_stride(arg189_1, (768, ), (1, )) assert_size_stride(arg190_1, (768, ), (1, )) assert_size_stride(arg191_1, (3072, 768), (768, 1)) assert_size_stride(arg192_1, (3072, ), (1, )) assert_size_stride(arg193_1, (768, 3072), (3072, 1)) assert_size_stride(arg194_1, (768, ), (1, )) assert_size_stride(arg195_1, (768, ), (1, )) assert_size_stride(arg196_1, (768, ), (1, )) assert_size_stride(arg197_1, (768, 768), (768, 1)) assert_size_stride(arg198_1, (768, ), (1, )) assert_size_stride(arg199_1, (768, ), (1, )) assert_size_stride(arg200_1, (768, ), (1, )) assert_size_stride(arg201_1, (30522, 768), (768, 1)) assert_size_stride(arg202_1, (30522, ), (1, )) assert_size_stride(arg203_1, (1, 512), (512, 1)) assert_size_stride(arg204_1, (1, 512), (512, 1)) assert_size_stride(arg205_1, (16, 512), (512, 1)) assert_size_stride(arg206_1, (16, 512), (512, 1)) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) buf0 = empty_strided_cuda((16, 512, 768), (393216, 768, 1), torch.bfloat16) buf4 = empty_strided_cuda((16, 512, 768), (393216, 768, 1), torch.bfloat16) # Source Nodes: [embeddings, embeddings_1, embeddings_2, inputs_embeds, position_embeddings, token_type_embeddings], Original ATen: [aten.add, aten.embedding, aten.native_layer_norm] stream0 = get_raw_stream(0) triton_per_fused_add_embedding_native_layer_norm_0.run(arg205_1, arg0_1, arg203_1, arg1_1, arg204_1, arg2_1, arg3_1, arg4_1, buf0, buf4, 8192, 768, grid=grid(8192), stream=stream0) del arg0_1 del arg1_1 del arg203_1 del arg204_1 del arg205_1 del arg2_1 del arg3_1 del arg4_1 buf5 = reinterpret_tensor(buf0, (8192, 768), (768, 1), 0); del buf0 # reuse # Source Nodes: [l__mod___bert_encoder_layer_0_attention_self_query], Original ATen: [aten.addmm] extern_kernels.addmm(arg6_1, reinterpret_tensor(buf4, (8192, 768), (768, 1), 0), reinterpret_tensor(arg5_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf5) del arg5_1 del arg6_1 buf6 = empty_strided_cuda((8192, 768), (768, 1), torch.bfloat16) # Source Nodes: [l__mod___bert_encoder_layer_0_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg8_1, reinterpret_tensor(buf4, (8192, 768), (768, 1), 0), reinterpret_tensor(arg7_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf6) del arg7_1 del arg8_1 buf7 = empty_strided_cuda((8192, 768), (768, 1), torch.bfloat16) # Source Nodes: [l__mod___bert_encoder_layer_0_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg10_1, reinterpret_tensor(buf4, (8192, 768), (768, 1), 0), reinterpret_tensor(arg9_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf7) del arg10_1 del arg9_1 buf8 = empty_strided_cuda((16, 12, 512, 512), (3145728, 262144, 512, 1), torch.bfloat16) # Source Nodes: [attn_output], Original ATen: [aten._scaled_dot_product_efficient_attention] triton_poi_fused__scaled_dot_product_efficient_attention_1.run(buf8, 50331648, grid=grid(50331648), stream=stream0) # Source Nodes: [attn_output], Original ATen: [aten._scaled_dot_product_efficient_attention] buf9 = torch.ops.aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf5, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf6, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf7, (16, 12, 512, 64), (393216, 64, 768, 1), 0), buf8, False) del buf5 buf10 = buf9[0] del buf9 buf14 = buf7; del buf7 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf10, (8192, 768), (768, 1), 0), reinterpret_tensor(arg11_1, (768, 768), (1, 768), 0), out=buf14) del arg11_1 buf18 = reinterpret_tensor(buf10, (16, 512, 768), (393216, 768, 1), 0); del buf10 # reuse # Source Nodes: [add_1, hidden_states_2], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf14, arg12_1, buf4, arg13_1, arg14_1, buf18, 8192, 768, grid=grid(8192), stream=stream0) del arg12_1 del arg13_1 del arg14_1 buf19 = empty_strided_cuda((8192, 3072), (3072, 1), torch.bfloat16) # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf18, (8192, 768), (768, 1), 0), reinterpret_tensor(arg15_1, (768, 3072), (1, 768), 0), out=buf19) del arg15_1 buf20 = reinterpret_tensor(buf19, (16, 512, 3072), (1572864, 3072, 1), 0); del buf19 # reuse # Source Nodes: [hidden_states_4], Original ATen: [aten.gelu] triton_poi_fused_gelu_3.run(buf20, arg16_1, 25165824, grid=grid(25165824), stream=stream0) del arg16_1 buf21 = reinterpret_tensor(buf4, (8192, 768), (768, 1), 0); del buf4 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf20, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg17_1, (3072, 768), (1, 3072), 0), out=buf21) del arg17_1 buf25 = reinterpret_tensor(buf14, (16, 512, 768), (393216, 768, 1), 0); del buf14 # reuse # Source Nodes: [add_2, hidden_states_7], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf21, arg18_1, buf18, arg19_1, arg20_1, buf25, 8192, 768, grid=grid(8192), stream=stream0) del arg18_1 del arg19_1 del arg20_1 buf26 = buf21; del buf21 # reuse # Source Nodes: [l__mod___bert_encoder_layer_1_attention_self_query], Original ATen: [aten.addmm] extern_kernels.addmm(arg22_1, reinterpret_tensor(buf25, (8192, 768), (768, 1), 0), reinterpret_tensor(arg21_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf26) del arg21_1 del arg22_1 buf27 = reinterpret_tensor(buf18, (8192, 768), (768, 1), 0); del buf18 # reuse # Source Nodes: [l__mod___bert_encoder_layer_1_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg24_1, reinterpret_tensor(buf25, (8192, 768), (768, 1), 0), reinterpret_tensor(arg23_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf27) del arg23_1 del arg24_1 buf28 = buf6; del buf6 # reuse # Source Nodes: [l__mod___bert_encoder_layer_1_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg26_1, reinterpret_tensor(buf25, (8192, 768), (768, 1), 0), reinterpret_tensor(arg25_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf28) del arg25_1 del arg26_1 buf29 = buf8; del buf8 # reuse # Source Nodes: [attn_output_3], Original ATen: [aten._scaled_dot_product_efficient_attention] triton_poi_fused__scaled_dot_product_efficient_attention_1.run(buf29, 50331648, grid=grid(50331648), stream=stream0) # Source Nodes: [attn_output_3], Original ATen: [aten._scaled_dot_product_efficient_attention] buf30 = torch.ops.aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf26, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf27, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf28, (16, 12, 512, 64), (393216, 64, 768, 1), 0), buf29, False) del buf26 buf31 = buf30[0] del buf30 buf35 = buf28; del buf28 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf31, (8192, 768), (768, 1), 0), reinterpret_tensor(arg27_1, (768, 768), (1, 768), 0), out=buf35) del arg27_1 buf39 = reinterpret_tensor(buf31, (16, 512, 768), (393216, 768, 1), 0); del buf31 # reuse # Source Nodes: [add_3, hidden_states_10], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf35, arg28_1, buf25, arg29_1, arg30_1, buf39, 8192, 768, grid=grid(8192), stream=stream0) del arg28_1 del arg29_1 del arg30_1 buf40 = reinterpret_tensor(buf20, (8192, 3072), (3072, 1), 0); del buf20 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf39, (8192, 768), (768, 1), 0), reinterpret_tensor(arg31_1, (768, 3072), (1, 768), 0), out=buf40) del arg31_1 buf41 = reinterpret_tensor(buf40, (16, 512, 3072), (1572864, 3072, 1), 0); del buf40 # reuse # Source Nodes: [hidden_states_12], Original ATen: [aten.gelu] triton_poi_fused_gelu_3.run(buf41, arg32_1, 25165824, grid=grid(25165824), stream=stream0) del arg32_1 buf42 = buf35; del buf35 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf41, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg33_1, (3072, 768), (1, 3072), 0), out=buf42) del arg33_1 buf46 = buf25; del buf25 # reuse # Source Nodes: [add_4, hidden_states_15], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf42, arg34_1, buf39, arg35_1, arg36_1, buf46, 8192, 768, grid=grid(8192), stream=stream0) del arg34_1 del arg35_1 del arg36_1 buf47 = buf42; del buf42 # reuse # Source Nodes: [l__mod___bert_encoder_layer_2_attention_self_query], Original ATen: [aten.addmm] extern_kernels.addmm(arg38_1, reinterpret_tensor(buf46, (8192, 768), (768, 1), 0), reinterpret_tensor(arg37_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf47) del arg37_1 del arg38_1 buf48 = reinterpret_tensor(buf39, (8192, 768), (768, 1), 0); del buf39 # reuse # Source Nodes: [l__mod___bert_encoder_layer_2_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg40_1, reinterpret_tensor(buf46, (8192, 768), (768, 1), 0), reinterpret_tensor(arg39_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf48) del arg39_1 del arg40_1 buf49 = buf27; del buf27 # reuse # Source Nodes: [l__mod___bert_encoder_layer_2_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg42_1, reinterpret_tensor(buf46, (8192, 768), (768, 1), 0), reinterpret_tensor(arg41_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf49) del arg41_1 del arg42_1 buf50 = buf29; del buf29 # reuse # Source Nodes: [attn_output_6], Original ATen: [aten._scaled_dot_product_efficient_attention] triton_poi_fused__scaled_dot_product_efficient_attention_1.run(buf50, 50331648, grid=grid(50331648), stream=stream0) # Source Nodes: [attn_output_6], Original ATen: [aten._scaled_dot_product_efficient_attention] buf51 = torch.ops.aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf47, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf48, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf49, (16, 12, 512, 64), (393216, 64, 768, 1), 0), buf50, False) del buf47 buf52 = buf51[0] del buf51 buf56 = buf49; del buf49 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf52, (8192, 768), (768, 1), 0), reinterpret_tensor(arg43_1, (768, 768), (1, 768), 0), out=buf56) del arg43_1 buf60 = reinterpret_tensor(buf52, (16, 512, 768), (393216, 768, 1), 0); del buf52 # reuse # Source Nodes: [add_5, hidden_states_18], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf56, arg44_1, buf46, arg45_1, arg46_1, buf60, 8192, 768, grid=grid(8192), stream=stream0) del arg44_1 del arg45_1 del arg46_1 buf61 = reinterpret_tensor(buf41, (8192, 3072), (3072, 1), 0); del buf41 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf60, (8192, 768), (768, 1), 0), reinterpret_tensor(arg47_1, (768, 3072), (1, 768), 0), out=buf61) del arg47_1 buf62 = reinterpret_tensor(buf61, (16, 512, 3072), (1572864, 3072, 1), 0); del buf61 # reuse # Source Nodes: [hidden_states_20], Original ATen: [aten.gelu] triton_poi_fused_gelu_3.run(buf62, arg48_1, 25165824, grid=grid(25165824), stream=stream0) del arg48_1 buf63 = buf56; del buf56 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf62, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg49_1, (3072, 768), (1, 3072), 0), out=buf63) del arg49_1 buf67 = buf46; del buf46 # reuse # Source Nodes: [add_6, hidden_states_23], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf63, arg50_1, buf60, arg51_1, arg52_1, buf67, 8192, 768, grid=grid(8192), stream=stream0) del arg50_1 del arg51_1 del arg52_1 buf68 = buf63; del buf63 # reuse # Source Nodes: [l__mod___bert_encoder_layer_3_attention_self_query], Original ATen: [aten.addmm] extern_kernels.addmm(arg54_1, reinterpret_tensor(buf67, (8192, 768), (768, 1), 0), reinterpret_tensor(arg53_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf68) del arg53_1 del arg54_1 buf69 = reinterpret_tensor(buf60, (8192, 768), (768, 1), 0); del buf60 # reuse # Source Nodes: [l__mod___bert_encoder_layer_3_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg56_1, reinterpret_tensor(buf67, (8192, 768), (768, 1), 0), reinterpret_tensor(arg55_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf69) del arg55_1 del arg56_1 buf70 = buf48; del buf48 # reuse # Source Nodes: [l__mod___bert_encoder_layer_3_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg58_1, reinterpret_tensor(buf67, (8192, 768), (768, 1), 0), reinterpret_tensor(arg57_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf70) del arg57_1 del arg58_1 buf71 = buf50; del buf50 # reuse # Source Nodes: [attn_output_9], Original ATen: [aten._scaled_dot_product_efficient_attention] triton_poi_fused__scaled_dot_product_efficient_attention_1.run(buf71, 50331648, grid=grid(50331648), stream=stream0) # Source Nodes: [attn_output_9], Original ATen: [aten._scaled_dot_product_efficient_attention] buf72 = torch.ops.aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf68, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf69, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf70, (16, 12, 512, 64), (393216, 64, 768, 1), 0), buf71, False) del buf68 buf73 = buf72[0] del buf72 buf77 = buf70; del buf70 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf73, (8192, 768), (768, 1), 0), reinterpret_tensor(arg59_1, (768, 768), (1, 768), 0), out=buf77) del arg59_1 buf81 = reinterpret_tensor(buf73, (16, 512, 768), (393216, 768, 1), 0); del buf73 # reuse # Source Nodes: [add_7, hidden_states_26], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf77, arg60_1, buf67, arg61_1, arg62_1, buf81, 8192, 768, grid=grid(8192), stream=stream0) del arg60_1 del arg61_1 del arg62_1 buf82 = reinterpret_tensor(buf62, (8192, 3072), (3072, 1), 0); del buf62 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf81, (8192, 768), (768, 1), 0), reinterpret_tensor(arg63_1, (768, 3072), (1, 768), 0), out=buf82) del arg63_1 buf83 = reinterpret_tensor(buf82, (16, 512, 3072), (1572864, 3072, 1), 0); del buf82 # reuse # Source Nodes: [hidden_states_28], Original ATen: [aten.gelu] triton_poi_fused_gelu_3.run(buf83, arg64_1, 25165824, grid=grid(25165824), stream=stream0) del arg64_1 buf84 = buf77; del buf77 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf83, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg65_1, (3072, 768), (1, 3072), 0), out=buf84) del arg65_1 buf88 = buf67; del buf67 # reuse # Source Nodes: [add_8, hidden_states_31], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf84, arg66_1, buf81, arg67_1, arg68_1, buf88, 8192, 768, grid=grid(8192), stream=stream0) del arg66_1 del arg67_1 del arg68_1 buf89 = buf84; del buf84 # reuse # Source Nodes: [l__mod___bert_encoder_layer_4_attention_self_query], Original ATen: [aten.addmm] extern_kernels.addmm(arg70_1, reinterpret_tensor(buf88, (8192, 768), (768, 1), 0), reinterpret_tensor(arg69_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf89) del arg69_1 del arg70_1 buf90 = reinterpret_tensor(buf81, (8192, 768), (768, 1), 0); del buf81 # reuse # Source Nodes: [l__mod___bert_encoder_layer_4_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg72_1, reinterpret_tensor(buf88, (8192, 768), (768, 1), 0), reinterpret_tensor(arg71_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf90) del arg71_1 del arg72_1 buf91 = buf69; del buf69 # reuse # Source Nodes: [l__mod___bert_encoder_layer_4_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg74_1, reinterpret_tensor(buf88, (8192, 768), (768, 1), 0), reinterpret_tensor(arg73_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf91) del arg73_1 del arg74_1 buf92 = buf71; del buf71 # reuse # Source Nodes: [attn_output_12], Original ATen: [aten._scaled_dot_product_efficient_attention] triton_poi_fused__scaled_dot_product_efficient_attention_1.run(buf92, 50331648, grid=grid(50331648), stream=stream0) # Source Nodes: [attn_output_12], Original ATen: [aten._scaled_dot_product_efficient_attention] buf93 = torch.ops.aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf89, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf90, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf91, (16, 12, 512, 64), (393216, 64, 768, 1), 0), buf92, False) del buf89 buf94 = buf93[0] del buf93 buf98 = buf91; del buf91 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf94, (8192, 768), (768, 1), 0), reinterpret_tensor(arg75_1, (768, 768), (1, 768), 0), out=buf98) del arg75_1 buf102 = reinterpret_tensor(buf94, (16, 512, 768), (393216, 768, 1), 0); del buf94 # reuse # Source Nodes: [add_9, hidden_states_34], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf98, arg76_1, buf88, arg77_1, arg78_1, buf102, 8192, 768, grid=grid(8192), stream=stream0) del arg76_1 del arg77_1 del arg78_1 buf103 = reinterpret_tensor(buf83, (8192, 3072), (3072, 1), 0); del buf83 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf102, (8192, 768), (768, 1), 0), reinterpret_tensor(arg79_1, (768, 3072), (1, 768), 0), out=buf103) del arg79_1 buf104 = reinterpret_tensor(buf103, (16, 512, 3072), (1572864, 3072, 1), 0); del buf103 # reuse # Source Nodes: [hidden_states_36], Original ATen: [aten.gelu] triton_poi_fused_gelu_3.run(buf104, arg80_1, 25165824, grid=grid(25165824), stream=stream0) del arg80_1 buf105 = buf98; del buf98 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf104, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg81_1, (3072, 768), (1, 3072), 0), out=buf105) del arg81_1 buf109 = buf88; del buf88 # reuse # Source Nodes: [add_10, hidden_states_39], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf105, arg82_1, buf102, arg83_1, arg84_1, buf109, 8192, 768, grid=grid(8192), stream=stream0) del arg82_1 del arg83_1 del arg84_1 buf110 = buf105; del buf105 # reuse # Source Nodes: [l__mod___bert_encoder_layer_5_attention_self_query], Original ATen: [aten.addmm] extern_kernels.addmm(arg86_1, reinterpret_tensor(buf109, (8192, 768), (768, 1), 0), reinterpret_tensor(arg85_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf110) del arg85_1 del arg86_1 buf111 = reinterpret_tensor(buf102, (8192, 768), (768, 1), 0); del buf102 # reuse # Source Nodes: [l__mod___bert_encoder_layer_5_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg88_1, reinterpret_tensor(buf109, (8192, 768), (768, 1), 0), reinterpret_tensor(arg87_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf111) del arg87_1 del arg88_1 buf112 = buf90; del buf90 # reuse # Source Nodes: [l__mod___bert_encoder_layer_5_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg90_1, reinterpret_tensor(buf109, (8192, 768), (768, 1), 0), reinterpret_tensor(arg89_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf112) del arg89_1 del arg90_1 buf113 = buf92; del buf92 # reuse # Source Nodes: [attn_output_15], Original ATen: [aten._scaled_dot_product_efficient_attention] triton_poi_fused__scaled_dot_product_efficient_attention_1.run(buf113, 50331648, grid=grid(50331648), stream=stream0) # Source Nodes: [attn_output_15], Original ATen: [aten._scaled_dot_product_efficient_attention] buf114 = torch.ops.aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf110, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf111, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf112, (16, 12, 512, 64), (393216, 64, 768, 1), 0), buf113, False) del buf110 buf115 = buf114[0] del buf114 buf119 = buf112; del buf112 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf115, (8192, 768), (768, 1), 0), reinterpret_tensor(arg91_1, (768, 768), (1, 768), 0), out=buf119) del arg91_1 buf123 = reinterpret_tensor(buf115, (16, 512, 768), (393216, 768, 1), 0); del buf115 # reuse # Source Nodes: [add_11, hidden_states_42], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf119, arg92_1, buf109, arg93_1, arg94_1, buf123, 8192, 768, grid=grid(8192), stream=stream0) del arg92_1 del arg93_1 del arg94_1 buf124 = reinterpret_tensor(buf104, (8192, 3072), (3072, 1), 0); del buf104 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf123, (8192, 768), (768, 1), 0), reinterpret_tensor(arg95_1, (768, 3072), (1, 768), 0), out=buf124) del arg95_1 buf125 = reinterpret_tensor(buf124, (16, 512, 3072), (1572864, 3072, 1), 0); del buf124 # reuse # Source Nodes: [hidden_states_44], Original ATen: [aten.gelu] triton_poi_fused_gelu_3.run(buf125, arg96_1, 25165824, grid=grid(25165824), stream=stream0) del arg96_1 buf126 = buf119; del buf119 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf125, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg97_1, (3072, 768), (1, 3072), 0), out=buf126) del arg97_1 buf130 = buf109; del buf109 # reuse # Source Nodes: [add_12, hidden_states_47], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf126, arg98_1, buf123, arg99_1, arg100_1, buf130, 8192, 768, grid=grid(8192), stream=stream0) del arg100_1 del arg98_1 del arg99_1 buf131 = buf126; del buf126 # reuse # Source Nodes: [l__mod___bert_encoder_layer_6_attention_self_query], Original ATen: [aten.addmm] extern_kernels.addmm(arg102_1, reinterpret_tensor(buf130, (8192, 768), (768, 1), 0), reinterpret_tensor(arg101_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf131) del arg101_1 del arg102_1 buf132 = reinterpret_tensor(buf123, (8192, 768), (768, 1), 0); del buf123 # reuse # Source Nodes: [l__mod___bert_encoder_layer_6_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg104_1, reinterpret_tensor(buf130, (8192, 768), (768, 1), 0), reinterpret_tensor(arg103_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf132) del arg103_1 del arg104_1 buf133 = buf111; del buf111 # reuse # Source Nodes: [l__mod___bert_encoder_layer_6_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg106_1, reinterpret_tensor(buf130, (8192, 768), (768, 1), 0), reinterpret_tensor(arg105_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf133) del arg105_1 del arg106_1 buf134 = buf113; del buf113 # reuse # Source Nodes: [attn_output_18], Original ATen: [aten._scaled_dot_product_efficient_attention] triton_poi_fused__scaled_dot_product_efficient_attention_1.run(buf134, 50331648, grid=grid(50331648), stream=stream0) # Source Nodes: [attn_output_18], Original ATen: [aten._scaled_dot_product_efficient_attention] buf135 = torch.ops.aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf131, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf132, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf133, (16, 12, 512, 64), (393216, 64, 768, 1), 0), buf134, False) del buf131 buf136 = buf135[0] del buf135 buf140 = buf133; del buf133 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf136, (8192, 768), (768, 1), 0), reinterpret_tensor(arg107_1, (768, 768), (1, 768), 0), out=buf140) del arg107_1 buf144 = reinterpret_tensor(buf136, (16, 512, 768), (393216, 768, 1), 0); del buf136 # reuse # Source Nodes: [add_13, hidden_states_50], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf140, arg108_1, buf130, arg109_1, arg110_1, buf144, 8192, 768, grid=grid(8192), stream=stream0) del arg108_1 del arg109_1 del arg110_1 buf145 = reinterpret_tensor(buf125, (8192, 3072), (3072, 1), 0); del buf125 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf144, (8192, 768), (768, 1), 0), reinterpret_tensor(arg111_1, (768, 3072), (1, 768), 0), out=buf145) del arg111_1 buf146 = reinterpret_tensor(buf145, (16, 512, 3072), (1572864, 3072, 1), 0); del buf145 # reuse # Source Nodes: [hidden_states_52], Original ATen: [aten.gelu] triton_poi_fused_gelu_3.run(buf146, arg112_1, 25165824, grid=grid(25165824), stream=stream0) del arg112_1 buf147 = buf140; del buf140 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf146, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg113_1, (3072, 768), (1, 3072), 0), out=buf147) del arg113_1 buf151 = buf130; del buf130 # reuse # Source Nodes: [add_14, hidden_states_55], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf147, arg114_1, buf144, arg115_1, arg116_1, buf151, 8192, 768, grid=grid(8192), stream=stream0) del arg114_1 del arg115_1 del arg116_1 buf152 = buf147; del buf147 # reuse # Source Nodes: [l__mod___bert_encoder_layer_7_attention_self_query], Original ATen: [aten.addmm] extern_kernels.addmm(arg118_1, reinterpret_tensor(buf151, (8192, 768), (768, 1), 0), reinterpret_tensor(arg117_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf152) del arg117_1 del arg118_1 buf153 = reinterpret_tensor(buf144, (8192, 768), (768, 1), 0); del buf144 # reuse # Source Nodes: [l__mod___bert_encoder_layer_7_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg120_1, reinterpret_tensor(buf151, (8192, 768), (768, 1), 0), reinterpret_tensor(arg119_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf153) del arg119_1 del arg120_1 buf154 = buf132; del buf132 # reuse # Source Nodes: [l__mod___bert_encoder_layer_7_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg122_1, reinterpret_tensor(buf151, (8192, 768), (768, 1), 0), reinterpret_tensor(arg121_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf154) del arg121_1 del arg122_1 buf155 = buf134; del buf134 # reuse # Source Nodes: [attn_output_21], Original ATen: [aten._scaled_dot_product_efficient_attention] triton_poi_fused__scaled_dot_product_efficient_attention_1.run(buf155, 50331648, grid=grid(50331648), stream=stream0) # Source Nodes: [attn_output_21], Original ATen: [aten._scaled_dot_product_efficient_attention] buf156 = torch.ops.aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf152, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf153, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf154, (16, 12, 512, 64), (393216, 64, 768, 1), 0), buf155, False) del buf152 buf157 = buf156[0] del buf156 buf161 = buf154; del buf154 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf157, (8192, 768), (768, 1), 0), reinterpret_tensor(arg123_1, (768, 768), (1, 768), 0), out=buf161) del arg123_1 buf165 = reinterpret_tensor(buf157, (16, 512, 768), (393216, 768, 1), 0); del buf157 # reuse # Source Nodes: [add_15, hidden_states_58], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf161, arg124_1, buf151, arg125_1, arg126_1, buf165, 8192, 768, grid=grid(8192), stream=stream0) del arg124_1 del arg125_1 del arg126_1 buf166 = reinterpret_tensor(buf146, (8192, 3072), (3072, 1), 0); del buf146 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf165, (8192, 768), (768, 1), 0), reinterpret_tensor(arg127_1, (768, 3072), (1, 768), 0), out=buf166) del arg127_1 buf167 = reinterpret_tensor(buf166, (16, 512, 3072), (1572864, 3072, 1), 0); del buf166 # reuse # Source Nodes: [hidden_states_60], Original ATen: [aten.gelu] triton_poi_fused_gelu_3.run(buf167, arg128_1, 25165824, grid=grid(25165824), stream=stream0) del arg128_1 buf168 = buf161; del buf161 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf167, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg129_1, (3072, 768), (1, 3072), 0), out=buf168) del arg129_1 buf172 = buf151; del buf151 # reuse # Source Nodes: [add_16, hidden_states_63], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf168, arg130_1, buf165, arg131_1, arg132_1, buf172, 8192, 768, grid=grid(8192), stream=stream0) del arg130_1 del arg131_1 del arg132_1 buf173 = buf168; del buf168 # reuse # Source Nodes: [l__mod___bert_encoder_layer_8_attention_self_query], Original ATen: [aten.addmm] extern_kernels.addmm(arg134_1, reinterpret_tensor(buf172, (8192, 768), (768, 1), 0), reinterpret_tensor(arg133_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf173) del arg133_1 del arg134_1 buf174 = reinterpret_tensor(buf165, (8192, 768), (768, 1), 0); del buf165 # reuse # Source Nodes: [l__mod___bert_encoder_layer_8_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg136_1, reinterpret_tensor(buf172, (8192, 768), (768, 1), 0), reinterpret_tensor(arg135_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf174) del arg135_1 del arg136_1 buf175 = buf153; del buf153 # reuse # Source Nodes: [l__mod___bert_encoder_layer_8_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg138_1, reinterpret_tensor(buf172, (8192, 768), (768, 1), 0), reinterpret_tensor(arg137_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf175) del arg137_1 del arg138_1 buf176 = buf155; del buf155 # reuse # Source Nodes: [attn_output_24], Original ATen: [aten._scaled_dot_product_efficient_attention] triton_poi_fused__scaled_dot_product_efficient_attention_1.run(buf176, 50331648, grid=grid(50331648), stream=stream0) # Source Nodes: [attn_output_24], Original ATen: [aten._scaled_dot_product_efficient_attention] buf177 = torch.ops.aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf173, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf174, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf175, (16, 12, 512, 64), (393216, 64, 768, 1), 0), buf176, False) del buf173 buf178 = buf177[0] del buf177 buf182 = buf175; del buf175 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf178, (8192, 768), (768, 1), 0), reinterpret_tensor(arg139_1, (768, 768), (1, 768), 0), out=buf182) del arg139_1 buf186 = reinterpret_tensor(buf178, (16, 512, 768), (393216, 768, 1), 0); del buf178 # reuse # Source Nodes: [add_17, hidden_states_66], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf182, arg140_1, buf172, arg141_1, arg142_1, buf186, 8192, 768, grid=grid(8192), stream=stream0) del arg140_1 del arg141_1 del arg142_1 buf187 = reinterpret_tensor(buf167, (8192, 3072), (3072, 1), 0); del buf167 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf186, (8192, 768), (768, 1), 0), reinterpret_tensor(arg143_1, (768, 3072), (1, 768), 0), out=buf187) del arg143_1 buf188 = reinterpret_tensor(buf187, (16, 512, 3072), (1572864, 3072, 1), 0); del buf187 # reuse # Source Nodes: [hidden_states_68], Original ATen: [aten.gelu] triton_poi_fused_gelu_3.run(buf188, arg144_1, 25165824, grid=grid(25165824), stream=stream0) del arg144_1 buf189 = buf182; del buf182 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf188, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg145_1, (3072, 768), (1, 3072), 0), out=buf189) del arg145_1 buf193 = buf172; del buf172 # reuse # Source Nodes: [add_18, hidden_states_71], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf189, arg146_1, buf186, arg147_1, arg148_1, buf193, 8192, 768, grid=grid(8192), stream=stream0) del arg146_1 del arg147_1 del arg148_1 buf194 = buf189; del buf189 # reuse # Source Nodes: [l__mod___bert_encoder_layer_9_attention_self_query], Original ATen: [aten.addmm] extern_kernels.addmm(arg150_1, reinterpret_tensor(buf193, (8192, 768), (768, 1), 0), reinterpret_tensor(arg149_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf194) del arg149_1 del arg150_1 buf195 = reinterpret_tensor(buf186, (8192, 768), (768, 1), 0); del buf186 # reuse # Source Nodes: [l__mod___bert_encoder_layer_9_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg152_1, reinterpret_tensor(buf193, (8192, 768), (768, 1), 0), reinterpret_tensor(arg151_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf195) del arg151_1 del arg152_1 buf196 = buf174; del buf174 # reuse # Source Nodes: [l__mod___bert_encoder_layer_9_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg154_1, reinterpret_tensor(buf193, (8192, 768), (768, 1), 0), reinterpret_tensor(arg153_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf196) del arg153_1 del arg154_1 buf197 = buf176; del buf176 # reuse # Source Nodes: [attn_output_27], Original ATen: [aten._scaled_dot_product_efficient_attention] triton_poi_fused__scaled_dot_product_efficient_attention_1.run(buf197, 50331648, grid=grid(50331648), stream=stream0) # Source Nodes: [attn_output_27], Original ATen: [aten._scaled_dot_product_efficient_attention] buf198 = torch.ops.aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf194, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf195, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf196, (16, 12, 512, 64), (393216, 64, 768, 1), 0), buf197, False) del buf194 buf199 = buf198[0] del buf198 buf203 = buf196; del buf196 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf199, (8192, 768), (768, 1), 0), reinterpret_tensor(arg155_1, (768, 768), (1, 768), 0), out=buf203) del arg155_1 buf207 = reinterpret_tensor(buf199, (16, 512, 768), (393216, 768, 1), 0); del buf199 # reuse # Source Nodes: [add_19, hidden_states_74], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf203, arg156_1, buf193, arg157_1, arg158_1, buf207, 8192, 768, grid=grid(8192), stream=stream0) del arg156_1 del arg157_1 del arg158_1 buf208 = reinterpret_tensor(buf188, (8192, 3072), (3072, 1), 0); del buf188 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf207, (8192, 768), (768, 1), 0), reinterpret_tensor(arg159_1, (768, 3072), (1, 768), 0), out=buf208) del arg159_1 buf209 = reinterpret_tensor(buf208, (16, 512, 3072), (1572864, 3072, 1), 0); del buf208 # reuse # Source Nodes: [hidden_states_76], Original ATen: [aten.gelu] triton_poi_fused_gelu_3.run(buf209, arg160_1, 25165824, grid=grid(25165824), stream=stream0) del arg160_1 buf210 = buf203; del buf203 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf209, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg161_1, (3072, 768), (1, 3072), 0), out=buf210) del arg161_1 buf214 = buf193; del buf193 # reuse # Source Nodes: [add_20, hidden_states_79], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf210, arg162_1, buf207, arg163_1, arg164_1, buf214, 8192, 768, grid=grid(8192), stream=stream0) del arg162_1 del arg163_1 del arg164_1 buf215 = buf210; del buf210 # reuse # Source Nodes: [l__mod___bert_encoder_layer_10_attention_self_query], Original ATen: [aten.addmm] extern_kernels.addmm(arg166_1, reinterpret_tensor(buf214, (8192, 768), (768, 1), 0), reinterpret_tensor(arg165_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf215) del arg165_1 del arg166_1 buf216 = reinterpret_tensor(buf207, (8192, 768), (768, 1), 0); del buf207 # reuse # Source Nodes: [l__mod___bert_encoder_layer_10_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg168_1, reinterpret_tensor(buf214, (8192, 768), (768, 1), 0), reinterpret_tensor(arg167_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf216) del arg167_1 del arg168_1 buf217 = buf195; del buf195 # reuse # Source Nodes: [l__mod___bert_encoder_layer_10_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg170_1, reinterpret_tensor(buf214, (8192, 768), (768, 1), 0), reinterpret_tensor(arg169_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf217) del arg169_1 del arg170_1 buf218 = buf197; del buf197 # reuse # Source Nodes: [attn_output_30], Original ATen: [aten._scaled_dot_product_efficient_attention] triton_poi_fused__scaled_dot_product_efficient_attention_1.run(buf218, 50331648, grid=grid(50331648), stream=stream0) # Source Nodes: [attn_output_30], Original ATen: [aten._scaled_dot_product_efficient_attention] buf219 = torch.ops.aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf215, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf216, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf217, (16, 12, 512, 64), (393216, 64, 768, 1), 0), buf218, False) del buf215 buf220 = buf219[0] del buf219 buf224 = buf217; del buf217 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf220, (8192, 768), (768, 1), 0), reinterpret_tensor(arg171_1, (768, 768), (1, 768), 0), out=buf224) del arg171_1 buf228 = reinterpret_tensor(buf220, (16, 512, 768), (393216, 768, 1), 0); del buf220 # reuse # Source Nodes: [add_21, hidden_states_82], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf224, arg172_1, buf214, arg173_1, arg174_1, buf228, 8192, 768, grid=grid(8192), stream=stream0) del arg172_1 del arg173_1 del arg174_1 buf229 = reinterpret_tensor(buf209, (8192, 3072), (3072, 1), 0); del buf209 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf228, (8192, 768), (768, 1), 0), reinterpret_tensor(arg175_1, (768, 3072), (1, 768), 0), out=buf229) del arg175_1 buf230 = reinterpret_tensor(buf229, (16, 512, 3072), (1572864, 3072, 1), 0); del buf229 # reuse # Source Nodes: [hidden_states_84], Original ATen: [aten.gelu] triton_poi_fused_gelu_3.run(buf230, arg176_1, 25165824, grid=grid(25165824), stream=stream0) del arg176_1 buf231 = buf224; del buf224 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf230, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg177_1, (3072, 768), (1, 3072), 0), out=buf231) del arg177_1 buf235 = buf214; del buf214 # reuse # Source Nodes: [add_22, hidden_states_87], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf231, arg178_1, buf228, arg179_1, arg180_1, buf235, 8192, 768, grid=grid(8192), stream=stream0) del arg178_1 del arg179_1 del arg180_1 buf236 = buf231; del buf231 # reuse # Source Nodes: [l__mod___bert_encoder_layer_11_attention_self_query], Original ATen: [aten.addmm] extern_kernels.addmm(arg182_1, reinterpret_tensor(buf235, (8192, 768), (768, 1), 0), reinterpret_tensor(arg181_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf236) del arg181_1 del arg182_1 buf237 = reinterpret_tensor(buf228, (8192, 768), (768, 1), 0); del buf228 # reuse # Source Nodes: [l__mod___bert_encoder_layer_11_attention_self_key], Original ATen: [aten.addmm] extern_kernels.addmm(arg184_1, reinterpret_tensor(buf235, (8192, 768), (768, 1), 0), reinterpret_tensor(arg183_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf237) del arg183_1 del arg184_1 buf238 = buf216; del buf216 # reuse # Source Nodes: [l__mod___bert_encoder_layer_11_attention_self_value], Original ATen: [aten.addmm] extern_kernels.addmm(arg186_1, reinterpret_tensor(buf235, (8192, 768), (768, 1), 0), reinterpret_tensor(arg185_1, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf238) del arg185_1 del arg186_1 buf239 = buf218; del buf218 # reuse # Source Nodes: [attn_output_33], Original ATen: [aten._scaled_dot_product_efficient_attention] triton_poi_fused__scaled_dot_product_efficient_attention_1.run(buf239, 50331648, grid=grid(50331648), stream=stream0) # Source Nodes: [attn_output_33], Original ATen: [aten._scaled_dot_product_efficient_attention] buf240 = torch.ops.aten._scaled_dot_product_efficient_attention.default(reinterpret_tensor(buf236, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf237, (16, 12, 512, 64), (393216, 64, 768, 1), 0), reinterpret_tensor(buf238, (16, 12, 512, 64), (393216, 64, 768, 1), 0), buf239, False) del buf236 del buf237 del buf239 buf241 = buf240[0] del buf240 buf245 = buf238; del buf238 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf241, (8192, 768), (768, 1), 0), reinterpret_tensor(arg187_1, (768, 768), (1, 768), 0), out=buf245) del arg187_1 buf249 = reinterpret_tensor(buf241, (16, 512, 768), (393216, 768, 1), 0); del buf241 # reuse # Source Nodes: [add_23, hidden_states_90], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf245, arg188_1, buf235, arg189_1, arg190_1, buf249, 8192, 768, grid=grid(8192), stream=stream0) del arg188_1 del arg189_1 del arg190_1 buf250 = reinterpret_tensor(buf230, (8192, 3072), (3072, 1), 0); del buf230 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf249, (8192, 768), (768, 1), 0), reinterpret_tensor(arg191_1, (768, 3072), (1, 768), 0), out=buf250) del arg191_1 buf251 = reinterpret_tensor(buf250, (16, 512, 3072), (1572864, 3072, 1), 0); del buf250 # reuse # Source Nodes: [hidden_states_92], Original ATen: [aten.gelu] triton_poi_fused_gelu_3.run(buf251, arg192_1, 25165824, grid=grid(25165824), stream=stream0) del arg192_1 buf252 = buf245; del buf245 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf251, (8192, 3072), (3072, 1), 0), reinterpret_tensor(arg193_1, (3072, 768), (1, 3072), 0), out=buf252) del arg193_1 del buf251 buf256 = buf235; del buf235 # reuse # Source Nodes: [add_24, hidden_states_95], Original ATen: [aten.add, aten.native_layer_norm] triton_per_fused_add_native_layer_norm_2.run(buf252, arg194_1, buf249, arg195_1, arg196_1, buf256, 8192, 768, grid=grid(8192), stream=stream0) del arg194_1 del arg195_1 del arg196_1 del buf249 buf257 = buf252; del buf252 # reuse # Source Nodes: [], Original ATen: [] extern_kernels.mm(reinterpret_tensor(buf256, (8192, 768), (768, 1), 0), reinterpret_tensor(arg197_1, (768, 768), (1, 768), 0), out=buf257) del arg197_1 buf262 = buf256; del buf256 # reuse # Source Nodes: [hidden_states_97, hidden_states_98], Original ATen: [aten.gelu, aten.native_layer_norm] triton_per_fused_gelu_native_layer_norm_4.run(buf257, arg198_1, arg199_1, arg200_1, buf262, 8192, 768, grid=grid(8192), stream=stream0) del arg198_1 del arg199_1 del arg200_1 del buf257 buf263 = empty_strided_cuda((768, 30528), (30528, 1), torch.bfloat16) # Source Nodes: [], Original ATen: [] triton_poi_fused_5.run(arg201_1, buf263, 23445504, grid=grid(23445504), stream=stream0) del arg201_1 buf264 = empty_strided_cuda((30528, ), (1, ), torch.bfloat16) # Source Nodes: [], Original ATen: [] triton_poi_fused_6.run(arg202_1, buf264, 30528, grid=grid(30528), stream=stream0) del arg202_1 buf265 = empty_strided_cuda((8192, 30528), (30528, 1), torch.bfloat16) # Source Nodes: [], Original ATen: [] extern_kernels.addmm(buf264, reinterpret_tensor(buf262, (8192, 768), (768, 1), 0), buf263, alpha=1, beta=1, out=buf265) del buf262 del buf263 del buf264 buf266 = empty_strided_cuda((8192, 1), (1, 8192), torch.float32) buf267 = empty_strided_cuda((8192, 1), (1, 8192), torch.float32) # Source Nodes: [masked_lm_loss], Original ATen: [aten._log_softmax] triton_red_fused__log_softmax_7.run(buf265, buf266, buf267, 8192, 30522, grid=grid(8192), stream=stream0) buf268 = empty_strided_cuda((), (), torch.bfloat16) buf270 = buf268; del buf268 # reuse # Source Nodes: [masked_lm_loss], Original ATen: [aten.nll_loss_forward] triton_red_fused_nll_loss_forward_8.run(buf270, arg206_1, buf265, buf266, buf267, 1, 8192, grid=grid(1), stream=stream0) del arg206_1 del buf266 del buf267 return (buf270, reinterpret_tensor(buf265, (16, 512, 30522), (15630336, 30528, 1), 0), ) def benchmark_compiled_module(times=10, repeat=10): from torch._dynamo.testing import rand_strided from torch._inductor.utils import print_performance arg0_1 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg1_1 = rand_strided((2, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg2_1 = rand_strided((512, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg3_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg4_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg5_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg6_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg7_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg8_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg9_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg10_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg11_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg12_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg13_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg14_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg15_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg16_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg17_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg18_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg19_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg20_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg21_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg22_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg23_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg24_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg25_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg26_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg27_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg28_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg29_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg30_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg31_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg32_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg33_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg34_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg35_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg36_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg37_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg38_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg39_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg40_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg41_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg42_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg43_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg44_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg45_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg46_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg47_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg48_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg49_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg50_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg51_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg52_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg53_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg54_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg55_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg56_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg57_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg58_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg59_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg60_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg61_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg62_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg63_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg64_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg65_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg66_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg67_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg68_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg69_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg70_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg71_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg72_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg73_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg74_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg75_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg76_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg77_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg78_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg79_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg80_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg81_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg82_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg83_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg84_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg85_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg86_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg87_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg88_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg89_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg90_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg91_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg92_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg93_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg94_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg95_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg96_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg97_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg98_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg99_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg100_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg101_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg102_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg103_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg104_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg105_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg106_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg107_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg108_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg109_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg110_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg111_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg112_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg113_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg114_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg115_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg116_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg117_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg118_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg119_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg120_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg121_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg122_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg123_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg124_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg125_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg126_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg127_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg128_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg129_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg130_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg131_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg132_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg133_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg134_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg135_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg136_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg137_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg138_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg139_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg140_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg141_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg142_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg143_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg144_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg145_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg146_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg147_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg148_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg149_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg150_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg151_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg152_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg153_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg154_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg155_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg156_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg157_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg158_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg159_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg160_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg161_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg162_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg163_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg164_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg165_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg166_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg167_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg168_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg169_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg170_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg171_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg172_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg173_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg174_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg175_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg176_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg177_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg178_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg179_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg180_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg181_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg182_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg183_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg184_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg185_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg186_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg187_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg188_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg189_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg190_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg191_1 = rand_strided((3072, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg192_1 = rand_strided((3072, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg193_1 = rand_strided((768, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) arg194_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg195_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg196_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg197_1 = rand_strided((768, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg198_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg199_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg200_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg201_1 = rand_strided((30522, 768), (768, 1), device='cuda:0', dtype=torch.bfloat16) arg202_1 = rand_strided((30522, ), (1, ), device='cuda:0', dtype=torch.bfloat16) arg203_1 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64) arg204_1 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64) arg205_1 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64) arg206_1 = rand_strided((16, 512), (512, 1), device='cuda:0', dtype=torch.int64) fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1]) return print_performance(fn, times=times, repeat=repeat) if __name__ == "__main__": from torch._inductor.wrapper_benchmark import compiled_module_main compiled_module_main('BertForMaskedLM', benchmark_compiled_module)
비교하기