SM90 vs SM100 rowwise CUTLASS GEMM

SM90 file: 189 lines, 27 removals. SM100 file: 179 lines, 13 additions.
-// Cutlass rowwise kernel for sm90
+// Cutlass rowwise kernel for SM100
 template <
     typename TileShape,
     typename ClusterShape,
     typename Transposed,
     typename FastAccum,
     typename DtypeA,
     typename DtypeB,
     typename DtypeBias>
-void f8f8bf16_rowwise_impl(
+void f8f8bf16_rowwise_impl_sm100(
     at::Tensor XQ, // FP8
     at::Tensor WQ, // FP8
     at::Tensor x_scale,
     at::Tensor w_scale,
     std::optional<at::Tensor> bias,
     at::Tensor out,
     const int swizzle) {
   int M = XQ.size(0);
   int N = WQ.size(1);
   int K = XQ.size(1);
 
   // Workaround for https://github.com/pytorch/pytorch/issues/133334.
   if (M % 256 > 0) {
     int padded_M = ((M - 1) / 256 + 1) * 256;
     at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1});
     padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M)
         .copy_(std::move(x_scale));
     x_scale = std::move(padded_x_scale);
   }
 
   using LayoutInputA = cutlass::layout::RowMajor;
   constexpr int AlignmentInputA = 16 / sizeof(DtypeA);
 
   using LayoutInputB = cutlass::layout::ColumnMajor;
   constexpr int AlignmentInputB = 16 / sizeof(DtypeB);
 
   using LayoutOutput = std::conditional_t<
       Transposed::value,
       cutlass::layout::ColumnMajor,
       cutlass::layout::RowMajor>;
   constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
 
   // Tag indicating the minimum SM that supports the intended feature
-  using ArchTag = cutlass::arch::Sm90;
+  using ArchTag = cutlass::arch::Sm100;
   using OperatorClass = cutlass::arch::OpClassTensorOp;
 
   // Implement rowwise scaling epilogue.
   constexpr int ColBroadcastStages = 0;
   constexpr int RowBroadcastStages = 0;
 
   using XScale = cutlass::epilogue::fusion::
       Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
 
   using WScale = cutlass::epilogue::fusion::
       Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;
 
   using Bias = std::conditional_t<
       Transposed::value,
       cutlass::epilogue::fusion::
           Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>,
       cutlass::epilogue::fusion::
           Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;
 
   using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
   using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
       Multiply,
       WScale,
       cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
 
   using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
       Cast,
       cutlass::epilogue::fusion::Sm90EVT<
           Add,
           Bias,
           AccumScale>>;
 
-  constexpr bool large_tile = std::is_same_v<
-      TileShape, cute::Shape<cute::_128, cute::_128, cute::_128>>;
+  using EpilogueScheduleType =
+      cutlass::epilogue::collective::EpilogueScheduleAuto;
 
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
-          ArchTag,
-          OperatorClass,
-          TileShape,
-          ClusterShape,
+          cutlass::arch::Sm100, OperatorClass,
+          TileShape, ClusterShape,
           cutlass::epilogue::collective::EpilogueTileAuto,
-          DtypeAccum,
-          DtypeEpilogue,
-          DtypeOutput,
-          LayoutOutput,
-          AlignmentOutput,
-          DtypeOutput,
-          LayoutOutput,
-          AlignmentOutput,
-          typename Schedule<large_tile, FastAccum::value>::epilogue_type,
+          DtypeAccum, DtypeEpilogue,
+          DtypeOutput, LayoutOutput, AlignmentOutput,
+          DtypeOutput, LayoutOutput, AlignmentOutput,
+          EpilogueScheduleType,
           EpilogueEVT>::CollectiveOp;
 
+  using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
   using CollectiveMainloop =
       typename cutlass::gemm::collective::CollectiveBuilder<
           ArchTag,
           OperatorClass,
           DtypeA,
           LayoutInputA,
           AlignmentInputA,
           DtypeB,
           LayoutInputB,
           AlignmentInputB,
           DtypeAccum,
           TileShape,
           ClusterShape,
           cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
               sizeof(typename CollectiveEpilogue::SharedStorage))>,
-          typename Schedule<large_tile, FastAccum::value>::type>::
-          CollectiveOp;
+          MainloopScheduleType>::CollectiveOp;
 
   using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
       cute::Shape<int, int, int>,
       CollectiveMainloop,
       CollectiveEpilogue>;
 
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
 
   using StrideInputA = typename Gemm::GemmKernel::StrideA;
   using StrideInputB = typename Gemm::GemmKernel::StrideB;
   using StrideOutput = typename Gemm::GemmKernel::StrideC;
 
   StrideInputA stride_a = cutlass::make_cute_packed_stride(
       StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1));
   StrideInputB stride_b = cutlass::make_cute_packed_stride(
       StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1));
   StrideOutput stride_output = cutlass::make_cute_packed_stride(
       StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
 
   typename Gemm::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGemm,
       {M, N, K},
       {reinterpret_cast<DtypeA*>(XQ.data_ptr()),
        stride_a,
        reinterpret_cast<DtypeB*>(WQ.data_ptr()),
        stride_b},
       {{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
                            : nullptr},
          {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())},
           {{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}},
        reinterpret_cast<DtypeOutput*>(out.data_ptr()),
        stride_output,
        reinterpret_cast<DtypeOutput*>(out.data_ptr()),
        stride_output}};
 
   Gemm gemm;
 
   // Using the arguments, query for extra workspace required for matrix
   // multiplication computation
   size_t workspace_size = Gemm::get_workspace_size(arguments);
 
   // Ensure persistent kernels leave enough free SMs for NCCL background ops.
   if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
     arguments.hw_info.sm_count =
         at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
         at::globalContext()._SMCarveout_EXPERIMENTAL().value();
   }
 
   // Set the swizzle size
   arguments.scheduler.max_swizzle_size = swizzle;
 
   // Allocate workspace memory
   auto workspace = XQ.new_empty(
       {static_cast<int64_t>(workspace_size)},
       at::TensorOptions().dtype(at::kByte));
 
   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot implement");
   }
 
   // Initialize CUTLASS kernel with arguments and workspace pointer
   status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }
 
   status = gemm(at::cuda::getCurrentCUDAStream());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error(
         std::string("cutlass cannot run") +
         cutlass::cutlassGetStatusString(status));
   }
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
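
Note on the SM90 path: it selects its mainloop and epilogue schedules through a Schedule<large_tile, FastAccum::value> trait defined elsewhere in the file (as are DtypeScale, DtypeAccum, DtypeEpilogue, DtypeOutput and the Multiply/Add/Cast EVT nodes). A minimal sketch of what such a trait could look like, assuming the warp-specialized schedule tags from CUTLASS 3.x; the exact mapping in the real file may differ:

// Hypothetical sketch, not the file's actual code: picks SM90 kernel and
// epilogue schedules from the tile size and the fast-accumulation flag.
template <bool LargeTile, bool FastAccum>
struct Schedule;

template <>
struct Schedule<true, true> {
  // Large tile + FP8 fast accumulation: cooperative schedule.
  using type = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
  using epilogue_type = cutlass::epilogue::TmaWarpSpecializedCooperative;
};

template <>
struct Schedule<true, false> {
  using type = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
  using epilogue_type = cutlass::epilogue::TmaWarpSpecializedCooperative;
};

template <>
struct Schedule<false, true> {
  // Smaller tiles: ping-pong scheduling usually overlaps better.
  using type = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using epilogue_type = cutlass::epilogue::TmaWarpSpecialized;
};

template <>
struct Schedule<false, false> {
  using type = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using epilogue_type = cutlass::epilogue::TmaWarpSpecialized;
};

The SM100 path needs no such trait: it hands KernelScheduleAuto and EpilogueScheduleAuto to the CollectiveBuilder and lets CUTLASS choose schedules appropriate for the architecture.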
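
For completeness, a hedged sketch of how a caller might dispatch between the two builds at runtime. The two impl functions come from the diff above; the dispatcher itself, the tile/cluster shapes, and the dtype choices are illustrative assumptions, not the surrounding file's actual dispatch logic:

// Hypothetical dispatcher: compute capability 10.x (Blackwell) takes the
// SM100 build, 9.x (Hopper) the SM90 build. Shapes below are placeholders,
// not tuned configurations.
void dispatch_f8f8bf16_rowwise(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    at::Tensor out,
    int swizzle) {
  using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
  using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
  using DtypeFp8 = cutlass::float_e4m3_t;
  using DtypeBias = cutlass::bfloat16_t;

  const auto* props = at::cuda::getDeviceProperties(out.device().index());
  if (props->major >= 10) {
    f8f8bf16_rowwise_impl_sm100<
        TileShape, ClusterShape,
        /*Transposed=*/std::false_type, /*FastAccum=*/std::true_type,
        DtypeFp8, DtypeFp8, DtypeBias>(
        XQ, WQ, x_scale, w_scale, bias, out, swizzle);
  } else {
    f8f8bf16_rowwise_impl<
        TileShape, ClusterShape,
        /*Transposed=*/std::false_type, /*FastAccum=*/std::true_type,
        DtypeFp8, DtypeFp8, DtypeBias>(
        XQ, WQ, x_scale, w_scale, bias, out, swizzle);
  }
}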