Diff
checker
Texte
Texte
Images
Documents
Excel
Dossiers
Legal
Enterprise
Application de bureau
Prix
Se connecter
Télécharger Diffchecker Desktop
Comparer le texte
Trouver la différence entre deux fichiers texte
Outils
Historique
Éditeur live
Cacher identiques
Sans retour à la ligne
Vue
Divisé
Unifié
Niveau de précision
Intelligent
Mot
Caractère
Coloration syntaxique
Choisir la syntaxe
Ignorer
Transformer le texte
Aller au premier écart
Modifier l'entrée
Diffchecker Desktop
La façon la plus sécurisée d'utiliser Diffchecker. Obtenez l'application Diffchecker Desktop : vos diffs ne quittent jamais votre ordinateur !
Obtenir Desktop
sm90 vs sm100 rowwise cutlass gemm
Créé
l’année dernière
Le diff n'expire jamais
Effacer
Exporter
Partager
Expliquer
33 suppressions
Lignes
Total
Supprimé
Caractères
Total
Supprimé
Pour continuer à utiliser cette fonctionnalité, passez à
Diff
checker
Pro
Voir les prix
189 lignes
Copier tout
10 ajouts
Lignes
Total
Ajouté
Caractères
Total
Ajouté
Pour continuer à utiliser cette fonctionnalité, passez à
Diff
checker
Pro
Voir les prix
179 lignes
Copier tout
Copier
Copié
Copier
Copié
// Cutlass rowwise kernel for
sm90
// Cutlass rowwise kernel for
SM100
template <
template <
typename TileShape,
typename TileShape,
typename ClusterShape,
typename ClusterShape,
typename Transposed,
typename Transposed,
typename FastAccum,
typename FastAccum,
typename DtypeA,
typename DtypeA,
typename DtypeB,
typename DtypeB,
typename DtypeBias>
typename DtypeBias>
Copier
Copié
Copier
Copié
void f8f8bf16_rowwise_impl
(
void f8f8bf16_rowwise_impl
_sm100
(
at::Tensor XQ, // FP8
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor WQ, // FP8
at::Tensor x_scale,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor w_scale,
std::optional<at::Tensor> bias,
std::optional<at::Tensor> bias,
at::Tensor out,
at::Tensor out,
const int swizzle) {
const int swizzle) {
int M = XQ.size(0);
int M = XQ.size(0);
int N = WQ.size(1);
int N = WQ.size(1);
int K = XQ.size(1);
int K = XQ.size(1);
// Workaround for https://github.com/pytorch/pytorch/issues/133334.
// Workaround for https://github.com/pytorch/pytorch/issues/133334.
if (M % 256 > 0) {
if (M % 256 > 0) {
int padded_M = ((M - 1) / 256 + 1) * 256;
int padded_M = ((M - 1) / 256 + 1) * 256;
at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1});
at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1});
padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M)
padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M)
.copy_(std::move(x_scale));
.copy_(std::move(x_scale));
x_scale = std::move(padded_x_scale);
x_scale = std::move(padded_x_scale);
}
}
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputA = cutlass::layout::RowMajor;
constexpr int AlignmentInputA = 16 / sizeof(DtypeA);
constexpr int AlignmentInputA = 16 / sizeof(DtypeA);
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputB = 16 / sizeof(DtypeB);
constexpr int AlignmentInputB = 16 / sizeof(DtypeB);
using LayoutOutput = std::conditional_t<
using LayoutOutput = std::conditional_t<
Transposed::value,
Transposed::value,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>;
cutlass::layout::RowMajor>;
constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
// Tag indicating the minimum SM that supports the intended feature
// Tag indicating the minimum SM that supports the intended feature
Copier
Copié
Copier
Copié
using ArchTag = cutlass::arch::Sm
9
0;
using ArchTag = cutlass::arch::Sm
10
0;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using OperatorClass = cutlass::arch::OpClassTensorOp;
// Implement rowwise scaling epilogue.
// Implement rowwise scaling epilogue.
constexpr int ColBroadcastStages = 0;
constexpr int ColBroadcastStages = 0;
constexpr int RowBroadcastStages = 0;
constexpr int RowBroadcastStages = 0;
using XScale = cutlass::epilogue::fusion::
using XScale = cutlass::epilogue::fusion::
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
using WScale = cutlass::epilogue::fusion::
using WScale = cutlass::epilogue::fusion::
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;
using Bias = std::conditional_t<
using Bias = std::conditional_t<
Transposed::value,
Transposed::value,
cutlass::epilogue::fusion::
cutlass::epilogue::fusion::
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>,
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>,
cutlass::epilogue::fusion::
cutlass::epilogue::fusion::
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
Multiply,
Multiply,
WScale,
WScale,
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
Cast,
Cast,
cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90EVT<
Add,
Add,
Bias,
Bias,
AccumScale>>;
AccumScale>>;
Copier
Copié
Copier
Copié
constexpr bool large_tile = std::is_same_v<TileShape, cute::Shape<cute::_128, cute::_128, cute::_128>>;
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
using CollectiveEpilogue =
cutlass::arch::Sm100,
OperatorClass,
typename cutlass::epilogue::collective::CollectiveBuilder<
TileShape,
ClusterShape,
ArchTag,
cutlass::epilogue::collective::EpilogueTileAuto,
OperatorClass,
DtypeAccum,
DtypeEpilogue,
TileShape,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
ClusterShape,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::collective::EpilogueTileAuto,
EpilogueScheduleType,
DtypeAccum,
EpilogueEVT>::CollectiveOp;
DtypeEpilogue,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
typename Schedule<large_tile, FastAccum::value>::epilogue_type,
EpilogueEVT>::CollectiveOp;
Copier
Copié
Copier
Copié
using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
using CollectiveMainloop =
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
ArchTag,
OperatorClass,
OperatorClass,
DtypeA,
DtypeA,
LayoutInputA,
LayoutInputA,
AlignmentInputA,
AlignmentInputA,
DtypeB,
DtypeB,
LayoutInputB,
LayoutInputB,
AlignmentInputB,
AlignmentInputB,
DtypeAccum,
DtypeAccum,
TileShape,
TileShape,
ClusterShape,
ClusterShape,
Copier
Copié
Copier
Copié
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType
>::
CollectiveOp;
typename Schedule<large_tile, FastAccum::value>::type
>::
CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveMainloop,
CollectiveEpilogue>;
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1));
StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1));
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1));
StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
typename Gemm::Arguments arguments{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K},
{M, N, K},
{reinterpret_cast<DtypeA*>(XQ.data_ptr()),
{reinterpret_cast<DtypeA*>(XQ.data_ptr()),
stride_a,
stride_a,
reinterpret_cast<DtypeB*>(WQ.data_ptr()),
reinterpret_cast<DtypeB*>(WQ.data_ptr()),
stride_b},
stride_b},
{{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
{{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
: nullptr},
: nullptr},
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())},
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())},
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}},
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}},
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
stride_output,
stride_output,
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
stride_output}};
stride_output}};
Gemm gemm;
Gemm gemm;
// Using the arguments, query for extra workspace required for matrix
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Ensure persistent kernels leave enough free SMs for NCCL background ops.
// Ensure persistent kernels leave enough free SMs for NCCL background ops.
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
arguments.hw_info.sm_count =
arguments.hw_info.sm_count =
at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value();
at::globalContext()._SMCarveout_EXPERIMENTAL().value();
}
}
// Set the swizzle size
// Set the swizzle size
arguments.scheduler.max_swizzle_size = swizzle;
arguments.scheduler.max_swizzle_size = swizzle;
// Allocate workspace memory
// Allocate workspace memory
auto workspace = XQ.new_empty(
auto workspace = XQ.new_empty(
{static_cast<int64_t>(workspace_size)},
{static_cast<int64_t>(workspace_size)},
at::TensorOptions().dtype(at::kByte));
at::TensorOptions().dtype(at::kByte));
// Check the problem size is supported or not
// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
throw std::runtime_error("cutlass cannot implement");
}
}
// Initialize CUTLASS kernel with arguments and workspace pointer
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.data_ptr());
status = gemm.initialize(arguments, workspace.data_ptr());
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
throw std::runtime_error("cutlass cannot initialize");
}
}
status = gemm(at::cuda::getCurrentCUDAStream());
status = gemm(at::cuda::getCurrentCUDAStream());
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
throw std::runtime_error(
std::string("cutlass cannot run") +
std::string("cutlass cannot run") +
cutlass::cutlassGetStatusString(status));
cutlass::cutlassGetStatusString(status));
}
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
Différences enregistrées
Texte d'origine
Ouvrir un fichier
// Cutlass rowwise kernel for sm90 template < typename TileShape, typename ClusterShape, typename Transposed, typename FastAccum, typename DtypeA, typename DtypeB, typename DtypeBias> void f8f8bf16_rowwise_impl( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 at::Tensor x_scale, at::Tensor w_scale, std::optional<at::Tensor> bias, at::Tensor out, const int swizzle) { int M = XQ.size(0); int N = WQ.size(1); int K = XQ.size(1); // Workaround for https://github.com/pytorch/pytorch/issues/133334. if (M % 256 > 0) { int padded_M = ((M - 1) / 256 + 1) * 256; at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1}); padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M) .copy_(std::move(x_scale)); x_scale = std::move(padded_x_scale); } using LayoutInputA = cutlass::layout::RowMajor; constexpr int AlignmentInputA = 16 / sizeof(DtypeA); using LayoutInputB = cutlass::layout::ColumnMajor; constexpr int AlignmentInputB = 16 / sizeof(DtypeB); using LayoutOutput = std::conditional_t< Transposed::value, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); // Tag indicating the minimum SM that supports the intended feature using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; // Implement rowwise scaling epilogue. constexpr int ColBroadcastStages = 0; constexpr int RowBroadcastStages = 0; using XScale = cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>; using WScale = cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>; using Bias = std::conditional_t< Transposed::value, cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>, cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using AccumScale = cutlass::epilogue::fusion::Sm90EVT< Multiply, WScale, cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>; using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< Cast, cutlass::epilogue::fusion::Sm90EVT< Add, Bias, AccumScale>>; constexpr bool large_tile = std::is_same_v<TileShape, cute::Shape<cute::_128, cute::_128, cute::_128>>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeEpilogue, DtypeOutput, LayoutOutput, AlignmentOutput, DtypeOutput, LayoutOutput, AlignmentOutput, typename Schedule<large_tile, FastAccum::value>::epilogue_type, EpilogueEVT>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, DtypeA, LayoutInputA, AlignmentInputA, DtypeB, LayoutInputB, AlignmentInputB, DtypeAccum, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( sizeof(typename CollectiveEpilogue::SharedStorage))>, typename Schedule<large_tile, FastAccum::value>::type>:: CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< cute::Shape<int, int, int>, CollectiveMainloop, CollectiveEpilogue>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using StrideInputA = typename Gemm::GemmKernel::StrideA; using StrideInputB = typename Gemm::GemmKernel::StrideB; using StrideOutput = typename Gemm::GemmKernel::StrideC; StrideInputA stride_a = cutlass::make_cute_packed_stride( StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1)); StrideInputB stride_b = cutlass::make_cute_packed_stride( StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1)); StrideOutput stride_output = cutlass::make_cute_packed_stride( StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1)); typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K}, {reinterpret_cast<DtypeA*>(XQ.data_ptr()), stride_a, reinterpret_cast<DtypeB*>(WQ.data_ptr()), stride_b}, {{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr()) : nullptr}, {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}, {{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}}, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output}}; Gemm gemm; // Using the arguments, query for extra workspace required for matrix // multiplication computation size_t workspace_size = Gemm::get_workspace_size(arguments); // Ensure persistent kernels leave enough free SMs for NCCL background ops. if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { arguments.hw_info.sm_count = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value(); } // Set the swizzle size arguments.scheduler.max_swizzle_size = swizzle; // Allocate workspace memory auto workspace = XQ.new_empty( {static_cast<int64_t>(workspace_size)}, at::TensorOptions().dtype(at::kByte)); // Check the problem size is supported or not cutlass::Status status = gemm.can_implement(arguments); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot implement"); } // Initialize CUTLASS kernel with arguments and workspace pointer status = gemm.initialize(arguments, workspace.data_ptr()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); } status = gemm(at::cuda::getCurrentCUDAStream()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error( std::string("cutlass cannot run") + cutlass::cutlassGetStatusString(status)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); }
Texte modifié
Ouvrir un fichier
// Cutlass rowwise kernel for SM100 template < typename TileShape, typename ClusterShape, typename Transposed, typename FastAccum, typename DtypeA, typename DtypeB, typename DtypeBias> void f8f8bf16_rowwise_impl_sm100( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 at::Tensor x_scale, at::Tensor w_scale, std::optional<at::Tensor> bias, at::Tensor out, const int swizzle) { int M = XQ.size(0); int N = WQ.size(1); int K = XQ.size(1); // Workaround for https://github.com/pytorch/pytorch/issues/133334. if (M % 256 > 0) { int padded_M = ((M - 1) / 256 + 1) * 256; at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1}); padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M) .copy_(std::move(x_scale)); x_scale = std::move(padded_x_scale); } using LayoutInputA = cutlass::layout::RowMajor; constexpr int AlignmentInputA = 16 / sizeof(DtypeA); using LayoutInputB = cutlass::layout::ColumnMajor; constexpr int AlignmentInputB = 16 / sizeof(DtypeB); using LayoutOutput = std::conditional_t< Transposed::value, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); // Tag indicating the minimum SM that supports the intended feature using ArchTag = cutlass::arch::Sm100; using OperatorClass = cutlass::arch::OpClassTensorOp; // Implement rowwise scaling epilogue. constexpr int ColBroadcastStages = 0; constexpr int RowBroadcastStages = 0; using XScale = cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>; using WScale = cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>; using Bias = std::conditional_t< Transposed::value, cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>, cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using AccumScale = cutlass::epilogue::fusion::Sm90EVT< Multiply, WScale, cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>; using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< Cast, cutlass::epilogue::fusion::Sm90EVT< Add, Bias, AccumScale>>; using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm100, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeEpilogue, DtypeOutput, LayoutOutput, AlignmentOutput, DtypeOutput, LayoutOutput, AlignmentOutput, EpilogueScheduleType, EpilogueEVT>::CollectiveOp; using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, DtypeA, LayoutInputA, AlignmentInputA, DtypeB, LayoutInputB, AlignmentInputB, DtypeAccum, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>, MainloopScheduleType>::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< cute::Shape<int, int, int>, CollectiveMainloop, CollectiveEpilogue>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using StrideInputA = typename Gemm::GemmKernel::StrideA; using StrideInputB = typename Gemm::GemmKernel::StrideB; using StrideOutput = typename Gemm::GemmKernel::StrideC; StrideInputA stride_a = cutlass::make_cute_packed_stride( StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1)); StrideInputB stride_b = cutlass::make_cute_packed_stride( StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1)); StrideOutput stride_output = cutlass::make_cute_packed_stride( StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1)); typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K}, {reinterpret_cast<DtypeA*>(XQ.data_ptr()), stride_a, reinterpret_cast<DtypeB*>(WQ.data_ptr()), stride_b}, {{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr()) : nullptr}, {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}, {{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}}, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output}}; Gemm gemm; // Using the arguments, query for extra workspace required for matrix // multiplication computation size_t workspace_size = Gemm::get_workspace_size(arguments); // Ensure persistent kernels leave enough free SMs for NCCL background ops. if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { arguments.hw_info.sm_count = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value(); } // Set the swizzle size arguments.scheduler.max_swizzle_size = swizzle; // Allocate workspace memory auto workspace = XQ.new_empty( {static_cast<int64_t>(workspace_size)}, at::TensorOptions().dtype(at::kByte)); // Check the problem size is supported or not cutlass::Status status = gemm.can_implement(arguments); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot implement"); } // Initialize CUTLASS kernel with arguments and workspace pointer status = gemm.initialize(arguments, workspace.data_ptr()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); } status = gemm(at::cuda::getCurrentCUDAStream()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error( std::string("cutlass cannot run") + cutlass::cutlassGetStatusString(status)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); }
Trouver la différence