Diff
checker
文本
文本
圖像
文檔
Excel
文件夾
Legal
Enterprise
桌面版
定價
登入
下載 Diffchecker 桌面版
比較文本
尋找兩個文字檔案之間的差異
工具
歷史
即時編輯器
摺疊未變更行
關閉換行
檢視
拆分
統一
比對精度
智能
單詞
字符
語法突出顯示
選擇語法
忽略
文字轉換
前往第一個差異
編輯輸入
Diffchecker Desktop
執行Diffchecker最安全的方式。取得Diffchecker桌面應用程式:您的差異永遠不會離開您的電腦!
取得桌面版
sm90 vs sm100 rowwise cutlass gemm
建立於
去年
差異永不過期
清除
匯出
分享
解釋
33 刪除
行
總計
刪除
字符
總計
刪除
要繼續使用此功能,請升級到
Diff
checker
Pro
查看價格
189 行
全部複製
10 新增
行
總計
新增
字符
總計
新增
要繼續使用此功能,請升級到
Diff
checker
Pro
查看價格
179 行
全部複製
複製
已複製
複製
已複製
// Cutlass rowwise kernel for
sm90
// Cutlass rowwise kernel for
SM100
template <
template <
typename TileShape,
typename TileShape,
typename ClusterShape,
typename ClusterShape,
typename Transposed,
typename Transposed,
typename FastAccum,
typename FastAccum,
typename DtypeA,
typename DtypeA,
typename DtypeB,
typename DtypeB,
typename DtypeBias>
typename DtypeBias>
複製
已複製
複製
已複製
void f8f8bf16_rowwise_impl
(
void f8f8bf16_rowwise_impl
_sm100
(
at::Tensor XQ, // FP8
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor WQ, // FP8
at::Tensor x_scale,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor w_scale,
std::optional<at::Tensor> bias,
std::optional<at::Tensor> bias,
at::Tensor out,
at::Tensor out,
const int swizzle) {
const int swizzle) {
int M = XQ.size(0);
int M = XQ.size(0);
int N = WQ.size(1);
int N = WQ.size(1);
int K = XQ.size(1);
int K = XQ.size(1);
// Workaround for https://github.com/pytorch/pytorch/issues/133334.
// Workaround for https://github.com/pytorch/pytorch/issues/133334.
if (M % 256 > 0) {
if (M % 256 > 0) {
int padded_M = ((M - 1) / 256 + 1) * 256;
int padded_M = ((M - 1) / 256 + 1) * 256;
at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1});
at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1});
padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M)
padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M)
.copy_(std::move(x_scale));
.copy_(std::move(x_scale));
x_scale = std::move(padded_x_scale);
x_scale = std::move(padded_x_scale);
}
}
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputA = cutlass::layout::RowMajor;
constexpr int AlignmentInputA = 16 / sizeof(DtypeA);
constexpr int AlignmentInputA = 16 / sizeof(DtypeA);
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputB = 16 / sizeof(DtypeB);
constexpr int AlignmentInputB = 16 / sizeof(DtypeB);
using LayoutOutput = std::conditional_t<
using LayoutOutput = std::conditional_t<
Transposed::value,
Transposed::value,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>;
cutlass::layout::RowMajor>;
constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
// Tag indicating the minimum SM that supports the intended feature
// Tag indicating the minimum SM that supports the intended feature
複製
已複製
複製
已複製
using ArchTag = cutlass::arch::Sm
9
0;
using ArchTag = cutlass::arch::Sm
10
0;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using OperatorClass = cutlass::arch::OpClassTensorOp;
// Implement rowwise scaling epilogue.
// Implement rowwise scaling epilogue.
constexpr int ColBroadcastStages = 0;
constexpr int ColBroadcastStages = 0;
constexpr int RowBroadcastStages = 0;
constexpr int RowBroadcastStages = 0;
using XScale = cutlass::epilogue::fusion::
using XScale = cutlass::epilogue::fusion::
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
using WScale = cutlass::epilogue::fusion::
using WScale = cutlass::epilogue::fusion::
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;
using Bias = std::conditional_t<
using Bias = std::conditional_t<
Transposed::value,
Transposed::value,
cutlass::epilogue::fusion::
cutlass::epilogue::fusion::
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>,
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>,
cutlass::epilogue::fusion::
cutlass::epilogue::fusion::
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
Multiply,
Multiply,
WScale,
WScale,
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
Cast,
Cast,
cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90EVT<
Add,
Add,
Bias,
Bias,
AccumScale>>;
AccumScale>>;
複製
已複製
複製
已複製
constexpr bool large_tile = std::is_same_v<TileShape, cute::Shape<cute::_128, cute::_128, cute::_128>>;
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
using CollectiveEpilogue =
cutlass::arch::Sm100,
OperatorClass,
typename cutlass::epilogue::collective::CollectiveBuilder<
TileShape,
ClusterShape,
ArchTag,
cutlass::epilogue::collective::EpilogueTileAuto,
OperatorClass,
DtypeAccum,
DtypeEpilogue,
TileShape,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
ClusterShape,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::collective::EpilogueTileAuto,
EpilogueScheduleType,
DtypeAccum,
EpilogueEVT>::CollectiveOp;
DtypeEpilogue,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
typename Schedule<large_tile, FastAccum::value>::epilogue_type,
EpilogueEVT>::CollectiveOp;
複製
已複製
複製
已複製
using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
using CollectiveMainloop =
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
ArchTag,
OperatorClass,
OperatorClass,
DtypeA,
DtypeA,
LayoutInputA,
LayoutInputA,
AlignmentInputA,
AlignmentInputA,
DtypeB,
DtypeB,
LayoutInputB,
LayoutInputB,
AlignmentInputB,
AlignmentInputB,
DtypeAccum,
DtypeAccum,
TileShape,
TileShape,
ClusterShape,
ClusterShape,
複製
已複製
複製
已複製
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType
>::
CollectiveOp;
typename Schedule<large_tile, FastAccum::value>::type
>::
CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveMainloop,
CollectiveEpilogue>;
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1));
StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1));
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1));
StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
typename Gemm::Arguments arguments{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K},
{M, N, K},
{reinterpret_cast<DtypeA*>(XQ.data_ptr()),
{reinterpret_cast<DtypeA*>(XQ.data_ptr()),
stride_a,
stride_a,
reinterpret_cast<DtypeB*>(WQ.data_ptr()),
reinterpret_cast<DtypeB*>(WQ.data_ptr()),
stride_b},
stride_b},
{{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
{{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
: nullptr},
: nullptr},
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())},
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())},
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}},
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}},
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
stride_output,
stride_output,
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
stride_output}};
stride_output}};
Gemm gemm;
Gemm gemm;
// Using the arguments, query for extra workspace required for matrix
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Ensure persistent kernels leave enough free SMs for NCCL background ops.
// Ensure persistent kernels leave enough free SMs for NCCL background ops.
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
arguments.hw_info.sm_count =
arguments.hw_info.sm_count =
at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value();
at::globalContext()._SMCarveout_EXPERIMENTAL().value();
}
}
// Set the swizzle size
// Set the swizzle size
arguments.scheduler.max_swizzle_size = swizzle;
arguments.scheduler.max_swizzle_size = swizzle;
// Allocate workspace memory
// Allocate workspace memory
auto workspace = XQ.new_empty(
auto workspace = XQ.new_empty(
{static_cast<int64_t>(workspace_size)},
{static_cast<int64_t>(workspace_size)},
at::TensorOptions().dtype(at::kByte));
at::TensorOptions().dtype(at::kByte));
// Check the problem size is supported or not
// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
throw std::runtime_error("cutlass cannot implement");
}
}
// Initialize CUTLASS kernel with arguments and workspace pointer
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.data_ptr());
status = gemm.initialize(arguments, workspace.data_ptr());
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
throw std::runtime_error("cutlass cannot initialize");
}
}
status = gemm(at::cuda::getCurrentCUDAStream());
status = gemm(at::cuda::getCurrentCUDAStream());
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
throw std::runtime_error(
std::string("cutlass cannot run") +
std::string("cutlass cannot run") +
cutlass::cutlassGetStatusString(status));
cutlass::cutlassGetStatusString(status));
}
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
已保存差異
原始文本
開啟檔案
// Cutlass rowwise kernel for sm90 template < typename TileShape, typename ClusterShape, typename Transposed, typename FastAccum, typename DtypeA, typename DtypeB, typename DtypeBias> void f8f8bf16_rowwise_impl( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 at::Tensor x_scale, at::Tensor w_scale, std::optional<at::Tensor> bias, at::Tensor out, const int swizzle) { int M = XQ.size(0); int N = WQ.size(1); int K = XQ.size(1); // Workaround for https://github.com/pytorch/pytorch/issues/133334. if (M % 256 > 0) { int padded_M = ((M - 1) / 256 + 1) * 256; at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1}); padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M) .copy_(std::move(x_scale)); x_scale = std::move(padded_x_scale); } using LayoutInputA = cutlass::layout::RowMajor; constexpr int AlignmentInputA = 16 / sizeof(DtypeA); using LayoutInputB = cutlass::layout::ColumnMajor; constexpr int AlignmentInputB = 16 / sizeof(DtypeB); using LayoutOutput = std::conditional_t< Transposed::value, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); // Tag indicating the minimum SM that supports the intended feature using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; // Implement rowwise scaling epilogue. constexpr int ColBroadcastStages = 0; constexpr int RowBroadcastStages = 0; using XScale = cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>; using WScale = cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>; using Bias = std::conditional_t< Transposed::value, cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>, cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using AccumScale = cutlass::epilogue::fusion::Sm90EVT< Multiply, WScale, cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>; using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< Cast, cutlass::epilogue::fusion::Sm90EVT< Add, Bias, AccumScale>>; constexpr bool large_tile = std::is_same_v<TileShape, cute::Shape<cute::_128, cute::_128, cute::_128>>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeEpilogue, DtypeOutput, LayoutOutput, AlignmentOutput, DtypeOutput, LayoutOutput, AlignmentOutput, typename Schedule<large_tile, FastAccum::value>::epilogue_type, EpilogueEVT>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, DtypeA, LayoutInputA, AlignmentInputA, DtypeB, LayoutInputB, AlignmentInputB, DtypeAccum, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( sizeof(typename CollectiveEpilogue::SharedStorage))>, typename Schedule<large_tile, FastAccum::value>::type>:: CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< cute::Shape<int, int, int>, CollectiveMainloop, CollectiveEpilogue>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using StrideInputA = typename Gemm::GemmKernel::StrideA; using StrideInputB = typename Gemm::GemmKernel::StrideB; using StrideOutput = typename Gemm::GemmKernel::StrideC; StrideInputA stride_a = cutlass::make_cute_packed_stride( StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1)); StrideInputB stride_b = cutlass::make_cute_packed_stride( StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1)); StrideOutput stride_output = cutlass::make_cute_packed_stride( StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1)); typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K}, {reinterpret_cast<DtypeA*>(XQ.data_ptr()), stride_a, reinterpret_cast<DtypeB*>(WQ.data_ptr()), stride_b}, {{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr()) : nullptr}, {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}, {{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}}, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output}}; Gemm gemm; // Using the arguments, query for extra workspace required for matrix // multiplication computation size_t workspace_size = Gemm::get_workspace_size(arguments); // Ensure persistent kernels leave enough free SMs for NCCL background ops. if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { arguments.hw_info.sm_count = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value(); } // Set the swizzle size arguments.scheduler.max_swizzle_size = swizzle; // Allocate workspace memory auto workspace = XQ.new_empty( {static_cast<int64_t>(workspace_size)}, at::TensorOptions().dtype(at::kByte)); // Check the problem size is supported or not cutlass::Status status = gemm.can_implement(arguments); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot implement"); } // Initialize CUTLASS kernel with arguments and workspace pointer status = gemm.initialize(arguments, workspace.data_ptr()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); } status = gemm(at::cuda::getCurrentCUDAStream()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error( std::string("cutlass cannot run") + cutlass::cutlassGetStatusString(status)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); }
更改後文本
開啟檔案
// Cutlass rowwise kernel for SM100 template < typename TileShape, typename ClusterShape, typename Transposed, typename FastAccum, typename DtypeA, typename DtypeB, typename DtypeBias> void f8f8bf16_rowwise_impl_sm100( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 at::Tensor x_scale, at::Tensor w_scale, std::optional<at::Tensor> bias, at::Tensor out, const int swizzle) { int M = XQ.size(0); int N = WQ.size(1); int K = XQ.size(1); // Workaround for https://github.com/pytorch/pytorch/issues/133334. if (M % 256 > 0) { int padded_M = ((M - 1) / 256 + 1) * 256; at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1}); padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M) .copy_(std::move(x_scale)); x_scale = std::move(padded_x_scale); } using LayoutInputA = cutlass::layout::RowMajor; constexpr int AlignmentInputA = 16 / sizeof(DtypeA); using LayoutInputB = cutlass::layout::ColumnMajor; constexpr int AlignmentInputB = 16 / sizeof(DtypeB); using LayoutOutput = std::conditional_t< Transposed::value, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); // Tag indicating the minimum SM that supports the intended feature using ArchTag = cutlass::arch::Sm100; using OperatorClass = cutlass::arch::OpClassTensorOp; // Implement rowwise scaling epilogue. constexpr int ColBroadcastStages = 0; constexpr int RowBroadcastStages = 0; using XScale = cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>; using WScale = cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>; using Bias = std::conditional_t< Transposed::value, cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>, cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using AccumScale = cutlass::epilogue::fusion::Sm90EVT< Multiply, WScale, cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>; using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< Cast, cutlass::epilogue::fusion::Sm90EVT< Add, Bias, AccumScale>>; using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm100, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeEpilogue, DtypeOutput, LayoutOutput, AlignmentOutput, DtypeOutput, LayoutOutput, AlignmentOutput, EpilogueScheduleType, EpilogueEVT>::CollectiveOp; using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, DtypeA, LayoutInputA, AlignmentInputA, DtypeB, LayoutInputB, AlignmentInputB, DtypeAccum, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>, MainloopScheduleType>::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< cute::Shape<int, int, int>, CollectiveMainloop, CollectiveEpilogue>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using StrideInputA = typename Gemm::GemmKernel::StrideA; using StrideInputB = typename Gemm::GemmKernel::StrideB; using StrideOutput = typename Gemm::GemmKernel::StrideC; StrideInputA stride_a = cutlass::make_cute_packed_stride( StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1)); StrideInputB stride_b = cutlass::make_cute_packed_stride( StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1)); StrideOutput stride_output = cutlass::make_cute_packed_stride( StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1)); typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K}, {reinterpret_cast<DtypeA*>(XQ.data_ptr()), stride_a, reinterpret_cast<DtypeB*>(WQ.data_ptr()), stride_b}, {{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr()) : nullptr}, {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}, {{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}}, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output}}; Gemm gemm; // Using the arguments, query for extra workspace required for matrix // multiplication computation size_t workspace_size = Gemm::get_workspace_size(arguments); // Ensure persistent kernels leave enough free SMs for NCCL background ops. if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { arguments.hw_info.sm_count = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value(); } // Set the swizzle size arguments.scheduler.max_swizzle_size = swizzle; // Allocate workspace memory auto workspace = XQ.new_empty( {static_cast<int64_t>(workspace_size)}, at::TensorOptions().dtype(at::kByte)); // Check the problem size is supported or not cutlass::Status status = gemm.can_implement(arguments); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot implement"); } // Initialize CUTLASS kernel with arguments and workspace pointer status = gemm.initialize(arguments, workspace.data_ptr()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); } status = gemm(at::cuda::getCurrentCUDAStream()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error( std::string("cutlass cannot run") + cutlass::cutlassGetStatusString(status)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); }
尋找差異