Composable Kernel: include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp Source File

gemm_quant_kernel.hpp
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <string>
7 
8 #include "ck_tile/core.hpp"
14 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
19 namespace detail {
20 // Helper templates for safe type extraction
21 template <typename, typename Default, typename = void>
22 struct get_aq_layout_or
23 {
24  using type = Default;
25 };
26 
27 template <typename T, typename Default>
28 struct get_aq_layout_or<T, Default, std::void_t<typename T::AQLayout>>
29 {
30  using type = typename T::AQLayout;
31 };
32 
33 template <typename, typename Default, typename = void>
34 struct get_bq_layout_or
35 {
36  using type = Default;
37 };
38 
39 template <typename T, typename Default>
40 struct get_bq_layout_or<T, Default, std::void_t<typename T::BQLayout>>
41 {
42  using type = typename T::BQLayout;
43 };
44 
45 template <typename, typename Default, typename = void>
46 struct get_aq_data_type_or
47 {
48  using type = Default;
49 };
50 
51 template <typename T, typename Default>
52 struct get_aq_data_type_or<T, Default, std::void_t<typename T::AQDataType>>
53 {
54  using type = typename T::AQDataType;
55 };
56 
57 template <typename, typename Default, typename = void>
58 struct get_bq_data_type_or
59 {
60  using type = Default;
61 };
62 
63 template <typename T, typename Default>
64 struct get_bq_data_type_or<T, Default, std::void_t<typename T::BQDataType>>
65 {
66  using type = typename T::BQDataType;
67 };
68 
69 template <typename, typename = void>
70 struct is_quantpreshuffle_enabled
71 {
72  static constexpr bool value = false;
73 };
74 
75 template <typename T>
76 struct is_quantpreshuffle_enabled<T, std::void_t<decltype(T::PreshuffleQuant)>>
77 {
78  static constexpr bool value = T::PreshuffleQuant;
79 };
80 
81 template <typename, typename = void>
82 struct is_preshuffleB_enabled
83 {
84  static constexpr bool value = false;
85 };
86 
87 template <typename T>
88 struct is_preshuffleB_enabled<T, std::void_t<decltype(T::PreshuffleB)>>
89 {
90  static constexpr bool value = T::PreshuffleB;
91 };
92 } // namespace detail
93 
94 struct QuantGemmProblem
95 {
96  CK_TILE_HOST QuantGemmProblem() = default;
97  CK_TILE_HOST QuantGemmProblem(index_t M_,
98  index_t N_,
99  index_t K_,
100  index_t QK_A_,
101  index_t QK_B_,
102  index_t stride_A_,
103  index_t stride_B_,
104  index_t stride_C_,
105  index_t stride_AQ_,
106  index_t stride_BQ_)
107  : M(M_),
108  N(N_),
109  K(K_),
110  QK_A(QK_A_),
111  QK_B(QK_B_),
112  stride_A(stride_A_),
113  stride_B(stride_B_),
114  stride_C(stride_C_),
115  stride_AQ(stride_AQ_),
116  stride_BQ(stride_BQ_)
117  {
118  }
119 
120  index_t M;
121  index_t N;
122  index_t K;
123  index_t QK_A;
124  index_t QK_B;
125  index_t stride_A;
126  index_t stride_B;
127  index_t stride_C;
128  index_t stride_AQ;
129  index_t stride_BQ;
130 };
131 
132 struct QuantGemmHostArgs : public QuantGemmProblem
133 {
134  CK_TILE_HOST QuantGemmHostArgs() = default;
135  CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_,
136  const void* b_ptr_,
137  void* c_ptr_,
138  const void* aq_ptr_,
139  const void* bq_ptr_,
140  index_t k_batch_,
141  index_t M_,
142  index_t N_,
143  index_t K_,
144  index_t QK_A_,
145  index_t QK_B_,
146  index_t stride_A_,
147  index_t stride_B_,
148  index_t stride_C_,
149  index_t stride_AQ_,
150  index_t stride_BQ_)
151  : QuantGemmProblem(
152  M_, N_, K_, QK_A_, QK_B_, stride_A_, stride_B_, stride_C_, stride_AQ_, stride_BQ_),
153  a_ptr(a_ptr_),
154  b_ptr(b_ptr_),
155  aq_ptr(aq_ptr_),
156  bq_ptr(bq_ptr_),
157  c_ptr(c_ptr_),
158  k_batch(k_batch_)
159  {
160  }
161 
162  const void* a_ptr = nullptr;
163  const void* b_ptr = nullptr;
164  const void* aq_ptr = nullptr;
165  const void* bq_ptr = nullptr;
166  void* c_ptr = nullptr;
167  index_t k_batch = 1;
168 };
169 
170 struct QuantGemmKernelArgs
171 {
172  const void* a_ptr;
173  const void* b_ptr;
174  const void* aq_ptr;
175  const void* bq_ptr;
176  void* c_ptr;
177  index_t M;
178  index_t N;
179  index_t K;
180  index_t QK_A;
181  index_t QK_B;
182  index_t stride_A;
183  index_t stride_B;
184  index_t stride_C;
185  index_t stride_AQ;
186  index_t stride_BQ;
187  index_t k_batch;
188 };
189 
190 template <typename TilePartitioner_,
191  typename GemmPipeline_,
192  typename EpiloguePipeline_,
193  QuantType QuantType_>
194 struct QuantGemmKernel
195 {
196  using TilePartitioner = remove_cvref_t<TilePartitioner_>;
197  using GemmPipeline = remove_cvref_t<GemmPipeline_>;
198  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
199  using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
200  using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
201  using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
202 
203  using AQLayout =
204  remove_cvref_t<typename detail::get_aq_layout_or<GemmPipeline, typename GemmPipeline::ALayout>::type>;
205  using BQLayout =
206  remove_cvref_t<typename detail::get_bq_layout_or<GemmPipeline, typename GemmPipeline::BLayout>::type>;
207 
208  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
209  static constexpr bool PreshuffleQuant =
210  detail::is_quantpreshuffle_enabled<GemmPipeline>::value;
211  static constexpr bool PreshuffleB = detail::is_preshuffleB_enabled<GemmPipeline>::value;
212 
213  using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
214  using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
215  using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
216  using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
217 
218  using AQDataType =
219  remove_cvref_t<typename detail::get_aq_data_type_or<GemmPipeline, AccDataType>::type>;
220  using BQDataType =
221  remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;
222 
223  static constexpr auto I0 = number<0>(); // A Tensor
224  static constexpr auto I1 = number<1>(); // AQ Tensor
225  static constexpr auto I2 = number<2>(); // B Tensor
226  static constexpr auto I3 = number<3>(); // BQ Tensor
227  static constexpr auto I4 = number<4>(); // C Tensor
228 
229  static constexpr auto kQuantType = QuantType_;
230 
231  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
232  {
233  // clang-format off
234  return concat('_', "gemm_quant", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
235  // clang-format on
236  }
237 
238  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
239  {
240  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
241  }
242 
243  CK_TILE_HOST static auto BlockSize()
244  {
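 // kBlockSize is presumably expressed for 64-lane waves; halving it on wave32 targets
 // keeps the number of waves per workgroup unchanged.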
245  return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
246  }
247 
248  CK_TILE_HOST static constexpr QuantGemmKernelArgs
249  MakeKernelArgs(const QuantGemmHostArgs& hostArgs)
250  {
251  return QuantGemmKernelArgs{hostArgs.a_ptr,
252  hostArgs.b_ptr,
253  hostArgs.aq_ptr,
254  hostArgs.bq_ptr,
255  hostArgs.c_ptr,
256  hostArgs.M,
257  hostArgs.N,
258  hostArgs.K,
259  hostArgs.QK_A,
260  hostArgs.QK_B,
261  hostArgs.stride_A,
262  hostArgs.stride_B,
263  hostArgs.stride_C,
264  hostArgs.stride_AQ,
265  hostArgs.stride_BQ,
266  hostArgs.k_batch};
267  }
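 // Host-side usage sketch (illustrative only; device pointer names such as a_dev are
 // assumptions, not part of this header):
 //   QuantGemmHostArgs host_args(a_dev, b_dev, c_dev, aq_dev, bq_dev, /*k_batch=*/1,
 //                               M, N, K, QK_A, QK_B,
 //                               stride_A, stride_B, stride_C, stride_AQ, stride_BQ);
 //   auto kargs      = MakeKernelArgs(host_args);
 //   const dim3 grid = GridSize(M, N, /*KBatch=*/1);
 //   const dim3 blk  = BlockSize();
 //   // IsSupportedArgument(kargs) should be checked before launching.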
268 
269  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
270  {
271  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
272  }
273 
274  private:
275  CK_TILE_DEVICE static constexpr index_t get_padding_size(index_t length, index_t alignment)
276  {
277  return ck_tile::integer_least_multiple(length, alignment) - length;
278  };
279  // ===================================================================
280  // Helper: Create Pre-shuffled Quantization Tensor Descriptor
281  // ===================================================================
282  template <index_t KPerBlockBQ,
283  index_t NPerBlock,
284  index_t WarpTileN,
285  index_t GetVectorSizeBQ,
286  typename BQDataType_>
287  CK_TILE_DEVICE static auto
288  MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QK_B)
289  {
290  // Step 1: Calculate base BQ tensor dimensions
291  // ----------------------------------------------------------
292  // bq_x: Number of quantization groups in N dimension
293  // = N * KPerBlockBQ, where KPerBlockBQ is the number of
294  // K-dimension groups per block
295  // bq_y: Number of quantization groups in K dimension
296  // = Total K groups (QK_B) / groups per block
297  const auto bq_x = N * KPerBlockBQ;
298  const auto bq_y = QK_B / KPerBlockBQ;
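 // Worked example (hypothetical sizes): with N = 128, QK_B = 32 and KPerBlockBQ = 4,
 // bq_x = 128 * 4 = 512 and bq_y = 32 / 4 = 8, so the base descriptor is 8 x 512.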
299 
300  const auto bq_desc = make_naive_tensor_descriptor(
301  make_tuple(bq_y, bq_x), make_tuple(bq_x, 1), number<GetVectorSizeBQ>{}, number<1>{});
302 
303  // Step 2: First padding transformation (block-level alignment)
304  // ----------------------------------------------------------
305  // Pad the X dimension to be a multiple of block_tile_size to ensure
306  // each thread block can process complete tiles without edge cases
307  const auto block_tile_size = NPerBlock * KPerBlockBQ;
308  const auto bq_pad0_desc = transform_tensor_descriptor(
309  bq_desc,
310  make_tuple(make_pass_through_transform(bq_y),
311  make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))),
312  make_tuple(sequence<0>{}, sequence<1>{}),
313  make_tuple(sequence<0>{}, sequence<1>{}));
314 
315  // Step 3: Unmerge transformation (wave-level decomposition)
316  // ----------------------------------------------------------
317  // Split the X dimension into [wave_tile_count_x, wave_tile_size]
318  // This separates the work into tiles that can be processed by
319  // individual warps/waves
320  const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1];
321  const auto wave_tile_size = WarpTileN * KPerBlockBQ;
322  const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size);
323 
324  const auto bq_unmerge_pad0_desc = transform_tensor_descriptor(
325  bq_pad0_desc,
326  make_tuple(make_pass_through_transform(bq_y),
327  make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
328  make_tuple(sequence<0>{}, sequence<1>{}),
329  make_tuple(sequence<0>{}, sequence<1, 2>{}));
330 
331  // Step 4: Second padding transformation (warp-level alignment)
332  // ----------------------------------------------------------
333  // Pad wave_tile_size to be a multiple of warp_size (typically 32 or 64)
334  // This ensures coalesced memory accesses within each warp
335  const auto bq_pad1_desc = transform_tensor_descriptor(
336  bq_unmerge_pad0_desc,
337  make_tuple(make_pass_through_transform(bq_y),
338  make_pass_through_transform(wave_tile_count_x),
339  make_right_pad_transform(wave_tile_size,
340  get_padding_size(wave_tile_size, get_warp_size()))),
341  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
342  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
343 
344  // Step 5: Final merge transformation (prepare for indexing)
345  // ----------------------------------------------------------
346  // Merge [bq_y, wave_tile_count_x] into a single outer dimension
347  // This creates a 2D layout: [merged_outer_dim, pad_wave_size]
348  // where merged_outer_dim = bq_y * wave_tile_count_x
349  // This layout facilitates efficient block-to-data mapping
350  const auto pad_wave_size = ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
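 // Continuing the example above with WarpTileN = 16 and a 64-lane warp:
 // wave_tile_size = 16 * 4 = 64, so no warp-level padding is needed, pad_wave_size = 64,
 // and the merged outer dimension below becomes bq_y * wave_tile_count_x.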
351  const auto bq_merge_pad1_desc = transform_tensor_descriptor(
352  bq_pad1_desc,
353  make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)),
354  make_pass_through_transform(pad_wave_size)),
355  make_tuple(sequence<0, 1>{}, sequence<2>{}),
356  make_tuple(sequence<0>{}, sequence<1>{}));
357 
358  return make_tensor_view<address_space_enum::global>(bq_ptr, bq_merge_pad1_desc);
359  }
360 
361  public:
362  struct SplitKBatchOffset
363  {
364  __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
365  const std::size_t k_id = blockIdx.z)
366  {
367  constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
368  const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
369  const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
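 // Example (hypothetical): with K = 1000, k_batch = 4 and a warp-tile K (K1) of 32,
 // K_t = 4 * 32 = 128 and KRead = ((1000 + 127) / 128) * 32 = 256, so the first three
 // batches each read 256 elements along K and the last batch reads the remaining 232.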
370 
371  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
372  {
373  a_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
374  }
375  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
376  {
377  a_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_A);
378  }
379 
380  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
381  {
382  b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B);
383  }
384  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
385  {
386  b_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
387  }
388 
389  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
390  {
391  splitted_k = amd_wave_read_first_lane(KRead);
392  }
393  else
394  {
395  splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
396  }
397  }
398 
399  index_t a_k_split_offset;
400  index_t b_k_split_offset;
401  index_t splitted_k;
402  };
403 
404  CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
405  {
406  if(kargs.k_batch != 1)
407  {
408  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
409  {
410  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
411  }
412  return false;
413  }
414 
415  if constexpr(kQuantType == QuantType::AQuantGrouped)
416  {
417  if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0)
418  {
419  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
420  {
421  CK_TILE_ERROR("K_A is not a multiple of vector load size for A tensor!");
422  }
423  return false;
424  }
425  }
426 
427  if constexpr(kQuantType == QuantType::BQuantGrouped)
428  {
429  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
430  if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0)
431  {
432  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
433  {
434  CK_TILE_ERROR("K_B is not a multiple of vector load size for B tensor!");
435  }
436  return false;
437  }
438  }
439 
440  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
441  {
442  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
443  GemmPipeline::kPadK == false)
444  {
445  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
446  {
447  CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
448  "without padding!");
449  }
450  return false;
451  }
452  if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
453  {
454  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
455  {
456  CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
457  }
458  return false;
459  }
460  }
461  else
462  {
463  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
464  {
465  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
466  {
467  CK_TILE_ERROR(
468  "Can't support M that is not a multiple of MPerBlock without padding!");
469  }
470  return false;
471  }
472  if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
473  {
474  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
475  {
476  CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
477  }
478  return false;
479  }
480  }
481 
482  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
483  {
484  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
485  {
486  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
487  {
488  CK_TILE_ERROR(
489  "Can't support N that is not a multiple of NPerBlock without padding!");
490  }
491  return false;
492  }
493  if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
494  {
495  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
496  {
497  CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
498  }
499  return false;
500  }
501  }
502  else
503  {
504  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
505  GemmPipeline::kPadK == false)
506  {
507  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
508  {
509  CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
510  "without padding!");
511  }
512  return false;
513  }
514  if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
515  {
516  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
517  {
518  CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
519  }
520  return false;
521  }
522  }
523 
524  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
525  {
526  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
527  {
528  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
529  {
530  CK_TILE_ERROR(
531  "Can't support N that is not a multiple of NPerBlock without padding!");
532  }
533  return false;
534  }
535  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
536  {
537  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
538  {
539  CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
540  }
541  return false;
542  }
543  }
544  else
545  {
546  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
547  {
548  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
549  {
550  CK_TILE_ERROR(
551  "Can't support M that is not a multiple of MPerBlock without padding!");
552  }
553  return false;
554  }
555  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
556  {
557  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
558  {
559  CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
560  }
561  return false;
562  }
563  }
564  return true;
565  }
566 
567  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
568  CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
569  const BDataType* b_ptr,
570  const AQDataType* aq_ptr,
571  const BQDataType* bq_ptr,
572  CDataType* c_ptr,
573  const QuantGemmKernelArgs& kargs,
574  const SplitKBatchOffset& splitk_batch_offset)
575  {
576 
577  static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
578  const auto& a_tensor_view = [&]() {
579  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
580  {
581  return make_naive_tensor_view<address_space_enum::global>(
582  a_ptr,
583  make_tuple(kargs.M, splitk_batch_offset.splitted_k),
584  make_tuple(kargs.stride_A, 1),
585  number<GemmPipeline::GetVectorSizeA()>{},
586  number<1>{});
587  }
588  else
589  {
590  return make_naive_tensor_view<address_space_enum::global>(
591  a_ptr,
592  make_tuple(splitk_batch_offset.splitted_k, kargs.M),
593  make_tuple(kargs.stride_A, 1),
594  number<GemmPipeline::GetVectorSizeA()>{},
595  number<1>{});
596  }
597  }();
598 
599  const auto& aq_tensor_view = [&]() {
600  if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
601  {
602  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
603  const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
604  const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
605  const auto aq_desc =
606  make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
607  make_tuple(aq_x, 1),
608  number<GemmPipeline::GetVectorSizeAQ()>{},
609  number<1>{});
610 
611  const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
612  const auto aq_pad0_desc = transform_tensor_descriptor(
613  aq_desc,
614  make_tuple(
615  make_pass_through_transform(aq_y),
616  make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
617  make_tuple(sequence<0>{}, sequence<1>{}),
618  make_tuple(sequence<0>{}, sequence<1>{}));
619 
620  const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
621  const auto wave_tile_size =
622  GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
623  const auto wave_tile_count_x =
624  ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
625 
626  const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
627  aq_pad0_desc,
628  make_tuple(
629  make_pass_through_transform(aq_y),
630  make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
631  make_tuple(sequence<0>{}, sequence<1>{}),
632  make_tuple(sequence<0>{}, sequence<1, 2>{}));
633 
634  const auto aq_pad1_desc = transform_tensor_descriptor(
635  aq_unmerge_pad0_desc,
636  make_tuple(
637  make_pass_through_transform(aq_y),
638  make_pass_through_transform(wave_tile_count_x),
639  make_right_pad_transform(
640  wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
641  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
642  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
643 
644  const auto pad_wave_size =
645  ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
646  const auto aq_merge_pad1_desc = transform_tensor_descriptor(
647  aq_pad1_desc,
648  make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
649  make_pass_through_transform(pad_wave_size)),
650  make_tuple(sequence<0, 1>{}, sequence<2>{}),
651  make_tuple(sequence<0>{}, sequence<1>{}));
652 
653  return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
654  }
655  else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
656  {
657  if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
658  {
659  return make_naive_tensor_view<address_space_enum::global>(
660  aq_ptr,
661  make_tuple(kargs.M, kargs.QK_A),
662  make_tuple(kargs.stride_AQ, 1),
663  number<GemmPipeline::GetVectorSizeAQ()>{},
664  number<1>{});
665  }
666  else // Column major AQ
667  {
668  return make_naive_tensor_view<address_space_enum::global>(
669  aq_ptr,
670  make_tuple(kargs.QK_A, kargs.M), // Swapped dimensions
671  make_tuple(kargs.stride_AQ, 1), // Same stride pattern
672  number<GemmPipeline::GetVectorSizeAQ()>{},
673  number<1>{});
674  }
675  }
676  else if constexpr(kQuantType == QuantType::RowColQuant)
677  {
678  return make_naive_tensor_view<address_space_enum::global>(
679  aq_ptr,
680  make_tuple(kargs.M, kargs.N),
681  make_tuple(1, 0), // broadcasting over n
682  number<1>{},
683  number<1>{});
684  }
685  else
686  {
687  return nullptr; // TODO: use some other "empty" type for this
688  }
689  }();
690 
691  const auto& b_tensor_view = [&]() {
692  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
693  {
694  if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
695  {
696  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
697  const index_t K0 = splitk_batch_offset.splitted_k / K1;
698  constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
699  const auto b_k0_n_k1_desc =
700  make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
701  make_tuple(kargs.N * K1, K1, I1),
702  number<VectorSizeB>{},
703  number<1>{});
704  const auto b_n_k_desc = transform_tensor_descriptor(
705  b_k0_n_k1_desc,
706  make_tuple(make_merge_transform(make_tuple(K0, K1)),
707  make_pass_through_transform(kargs.N)),
708  make_tuple(sequence<0, 2>{}, sequence<1>{}),
709  make_tuple(sequence<0>{}, sequence<1>{}));
710  return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
711  }
712  else
713  {
714  return make_naive_tensor_view<address_space_enum::global>(
715  b_ptr,
716  make_tuple(splitk_batch_offset.splitted_k, kargs.N),
717  make_tuple(kargs.stride_B, 1),
718  number<GemmPipeline::GetVectorSizeB()>{},
719  number<1>{});
720  }
721  }
722  else
723  {
724  if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
725  {
726  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
727  const index_t K0 = splitk_batch_offset.splitted_k / K1;
728  constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
729  const auto b_k0_n_k1_desc =
730  make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
731  make_tuple(kargs.N * K1, K1, I1),
732  number<VectorSizeB>{},
733  number<1>{});
734  const auto b_n_k_desc = transform_tensor_descriptor(
735  b_k0_n_k1_desc,
736  make_tuple(make_pass_through_transform(kargs.N),
737  make_merge_transform(make_tuple(K0, K1))),
738  make_tuple(sequence<1>{}, sequence<0, 2>{}),
739  make_tuple(sequence<0>{}, sequence<1>{}));
740  return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
741  }
742  else
743  {
744  if constexpr(PreshuffleB)
745  {
746  index_t kFlatK = GemmPipeline::flatKPerWarp *
747  (splitk_batch_offset.splitted_k /
748  GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
749  index_t kFlatN = kargs.N * kargs.K / kFlatK;
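 // With pre-shuffled B the tensor is addressed as a flat 2D array of warp-sized chunks:
 // kFlatK covers flatKPerWarp elements for every warp-tile step along K, and kFlatN is
 // N * K / kFlatK, so kFlatN * kFlatK still spans the whole B tensor.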
750 
751  return make_naive_tensor_view<address_space_enum::global>(
752  b_ptr,
753  make_tuple(kFlatN, kFlatK),
754  make_tuple(kFlatK, 1),
755  number<GemmPipeline::GetVectorSizeB()>{},
756  number<1>{});
757  }
758  else
759  {
760  return make_naive_tensor_view<address_space_enum::global>(
761  b_ptr,
762  make_tuple(kargs.N, splitk_batch_offset.splitted_k),
763  make_tuple(kargs.stride_B, 1),
764  number<GemmPipeline::GetVectorSizeB()>{},
765  number<1>{});
766  }
767  }
768  }
769  }();
770 
771  const auto& bq_tensor_view = [&]() {
772  if constexpr(kQuantType == QuantType::RowColQuant)
773  {
774  return make_naive_tensor_view<address_space_enum::global>(
775  bq_ptr,
776  make_tuple(kargs.M, kargs.N),
777  make_tuple(0, 1), // broadcasting over m
778  number<1>{},
779  number<1>{});
780  }
781  else if constexpr(kQuantType == QuantType::BQuantGrouped)
782  {
783  if constexpr(PreshuffleQuant)
784  {
785  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
786 
787  return MakePreshuffledQuantTensorView<
788  GemmPipeline::KPerBlockBQ,
789  GemmPipeline::NPerBlock,
790  TilePartitioner::BlockGemmShape::WarpTile::at(I1),
791  GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B);
792  }
793  else
794  {
795  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
796  using QuantGroupSize = typename GemmPipeline::QuantGroupSize;
797  return make_naive_tensor_view<address_space_enum::global>(
798  bq_ptr,
799  make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
800  make_tuple(kargs.stride_BQ, 1),
801  number<GemmPipeline::GetVectorSizeBQ()>{},
802  number<1>{});
803  }
804  }
805  else
806  {
807  return nullptr; // TODO: use some other "empty" type for this
808  }
809  }();
810 
811  // TODO: enable vector write for C in ColMajor
812  const auto& c_tensor_view = [&]() {
813  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
814  {
815  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
816  c_ptr,
817  make_tuple(kargs.M, kargs.N),
818  make_tuple(kargs.stride_C, 1),
819  number<EpiloguePipeline::GetVectorSizeC()>{},
820  number<1>{});
821  }
822  else
823  {
824  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
825  c_ptr,
826  make_tuple(kargs.M, kargs.N),
827  make_tuple(1, kargs.stride_C),
828  number<1>{},
829  number<1>{});
830  }
831  }();
832 
833  return make_tuple(
834  a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view);
835  }
836 
837  template <typename TensorView>
838  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
839  {
840  const auto& a_pad_view = [&]() {
841  const auto& a_tensor_view = views.at(I0);
842  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
843  {
844  return pad_tensor_view(a_tensor_view,
845  make_tuple(number<TilePartitioner::MPerBlock>{},
846  number<TilePartitioner::KPerBlock>{}),
847  sequence<false, GemmPipeline::kPadK>{});
848  }
849  else
850  {
851  return pad_tensor_view(a_tensor_view,
852  make_tuple(number<TilePartitioner::KPerBlock>{},
853  number<TilePartitioner::MPerBlock>{}),
854  sequence<false, GemmPipeline::kPadM>{});
855  }
856  }();
857 
858  // no padding
859  const auto& aq_pad_view = [&]() { return views.at(I1); }();
860 
861  const auto& b_flat_view = views.at(I2); // not applying any padding to flat B view
862 
863  const auto& b_pad_view = [&]() {
864  const auto& b_tensor_view = views.at(I2);
865  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
866  {
867  return pad_tensor_view(b_tensor_view,
868  make_tuple(number<TilePartitioner::NPerBlock>{},
869  number<TilePartitioner::KPerBlock>{}),
870  sequence<false, GemmPipeline::kPadK>{});
871  }
872  else
873  {
874  return pad_tensor_view(b_tensor_view,
875  make_tuple(number<TilePartitioner::KPerBlock>{},
876  number<TilePartitioner::NPerBlock>{}),
877  sequence<false, GemmPipeline::kPadN>{});
878  }
879  }();
880 
881  // no padding
882  const auto& bq_pad_view = [&]() { return views.at(I3); }();
883 
884  // TODO vector write in for C in ColMajor
885  const auto& c_pad_view = [&]() {
886  const auto& c_tensor_view = views.at(I4);
887  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
888  {
889  return pad_tensor_view(c_tensor_view,
890  make_tuple(number<TilePartitioner::MPerBlock>{},
891  number<TilePartitioner::NPerBlock>{}),
892  sequence<false, GemmPipeline::kPadN>{});
893  }
894  else
895  {
896  return pad_tensor_view(c_tensor_view,
897  make_tuple(number<TilePartitioner::MPerBlock>{},
898  number<TilePartitioner::NPerBlock>{}),
899  sequence<GemmPipeline::kPadM, false>{});
900  }
901  }();
902  if constexpr(PreshuffleB)
903  {
904 
905  return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view);
906  }
907  else
908  {
909  return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
910  }
911  }
912 
913  template <typename PadView>
914  CK_TILE_DEVICE static auto
915  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
916  {
917 
918  const auto& a_pad_view = views.at(I0);
919  const auto& aq_pad_view = views.at(I1);
920  const auto& b_pad_view = views.at(I2);
921  const auto& bq_pad_view = views.at(I3);
922  const auto& c_pad_view = views.at(I4);
923  const auto& a_block_window = [&]() {
924  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
925  {
926  return make_tile_window(a_pad_view,
927  make_tuple(number<TilePartitioner::MPerBlock>{},
928  number<TilePartitioner::KPerBlock>{}),
929  {i_m, 0});
930  }
931  else
932  {
933  return make_tile_window(a_pad_view,
934  make_tuple(number<TilePartitioner::KPerBlock>{},
935  number<TilePartitioner::MPerBlock>{}),
936  {0, i_m});
937  }
938  }();
939 
940  const auto& aq_block_window = [&]() {
941  if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
942  {
943  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
944  using QuantGroupSize = typename GemmPipeline::QuantGroupSize;
945  constexpr auto block_m = TilePartitioner::MPerBlock;
946  constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
947  constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
948  constexpr auto tile_window_width =
949  ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
950  constexpr auto tile_window_height = block_m / warp_m;
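 // Example (hypothetical): MPerBlock = 128, warp_m = 16, aqk_per_block = 4 and a 64-lane
 // warp give tile_window_width = least_multiple(16 * 4, 64) = 64 and
 // tile_window_height = 128 / 16 = 8.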
951  auto block_m_idx = i_m / block_m;
952  return make_tile_window(
953  aq_pad_view,
954  make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
955  {block_m_idx * tile_window_height, 0});
956  }
957  else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
958  {
959  using QuantGroupSize = typename GemmPipeline::QuantGroupSize;
960  constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
961  constexpr auto block_m = TilePartitioner::MPerBlock;
962  if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
963  {
964  return make_tile_window(aq_pad_view,
965  make_tuple(number<block_m>{}, number<aqk_per_block>{}),
966  {i_m, 0});
967  }
968  else // Column major AQ
969  {
970  return make_tile_window(aq_pad_view,
971  make_tuple(number<aqk_per_block>{}, number<block_m>{}),
972  {0, i_m});
973  }
974  }
975  else if constexpr(kQuantType == QuantType::RowColQuant)
976  {
977  return make_tile_window(aq_pad_view,
978  make_tuple(number<TilePartitioner::MPerBlock>{},
979  number<TilePartitioner::NPerBlock>{}),
980  {i_m, i_n});
981  }
982  else
983  {
984  return nullptr; // TODO: use some other "empty" type?
985  }
986  }();
987 
988  const auto& b_block_window = [&]() {
989  if constexpr(PreshuffleB)
990  {
991 
992  return make_tile_window(
993  b_pad_view,
994  make_tuple(number<GemmPipeline::flatNPerWarp>{},
995  number<GemmPipeline::flatKPerWarp>{}),
996  {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0});
997  }
998  else
999  {
1000  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
1001  {
1002  return make_tile_window(b_pad_view,
1003  make_tuple(number<TilePartitioner::NPerBlock>{},
1004  number<TilePartitioner::KPerBlock>{}),
1005  {i_n, 0});
1006  }
1007  else
1008  {
1009  return make_tile_window(b_pad_view,
1010  make_tuple(number<TilePartitioner::KPerBlock>{},
1011  number<TilePartitioner::NPerBlock>{}),
1012  {0, i_n});
1013  }
1014  }
1015  }();
1016 
1017  const auto& bq_block_window = [&]() {
1018  if constexpr(kQuantType == QuantType::RowColQuant)
1019  {
1020  return make_tile_window(bq_pad_view,
1021  make_tuple(number<TilePartitioner::MPerBlock>{},
1022  number<TilePartitioner::NPerBlock>{}),
1023  {i_m, i_n});
1024  }
1025  else if constexpr(kQuantType == QuantType::BQuantGrouped)
1026  {
1027  if constexpr(PreshuffleQuant)
1028  {
1029  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
1030  using QuantGroupSize = typename GemmPipeline::QuantGroupSize;
1031  constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
1032  constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
1033  constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
1034  constexpr auto tile_window_width =
1035  ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
1036  constexpr auto tile_window_height = block_n / warp_n;
1037  auto block_n_idx = i_n / block_n;
1038 
1039  return make_tile_window(
1040  bq_pad_view,
1041  make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
1042  {block_n_idx * tile_window_height, 0});
1043  }
1044  else
1045  {
1046  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
1047  using QuantGroupSize = typename GemmPipeline::QuantGroupSize;
1048  return make_tile_window(
1049  bq_pad_view,
1050  make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
1051  number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
1052  {i_n / QuantGroupSize::kN, 0});
1053  }
1054  }
1055  else
1056  {
1057  return nullptr; // TODO: use some other "empty" type here
1058  }
1059  }();
1060 
1061  auto c_block_window = make_tile_window(
1062  c_pad_view,
1063  make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
1064  {i_m, i_n});
1065 
1066  return make_tuple(
1067  a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window);
1068  }
1069 
1070  /// @brief Runs single GEMM problem cooperatively by whole workgroup.
1086  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
1087  CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
1088  const BDataType* b_ptr,
1089  const AQDataType* aq_ptr,
1090  const BQDataType* bq_ptr,
1091  CDataType* c_ptr,
1092  void* smem_ptr_0,
1093  const QuantGemmKernelArgs& kargs,
1094  const SplitKBatchOffset& splitk_batch_offset,
1095  const index_t block_idx_m,
1096  const index_t block_idx_n)
1097  {
1098  // Create Gemm tensor views, pad views and tile windows
1099  const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
1100  a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
1101 
1102  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
1103  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
1104 
1105  const index_t num_loop =
1106  amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
1107 
1108  // Run GEMM cooperatively by whole workgroup.
1109  const auto& a_block_window = gemm_tile_windows.at(I0);
1110  const auto& b_block_window = gemm_tile_windows.at(I2);
1111 
1112  const auto& c_block_tile = [&]() {
1113  if constexpr(kQuantType == QuantType::AQuantGrouped)
1114  {
1115  const auto& aq_block_window = gemm_tile_windows.at(I1);
1116  index_t m = 0;
1117  if constexpr(PreshuffleQuant)
1118  {
1119  m = kargs.M;
1120  }
1121  return GemmPipeline{}.template operator()(
1122  a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0, m);
1123  }
1124  else if constexpr(kQuantType == QuantType::BQuantGrouped)
1125  {
1126  const auto& bq_block_window = gemm_tile_windows.at(I3);
1127  index_t n = 0;
1128  if constexpr(PreshuffleQuant)
1129  {
1130  n = kargs.N;
1131  }
1132  return GemmPipeline{}.template operator()(
1133  a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);
1134  }
1135  else if constexpr(kQuantType == QuantType::RowColQuant ||
1136  kQuantType == QuantType::TensorQuant)
1137  {
1138  return GemmPipeline{}.template operator()(
1139  a_block_window, b_block_window, num_loop, smem_ptr_0);
1140  }
1141  }();
1142 
1143  // Run Epilogue Pipeline
1144  auto& c_block_window = gemm_tile_windows.at(I4);
1145 
1146  if constexpr(kQuantType == QuantType::AQuantGrouped ||
1147  kQuantType == QuantType::BQuantGrouped)
1148  {
1149  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
1150  }
1151  else if constexpr(kQuantType == QuantType::RowColQuant)
1152  {
1153  const auto& aq_block_window = gemm_tile_windows.at(I1);
1154  const auto& bq_block_window = gemm_tile_windows.at(I3);
1155  EpiloguePipeline{}(c_block_window,
1156  c_block_tile,
1157  c_block_window,
1158  smem_ptr_0,
1159  aq_block_window,
1160  bq_block_window);
1161  }
1162  else if constexpr(kQuantType == QuantType::TensorQuant)
1163  {
1164  // TODO: why doesn't readfirstlane work here?
1165  // const AccDataType aq_scale =
1166  // __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*aq_ptr));
1167  // const AccDataType bq_scale =
1168  // __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*bq_ptr));
1169  const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
1170  const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
1171  EpiloguePipeline{}(
1172  c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
1173  }
1174  }
1175  /// @brief Runs single GEMM problem cooperatively by whole workgroup.
1190  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
1191  CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
1192  const BDataType* b_ptr,
1193  const AQDataType* aq_ptr,
1194  const BQDataType* bq_ptr,
1195  CDataType* c_ptr,
1196  void* smem_ptr_0,
1197  void* smem_ptr_1,
1198  const QuantGemmKernelArgs& kargs,
1199  const SplitKBatchOffset& splitk_batch_offset,
1200  const index_t block_idx_m,
1201  const index_t block_idx_n)
1202  {
1203  // Create Gemm tensor views, pad views and tile windows
1204  const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
1205  a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
1206 
1207  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
1208  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
1209 
1210  const index_t num_loop = __builtin_amdgcn_readfirstlane(
1211  TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
1212 
1213  // Run GEMM cooperatively by whole workgroup.
1214  const auto& a_block_window = gemm_tile_windows.at(I0);
1215  const auto& b_block_window = gemm_tile_windows.at(I2);
1216 
1217  const auto& c_block_tile = [&]() {
1218  if constexpr(kQuantType == QuantType::BQuantGrouped)
1219  {
1220  const auto& bq_block_window = gemm_tile_windows.at(I3);
1221  index_t n = 0;
1222  if constexpr(PreshuffleQuant)
1223  {
1224  n = kargs.N;
1225  }
1226  return GemmPipeline{}.template operator()(a_block_window,
1227  b_block_window,
1228  bq_block_window,
1229  num_loop,
1230  smem_ptr_0,
1231  smem_ptr_1,
1232  n);
1233  }
1234  else
1235  {
1236  return nullptr;
1237  }
1238  }();
1239 
1240  // Run Epilogue Pipeline
1241  auto& c_block_window = gemm_tile_windows.at(I4);
1242 
1243  if constexpr(kQuantType == QuantType::BQuantGrouped)
1244  {
1245  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
1246  }
1247  else
1248  {
1249  return;
1250  // throw std::runtime_error("DoubleSmemBuffer Not implemented for AQuantGrouped or
1251  // RowColQuant"); static_assert(kQuantType == QuantType::BQuantGrouped,
1252  // "DoubleSmemBuffer Not implemented");
1253  }
1254  }
1255 
1256  CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
1257  {
1258  const auto blockId = amd_wave_read_first_lane(blockIdx.x);
1259  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1260  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1261  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1262  const SplitKBatchOffset splitk_batch_offset(kargs);
1263  // options
1264  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
1265  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
1266  const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
1267  const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
1268  CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
1269 
1270  // allocate LDS
1271  __shared__ char smem_ptr_0[GetSmemSize()];
1272  assert(kargs.k_batch == 1);
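 // Split-K is not yet exercised by this kernel: IsSupportedArgument() rejects
 // k_batch != 1, and the assert above restates that invariant on the device side.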
1273  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1274  {
1275  __shared__ char smem_ptr_1[GetSmemSize()];
1276 
1277  RunGemm2LDS(a_ptr,
1278  b_ptr,
1279  aq_ptr,
1280  bq_ptr,
1281  c_ptr,
1282  smem_ptr_0,
1283  smem_ptr_1,
1284  kargs,
1285  splitk_batch_offset,
1286  i_m,
1287  i_n);
1288  }
1289  else
1290  {
1291  RunGemm(a_ptr,
1292  b_ptr,
1293  aq_ptr,
1294  bq_ptr,
1295  c_ptr,
1296  smem_ptr_0,
1297  kargs,
1298  splitk_batch_offset,
1299  i_m,
1300  i_n);
1301  }
1302  }
1303 };
1304 
1305 } // namespace ck_tile