Diff
checker
Texto
Texto
Imágenes
Documentos
Excel
Carpetas
Legal
Enterprise
Aplicación de escritorio
Precios
Iniciar sesión
Descargar Diffchecker Desktop
Comparar texto
Encuentra la diferencia entre dos archivos de texto
Herramientas
Historial
Editor live
Ocultar sin cambios
Sin ajuste de línea
Vista
Dividido
Unificado
Nivel de detalle
Inteligente
Palabra
Letra
Resaltado de sintaxis
Elegir sintaxis
Ignorar
Transformar texto
Ir al primer cambio
Editar entrada
Diffchecker Desktop
La forma más segura de usar Diffchecker. ¡Obtén la app de Diffchecker Desktop: tus diffs nunca salen de tu computadora!
Obtener Desktop
sm90 vs sm100 rowwise cutlass gemm
Creado
el año pasado
El diff nunca expira
Borrar
Exportar
Compartir
Explicar
33 eliminaciones
Líneas
Total
Eliminado
Caracteres
Total
Eliminado
Para continuar usando esta función, actualice a
Diff
checker
Pro
Ver precios
189 líneas
Copiar todo
10 adiciones
Líneas
Total
Añadido
Caracteres
Total
Añadido
Para continuar usando esta función, actualice a
Diff
checker
Pro
Ver precios
179 líneas
Copiar todo
Copiar
Copiado
Copiar
Copiado
// Cutlass rowwise kernel for
sm90
// Cutlass rowwise kernel for
SM100
template <
template <
typename TileShape,
typename TileShape,
typename ClusterShape,
typename ClusterShape,
typename Transposed,
typename Transposed,
typename FastAccum,
typename FastAccum,
typename DtypeA,
typename DtypeA,
typename DtypeB,
typename DtypeB,
typename DtypeBias>
typename DtypeBias>
Copiar
Copiado
Copiar
Copiado
void f8f8bf16_rowwise_impl
(
void f8f8bf16_rowwise_impl
_sm100
(
at::Tensor XQ, // FP8
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor WQ, // FP8
at::Tensor x_scale,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor w_scale,
std::optional<at::Tensor> bias,
std::optional<at::Tensor> bias,
at::Tensor out,
at::Tensor out,
const int swizzle) {
const int swizzle) {
int M = XQ.size(0);
int M = XQ.size(0);
int N = WQ.size(1);
int N = WQ.size(1);
int K = XQ.size(1);
int K = XQ.size(1);
// Workaround for https://github.com/pytorch/pytorch/issues/133334.
// Workaround for https://github.com/pytorch/pytorch/issues/133334.
if (M % 256 > 0) {
if (M % 256 > 0) {
int padded_M = ((M - 1) / 256 + 1) * 256;
int padded_M = ((M - 1) / 256 + 1) * 256;
at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1});
at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1});
padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M)
padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M)
.copy_(std::move(x_scale));
.copy_(std::move(x_scale));
x_scale = std::move(padded_x_scale);
x_scale = std::move(padded_x_scale);
}
}
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputA = cutlass::layout::RowMajor;
constexpr int AlignmentInputA = 16 / sizeof(DtypeA);
constexpr int AlignmentInputA = 16 / sizeof(DtypeA);
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputB = 16 / sizeof(DtypeB);
constexpr int AlignmentInputB = 16 / sizeof(DtypeB);
using LayoutOutput = std::conditional_t<
using LayoutOutput = std::conditional_t<
Transposed::value,
Transposed::value,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>;
cutlass::layout::RowMajor>;
constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
// Tag indicating the minimum SM that supports the intended feature
// Tag indicating the minimum SM that supports the intended feature
Copiar
Copiado
Copiar
Copiado
using ArchTag = cutlass::arch::Sm
9
0;
using ArchTag = cutlass::arch::Sm
10
0;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using OperatorClass = cutlass::arch::OpClassTensorOp;
// Implement rowwise scaling epilogue.
// Implement rowwise scaling epilogue.
constexpr int ColBroadcastStages = 0;
constexpr int ColBroadcastStages = 0;
constexpr int RowBroadcastStages = 0;
constexpr int RowBroadcastStages = 0;
using XScale = cutlass::epilogue::fusion::
using XScale = cutlass::epilogue::fusion::
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
using WScale = cutlass::epilogue::fusion::
using WScale = cutlass::epilogue::fusion::
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;
using Bias = std::conditional_t<
using Bias = std::conditional_t<
Transposed::value,
Transposed::value,
cutlass::epilogue::fusion::
cutlass::epilogue::fusion::
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>,
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>,
cutlass::epilogue::fusion::
cutlass::epilogue::fusion::
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
Multiply,
Multiply,
WScale,
WScale,
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
Cast,
Cast,
cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90EVT<
Add,
Add,
Bias,
Bias,
AccumScale>>;
AccumScale>>;
Copiar
Copiado
Copiar
Copiado
constexpr bool large_tile = std::is_same_v<TileShape, cute::Shape<cute::_128, cute::_128, cute::_128>>;
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
using CollectiveEpilogue =
cutlass::arch::Sm100,
OperatorClass,
typename cutlass::epilogue::collective::CollectiveBuilder<
TileShape,
ClusterShape,
ArchTag,
cutlass::epilogue::collective::EpilogueTileAuto,
OperatorClass,
DtypeAccum,
DtypeEpilogue,
TileShape,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
ClusterShape,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::collective::EpilogueTileAuto,
EpilogueScheduleType,
DtypeAccum,
EpilogueEVT>::CollectiveOp;
DtypeEpilogue,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
typename Schedule<large_tile, FastAccum::value>::epilogue_type,
EpilogueEVT>::CollectiveOp;
Copiar
Copiado
Copiar
Copiado
using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
using CollectiveMainloop =
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
ArchTag,
OperatorClass,
OperatorClass,
DtypeA,
DtypeA,
LayoutInputA,
LayoutInputA,
AlignmentInputA,
AlignmentInputA,
DtypeB,
DtypeB,
LayoutInputB,
LayoutInputB,
AlignmentInputB,
AlignmentInputB,
DtypeAccum,
DtypeAccum,
TileShape,
TileShape,
ClusterShape,
ClusterShape,
Copiar
Copiado
Copiar
Copiado
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType
>::
CollectiveOp;
typename Schedule<large_tile, FastAccum::value>::type
>::
CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveMainloop,
CollectiveEpilogue>;
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1));
StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1));
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1));
StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
typename Gemm::Arguments arguments{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K},
{M, N, K},
{reinterpret_cast<DtypeA*>(XQ.data_ptr()),
{reinterpret_cast<DtypeA*>(XQ.data_ptr()),
stride_a,
stride_a,
reinterpret_cast<DtypeB*>(WQ.data_ptr()),
reinterpret_cast<DtypeB*>(WQ.data_ptr()),
stride_b},
stride_b},
{{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
{{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
: nullptr},
: nullptr},
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())},
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())},
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}},
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}},
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
stride_output,
stride_output,
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
stride_output}};
stride_output}};
Gemm gemm;
Gemm gemm;
// Using the arguments, query for extra workspace required for matrix
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Ensure persistent kernels leave enough free SMs for NCCL background ops.
// Ensure persistent kernels leave enough free SMs for NCCL background ops.
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
arguments.hw_info.sm_count =
arguments.hw_info.sm_count =
at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value();
at::globalContext()._SMCarveout_EXPERIMENTAL().value();
}
}
// Set the swizzle size
// Set the swizzle size
arguments.scheduler.max_swizzle_size = swizzle;
arguments.scheduler.max_swizzle_size = swizzle;
// Allocate workspace memory
// Allocate workspace memory
auto workspace = XQ.new_empty(
auto workspace = XQ.new_empty(
{static_cast<int64_t>(workspace_size)},
{static_cast<int64_t>(workspace_size)},
at::TensorOptions().dtype(at::kByte));
at::TensorOptions().dtype(at::kByte));
// Check the problem size is supported or not
// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
throw std::runtime_error("cutlass cannot implement");
}
}
// Initialize CUTLASS kernel with arguments and workspace pointer
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.data_ptr());
status = gemm.initialize(arguments, workspace.data_ptr());
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
throw std::runtime_error("cutlass cannot initialize");
}
}
status = gemm(at::cuda::getCurrentCUDAStream());
status = gemm(at::cuda::getCurrentCUDAStream());
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
throw std::runtime_error(
std::string("cutlass cannot run") +
std::string("cutlass cannot run") +
cutlass::cutlassGetStatusString(status));
cutlass::cutlassGetStatusString(status));
}
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
Diferencias guardadas
Texto original
Abrir archivo
// Cutlass rowwise kernel for sm90 template < typename TileShape, typename ClusterShape, typename Transposed, typename FastAccum, typename DtypeA, typename DtypeB, typename DtypeBias> void f8f8bf16_rowwise_impl( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 at::Tensor x_scale, at::Tensor w_scale, std::optional<at::Tensor> bias, at::Tensor out, const int swizzle) { int M = XQ.size(0); int N = WQ.size(1); int K = XQ.size(1); // Workaround for https://github.com/pytorch/pytorch/issues/133334. if (M % 256 > 0) { int padded_M = ((M - 1) / 256 + 1) * 256; at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1}); padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M) .copy_(std::move(x_scale)); x_scale = std::move(padded_x_scale); } using LayoutInputA = cutlass::layout::RowMajor; constexpr int AlignmentInputA = 16 / sizeof(DtypeA); using LayoutInputB = cutlass::layout::ColumnMajor; constexpr int AlignmentInputB = 16 / sizeof(DtypeB); using LayoutOutput = std::conditional_t< Transposed::value, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); // Tag indicating the minimum SM that supports the intended feature using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; // Implement rowwise scaling epilogue. constexpr int ColBroadcastStages = 0; constexpr int RowBroadcastStages = 0; using XScale = cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>; using WScale = cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>; using Bias = std::conditional_t< Transposed::value, cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>, cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using AccumScale = cutlass::epilogue::fusion::Sm90EVT< Multiply, WScale, cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>; using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< Cast, cutlass::epilogue::fusion::Sm90EVT< Add, Bias, AccumScale>>; constexpr bool large_tile = std::is_same_v<TileShape, cute::Shape<cute::_128, cute::_128, cute::_128>>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeEpilogue, DtypeOutput, LayoutOutput, AlignmentOutput, DtypeOutput, LayoutOutput, AlignmentOutput, typename Schedule<large_tile, FastAccum::value>::epilogue_type, EpilogueEVT>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, DtypeA, LayoutInputA, AlignmentInputA, DtypeB, LayoutInputB, AlignmentInputB, DtypeAccum, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( sizeof(typename CollectiveEpilogue::SharedStorage))>, typename Schedule<large_tile, FastAccum::value>::type>:: CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< cute::Shape<int, int, int>, CollectiveMainloop, CollectiveEpilogue>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using StrideInputA = typename Gemm::GemmKernel::StrideA; using StrideInputB = typename Gemm::GemmKernel::StrideB; using StrideOutput = typename Gemm::GemmKernel::StrideC; StrideInputA stride_a = cutlass::make_cute_packed_stride( StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1)); StrideInputB stride_b = cutlass::make_cute_packed_stride( StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1)); StrideOutput stride_output = cutlass::make_cute_packed_stride( StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1)); typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K}, {reinterpret_cast<DtypeA*>(XQ.data_ptr()), stride_a, reinterpret_cast<DtypeB*>(WQ.data_ptr()), stride_b}, {{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr()) : nullptr}, {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}, {{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}}, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output}}; Gemm gemm; // Using the arguments, query for extra workspace required for matrix // multiplication computation size_t workspace_size = Gemm::get_workspace_size(arguments); // Ensure persistent kernels leave enough free SMs for NCCL background ops. if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { arguments.hw_info.sm_count = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value(); } // Set the swizzle size arguments.scheduler.max_swizzle_size = swizzle; // Allocate workspace memory auto workspace = XQ.new_empty( {static_cast<int64_t>(workspace_size)}, at::TensorOptions().dtype(at::kByte)); // Check the problem size is supported or not cutlass::Status status = gemm.can_implement(arguments); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot implement"); } // Initialize CUTLASS kernel with arguments and workspace pointer status = gemm.initialize(arguments, workspace.data_ptr()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); } status = gemm(at::cuda::getCurrentCUDAStream()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error( std::string("cutlass cannot run") + cutlass::cutlassGetStatusString(status)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); }
Texto modificado
Abrir archivo
// Cutlass rowwise kernel for SM100 template < typename TileShape, typename ClusterShape, typename Transposed, typename FastAccum, typename DtypeA, typename DtypeB, typename DtypeBias> void f8f8bf16_rowwise_impl_sm100( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 at::Tensor x_scale, at::Tensor w_scale, std::optional<at::Tensor> bias, at::Tensor out, const int swizzle) { int M = XQ.size(0); int N = WQ.size(1); int K = XQ.size(1); // Workaround for https://github.com/pytorch/pytorch/issues/133334. if (M % 256 > 0) { int padded_M = ((M - 1) / 256 + 1) * 256; at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1}); padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M) .copy_(std::move(x_scale)); x_scale = std::move(padded_x_scale); } using LayoutInputA = cutlass::layout::RowMajor; constexpr int AlignmentInputA = 16 / sizeof(DtypeA); using LayoutInputB = cutlass::layout::ColumnMajor; constexpr int AlignmentInputB = 16 / sizeof(DtypeB); using LayoutOutput = std::conditional_t< Transposed::value, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); // Tag indicating the minimum SM that supports the intended feature using ArchTag = cutlass::arch::Sm100; using OperatorClass = cutlass::arch::OpClassTensorOp; // Implement rowwise scaling epilogue. constexpr int ColBroadcastStages = 0; constexpr int RowBroadcastStages = 0; using XScale = cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>; using WScale = cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>; using Bias = std::conditional_t< Transposed::value, cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>, cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using AccumScale = cutlass::epilogue::fusion::Sm90EVT< Multiply, WScale, cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>; using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< Cast, cutlass::epilogue::fusion::Sm90EVT< Add, Bias, AccumScale>>; using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm100, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeEpilogue, DtypeOutput, LayoutOutput, AlignmentOutput, DtypeOutput, LayoutOutput, AlignmentOutput, EpilogueScheduleType, EpilogueEVT>::CollectiveOp; using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, DtypeA, LayoutInputA, AlignmentInputA, DtypeB, LayoutInputB, AlignmentInputB, DtypeAccum, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>, MainloopScheduleType>::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< cute::Shape<int, int, int>, CollectiveMainloop, CollectiveEpilogue>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using StrideInputA = typename Gemm::GemmKernel::StrideA; using StrideInputB = typename Gemm::GemmKernel::StrideB; using StrideOutput = typename Gemm::GemmKernel::StrideC; StrideInputA stride_a = cutlass::make_cute_packed_stride( StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1)); StrideInputB stride_b = cutlass::make_cute_packed_stride( StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1)); StrideOutput stride_output = cutlass::make_cute_packed_stride( StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1)); typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K}, {reinterpret_cast<DtypeA*>(XQ.data_ptr()), stride_a, reinterpret_cast<DtypeB*>(WQ.data_ptr()), stride_b}, {{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr()) : nullptr}, {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}, {{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}}, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output}}; Gemm gemm; // Using the arguments, query for extra workspace required for matrix // multiplication computation size_t workspace_size = Gemm::get_workspace_size(arguments); // Ensure persistent kernels leave enough free SMs for NCCL background ops. if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { arguments.hw_info.sm_count = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value(); } // Set the swizzle size arguments.scheduler.max_swizzle_size = swizzle; // Allocate workspace memory auto workspace = XQ.new_empty( {static_cast<int64_t>(workspace_size)}, at::TensorOptions().dtype(at::kByte)); // Check the problem size is supported or not cutlass::Status status = gemm.can_implement(arguments); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot implement"); } // Initialize CUTLASS kernel with arguments and workspace pointer status = gemm.initialize(arguments, workspace.data_ptr()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); } status = gemm(at::cuda::getCurrentCUDAStream()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error( std::string("cutlass cannot run") + cutlass::cutlassGetStatusString(status)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); }
Encontrar la diferencia