Diff
checker
文本
文本
图像
文档
Excel
文件夹
Legal
Enterprise
桌面版
定价
登录
下载 Diffchecker 桌面版
比较文本
查找两个文本文件之间的差异
工具
历史
实时编辑器
折叠未更改行
关闭换行
视图
拆分
统一
比对精度
智能
单词
字符
语法高亮
选择语法
忽略
文本转换
转到第一个差异
编辑输入
Diffchecker Desktop
运行Diffchecker最安全的方式。获取Diffchecker桌面应用:您的差异永远不会离开您的电脑!
获取桌面版
sm90 vs sm100 rowwise cutlass gemm
创建于
去年
差异永不过期
清除
导出
分享
解释
33 删除
行
总计
删除
字符
总计
删除
要继续使用此功能,请升级到
Diff
checker
Pro
查看价格
189 行
全部复制
10 添加
行
总计
添加
字符
总计
添加
要继续使用此功能,请升级到
Diff
checker
Pro
查看价格
179 行
全部复制
复制
已复制
复制
已复制
// Cutlass rowwise kernel for
sm90
// Cutlass rowwise kernel for
SM100
template <
template <
typename TileShape,
typename TileShape,
typename ClusterShape,
typename ClusterShape,
typename Transposed,
typename Transposed,
typename FastAccum,
typename FastAccum,
typename DtypeA,
typename DtypeA,
typename DtypeB,
typename DtypeB,
typename DtypeBias>
typename DtypeBias>
复制
已复制
复制
已复制
void f8f8bf16_rowwise_impl
(
void f8f8bf16_rowwise_impl
_sm100
(
at::Tensor XQ, // FP8
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor WQ, // FP8
at::Tensor x_scale,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor w_scale,
std::optional<at::Tensor> bias,
std::optional<at::Tensor> bias,
at::Tensor out,
at::Tensor out,
const int swizzle) {
const int swizzle) {
int M = XQ.size(0);
int M = XQ.size(0);
int N = WQ.size(1);
int N = WQ.size(1);
int K = XQ.size(1);
int K = XQ.size(1);
// Workaround for https://github.com/pytorch/pytorch/issues/133334.
// Workaround for https://github.com/pytorch/pytorch/issues/133334.
if (M % 256 > 0) {
if (M % 256 > 0) {
int padded_M = ((M - 1) / 256 + 1) * 256;
int padded_M = ((M - 1) / 256 + 1) * 256;
at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1});
at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1});
padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M)
padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M)
.copy_(std::move(x_scale));
.copy_(std::move(x_scale));
x_scale = std::move(padded_x_scale);
x_scale = std::move(padded_x_scale);
}
}
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputA = cutlass::layout::RowMajor;
constexpr int AlignmentInputA = 16 / sizeof(DtypeA);
constexpr int AlignmentInputA = 16 / sizeof(DtypeA);
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputB = 16 / sizeof(DtypeB);
constexpr int AlignmentInputB = 16 / sizeof(DtypeB);
using LayoutOutput = std::conditional_t<
using LayoutOutput = std::conditional_t<
Transposed::value,
Transposed::value,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>;
cutlass::layout::RowMajor>;
constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
// Tag indicating the minimum SM that supports the intended feature
// Tag indicating the minimum SM that supports the intended feature
复制
已复制
复制
已复制
using ArchTag = cutlass::arch::Sm
9
0;
using ArchTag = cutlass::arch::Sm
10
0;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using OperatorClass = cutlass::arch::OpClassTensorOp;
// Implement rowwise scaling epilogue.
// Implement rowwise scaling epilogue.
constexpr int ColBroadcastStages = 0;
constexpr int ColBroadcastStages = 0;
constexpr int RowBroadcastStages = 0;
constexpr int RowBroadcastStages = 0;
using XScale = cutlass::epilogue::fusion::
using XScale = cutlass::epilogue::fusion::
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
using WScale = cutlass::epilogue::fusion::
using WScale = cutlass::epilogue::fusion::
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;
using Bias = std::conditional_t<
using Bias = std::conditional_t<
Transposed::value,
Transposed::value,
cutlass::epilogue::fusion::
cutlass::epilogue::fusion::
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>,
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>,
cutlass::epilogue::fusion::
cutlass::epilogue::fusion::
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
Multiply,
Multiply,
WScale,
WScale,
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
Cast,
Cast,
cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90EVT<
Add,
Add,
Bias,
Bias,
AccumScale>>;
AccumScale>>;
复制
已复制
复制
已复制
constexpr bool large_tile = std::is_same_v<TileShape, cute::Shape<cute::_128, cute::_128, cute::_128>>;
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
using CollectiveEpilogue =
cutlass::arch::Sm100,
OperatorClass,
typename cutlass::epilogue::collective::CollectiveBuilder<
TileShape,
ClusterShape,
ArchTag,
cutlass::epilogue::collective::EpilogueTileAuto,
OperatorClass,
DtypeAccum,
DtypeEpilogue,
TileShape,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
ClusterShape,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::collective::EpilogueTileAuto,
EpilogueScheduleType,
DtypeAccum,
EpilogueEVT>::CollectiveOp;
DtypeEpilogue,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
typename Schedule<large_tile, FastAccum::value>::epilogue_type,
EpilogueEVT>::CollectiveOp;
复制
已复制
复制
已复制
using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
using CollectiveMainloop =
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
ArchTag,
OperatorClass,
OperatorClass,
DtypeA,
DtypeA,
LayoutInputA,
LayoutInputA,
AlignmentInputA,
AlignmentInputA,
DtypeB,
DtypeB,
LayoutInputB,
LayoutInputB,
AlignmentInputB,
AlignmentInputB,
DtypeAccum,
DtypeAccum,
TileShape,
TileShape,
ClusterShape,
ClusterShape,
复制
已复制
复制
已复制
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType
>::
CollectiveOp;
typename Schedule<large_tile, FastAccum::value>::type
>::
CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveMainloop,
CollectiveEpilogue>;
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1));
StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1));
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1));
StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
typename Gemm::Arguments arguments{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K},
{M, N, K},
{reinterpret_cast<DtypeA*>(XQ.data_ptr()),
{reinterpret_cast<DtypeA*>(XQ.data_ptr()),
stride_a,
stride_a,
reinterpret_cast<DtypeB*>(WQ.data_ptr()),
reinterpret_cast<DtypeB*>(WQ.data_ptr()),
stride_b},
stride_b},
{{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
{{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
: nullptr},
: nullptr},
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())},
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())},
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}},
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}},
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
stride_output,
stride_output,
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
stride_output}};
stride_output}};
Gemm gemm;
Gemm gemm;
// Using the arguments, query for extra workspace required for matrix
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Ensure persistent kernels leave enough free SMs for NCCL background ops.
// Ensure persistent kernels leave enough free SMs for NCCL background ops.
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
arguments.hw_info.sm_count =
arguments.hw_info.sm_count =
at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value();
at::globalContext()._SMCarveout_EXPERIMENTAL().value();
}
}
// Set the swizzle size
// Set the swizzle size
arguments.scheduler.max_swizzle_size = swizzle;
arguments.scheduler.max_swizzle_size = swizzle;
// Allocate workspace memory
// Allocate workspace memory
auto workspace = XQ.new_empty(
auto workspace = XQ.new_empty(
{static_cast<int64_t>(workspace_size)},
{static_cast<int64_t>(workspace_size)},
at::TensorOptions().dtype(at::kByte));
at::TensorOptions().dtype(at::kByte));
// Check the problem size is supported or not
// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
throw std::runtime_error("cutlass cannot implement");
}
}
// Initialize CUTLASS kernel with arguments and workspace pointer
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.data_ptr());
status = gemm.initialize(arguments, workspace.data_ptr());
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
throw std::runtime_error("cutlass cannot initialize");
}
}
status = gemm(at::cuda::getCurrentCUDAStream());
status = gemm(at::cuda::getCurrentCUDAStream());
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
throw std::runtime_error(
std::string("cutlass cannot run") +
std::string("cutlass cannot run") +
cutlass::cutlassGetStatusString(status));
cutlass::cutlassGetStatusString(status));
}
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
已保存差异
原始文本
打开文件
// Cutlass rowwise kernel for sm90 template < typename TileShape, typename ClusterShape, typename Transposed, typename FastAccum, typename DtypeA, typename DtypeB, typename DtypeBias> void f8f8bf16_rowwise_impl( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 at::Tensor x_scale, at::Tensor w_scale, std::optional<at::Tensor> bias, at::Tensor out, const int swizzle) { int M = XQ.size(0); int N = WQ.size(1); int K = XQ.size(1); // Workaround for https://github.com/pytorch/pytorch/issues/133334. if (M % 256 > 0) { int padded_M = ((M - 1) / 256 + 1) * 256; at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1}); padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M) .copy_(std::move(x_scale)); x_scale = std::move(padded_x_scale); } using LayoutInputA = cutlass::layout::RowMajor; constexpr int AlignmentInputA = 16 / sizeof(DtypeA); using LayoutInputB = cutlass::layout::ColumnMajor; constexpr int AlignmentInputB = 16 / sizeof(DtypeB); using LayoutOutput = std::conditional_t< Transposed::value, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); // Tag indicating the minimum SM that supports the intended feature using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; // Implement rowwise scaling epilogue. constexpr int ColBroadcastStages = 0; constexpr int RowBroadcastStages = 0; using XScale = cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>; using WScale = cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>; using Bias = std::conditional_t< Transposed::value, cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>, cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using AccumScale = cutlass::epilogue::fusion::Sm90EVT< Multiply, WScale, cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>; using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< Cast, cutlass::epilogue::fusion::Sm90EVT< Add, Bias, AccumScale>>; constexpr bool large_tile = std::is_same_v<TileShape, cute::Shape<cute::_128, cute::_128, cute::_128>>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeEpilogue, DtypeOutput, LayoutOutput, AlignmentOutput, DtypeOutput, LayoutOutput, AlignmentOutput, typename Schedule<large_tile, FastAccum::value>::epilogue_type, EpilogueEVT>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, DtypeA, LayoutInputA, AlignmentInputA, DtypeB, LayoutInputB, AlignmentInputB, DtypeAccum, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( sizeof(typename CollectiveEpilogue::SharedStorage))>, typename Schedule<large_tile, FastAccum::value>::type>:: CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< cute::Shape<int, int, int>, CollectiveMainloop, CollectiveEpilogue>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using StrideInputA = typename Gemm::GemmKernel::StrideA; using StrideInputB = typename Gemm::GemmKernel::StrideB; using StrideOutput = typename Gemm::GemmKernel::StrideC; StrideInputA stride_a = cutlass::make_cute_packed_stride( StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1)); StrideInputB stride_b = cutlass::make_cute_packed_stride( StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1)); StrideOutput stride_output = cutlass::make_cute_packed_stride( StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1)); typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K}, {reinterpret_cast<DtypeA*>(XQ.data_ptr()), stride_a, reinterpret_cast<DtypeB*>(WQ.data_ptr()), stride_b}, {{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr()) : nullptr}, {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}, {{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}}, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output}}; Gemm gemm; // Using the arguments, query for extra workspace required for matrix // multiplication computation size_t workspace_size = Gemm::get_workspace_size(arguments); // Ensure persistent kernels leave enough free SMs for NCCL background ops. if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { arguments.hw_info.sm_count = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value(); } // Set the swizzle size arguments.scheduler.max_swizzle_size = swizzle; // Allocate workspace memory auto workspace = XQ.new_empty( {static_cast<int64_t>(workspace_size)}, at::TensorOptions().dtype(at::kByte)); // Check the problem size is supported or not cutlass::Status status = gemm.can_implement(arguments); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot implement"); } // Initialize CUTLASS kernel with arguments and workspace pointer status = gemm.initialize(arguments, workspace.data_ptr()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); } status = gemm(at::cuda::getCurrentCUDAStream()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error( std::string("cutlass cannot run") + cutlass::cutlassGetStatusString(status)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); }
更改后文本
打开文件
// Cutlass rowwise kernel for SM100 template < typename TileShape, typename ClusterShape, typename Transposed, typename FastAccum, typename DtypeA, typename DtypeB, typename DtypeBias> void f8f8bf16_rowwise_impl_sm100( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 at::Tensor x_scale, at::Tensor w_scale, std::optional<at::Tensor> bias, at::Tensor out, const int swizzle) { int M = XQ.size(0); int N = WQ.size(1); int K = XQ.size(1); // Workaround for https://github.com/pytorch/pytorch/issues/133334. if (M % 256 > 0) { int padded_M = ((M - 1) / 256 + 1) * 256; at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1}); padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M) .copy_(std::move(x_scale)); x_scale = std::move(padded_x_scale); } using LayoutInputA = cutlass::layout::RowMajor; constexpr int AlignmentInputA = 16 / sizeof(DtypeA); using LayoutInputB = cutlass::layout::ColumnMajor; constexpr int AlignmentInputB = 16 / sizeof(DtypeB); using LayoutOutput = std::conditional_t< Transposed::value, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); // Tag indicating the minimum SM that supports the intended feature using ArchTag = cutlass::arch::Sm100; using OperatorClass = cutlass::arch::OpClassTensorOp; // Implement rowwise scaling epilogue. constexpr int ColBroadcastStages = 0; constexpr int RowBroadcastStages = 0; using XScale = cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>; using WScale = cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>; using Bias = std::conditional_t< Transposed::value, cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>, cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using AccumScale = cutlass::epilogue::fusion::Sm90EVT< Multiply, WScale, cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>; using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< Cast, cutlass::epilogue::fusion::Sm90EVT< Add, Bias, AccumScale>>; using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm100, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeEpilogue, DtypeOutput, LayoutOutput, AlignmentOutput, DtypeOutput, LayoutOutput, AlignmentOutput, EpilogueScheduleType, EpilogueEVT>::CollectiveOp; using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, DtypeA, LayoutInputA, AlignmentInputA, DtypeB, LayoutInputB, AlignmentInputB, DtypeAccum, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>, MainloopScheduleType>::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< cute::Shape<int, int, int>, CollectiveMainloop, CollectiveEpilogue>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using StrideInputA = typename Gemm::GemmKernel::StrideA; using StrideInputB = typename Gemm::GemmKernel::StrideB; using StrideOutput = typename Gemm::GemmKernel::StrideC; StrideInputA stride_a = cutlass::make_cute_packed_stride( StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1)); StrideInputB stride_b = cutlass::make_cute_packed_stride( StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1)); StrideOutput stride_output = cutlass::make_cute_packed_stride( StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1)); typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K}, {reinterpret_cast<DtypeA*>(XQ.data_ptr()), stride_a, reinterpret_cast<DtypeB*>(WQ.data_ptr()), stride_b}, {{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr()) : nullptr}, {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}, {{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}}, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output}}; Gemm gemm; // Using the arguments, query for extra workspace required for matrix // multiplication computation size_t workspace_size = Gemm::get_workspace_size(arguments); // Ensure persistent kernels leave enough free SMs for NCCL background ops. if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { arguments.hw_info.sm_count = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value(); } // Set the swizzle size arguments.scheduler.max_swizzle_size = swizzle; // Allocate workspace memory auto workspace = XQ.new_empty( {static_cast<int64_t>(workspace_size)}, at::TensorOptions().dtype(at::kByte)); // Check the problem size is supported or not cutlass::Status status = gemm.can_implement(arguments); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot implement"); } // Initialize CUTLASS kernel with arguments and workspace pointer status = gemm.initialize(arguments, workspace.data_ptr()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); } status = gemm(at::cuda::getCurrentCUDAStream()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error( std::string("cutlass cannot run") + cutlass::cutlassGetStatusString(status)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); }
查找差异