Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp Source File

gemm_quant_kernel.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <string>
7 
8 #include "ck_tile/core.hpp"
14 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
19 namespace detail {
20 // Helper templates for safe type extraction
21 template <typename T, typename Default>
22 struct get_aq_layout_or
23 {
24  using type = Default;
25 };
26 
27 template <typename T, typename Default>
28  requires requires { typename T::AQLayout; }
29 struct get_aq_layout_or<T, Default>
30 {
31  using type = typename T::AQLayout;
32 };
33 
34 template <typename T, typename Default>
35 struct get_bq_layout_or
36 {
37  using type = Default;
38 };
39 
40 template <typename T, typename Default>
41  requires requires { typename T::BQLayout; }
42 struct get_bq_layout_or<T, Default>
43 {
44  using type = typename T::BQLayout;
45 };
46 
47 template <typename T, typename Default>
48 struct get_aq_data_type_or
49 {
50  using type = Default;
51 };
52 
53 template <typename T, typename Default>
54  requires requires { typename T::AQDataType; }
55 struct get_aq_data_type_or<T, Default>
56 {
57  using type = typename T::AQDataType;
58 };
59 
60 template <typename T, typename Default>
61 struct get_bq_data_type_or
62 {
63  using type = Default;
64 };
65 
66 template <typename T, typename Default>
67  requires requires { typename T::BQDataType; }
68 struct get_bq_data_type_or<T, Default>
69 {
70  using type = typename T::BQDataType;
71 };
72 
73 template <typename T, typename Default>
74 struct get_preshuffle_or
75 {
76  using type = Default;
77 };
78 
79 template <typename T, typename Default>
80  requires requires { typename T::PreshuffleQuant; }
81 struct get_preshuffle_or<T, Default>
82 {
83  using type = typename T::PreshuffleQuant;
84 };
85 } // namespace detail
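The five detail helpers above all implement the same "member type or default" detection pattern: an unconstrained primary template yields the Default, and a constrained partial specialization takes over whenever the nested member type exists. A minimal, self-contained sketch of the pattern (the names get_member_or, Member, WithMember and Fallback are hypothetical, used only for illustration):

    #include <type_traits>

    template <typename T, typename Default>
    struct get_member_or                          // fallback: T has no nested Member type
    {
        using type = Default;
    };

    template <typename T, typename Default>
        requires requires { typename T::Member; } // preferred when T::Member exists
    struct get_member_or<T, Default>
    {
        using type = typename T::Member;
    };

    struct WithMember { using Member = int; };
    struct Fallback   { };

    static_assert(std::is_same_v<get_member_or<WithMember, float>::type, int>);
    static_assert(std::is_same_v<get_member_or<Fallback, float>::type, float>);

This is how AQLayout, BQLayout, AQDataType, BQDataType and PreshuffleQuant below fall back to sensible defaults when the pipeline type does not declare them.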
86 
87 struct QuantGemmProblem
88 {
89  CK_TILE_HOST QuantGemmProblem() = default;
90  CK_TILE_HOST QuantGemmProblem(index_t M_,
91  index_t N_,
92  index_t K_,
93  index_t QK_A_,
94  index_t QK_B_,
95  index_t stride_A_,
96  index_t stride_B_,
97  index_t stride_C_,
98  index_t stride_AQ_,
99  index_t stride_BQ_)
100  : M(M_),
101  N(N_),
102  K(K_),
103  QK_A(QK_A_),
104  QK_B(QK_B_),
105  stride_A(stride_A_),
106  stride_B(stride_B_),
107  stride_C(stride_C_),
108  stride_AQ(stride_AQ_),
109  stride_BQ(stride_BQ_)
110  {
111  }
112 
113  index_t M;
114  index_t N;
115  index_t K;
116  index_t QK_A;
117  index_t QK_B;
118  index_t stride_A;
119  index_t stride_B;
120  index_t stride_C;
121  index_t stride_AQ;
122  index_t stride_BQ;
123 };
124 
125 struct QuantGemmHostArgs : public QuantGemmProblem
126 {
127  CK_TILE_HOST QuantGemmHostArgs() = default;
128  CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_,
129  const void* b_ptr_,
130  void* c_ptr_,
131  const void* aq_ptr_,
132  const void* bq_ptr_,
133  index_t k_batch_,
134  index_t M_,
135  index_t N_,
136  index_t K_,
137  index_t QK_A_,
138  index_t QK_B_,
139  index_t stride_A_,
140  index_t stride_B_,
141  index_t stride_C_,
142  index_t stride_AQ_,
143  index_t stride_BQ_)
144  : QuantGemmProblem(
145  M_, N_, K_, QK_A_, QK_B_, stride_A_, stride_B_, stride_C_, stride_AQ_, stride_BQ_),
146  a_ptr(a_ptr_),
147  b_ptr(b_ptr_),
148  aq_ptr(aq_ptr_),
149  bq_ptr(bq_ptr_),
150  c_ptr(c_ptr_),
151  k_batch(k_batch_)
152  {
153  }
154 
155  const void* a_ptr = nullptr;
156  const void* b_ptr = nullptr;
157  const void* aq_ptr = nullptr;
158  const void* bq_ptr = nullptr;
159  void* c_ptr = nullptr;
160  index_t k_batch = 1;
161 };
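For illustration only, a host-side sketch of filling QuantGemmHostArgs; every value and pointer below (a_dev through bq_dev, the sizes and the strides) is hypothetical. QK_A and QK_B are the number of quantization scales per row of A and per column of B, typically K divided by the pipeline's quantization group size:

    // Hypothetical problem: M = N = 1024, K = 4096, quant group size 128 -> QK_A = QK_B = 32.
    ck_tile::QuantGemmHostArgs host_args(a_dev, b_dev, c_dev, aq_dev, bq_dev,
                                         /*k_batch_=*/1,
                                         /*M_=*/1024, /*N_=*/1024, /*K_=*/4096,
                                         /*QK_A_=*/32, /*QK_B_=*/32,
                                         /*stride_A_=*/4096, /*stride_B_=*/4096,
                                         /*stride_C_=*/1024,
                                         /*stride_AQ_=*/32, /*stride_BQ_=*/32);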
162 
163 struct QuantGemmKernelArgs
164 {
165  const void* a_ptr;
166  const void* b_ptr;
167  const void* aq_ptr;
168  const void* bq_ptr;
169  void* c_ptr;
170  index_t M;
171  index_t N;
172  index_t K;
173  index_t QK_A;
174  index_t QK_B;
175  index_t stride_A;
176  index_t stride_B;
177  index_t stride_C;
178  index_t stride_AQ;
179  index_t stride_BQ;
180  index_t k_batch;
181 };
182 
183 template <typename TilePartitioner_,
184  typename GemmPipeline_,
185  typename EpiloguePipeline_,
186  QuantType QuantType_>
187 struct QuantGemmKernel
188 {
189  using TilePartitioner = remove_cvref_t<TilePartitioner_>;
190  using GemmPipeline = remove_cvref_t<GemmPipeline_>;
191  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
192  using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
193  using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
194  using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
195 
196  using AQLayout = remove_cvref_t<
197  typename detail::get_aq_layout_or<GemmPipeline, typename GemmPipeline::ALayout>::type>;
198  using BQLayout = remove_cvref_t<
199  typename detail::get_bq_layout_or<GemmPipeline, typename GemmPipeline::BLayout>::type>;
200 
201  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
202  static constexpr bool PreshuffleQuant = remove_cvref_t<
203  typename detail::get_preshuffle_or<GemmPipeline, std::false_type>::type>::value;
204 
205  using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
206  using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
207  using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
208  using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
209 
210  using AQDataType =
211  remove_cvref_t<typename detail::get_aq_data_type_or<GemmPipeline, AccDataType>::type>;
212  using BQDataType =
213  remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;
214 
215  static constexpr auto I0 = number<0>(); // A Tensor
216  static constexpr auto I1 = number<1>(); // AQ Tensor
217  static constexpr auto I2 = number<2>(); // B Tensor
218  static constexpr auto I3 = number<3>(); // BQ Tensor
219  static constexpr auto I4 = number<4>(); // C Tensor
220 
221  static constexpr auto kQuantType = QuantType_;
222 
223  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
224  {
225  // clang-format off
226  return concat('_', "gemm_quant", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
227  // clang-format on
228  }
229 
230  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
231  {
232  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
233  }
234 
235  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
236 
237  CK_TILE_HOST static constexpr QuantGemmKernelArgs
238  MakeKernelArgs(const QuantGemmHostArgs& hostArgs)
239  {
240  return QuantGemmKernelArgs{hostArgs.a_ptr,
241  hostArgs.b_ptr,
242  hostArgs.aq_ptr,
243  hostArgs.bq_ptr,
244  hostArgs.c_ptr,
245  hostArgs.M,
246  hostArgs.N,
247  hostArgs.K,
248  hostArgs.QK_A,
249  hostArgs.QK_B,
250  hostArgs.stride_A,
251  hostArgs.stride_B,
252  hostArgs.stride_C,
253  hostArgs.stride_AQ,
254  hostArgs.stride_BQ,
255  hostArgs.k_batch};
256  }
257 
258  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
259  {
260  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
261  }
262 
263  struct SplitKBatchOffset
264  {
265  __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
266  const std::size_t k_id = blockIdx.z)
267  {
268  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(I2);
269  const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
270  const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
271 
272  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
273  {
274  a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
275  }
276  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
277  {
278  a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A);
279  }
280 
281  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
282  {
283  b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B);
284  }
285  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
286  {
287  b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
288  }
289 
290  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
291  {
292  splitted_k = __builtin_amdgcn_readfirstlane(KRead);
293  }
294  else
295  {
296  splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
297  }
298  }
299 
300  index_t a_k_split_offset;
301  index_t b_k_split_offset;
302  index_t splitted_k;
303  };
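A worked example of the split-K arithmetic above, with hypothetical numbers: for kargs.K = 1000, kargs.k_batch = 4 and a warp-tile K extent K1 = 32, K_t = 4 * 32 = 128 and KRead = ceil(1000 / 128) * 32 = 256. Blocks with k_id 0..2 therefore process splitted_k = 256 elements along K, and the last block (k_id == 3) processes the remainder, 1000 - 3 * 256 = 232. For row-major A the element offset a_k_split_offset is k_id * KRead; for column-major A it is k_id * KRead * stride_A (and symmetrically for B).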
304 
305  CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
306  {
307  if(kargs.k_batch != 1)
308  {
309  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
310  {
311  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
312  }
313  return false;
314  }
315 
316  if constexpr(kQuantType == QuantType::AQuantGrouped)
317  {
318  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
319  if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0)
320  {
321  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
322  {
323  CK_TILE_ERROR("K_A is not a multiple of vector load size for A tensor!");
324  }
325  return false;
326  }
327  }
328 
329  // NOTE: no kernel currently uses BQuant like this:
330  if constexpr(kQuantType == QuantType::BQuantGrouped)
331  {
332  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
333  if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0)
334  {
335  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
336  {
337  CK_TILE_ERROR("K_B is not a multiple of vector load size for B tensor!");
338  }
339  return false;
340  }
341  }
342 
343  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
344  {
345  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
346  GemmPipeline::kPadK == false)
347  {
348  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
349  {
350  CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
351  "without padding!");
352  }
353  return false;
354  }
355  if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
356  {
357  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
358  {
359  CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
360  }
361  return false;
362  }
363  }
364  else
365  {
366  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
367  {
368  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
369  {
370  CK_TILE_ERROR(
371  "Can't support M that is not a multiple of MPerBlock without padding!");
372  }
373  return false;
374  }
375  if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
376  {
377  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
378  {
379  CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
380  }
381  return false;
382  }
383  }
384 
385  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
386  {
387  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
388  {
389  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
390  {
391  CK_TILE_ERROR(
392  "Can't support N that is not a multiple of NPerBlock without padding!");
393  }
394  return false;
395  }
396  if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
397  {
398  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
399  {
400  CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
401  }
402  return false;
403  }
404  }
405  else
406  {
407  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
408  GemmPipeline::kPadK == false)
409  {
410  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
411  {
412  CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
413  "without padding!");
414  }
415  return false;
416  }
417  if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
418  {
419  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
420  {
421  CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
422  }
423  return false;
424  }
425  }
426 
427  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
428  {
429  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
430  {
431  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
432  {
433  CK_TILE_ERROR(
434  "Can't support N that is not a multiple of NPerBlock without padding!");
435  }
436  return false;
437  }
438  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
439  {
440  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
441  {
442  CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
443  }
444  return false;
445  }
446  }
447  else
448  {
449  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
450  {
451  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
452  {
453  CK_TILE_ERROR(
454  "Can't support M that is not a multiple of MPerBlock without padding!");
455  }
456  return false;
457  }
458  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
459  {
460  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
461  {
462  CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
463  }
464  return false;
465  }
466  }
467  return true;
468  }
469 
470  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
471  CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
472  const BDataType* b_ptr,
473  const AQDataType* aq_ptr,
474  const BQDataType* bq_ptr,
475  CDataType* c_ptr,
476  const QuantGemmKernelArgs& kargs,
477  const SplitKBatchOffset& splitk_batch_offset)
478  {
479  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
480  const auto& a_tensor_view = [&]() {
481  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
482  {
483  return make_naive_tensor_view<address_space_enum::global>(
484  a_ptr,
485  make_tuple(kargs.M, splitk_batch_offset.splitted_k),
486  make_tuple(kargs.stride_A, 1),
487  number<GemmPipeline::GetVectorSizeA()>{},
488  number<1>{});
489  }
490  else
491  {
492  return make_naive_tensor_view<address_space_enum::global>(
493  a_ptr,
494  make_tuple(splitk_batch_offset.splitted_k, kargs.M),
495  make_tuple(kargs.stride_A, 1),
496  number<GemmPipeline::GetVectorSizeA()>{},
497  number<1>{});
498  }
499  }();
500 
501  const auto get_padding_size = [](index_t length, index_t alignment) {
502  return ck_tile::integer_least_multiple(length, alignment) - length;
503  };
504 
505  const auto& aq_tensor_view = [&]() {
506  if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
507  {
508  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
509  const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
510  const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
511 
512  const auto aq_desc =
513  make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
514  make_tuple(aq_x, 1),
515  number<GemmPipeline::GetVectorSizeAQ()>{},
516  number<1>{});
517 
518  const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
519  const auto aq_pad0_desc = transform_tensor_descriptor(
520  aq_desc,
521  make_tuple(
522  make_pass_through_transform(aq_y),
523  make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
524  make_tuple(sequence<0>{}, sequence<1>{}),
525  make_tuple(sequence<0>{}, sequence<1>{}));
526 
527  const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
528  const auto wave_tile_size =
529  TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
530  const auto wave_tile_count_x =
531  ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
532  const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
533  aq_pad0_desc,
534  make_tuple(
535  make_pass_through_transform(aq_y),
536  make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
537  make_tuple(sequence<0>{}, sequence<1>{}),
538  make_tuple(sequence<0>{}, sequence<1, 2>{}));
539 
540  const auto aq_pad1_desc = transform_tensor_descriptor(
541  aq_unmerge_pad0_desc,
542  make_tuple(
543  make_pass_through_transform(aq_y),
544  make_pass_through_transform(wave_tile_count_x),
545  make_right_pad_transform(
546  wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
547  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
548  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
549 
550  const auto pad_wave_size =
552  const auto aq_merge_pad1_desc = transform_tensor_descriptor(
553  aq_pad1_desc,
554  make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
555  make_pass_through_transform(pad_wave_size)),
556  make_tuple(sequence<0, 1>{}, sequence<2>{}),
557  make_tuple(sequence<0>{}, sequence<1>{}));
558 
559  return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
560  }
561  else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
562  {
563  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
564  return make_naive_tensor_view<address_space_enum::global>(
565  aq_ptr,
566  make_tuple(kargs.M, kargs.QK_A),
567  make_tuple(kargs.stride_AQ, 1),
568  number<GemmPipeline::GetVectorSizeAQ()>{},
569  number<1>{});
570  }
571  else if constexpr(kQuantType == QuantType::RowColQuant)
572  {
573  return make_naive_tensor_view<address_space_enum::global>(
574  aq_ptr,
575  make_tuple(kargs.M, kargs.N),
576  make_tuple(1, 0), // broadcasting over n
577  number<1>{},
578  number<1>{});
579  }
580  else
581  {
582  return nullptr; // TODO: use some other "empty" type for this
583  }
584  }();
585 
586  const auto& b_tensor_view = [&]() {
587  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
588  {
589  if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
590  {
591  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
592  const index_t K0 = splitk_batch_offset.splitted_k / K1;
593  constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
594  const auto b_k0_n_k1_desc =
595  make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
596  make_tuple(kargs.N * K1, K1, I1),
597  number<VectorSizeB>{},
598  number<1>{});
599  const auto b_n_k_desc = transform_tensor_descriptor(
600  b_k0_n_k1_desc,
605  return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
606  }
607  else
608  {
609  return make_naive_tensor_view<address_space_enum::global>(
610  b_ptr,
611  make_tuple(splitk_batch_offset.splitted_k, kargs.N),
612  make_tuple(kargs.stride_B, 1),
613  number<GemmPipeline::GetVectorSizeB()>{},
614  number<1>{});
615  }
616  }
617  else
618  {
619  if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
620  {
621  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
622  const index_t K0 = splitk_batch_offset.splitted_k / K1;
623  constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
624  const auto b_k0_n_k1_desc =
625  make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
626  make_tuple(kargs.N * K1, K1, I1),
627  number<VectorSizeB>{},
628  number<1>{});
629  const auto b_n_k_desc = transform_tensor_descriptor(
630  b_k0_n_k1_desc,
635  return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
636  }
637  else
638  {
639  return make_naive_tensor_view<address_space_enum::global>(
640  b_ptr,
641  make_tuple(kargs.N, splitk_batch_offset.splitted_k),
642  make_tuple(kargs.stride_B, 1),
643  number<GemmPipeline::GetVectorSizeB()>{},
644  number<1>{});
645  }
646  }
647  }();
648 
649  const auto& bq_tensor_view = [&]() {
650  if constexpr(kQuantType == QuantType::RowColQuant)
651  {
652  return make_naive_tensor_view<address_space_enum::global>(
653  bq_ptr,
654  make_tuple(kargs.M, kargs.N),
655  make_tuple(0, 1), // broadcasting over m
656  number<1>{},
657  number<1>{});
658  }
659  else if constexpr(kQuantType == QuantType::BQuantGrouped)
660  {
661  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
662  return make_naive_tensor_view<address_space_enum::global>(
663  bq_ptr,
664  make_tuple(kargs.N, kargs.QK_B),
665  make_tuple(kargs.stride_BQ, 1),
666  number<GemmPipeline::GetVectorSizeBQ()>{},
667  number<1>{});
668  }
669  else
670  {
671  return nullptr; // TODO: use some other "empty" type for this
672  }
673  }();
674 
675  // TODO: enable vector write for C in ColMajor
676  const auto& c_tensor_view = [&]() {
677  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
678  {
679  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
680  c_ptr,
681  make_tuple(kargs.M, kargs.N),
682  make_tuple(kargs.stride_C, 1),
683  number<EpiloguePipeline::GetVectorSizeC()>{},
684  number<1>{});
685  }
686  else
687  {
688  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
689  c_ptr,
690  make_tuple(kargs.M, kargs.N),
691  make_tuple(1, kargs.stride_C),
692  number<1>{},
693  number<1>{});
694  }
695  }();
696 
697  return make_tuple(
698  a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view);
699  }
700 
701  template <typename TensorView>
702  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
703  {
704  const auto& a_pad_view = [&]() {
705  const auto& a_tensor_view = views.at(I0);
706  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
707  {
708  return pad_tensor_view(a_tensor_view,
712  }
713  else
714  {
715  return pad_tensor_view(a_tensor_view,
719  }
720  }();
721 
722  // no padding
723  const auto& aq_pad_view = [&]() { return views.at(I1); }();
724 
725  const auto& b_pad_view = [&]() {
726  const auto& b_tensor_view = views.at(I2);
727  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
728  {
729  return pad_tensor_view(b_tensor_view,
733  }
734  else
735  {
736  return pad_tensor_view(b_tensor_view,
740  }
741  }();
742 
743  // no padding
744  const auto& bq_pad_view = [&]() { return views.at(I3); }();
745 
746  // TODO vector write in for C in ColMajor
747  const auto& c_pad_view = [&]() {
748  const auto& c_tensor_view = views.at(I4);
749  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
750  {
751  return pad_tensor_view(c_tensor_view,
755  }
756  else
757  {
758  return pad_tensor_view(c_tensor_view,
762  }
763  }();
764 
765  return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
766  }
767 
768  template <typename PadView>
769  CK_TILE_DEVICE static auto
770  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
771  {
772  const auto& a_pad_view = views.at(I0);
773  const auto& aq_pad_view = views.at(I1);
774  const auto& b_pad_view = views.at(I2);
775  const auto& bq_pad_view = views.at(I3);
776  const auto& c_pad_view = views.at(I4);
777 
778  const auto& a_block_window = [&]() {
779  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
780  {
781  return make_tile_window(a_pad_view,
782  make_tuple(number<TilePartitioner::MPerBlock>{},
783  number<TilePartitioner::KPerBlock>{}),
784  {i_m, 0});
785  }
786  else
787  {
788  return make_tile_window(a_pad_view,
789  make_tuple(number<TilePartitioner::KPerBlock>{},
790  number<TilePartitioner::MPerBlock>{}),
791  {0, i_m});
792  }
793  }();
794 
795  const auto& aq_block_window = [&]() {
796  if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
797  {
798  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
799  constexpr auto block_m = TilePartitioner::MPerBlock;
800  constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
801  constexpr auto aqk_per_block =
802  TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
803  constexpr auto tile_window_width =
804  ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
805  constexpr auto tile_window_height = block_m / warp_m;
806  auto block_m_idx = i_m / block_m;
807  return make_tile_window(
808  aq_pad_view,
809  make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
810  {block_m_idx * tile_window_height, 0});
811  }
812  else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
813  {
814  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
815  constexpr auto block_m = TilePartitioner::MPerBlock;
816  constexpr auto block_k = TilePartitioner::KPerBlock;
817  return make_tile_window(
818  aq_pad_view,
819  make_tuple(number<block_m>{}, number<block_k / GemmPipeline::QuantGroupSize>{}),
820  {i_m, 0});
821  }
822  else if constexpr(kQuantType == QuantType::RowColQuant)
823  {
824  return make_tile_window(aq_pad_view,
825  make_tuple(number<TilePartitioner::MPerBlock>{},
826  number<TilePartitioner::NPerBlock>{}),
827  {i_m, i_n});
828  }
829  else
830  {
831  return nullptr; // TODO: use some other "empty" type?
832  }
833  }();
834 
835  const auto& b_block_window = [&]() {
836  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
837  {
838  return make_tile_window(b_pad_view,
839  make_tuple(number<TilePartitioner::NPerBlock>{},
840  number<TilePartitioner::KPerBlock>{}),
841  {i_n, 0});
842  }
843  else
844  {
845  return make_tile_window(b_pad_view,
846  make_tuple(number<TilePartitioner::KPerBlock>{},
847  number<TilePartitioner::NPerBlock>{}),
848  {0, i_n});
849  }
850  }();
851 
852  const auto& bq_block_window = [&]() {
853  if constexpr(kQuantType == QuantType::RowColQuant)
854  {
855  return make_tile_window(bq_pad_view,
856  make_tuple(number<TilePartitioner::MPerBlock>{},
857  number<TilePartitioner::NPerBlock>{}),
858  {i_m, i_n});
859  }
860  else if constexpr(kQuantType == QuantType::BQuantGrouped)
861  {
862  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
863  return make_tile_window(
864  bq_pad_view,
865  make_tuple(number<TilePartitioner::NPerBlock>{},
866  number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
867  {i_n, 0});
868  }
869  else
870  {
871  return nullptr; // TODO: use some other "empty" type here
872  }
873  }();
874 
875  auto c_block_window = make_tile_window(
876  c_pad_view,
877  make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
878  {i_m, i_n});
879 
880  return make_tuple(
881  a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window);
882  }
883 
899  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
900  CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
901  const BDataType* b_ptr,
902  const AQDataType* aq_ptr,
903  const BQDataType* bq_ptr,
904  CDataType* c_ptr,
905  void* smem_ptr_0,
906  const QuantGemmKernelArgs& kargs,
907  const SplitKBatchOffset& splitk_batch_offset,
908  const index_t block_idx_m,
909  const index_t block_idx_n)
910  {
911  // Create Gemm tensor views, pad views and tile windows
912  const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
913  a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
914 
915  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
916  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
917 
918  const index_t num_loop = __builtin_amdgcn_readfirstlane(
919  TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
920 
921  // Run GEMM cooperatively by whole workgroup.
922  const auto& a_block_window = gemm_tile_windows.at(I0);
923  const auto& b_block_window = gemm_tile_windows.at(I2);
924 
925  const auto& c_block_tile = [&]() {
926  if constexpr(kQuantType == QuantType::AQuantGrouped)
927  {
928  const auto& aq_block_window = gemm_tile_windows.at(I1);
929  return GemmPipeline{}.template operator()(
930  a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
931  }
932  else if constexpr(kQuantType == QuantType::BQuantGrouped)
933  {
934  const auto& bq_block_window = gemm_tile_windows.at(I3);
935  return GemmPipeline{}.template operator()(
936  a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
937  }
938  else if constexpr(kQuantType == QuantType::RowColQuant)
939  {
940  return GemmPipeline{}.template operator()(
941  a_block_window, b_block_window, num_loop, smem_ptr_0);
942  }
943  }();
944 
945  // Run Epilogue Pipeline
946  auto& c_block_window = gemm_tile_windows.at(I4);
947 
948  if constexpr(kQuantType == QuantType::AQuantGrouped ||
949  kQuantType == QuantType::BQuantGrouped)
950  {
951  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
952  }
953  else if constexpr(kQuantType == QuantType::RowColQuant)
954  {
955  const auto& aq_block_window = gemm_tile_windows.at(I1);
956  const auto& bq_block_window = gemm_tile_windows.at(I3);
957  EpiloguePipeline{}(c_block_window,
958  c_block_tile,
959  c_block_window,
960  smem_ptr_0,
961  aq_block_window,
962  bq_block_window);
963  }
964  }
965 
966  CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
967  {
968  const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
969  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
970  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
971  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
972 
973  const SplitKBatchOffset splitk_batch_offset(kargs);
974  // options
975  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
976  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
977  const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
978  const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
979  CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
980 
981  // allocate LDS
982  __shared__ char smem_ptr_0[GetSmemSize()];
983 
984  assert(kargs.k_batch == 1);
985  RunGemm(
986  a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
987  }
988 };
989 
990 } // namespace ck_tile
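A minimal host-side sketch of how this kernel is typically driven. The concrete TilePartitioner, GemmPipeline and EpiloguePipeline types, the host_args object and the launch mechanism are assumptions for illustration; they come from the surrounding ck_tile GEMM configuration and host utilities, not from this header:

    using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
                                            GemmPipeline,
                                            EpiloguePipeline,
                                            ck_tile::QuantType::AQuantGrouped>;

    // Flatten the host-facing arguments into the POD kernel-argument struct.
    const auto kargs = Kernel::MakeKernelArgs(host_args);
    if(!Kernel::IsSupportedArgument(kargs))
        throw std::runtime_error("unsupported arguments for " + Kernel::GetName());

    const dim3 grids  = Kernel::GridSize(host_args.M, host_args.N, host_args.k_batch);
    const dim3 blocks = Kernel::BlockSize();
    // Launch through the project's kernel-entry / launch helpers (or a plain HIP launch);
    // the kernel allocates its LDS statically, sized by Kernel::GetSmemSize().

Note that operator() currently asserts kargs.k_batch == 1, which matches the k_batch check at the top of IsSupportedArgument.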