Diff
checker
텍스트
텍스트
이미지
문서
Excel
폴더
Legal
Enterprise
데스크톱
요금제
로그인
데스크톱 앱 다운로드
텍스트 비교
두 텍스트 파일의 차이점을 찾아보세요
도구
기록
실시간 편집
변경 없는 행 숨기기
줄바꿈 비활성화
레이아웃
나란히 보기
합쳐 보기
비교 단위
스마트
단어
글자
구문 강조
언어 선택
제외
텍스트 변환
첫 변경으로
수정
Diffchecker Desktop
가장 안전하게 Diffchecker를 사용하는 방법. 데스크톱 앱을 사용하면 비교 데이터가 외부로 전송되지 않습니다!
데스크톱 앱 받기
sm90 vs sm100 rowwise cutlass gemm
생성일
작년
비교 결과 만료 없음
초기화
내보내기
공유
설명
33 삭제
행
총
삭제
글자
총
삭제
이 기능을 계속 사용하려면 업그레이드해 주세요
Diff
checker
Pro
요금제 보기
189 행
복사
10 추가
행
총
추가
글자
총
추가
이 기능을 계속 사용하려면 업그레이드해 주세요
Diff
checker
Pro
요금제 보기
179 행
복사
복사
복사됨
복사
복사됨
// Cutlass rowwise kernel for
sm90
// Cutlass rowwise kernel for
SM100
template <
template <
typename TileShape,
typename TileShape,
typename ClusterShape,
typename ClusterShape,
typename Transposed,
typename Transposed,
typename FastAccum,
typename FastAccum,
typename DtypeA,
typename DtypeA,
typename DtypeB,
typename DtypeB,
typename DtypeBias>
typename DtypeBias>
복사
복사됨
복사
복사됨
void f8f8bf16_rowwise_impl
(
void f8f8bf16_rowwise_impl
_sm100
(
at::Tensor XQ, // FP8
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor WQ, // FP8
at::Tensor x_scale,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor w_scale,
std::optional<at::Tensor> bias,
std::optional<at::Tensor> bias,
at::Tensor out,
at::Tensor out,
const int swizzle) {
const int swizzle) {
int M = XQ.size(0);
int M = XQ.size(0);
int N = WQ.size(1);
int N = WQ.size(1);
int K = XQ.size(1);
int K = XQ.size(1);
// Workaround for https://github.com/pytorch/pytorch/issues/133334.
// Workaround for https://github.com/pytorch/pytorch/issues/133334.
if (M % 256 > 0) {
if (M % 256 > 0) {
int padded_M = ((M - 1) / 256 + 1) * 256;
int padded_M = ((M - 1) / 256 + 1) * 256;
at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1});
at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1});
padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M)
padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M)
.copy_(std::move(x_scale));
.copy_(std::move(x_scale));
x_scale = std::move(padded_x_scale);
x_scale = std::move(padded_x_scale);
}
}
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputA = cutlass::layout::RowMajor;
constexpr int AlignmentInputA = 16 / sizeof(DtypeA);
constexpr int AlignmentInputA = 16 / sizeof(DtypeA);
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputB = 16 / sizeof(DtypeB);
constexpr int AlignmentInputB = 16 / sizeof(DtypeB);
using LayoutOutput = std::conditional_t<
using LayoutOutput = std::conditional_t<
Transposed::value,
Transposed::value,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>;
cutlass::layout::RowMajor>;
constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
// Tag indicating the minimum SM that supports the intended feature
// Tag indicating the minimum SM that supports the intended feature
복사
복사됨
복사
복사됨
using ArchTag = cutlass::arch::Sm
9
0;
using ArchTag = cutlass::arch::Sm
10
0;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using OperatorClass = cutlass::arch::OpClassTensorOp;
// Implement rowwise scaling epilogue.
// Implement rowwise scaling epilogue.
constexpr int ColBroadcastStages = 0;
constexpr int ColBroadcastStages = 0;
constexpr int RowBroadcastStages = 0;
constexpr int RowBroadcastStages = 0;
using XScale = cutlass::epilogue::fusion::
using XScale = cutlass::epilogue::fusion::
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
using WScale = cutlass::epilogue::fusion::
using WScale = cutlass::epilogue::fusion::
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;
using Bias = std::conditional_t<
using Bias = std::conditional_t<
Transposed::value,
Transposed::value,
cutlass::epilogue::fusion::
cutlass::epilogue::fusion::
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>,
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>,
cutlass::epilogue::fusion::
cutlass::epilogue::fusion::
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
Multiply,
Multiply,
WScale,
WScale,
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
Cast,
Cast,
cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90EVT<
Add,
Add,
Bias,
Bias,
AccumScale>>;
AccumScale>>;
복사
복사됨
복사
복사됨
constexpr bool large_tile = std::is_same_v<TileShape, cute::Shape<cute::_128, cute::_128, cute::_128>>;
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
using CollectiveEpilogue =
cutlass::arch::Sm100,
OperatorClass,
typename cutlass::epilogue::collective::CollectiveBuilder<
TileShape,
ClusterShape,
ArchTag,
cutlass::epilogue::collective::EpilogueTileAuto,
OperatorClass,
DtypeAccum,
DtypeEpilogue,
TileShape,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
ClusterShape,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::collective::EpilogueTileAuto,
EpilogueScheduleType,
DtypeAccum,
EpilogueEVT>::CollectiveOp;
DtypeEpilogue,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
DtypeOutput,
LayoutOutput,
AlignmentOutput,
typename Schedule<large_tile, FastAccum::value>::epilogue_type,
EpilogueEVT>::CollectiveOp;
복사
복사됨
복사
복사됨
using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
using CollectiveMainloop =
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
ArchTag,
OperatorClass,
OperatorClass,
DtypeA,
DtypeA,
LayoutInputA,
LayoutInputA,
AlignmentInputA,
AlignmentInputA,
DtypeB,
DtypeB,
LayoutInputB,
LayoutInputB,
AlignmentInputB,
AlignmentInputB,
DtypeAccum,
DtypeAccum,
TileShape,
TileShape,
ClusterShape,
ClusterShape,
복사
복사됨
복사
복사됨
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType
>::
CollectiveOp;
typename Schedule<large_tile, FastAccum::value>::type
>::
CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveMainloop,
CollectiveEpilogue>;
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1));
StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1));
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1));
StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
typename Gemm::Arguments arguments{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K},
{M, N, K},
{reinterpret_cast<DtypeA*>(XQ.data_ptr()),
{reinterpret_cast<DtypeA*>(XQ.data_ptr()),
stride_a,
stride_a,
reinterpret_cast<DtypeB*>(WQ.data_ptr()),
reinterpret_cast<DtypeB*>(WQ.data_ptr()),
stride_b},
stride_b},
{{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
{{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
: nullptr},
: nullptr},
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())},
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())},
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}},
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}},
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
stride_output,
stride_output,
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
stride_output}};
stride_output}};
Gemm gemm;
Gemm gemm;
// Using the arguments, query for extra workspace required for matrix
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Ensure persistent kernels leave enough free SMs for NCCL background ops.
// Ensure persistent kernels leave enough free SMs for NCCL background ops.
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
arguments.hw_info.sm_count =
arguments.hw_info.sm_count =
at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value();
at::globalContext()._SMCarveout_EXPERIMENTAL().value();
}
}
// Set the swizzle size
// Set the swizzle size
arguments.scheduler.max_swizzle_size = swizzle;
arguments.scheduler.max_swizzle_size = swizzle;
// Allocate workspace memory
// Allocate workspace memory
auto workspace = XQ.new_empty(
auto workspace = XQ.new_empty(
{static_cast<int64_t>(workspace_size)},
{static_cast<int64_t>(workspace_size)},
at::TensorOptions().dtype(at::kByte));
at::TensorOptions().dtype(at::kByte));
// Check the problem size is supported or not
// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
throw std::runtime_error("cutlass cannot implement");
}
}
// Initialize CUTLASS kernel with arguments and workspace pointer
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.data_ptr());
status = gemm.initialize(arguments, workspace.data_ptr());
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
throw std::runtime_error("cutlass cannot initialize");
}
}
status = gemm(at::cuda::getCurrentCUDAStream());
status = gemm(at::cuda::getCurrentCUDAStream());
if (status != cutlass::Status::kSuccess) {
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
throw std::runtime_error(
std::string("cutlass cannot run") +
std::string("cutlass cannot run") +
cutlass::cutlassGetStatusString(status));
cutlass::cutlassGetStatusString(status));
}
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
저장된 비교 결과
원본
파일 열기
// Cutlass rowwise kernel for sm90 template < typename TileShape, typename ClusterShape, typename Transposed, typename FastAccum, typename DtypeA, typename DtypeB, typename DtypeBias> void f8f8bf16_rowwise_impl( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 at::Tensor x_scale, at::Tensor w_scale, std::optional<at::Tensor> bias, at::Tensor out, const int swizzle) { int M = XQ.size(0); int N = WQ.size(1); int K = XQ.size(1); // Workaround for https://github.com/pytorch/pytorch/issues/133334. if (M % 256 > 0) { int padded_M = ((M - 1) / 256 + 1) * 256; at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1}); padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M) .copy_(std::move(x_scale)); x_scale = std::move(padded_x_scale); } using LayoutInputA = cutlass::layout::RowMajor; constexpr int AlignmentInputA = 16 / sizeof(DtypeA); using LayoutInputB = cutlass::layout::ColumnMajor; constexpr int AlignmentInputB = 16 / sizeof(DtypeB); using LayoutOutput = std::conditional_t< Transposed::value, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); // Tag indicating the minimum SM that supports the intended feature using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; // Implement rowwise scaling epilogue. constexpr int ColBroadcastStages = 0; constexpr int RowBroadcastStages = 0; using XScale = cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>; using WScale = cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>; using Bias = std::conditional_t< Transposed::value, cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>, cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using AccumScale = cutlass::epilogue::fusion::Sm90EVT< Multiply, WScale, cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>; using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< Cast, cutlass::epilogue::fusion::Sm90EVT< Add, Bias, AccumScale>>; constexpr bool large_tile = std::is_same_v<TileShape, cute::Shape<cute::_128, cute::_128, cute::_128>>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeEpilogue, DtypeOutput, LayoutOutput, AlignmentOutput, DtypeOutput, LayoutOutput, AlignmentOutput, typename Schedule<large_tile, FastAccum::value>::epilogue_type, EpilogueEVT>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, DtypeA, LayoutInputA, AlignmentInputA, DtypeB, LayoutInputB, AlignmentInputB, DtypeAccum, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( sizeof(typename CollectiveEpilogue::SharedStorage))>, typename Schedule<large_tile, FastAccum::value>::type>:: CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< cute::Shape<int, int, int>, CollectiveMainloop, CollectiveEpilogue>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using StrideInputA = typename Gemm::GemmKernel::StrideA; using StrideInputB = typename Gemm::GemmKernel::StrideB; using StrideOutput = typename Gemm::GemmKernel::StrideC; StrideInputA stride_a = cutlass::make_cute_packed_stride( StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1)); StrideInputB stride_b = cutlass::make_cute_packed_stride( StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1)); StrideOutput stride_output = cutlass::make_cute_packed_stride( StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1)); typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K}, {reinterpret_cast<DtypeA*>(XQ.data_ptr()), stride_a, reinterpret_cast<DtypeB*>(WQ.data_ptr()), stride_b}, {{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr()) : nullptr}, {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}, {{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}}, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output}}; Gemm gemm; // Using the arguments, query for extra workspace required for matrix // multiplication computation size_t workspace_size = Gemm::get_workspace_size(arguments); // Ensure persistent kernels leave enough free SMs for NCCL background ops. if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { arguments.hw_info.sm_count = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value(); } // Set the swizzle size arguments.scheduler.max_swizzle_size = swizzle; // Allocate workspace memory auto workspace = XQ.new_empty( {static_cast<int64_t>(workspace_size)}, at::TensorOptions().dtype(at::kByte)); // Check the problem size is supported or not cutlass::Status status = gemm.can_implement(arguments); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot implement"); } // Initialize CUTLASS kernel with arguments and workspace pointer status = gemm.initialize(arguments, workspace.data_ptr()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); } status = gemm(at::cuda::getCurrentCUDAStream()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error( std::string("cutlass cannot run") + cutlass::cutlassGetStatusString(status)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); }
수정본
파일 열기
// Cutlass rowwise kernel for SM100 template < typename TileShape, typename ClusterShape, typename Transposed, typename FastAccum, typename DtypeA, typename DtypeB, typename DtypeBias> void f8f8bf16_rowwise_impl_sm100( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 at::Tensor x_scale, at::Tensor w_scale, std::optional<at::Tensor> bias, at::Tensor out, const int swizzle) { int M = XQ.size(0); int N = WQ.size(1); int K = XQ.size(1); // Workaround for https://github.com/pytorch/pytorch/issues/133334. if (M % 256 > 0) { int padded_M = ((M - 1) / 256 + 1) * 256; at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1}); padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M) .copy_(std::move(x_scale)); x_scale = std::move(padded_x_scale); } using LayoutInputA = cutlass::layout::RowMajor; constexpr int AlignmentInputA = 16 / sizeof(DtypeA); using LayoutInputB = cutlass::layout::ColumnMajor; constexpr int AlignmentInputB = 16 / sizeof(DtypeB); using LayoutOutput = std::conditional_t< Transposed::value, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); // Tag indicating the minimum SM that supports the intended feature using ArchTag = cutlass::arch::Sm100; using OperatorClass = cutlass::arch::OpClassTensorOp; // Implement rowwise scaling epilogue. constexpr int ColBroadcastStages = 0; constexpr int RowBroadcastStages = 0; using XScale = cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>; using WScale = cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>; using Bias = std::conditional_t< Transposed::value, cutlass::epilogue::fusion:: Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>, cutlass::epilogue::fusion:: Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using AccumScale = cutlass::epilogue::fusion::Sm90EVT< Multiply, WScale, cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>; using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< Cast, cutlass::epilogue::fusion::Sm90EVT< Add, Bias, AccumScale>>; using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm100, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeEpilogue, DtypeOutput, LayoutOutput, AlignmentOutput, DtypeOutput, LayoutOutput, AlignmentOutput, EpilogueScheduleType, EpilogueEVT>::CollectiveOp; using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, DtypeA, LayoutInputA, AlignmentInputA, DtypeB, LayoutInputB, AlignmentInputB, DtypeAccum, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>, MainloopScheduleType>::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< cute::Shape<int, int, int>, CollectiveMainloop, CollectiveEpilogue>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using StrideInputA = typename Gemm::GemmKernel::StrideA; using StrideInputB = typename Gemm::GemmKernel::StrideB; using StrideOutput = typename Gemm::GemmKernel::StrideC; StrideInputA stride_a = cutlass::make_cute_packed_stride( StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1)); StrideInputB stride_b = cutlass::make_cute_packed_stride( StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1)); StrideOutput stride_output = cutlass::make_cute_packed_stride( StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1)); typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K}, {reinterpret_cast<DtypeA*>(XQ.data_ptr()), stride_a, reinterpret_cast<DtypeB*>(WQ.data_ptr()), stride_b}, {{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr()) : nullptr}, {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}, {{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())}}}}}, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output, reinterpret_cast<DtypeOutput*>(out.data_ptr()), stride_output}}; Gemm gemm; // Using the arguments, query for extra workspace required for matrix // multiplication computation size_t workspace_size = Gemm::get_workspace_size(arguments); // Ensure persistent kernels leave enough free SMs for NCCL background ops. if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { arguments.hw_info.sm_count = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value(); } // Set the swizzle size arguments.scheduler.max_swizzle_size = swizzle; // Allocate workspace memory auto workspace = XQ.new_empty( {static_cast<int64_t>(workspace_size)}, at::TensorOptions().dtype(at::kByte)); // Check the problem size is supported or not cutlass::Status status = gemm.can_implement(arguments); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot implement"); } // Initialize CUTLASS kernel with arguments and workspace pointer status = gemm.initialize(arguments, workspace.data_ptr()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); } status = gemm(at::cuda::getCurrentCUDAStream()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error( std::string("cutlass cannot run") + cutlass::cutlassGetStatusString(status)); } C10_CUDA_KERNEL_LAUNCH_CHECK(); }
비교하기