02/09 + patch vs. 03/01 + patch
370 lines
# AOT ID: ['0_backward']
# AOT ID: ['0_backward']
# Imports required by the generated wrapper code.
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import (
    grid,
    split_scan_grid,
    grid_combo_kernels,
    start_graph,
    end_graph,
    cooperative_reduction_grid,
)
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch._inductor.kernel.flex_attention

# Short module-level aliases for op namespaces.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized

# Aliases for guard / allocation helpers exposed by torch._C._dynamo.guards.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool

# Single AsyncCompile instance; the kernel definitions further down register
# their Triton source through async_compile.triton(...).
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
# kernel path: /tmp/torchinductor_root/6x/c6xa3pwry3z4jzo3u3cc6q3htotswajxyoqfcdvn5dgb3wcu3hbv.py
# kernel path: /tmp/torchinductor_root/6x/c6xa3pwry3z4jzo3u3cc6q3htotswajxyoqfcdvn5dgb3wcu3hbv.py
# Topologically Sorted Source Nodes: [loss, logits, mul_176, square_16, add_116, rsqrt, mul_177], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward, aten._to_copy, aten.mul, aten.pow, aten.add, aten.rsqrt, aten._log_softmax, aten._log_softmax_backward_data]
# Topologically Sorted Source Nodes: [loss, logits, mul_176, square_16, add_116, rsqrt, mul_177], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward, aten._to_copy, aten.mul, aten.pow, aten.add, aten.rsqrt, aten._log_softmax, aten._log_softmax_backward_data]
# Source node to ATen node mapping:
# Source node to ATen node mapping:
# add_116 => add_179
# add_116 => add_179
# logits => convert_element_type_323
# logits => convert_element_type_323
# loss => full_default_12, full_default_13, sub_4, sub_5
# loss => full_default_12, full_default_13, sub_4, sub_5
# mul_176 => mul_239
# mul_176 => mul_239
# mul_177 => mul_240
# mul_177 => mul_240
# rsqrt => rsqrt_63
# rsqrt => rsqrt_63
# square_16 => pow_80
# square_16 => pow_80
# Graph fragment:
# Graph fragment:
# %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%tangents_1, %convert_element_type_324), kwargs = {})
# %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%tangents_1, %convert_element_type_324), kwargs = {})
# %ne_3 : [num_users=2] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_164, -100), kwargs = {})
# %ne_3 : [num_users=2] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_164, -100), kwargs = {})
# %full_default_12 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
# %full_default_12 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where_6 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %unsqueeze_164, %full_default_12), kwargs = {})
# %where_6 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %unsqueeze_164, %full_default_12), kwargs = {})
# %scatter_upon_const_tensor : [num_users=1] = call_function[target=torch._inductor.fx_passes.post_grad.scatter_upon_const_tensor](args = (), kwargs = {shape: [65536, 50304], background_val: 0, dtype: torch.float32, dim: 1, selector: %where_6, val: -1.0})
# %scatter_upon_const_tensor : [num_users=1] = call_function[target=torch._inductor.fx_passes.post_grad.scatter_upon_const_tensor](args = (), kwargs = {shape: [65536, 50304], background_val: 0, dtype: torch.float32, dim: 1, selector: %where_6, val: -1.0})
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where_7 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %div_2, %full_default_13), kwargs = {})
# %where_7 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %div_2, %full_default_13), kwargs = {})
# %mul_241 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%scatter_upon_const_tensor, %where_7), kwargs = {})
# %mul_241 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%scatter_upon_const_tensor, %where_7), kwargs = {})
# %convert_element_type_323 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_62, torch.float32), kwargs = {})
# %convert_element_type_323 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_62, torch.float32), kwargs = {})
# %mul_239 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_323, 15), kwargs = {})
# %mul_239 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_323, 15), kwargs = {})
# %pow_80 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_323, 2), kwargs = {})
# %pow_80 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_323, 2), kwargs = {})
# %add_179 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%pow_80, 225), kwargs = {})
# %add_179 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%pow_80, 225), kwargs = {})
# %rsqrt_63 : [num_users=3] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_179,), kwargs = {})
# %rsqrt_63 : [num_users=3] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_179,), kwargs = {})
# %mul_240 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_239, %rsqrt_63), kwargs = {})
# %mul_240 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_239, %rsqrt_63), kwargs = {})
# %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_240, %amax), kwargs = {})
# %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_240, %amax), kwargs = {})
# %sub_5 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_4, %log), kwargs = {})
# %sub_5 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_4, %log), kwargs = {})
# %exp_1 : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_5,), kwargs = {})
# %exp_1 : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_5,), kwargs = {})
# %sum_10 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_241, [1], True), kwargs = {})
# %sum_10 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_241, [1], True), kwargs = {})
# %mul_242 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%exp_1, %sum_10), kwargs = {})
# %mul_242 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%exp_1, %sum_10), kwargs = {})
# %sub_6 : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_241, %mul_242), kwargs = {})
# %sub_6 : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_241, %mul_242), kwargs = {})
# %mul_243 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_6, %mul_239), kwargs = {})
# %mul_243 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_6, %mul_239), kwargs = {})
# %mul_244 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_6, %rsqrt_63), kwargs = {})
# %mul_244 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_6, %rsqrt_63), kwargs = {})
# %pow_81 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%rsqrt_63, 3), kwargs = {})
# %pow_81 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%rsqrt_63, 3), kwargs = {})
# %mul_245 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%mul_243, -0.5), kwargs = {})
# %mul_245 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%mul_243, -0.5), kwargs = {})
# %mul_246 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_245, %pow_81), kwargs = {})
# %mul_246 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_245, %pow_81), kwargs = {})
# %pow_82 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_323, 1.0), kwargs = {})
# %pow_82 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_323, 1.0), kwargs = {})
# %mul_247 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_82, 2.0), kwargs = {})
# %mul_247 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_82, 2.0), kwargs = {})
# %mul_248 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_246, %mul_247), kwargs = {})
# %mul_248 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_246, %mul_247), kwargs = {})
# %mul_249 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_244, 15), kwargs = {})
# %mul_249 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_244, 15), kwargs = {})
# %add_180 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_248, %mul_249), kwargs = {})
# %add_180 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_248, %mul_249), kwargs = {})
# %convert_element_type_325 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_180, torch.bfloat16), kwargs = {})
# %convert_element_type_325 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_180, torch.bfloat16), kwargs = {})
# Looped-reduction Triton kernel, registered as source text through
# async_compile.triton for device 'cuda'. The kernel body pins xnumel = 65536
# rows and r0_numel = 50304 columns and updates in_out_ptr0 in place.
#
# Pass 1 (first r0 loop): per row x, accumulate sum over r0 of
#     where(idx == r0, -1.0, 0.0) * where(in_ptr0[x] != -100, in_ptr1[0] / in_ptr2[0], 0.0)
#   with idx = in_ptr0[x], or 0 when in_ptr0[x] == -100, then reduce along
#   axis 1 into tmp18.
#   NOTE(review): -100 is presumably the nll_loss ignore_index, going by the
#   nll_loss nodes in the generated header comments -- confirm against the model.
# Pass 2 (second r0 loop): load v = in_out_ptr0[r0 + 50304*x] as float32, form
#     e = exp(15*v*rsqrt(v*v + 225) - in_ptr3[x] - in_ptr4[x])
#     g = (pass-1 integrand) - e * tmp18
#   and store  g*(15*v)*(-0.5)*rsqrt(...)**3*(2*v) + g*rsqrt(...)*15  back to
#   the same in_out_ptr0 location, masked by r0 < r0_numel.
#
# The text between the triple quotes is the kernel's source, so its
# indentation is significant.
triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0 = async_compile.triton('triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 65536, 'r0_': 65536},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*i64', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'xnumel': 'i64', 'r0_numel': 'i64'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 8, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 65536
    r0_numel = 50304
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64)
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64)
    rbase = r0_base
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
    tmp10 = tl.load(in_ptr1 + (0))
    tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
    tmp12 = tl.load(in_ptr2 + (0))
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
    _tmp18 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp1 = tl.full([1, 1], -100, tl.int64)
        tmp2 = tmp0 != tmp1
        tmp3 = tl.full([1, 1], 0, tl.int64)
        tmp4 = tl.where(tmp2, tmp0, tmp3)
        tmp5 = r0_1
        tmp6 = tmp4 == tmp5
        tmp7 = -1.0
        tmp8 = 0.0
        tmp9 = tl.where(tmp6, tmp7, tmp8)
        tmp14 = tmp11 / tmp13
        tmp15 = tl.where(tmp2, tmp14, tmp8)
        tmp16 = tmp9 * tmp15
        tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
        tmp19 = _tmp18 + tmp17
        _tmp18 = tl.where(r0_mask, tmp19, _tmp18)
    tmp18 = tl.sum(_tmp18, 1)[:, None]
    tmp29 = tl.load(in_ptr1 + (0))
    tmp30 = tl.broadcast_to(tmp29, [XBLOCK, R0_BLOCK])
    tmp31 = tl.load(in_ptr2 + (0))
    tmp32 = tl.broadcast_to(tmp31, [XBLOCK, R0_BLOCK])
    tmp45 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp47 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp36 = tl.load(in_out_ptr0 + (r0_1 + 50304*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp20 = tl.full([1, 1], -100, tl.int64)
        tmp21 = tmp0 != tmp20
        tmp22 = tl.full([1, 1], 0, tl.int64)
        tmp23 = tl.where(tmp21, tmp0, tmp22)
        tmp24 = r0_1
        tmp25 = tmp23 == tmp24
        tmp26 = -1.0
        tmp27 = 0.0
        tmp28 = tl.where(tmp25, tmp26, tmp27)
        tmp33 = tmp30 / tmp32
        tmp34 = tl.where(tmp21, tmp33, tmp27)
        tmp35 = tmp28 * tmp34
        tmp37 = tmp36.to(tl.float32)
        tmp38 = 15.0
        tmp39 = tmp37 * tmp38
        tmp40 = tmp37 * tmp37
        tmp41 = 225.0
        tmp42 = tmp40 + tmp41
        tmp43 = libdevice.rsqrt(tmp42)
        tmp44 = tmp39 * tmp43
        tmp46 = tmp44 - tmp45
        tmp48 = tmp46 - tmp47
        tmp49 = tl_math.exp(tmp48)
        tmp50 = tmp49 * tmp18
        tmp51 = tmp35 - tmp50
        tmp52 = tmp51 * tmp39
        tmp53 = -0.5
        tmp54 = tmp52 * tmp53
        tmp55 = tmp43 * tmp43
        tmp56 = tmp55 * tmp43
        tmp57 = tmp54 * tmp56
        tmp58 = 2.0
        tmp59 = tmp37 * tmp58
        tmp60 = tmp57 * tmp59
        tmp61 = tmp51 * tmp43
        tmp62 = tmp61 * tmp38
        tmp63 = tmp60 + tmp62
        tmp64 = tmp63.to(tl.float32)
        tl.store(in_out_ptr0 + (r0_1 + 50304*x0), tmp64, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/dr/cdr4v4mkzyrsy2jewbwgoh2vvxxwk7mm64ykaxpfkv6jlvpbdlk5.py
# kernel path: /tmp/torchinductor_root/dr/cdr4v4mkzyrsy2jewbwgoh2vvxxwk7mm64ykaxpfkv6jlvpbdlk5.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy]
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy]
# Source node to ATen node mapping:
# Source node to ATen node mapping:
# Graph fragment:
# Graph fragment:
# %convert_element_type_330 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_63, torch.float32), kwargs = {})
# %convert_element_type_330 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_63, torch.float32), kwargs = {})
# Pointwise Triton kernel, registered as source text through
# async_compile.triton for device 'cuda'. Each program handles XBLOCK
# consecutive elements: it loads in_ptr0[x] (declared '*bf16' in triton_meta),
# converts to float32 and stores the result to out_ptr0[x] (declared '*fp32').
# The kernel body pins xnumel = 51511296 and issues loads/stores with no mask.
#
# The text between the triple quotes is the kernel's source, so its
# indentation is significant.
triton_poi_fused__to_copy_1 = async_compile.triton('triton_poi_fused__to_copy_1', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 67108864},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 51511296
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/rl/crlraaykurjoxsxegkqbgdv7ae4g2a5erai7d57blek732nk3rjc.py
# kernel path: /tmp/torchinductor_root/rl/crlraaykurjoxsxegkqbgdv7ae4g2a5erai7d57blek732nk3rjc.py
# Topologically Sorted Source Nodes: [x_144], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
# Topologically Sorted Source Nodes: [x_144], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# Source node to ATen node mapping:
# x_144 => convert_element_type_318
# x_144 => convert_element_type_318
# Graph fragment:
# Graph fragment:
# %convert_element_type_331 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_184, torch.float32), kwargs = {})
# %convert_element_type_331 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_184, torch.float32), kwargs = {})
# %convert_element_type_318 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_177, torch.float32), kwargs = {})
# %convert_element_type_318 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_177, torch.float32), kwargs = {})
# %mul_250 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_331, %convert_element_type_318), kwargs = {})
# %mul_250 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_331, %convert_element_type_318), kwargs = {})
# %mul_251 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_331, %rsqrt_62), kwargs = {})
# %mul_251 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_331, %rsqrt_62), kwargs = {})
# %sum_11 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_250, [2], True), kwargs = {})
# %sum_11 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_250, [2], True), kwargs = {})
# %div_3 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_20, 1024), kwargs = {})
# %div_3 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_20, 1024), kwargs = {})
# %pow_84 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_318, 1.0), kwargs = {})
# %pow_84 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_318, 1.0), kwargs = {})
# %mul_254 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_84, 2.0), kwargs = {})
# %mul_254 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_84, 2.0), kwargs = {})
# %mul_255 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_3, %mul_254), kwargs = {})
# %mul_255 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_3, %mul_254), kwargs = {})
# %add_181 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_251, %mul_255), kwargs = {})
# %add_181 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_251, %mul_255), kwargs = {})
# %convert_element_type_332 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_181, torch.bfloat16), kwargs = {})
# %convert_element_type_332 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_181, torch.bfloat16), kwargs = {})
# NOTE(review): every line of this definition appears twice and carries no indentation; this looks
# like an artifact of a side-by-side diff dump rather than of the generator -- confirm against the
# original generated file before executing it.
#
# Persistent-reduction Triton kernel, registered through async_compile.triton under the name below.
# Per the triton_meta signature, in_out_ptr0 and in_ptr0 are bf16 pointers and in_ptr1 is an fp32
# pointer. xnumel and r0_numel are overwritten inside the kernel with 65536 and 1024.
# Each program handles one x index (XBLOCK = 1) together with a full R0_BLOCK = 1024 span:
#   a = in_out_ptr0[r + 1024*x]   (loaded and converted to float32)
#   b = in_ptr0[r + 1024*x]       (loaded and converted to float32)
#   s = in_ptr1[x]
#   total  = sum over r of (a * b)
#   result = a * s + (total * -0.5 * (s * s * s) * 0.0009765625) * (b * 2.0)
# result is stored back to in_out_ptr0[r + 1024*x], i.e. that buffer is updated in place
# (it is also the only entry in 'mutated_arg_names'). All loads and the store pass None as the mask.
# 0.0009765625 equals 1/1024, which matches the div.Scalar(..., 1024) node in the graph fragment
# listed above.
triton_per_fused__to_copy_add_div_mul_pow_sum_2 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_2', '''
triton_per_fused__to_copy_add_div_mul_pow_sum_2 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_2', '''
import triton
import triton
import triton.language as tl
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
reduction_hint=ReductionHint.INNER,
filename=__file__,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_2', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_2', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
)
@triton.jit
@triton.jit
def triton_per_fused__to_copy_add_div_mul_pow_sum_2(in_out_ptr0, in_ptr0, in_ptr1, xnumel, r0_numel):
def triton_per_fused__to_copy_add_div_mul_pow_sum_2(in_out_ptr0, in_ptr0, in_ptr1, xnumel, r0_numel):
xnumel = 65536
xnumel = 65536
XBLOCK: tl.constexpr = 1
XBLOCK: tl.constexpr = 1
r0_numel = 1024
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
roffset = r0_offset
rindex = r0_index
rindex = r0_index
r0_1 = r0_index
r0_1 = r0_index
x0 = xindex
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp8 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last')
tmp8 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last')
tmp1 = tmp0.to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 * tmp3
tmp4 = tmp1 * tmp3
tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
tmp9 = tmp1 * tmp8
tmp9 = tmp1 * tmp8
tmp10 = -0.5
tmp10 = -0.5
tmp11 = tmp7 * tmp10
tmp11 = tmp7 * tmp10
tmp12 = tmp8 * tmp8
tmp12 = tmp8 * tmp8
tmp13 = tmp12 * tmp8
tmp13 = tmp12 * tmp8
tmp14 = tmp11 * tmp13
tmp14 = tmp11 * tmp13
tmp15 = 0.0009765625
tmp15 = 0.0009765625
tmp16 = tmp14 * tmp15
tmp16 = tmp14 * tmp15
tmp17 = 2.0
tmp17 = 2.0
tmp18 = tmp3 * tmp17
tmp18 = tmp3 * tmp17
tmp19 = tmp16 * tmp18
tmp19 = tmp16 * tmp18
tmp20 = tmp9 + tmp19
tmp20 = tmp9 + tmp19
tmp21 = tmp20.to(tl.float32)
tmp21 = tmp20.to(tl.float32)
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp21, None)
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp21, None)
''', device_str='cuda')
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/u6/cu637rvtsfhxlrzleut2lz46jmik23yqzdbhtptkmjew57itpu2w.py
# kernel path: /tmp/torchinductor_root/u6/cu637rvtsfhxlrzleut2lz46jmik23yqzdbhtptkmjew57itpu2w.py
# Topologically Sorted Source Nodes: [relu_15], Original ATen: [aten.relu, aten.pow, aten.mul, aten.threshold_backward]
# Topologically Sorted Source Nodes: [relu_15], Original ATen: [aten.relu, aten.pow, aten.mul, aten.threshold_backward]
# Source node to ATen node mapping:
# Source node to ATen node mapping:
# relu_15 => relu_15
# relu_15 => relu_15
# Graph fragment:
# Graph fragment:
# %relu_15 : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%view_180,), kwargs = {})
# %relu_15 : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%view_180,), kwargs = {})
# %pow_85 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%relu_15, 1.0), kwargs = {})
# %pow_85 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%relu_15, 1.0), kwargs = {})
# %mul_256 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_85, 2.0), kwargs = {})
# %mul_256 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_85, 2.0), kwargs = {})
# %mul_257 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_186, %mul_256), kwargs = {})
# %mul_257 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_186, %mul_256), kwargs = {})
# %le_1 : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu_15, 0), kwargs = {})
# %le_1 : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu_15, 0), kwargs = {})
# %full_default_17 : [num_users=16] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
# %full_default_17 : [num_users=16] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where_8 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%le_1, %full_default_17, %mul_257), kwargs = {})
# %where_8 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%le_1, %full_default_17, %mul_257), kwargs = {})
# NOTE(review): every line of this definition appears twice and carries no indentation; this looks
# like an artifact of a side-by-side diff dump rather than of the generator -- confirm against the
# original generated file before executing it.
#
# Pointwise Triton kernel, registered through async_compile.triton under the name below.
# Per the triton_meta signature, in_out_ptr0 and in_ptr0 are both bf16 pointers. xnumel is
# overwritten inside the kernel with 268435456; XBLOCK is a constexpr parameter of the kernel.
# For each index x = program_id(0) * XBLOCK + arange(0, XBLOCK):
#   v   = in_out_ptr0[x]   (loaded and converted to float32)
#   g   = in_ptr0[x]       (loaded and converted to float32)
#   r   = maximum(0, v)
#   out = 0.0 where r <= 0.0, otherwise g * (r * 2.0)
# out is stored back to in_out_ptr0[x], i.e. that buffer is updated in place (it is also the only
# entry in 'mutated_arg_names'). Both loads and the store pass None as the mask; xmask is built
# but not used by them.
triton_poi_fused_mul_pow_relu_threshold_backward_3 = async_compile.triton('triton_poi_fused_mul_pow_relu_threshold_backward_3', '''
triton_poi_fused_mul_pow_relu_threshold_backward_3 = async_compile.triton('triton_poi_fused_mul_pow_relu_threshold_backward_3', '''
import triton
import triton
import triton.language as tl
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
@triton_heuristics.pointwise(
size_hints={'x': 268435456},
size_hints={'x': 268435456},
filename=__file__,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_pow_relu_threshold_backward_3', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_pow_relu_threshold_backward_3', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
min_elem_per_thread=0
)
)
@triton.jit
@triton.jit
def triton_poi_fused_mul_pow_relu_threshold_backward_3(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
def triton_poi_fused_mul_pow_relu_threshold_backward_3(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 268435456
xnumel = 268435456
xoffset = tl.program_id(0) * XBLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.full([1], 0, tl.int32)
tmp1 = tl.full([1], 0, tl.int32)
tmp2 = triton_helpers.maximum(tmp1, tmp0)
tmp2 = triton_helpers.maximum(tmp1, tmp0)
tmp3 = 0.0
tmp3 = 0.0
tmp4 = tmp2 <= tmp3
tmp4 = tmp2 <= tmp3
tmp6 = 2.0
tmp6 = 2.0
tmp7 = tmp2 * tmp6
tmp7 = tmp2 * tmp6
tmp8 = tmp5 * tmp7
tmp8 = tmp5 * tmp7
tmp9 = tl.where(tmp4, tmp3, tmp8)
tmp9 = tl.where(tmp4, tmp3, tmp8)
tl.store(in_out_ptr0 + (x0), tmp9, None)
tl.store(in_out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/y6/cy6sn724x4qlyi5erhmh4ehnfh6gw456auqzpb7n6yuo6c2sfuyv.py
# kernel path: /tmp/torchinductor_root/y6/cy6sn724x4qlyi5erhmh4ehnfh6gw456auqzpb7n6yuo6c2sfuyv.py
# Topologically Sorted Source Nodes: [rms_norm_61], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
# Topologically Sorted Source Nodes: [rms_norm_61], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# Source node to ATen node mapping:
# rms_norm_61 => convert_element_type_312
# rms_norm_61 => convert_element_type_312
# Graph fragment:
# Graph fragment:
# %convert_element_type_341 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_188, torch.float32), kwargs = {})
# %convert_element_type_341 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_188, torch.float32), kwargs = {})
# %convert_element_type_312 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_175, torch.float32), kwargs = {})
# %convert_element_type_312 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_175, torch.float32), kwargs = {})
# %mul_258 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_341, %convert_element_type_312), kwargs = {})
# %mul_258 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_341, %convert_element_type_312), kwargs = {})
# %mul_259 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_341, %rsqrt_61), kwargs = {})
# %mul_259 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_341, %rsqrt_61), kwargs = {})
# %sum_12 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_258, [2], True), kwargs = {})
# %sum_12 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_258, [2], True), kwargs = {})
# %div_4 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_21, 1024), kwargs = {})
# %div_4 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_21, 1024), kwargs = {})
# %pow_87 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_312, 1.0), kwargs
# %pow_87 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_312, 1.0), kwargs