Composable Kernel: include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp Source File

gemm_quant_kernel.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <string>
7 
8 #include "ck_tile/core.hpp"
14 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
19 namespace detail {
20 // Helper templates for safe type extraction
21 template <typename, typename Default, typename = void>
22 struct get_aq_layout_or
23 {
24  using type = Default;
25 };
26 
27 template <typename T, typename Default>
28 struct get_aq_layout_or<T, Default, std::void_t<typename T::AQLayout>>
29 {
30  using type = typename T::AQLayout;
31 };
32 
33 template <typename, typename Default, typename = void>
34 struct get_bq_layout_or
35 {
36  using type = Default;
37 };
38 
39 template <typename T, typename Default>
40 struct get_bq_layout_or<T, Default, std::void_t<typename T::BQLayout>>
41 {
42  using type = typename T::BQLayout;
43 };
44 
45 template <typename, typename Default, typename = void>
46 struct get_aq_data_type_or
47 {
48  using type = Default;
49 };
50 
51 template <typename T, typename Default>
52 struct get_aq_data_type_or<T, Default, std::void_t<typename T::AQDataType>>
53 {
54  using type = typename T::AQDataType;
55 };
56 
57 template <typename, typename Default, typename = void>
58 struct get_bq_data_type_or
59 {
60  using type = Default;
61 };
62 
63 template <typename T, typename Default>
64 struct get_bq_data_type_or<T, Default, std::void_t<typename T::BQDataType>>
65 {
66  using type = typename T::BQDataType;
67 };
68 
69 template <typename, typename = void>
70 struct is_quantpreshuffle_enabled
71 {
72  static constexpr bool value = false;
73 };
74 
75 template <typename T>
76 struct is_quantpreshuffle_enabled<T, std::void_t<decltype(T::PreshuffleQuant)>>
77 {
78  static constexpr bool value = T::PreshuffleQuant;
79 };
80 
81 template <typename, typename = void>
82 struct is_preshuffleB_enabled
83 {
84  static constexpr bool value = false;
85 };
86 
87 template <typename T>
88 struct is_preshuffleB_enabled<T, std::void_t<decltype(T::PreshuffleB)>>
89 {
90  static constexpr bool value = T::PreshuffleB;
91 };
92 } // namespace detail
93 
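// Illustrative sketch (not part of the original header): how the detection helpers above
// fall back to a default when a pipeline type does not provide the nested alias. The two
// pipeline types below are hypothetical.
//
//   struct PipelineWithAQLayout    { using AQLayout = tensor_layout::gemm::RowMajor; };
//   struct PipelineWithoutAQLayout { };
//
//   using Layout0 = detail::get_aq_layout_or<PipelineWithAQLayout,
//                                            tensor_layout::gemm::ColumnMajor>::type; // RowMajor
//   using Layout1 = detail::get_aq_layout_or<PipelineWithoutAQLayout,
//                                            tensor_layout::gemm::ColumnMajor>::type; // ColumnMajor (default)
//
//   static_assert(!detail::is_preshuffleB_enabled<PipelineWithoutAQLayout>::value);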
94 struct QuantGemmProblem
95 {
96  CK_TILE_HOST QuantGemmProblem() = default;
97  CK_TILE_HOST QuantGemmProblem(index_t M_,
98  index_t N_,
99  index_t K_,
100  index_t QK_A_,
101  index_t QK_B_,
102  index_t stride_A_,
103  index_t stride_B_,
104  index_t stride_C_,
105  index_t stride_AQ_,
106  index_t stride_BQ_)
107  : M(M_),
108  N(N_),
109  K(K_),
110  QK_A(QK_A_),
111  QK_B(QK_B_),
112  stride_A(stride_A_),
113  stride_B(stride_B_),
114  stride_C(stride_C_),
115  stride_AQ(stride_AQ_),
116  stride_BQ(stride_BQ_)
117  {
118  }
119 
120  index_t M;
121  index_t N;
122  index_t K;
123  index_t QK_A;
124  index_t QK_B;
125  index_t stride_A;
126  index_t stride_B;
127  index_t stride_C;
128  index_t stride_AQ;
129  index_t stride_BQ;
130 };
131 
132 struct QuantGemmHostArgs : public QuantGemmProblem
133 {
134  CK_TILE_HOST QuantGemmHostArgs() = default;
135  CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_,
136  const void* b_ptr_,
137  void* c_ptr_,
138  const void* aq_ptr_,
139  const void* bq_ptr_,
140  index_t k_batch_,
141  index_t M_,
142  index_t N_,
143  index_t K_,
144  index_t QK_A_,
145  index_t QK_B_,
146  index_t stride_A_,
147  index_t stride_B_,
148  index_t stride_C_,
149  index_t stride_AQ_,
150  index_t stride_BQ_)
151  : QuantGemmProblem(
152  M_, N_, K_, QK_A_, QK_B_, stride_A_, stride_B_, stride_C_, stride_AQ_, stride_BQ_),
153  a_ptr(a_ptr_),
154  b_ptr(b_ptr_),
155  aq_ptr(aq_ptr_),
156  bq_ptr(bq_ptr_),
157  c_ptr(c_ptr_),
158  k_batch(k_batch_)
159  {
160  }
161 
162  const void* a_ptr = nullptr;
163  const void* b_ptr = nullptr;
164  const void* aq_ptr = nullptr;
165  const void* bq_ptr = nullptr;
166  void* c_ptr = nullptr;
167  index_t k_batch = 1;
168 };
169 
170 struct QuantGemmKernelArgs
171 {
172  const void* a_ptr;
173  const void* b_ptr;
174  const void* aq_ptr;
175  const void* bq_ptr;
176  void* c_ptr;
177  index_t M;
178  index_t N;
179  index_t K;
180  index_t QK_A;
181  index_t QK_B;
182  index_t stride_A;
183  index_t stride_B;
184  index_t stride_C;
185  index_t stride_AQ;
186  index_t stride_BQ;
187  index_t k_batch;
188 };
189 
190 template <typename TilePartitioner_,
191  typename GemmPipeline_,
192  typename EpiloguePipeline_,
193  QuantType QuantType_>
194 struct QuantGemmKernel
195 {
196  using TilePartitioner = remove_cvref_t<TilePartitioner_>;
197  using GemmPipeline = remove_cvref_t<GemmPipeline_>;
198  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
199  using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
200  using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
201  using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
202 
203  using AQLayout = remove_cvref_t<
204  typename detail::get_aq_layout_or<GemmPipeline, typename GemmPipeline::ALayout>::type>;
205  using BQLayout = remove_cvref_t<
206  typename detail::get_bq_layout_or<GemmPipeline, typename GemmPipeline::BLayout>::type>;
207 
208  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
209  static constexpr bool PreshuffleQuant =
210  detail::is_quantpreshuffle_enabled<GemmPipeline>::value;
211  static constexpr bool PreshuffleB = detail::is_preshuffleB_enabled<GemmPipeline>::value;
212 
213  using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
214  using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
215  using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
216  using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
217 
218  using AQDataType =
219  remove_cvref_t<typename detail::get_aq_data_type_or<GemmPipeline, AccDataType>::type>;
220  using BQDataType =
221  remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;
222 
223  static constexpr auto I0 = number<0>(); // A Tensor
224  static constexpr auto I1 = number<1>(); // AQ Tensor
225  static constexpr auto I2 = number<2>(); // B Tensor
226  static constexpr auto I3 = number<3>(); // BQ Tensor
227  static constexpr auto I4 = number<4>(); // C Tensor
228 
229  static constexpr auto kQuantType = QuantType_;
230 
231  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
232  {
233  // clang-format off
234  return concat('_', "gemm_quant", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
235  // clang-format on
236  }
237 
238  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
239  {
240  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
241  }
242 
243  CK_TILE_HOST static auto BlockSize()
244  {
245  return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
246  }
247 
248  CK_TILE_HOST static constexpr QuantGemmKernelArgs
249  MakeKernelArgs(const QuantGemmHostArgs& hostArgs)
250  {
251  return QuantGemmKernelArgs{hostArgs.a_ptr,
252  hostArgs.b_ptr,
253  hostArgs.aq_ptr,
254  hostArgs.bq_ptr,
255  hostArgs.c_ptr,
256  hostArgs.M,
257  hostArgs.N,
258  hostArgs.K,
259  hostArgs.QK_A,
260  hostArgs.QK_B,
261  hostArgs.stride_A,
262  hostArgs.stride_B,
263  hostArgs.stride_C,
264  hostArgs.stride_AQ,
265  hostArgs.stride_BQ,
266  hostArgs.k_batch};
267  }
268 
269  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
270  {
271  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
272  }
273 
274  struct SplitKBatchOffset
275  {
276  __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
277  const std::size_t k_id = blockIdx.z)
278  {
279  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(I2);
280  const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
281  const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
282 
283  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
284  {
285  a_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
286  }
287  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
288  {
289  a_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_A);
290  }
291 
292  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
293  {
294  b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B);
295  }
296  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
297  {
298  b_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
299  }
300 
301  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
302  {
303  splitted_k = amd_wave_read_first_lane(KRead);
304  }
305  else
306  {
307  splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
308  }
309  }
310 
311  index_t a_k_split_offset;
312  index_t b_k_split_offset;
313  index_t splitted_k;
314  };
315 
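// Worked example (illustrative numbers): with K = 4096, k_batch = 4 and a warp-tile K
// extent of K1 = 64, K_t = 4 * 64 = 256 and KRead = ((4096 + 255) / 256) * 64 = 1024, so
// every batch except the last reads splitted_k = 1024 elements of K and the last batch
// reads K - KRead * (k_batch - 1) = 4096 - 3072 = 1024. The A/B pointer offsets advance by
// k_id * KRead, scaled by the corresponding stride when K is the strided dimension.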
316  CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
317  {
318  if(kargs.k_batch != 1)
319  {
320  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
321  {
322  CK_TILE_ERROR("Conditions not met for KBatch > 1!");
323  }
324  return false;
325  }
326 
327  if constexpr(kQuantType == QuantType::AQuantGrouped)
328  {
329  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
330  if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0)
331  {
332  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
333  {
334  CK_TILE_ERROR("QK_A is not a multiple of vector load size for AQ tensor!");
335  }
336  return false;
337  }
338  }
339 
340  if constexpr(kQuantType == QuantType::BQuantGrouped)
341  {
342  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
343  if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0)
344  {
345  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
346  {
347  CK_TILE_ERROR("QK_B is not a multiple of vector load size for BQ tensor!");
348  }
349  return false;
350  }
351  }
352 
353  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
354  {
355  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
356  GemmPipeline::kPadK == false)
357  {
358  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
359  {
360  CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
361  "without padding!");
362  }
363  return false;
364  }
365  if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
366  {
367  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
368  {
369  CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
370  }
371  return false;
372  }
373  }
374  else
375  {
376  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
377  {
378  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
379  {
380  CK_TILE_ERROR(
381  "Can't support M that is not a multiple of MPerBlock without padding!");
382  }
383  return false;
384  }
385  if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
386  {
387  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
388  {
389  CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
390  }
391  return false;
392  }
393  }
394 
395  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
396  {
397  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
398  {
399  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
400  {
401  CK_TILE_ERROR(
402  "Can't support N that is not a multiple of NPerBlock without padding!");
403  }
404  return false;
405  }
406  if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
407  {
408  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
409  {
410  CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
411  }
412  return false;
413  }
414  }
415  else
416  {
417  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
418  GemmPipeline::kPadK == false)
419  {
420  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
421  {
422  CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
423  "without padding!");
424  }
425  return false;
426  }
427  if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
428  {
429  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
430  {
431  CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
432  }
433  return false;
434  }
435  }
436 
437  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
438  {
439  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
440  {
441  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
442  {
443  CK_TILE_ERROR(
444  "Can't support N that is not a multiple of NPerBlock without padding!");
445  }
446  return false;
447  }
448  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
449  {
450  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
451  {
452  CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
453  }
454  return false;
455  }
456  }
457  else
458  {
459  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
460  {
461  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
462  {
463  CK_TILE_ERROR(
464  "Can't support M that is not a multiple of MPerBlock without padding!");
465  }
466  return false;
467  }
468  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
469  {
470  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
471  {
472  CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
473  }
474  return false;
475  }
476  }
477  return true;
478  }
479 
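// Taken together, the checks above mean the quant GEMM path currently requires k_batch == 1;
// for grouped quantization, QK_A / QK_B must be multiples of the AQ / BQ vector load sizes;
// and, unless the corresponding kPadK / kPadM / kPadN flag is set, K must be a multiple of
// k_batch * KPerBlock, while K, M and N must also be multiples of the A, B and C vector
// access sizes along whichever dimension is contiguous for the chosen layouts.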
480  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
481  CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
482  const BDataType* b_ptr,
483  const AQDataType* aq_ptr,
484  const BQDataType* bq_ptr,
485  CDataType* c_ptr,
486  const QuantGemmKernelArgs& kargs,
487  const SplitKBatchOffset& splitk_batch_offset)
488  {
489 
490  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
491  const auto& a_tensor_view = [&]() {
492  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
493  {
494  return make_naive_tensor_view<address_space_enum::global>(
495  a_ptr,
496  make_tuple(kargs.M, splitk_batch_offset.splitted_k),
497  make_tuple(kargs.stride_A, 1),
498  number<GemmPipeline::GetVectorSizeA()>{},
499  number<1>{});
500  }
501  else
502  {
503  return make_naive_tensor_view<address_space_enum::global>(
504  a_ptr,
505  make_tuple(splitk_batch_offset.splitted_k, kargs.M),
506  make_tuple(kargs.stride_A, 1),
507  number<GemmPipeline::GetVectorSizeA()>{},
508  number<1>{});
509  }
510  }();
511 
512  const auto get_padding_size = [](index_t length, index_t alignment) {
513  return ck_tile::integer_least_multiple(length, alignment) - length;
514  };
515 
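// Reading (interpretation) of the preshuffled-AQ branch below: the A scales are viewed as a
// (QK_A / KPerBlockAQ) x (M * KPerBlockAQ) matrix, the flattened second dimension is
// right-padded to a multiple of the block tile, split into per-wave tiles, and each wave
// tile is padded to a multiple of the warp size before the outer dimensions are merged
// back, so warps can issue contiguous, aligned scale loads.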
516  const auto& aq_tensor_view = [&]() {
517  if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
518  {
519  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
520  const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
521  const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
522 
523  const auto aq_desc =
524  make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
525  make_tuple(aq_x, 1),
526  number<GemmPipeline::GetVectorSizeAQ()>{},
527  number<1>{});
528 
529  const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
530  const auto aq_pad0_desc = transform_tensor_descriptor(
531  aq_desc,
532  make_tuple(
534  make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
537 
538  const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
539  const auto wave_tile_size =
540  TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
541  const auto wave_tile_count_x =
542  ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
543  const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
544  aq_pad0_desc,
545  make_tuple(
547  make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
550 
551  const auto aq_pad1_desc = transform_tensor_descriptor(
552  aq_unmerge_pad0_desc,
553  make_tuple(
555  make_pass_through_transform(wave_tile_count_x),
557  wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
560 
561  const auto pad_wave_size =
562  aq_pad1_desc.get_lengths()[I2];
563  const auto aq_merge_pad1_desc = transform_tensor_descriptor(
564  aq_pad1_desc,
565  make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
566  make_pass_through_transform(pad_wave_size)),
569 
570  return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
571  }
572  else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
573  {
574  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
575  return make_naive_tensor_view<address_space_enum::global>(
576  aq_ptr,
577  make_tuple(kargs.M, kargs.QK_A),
578  make_tuple(kargs.stride_AQ, 1),
579  number<GemmPipeline::GetVectorSizeAQ()>{},
580  number<1>{});
581  }
582  else if constexpr(kQuantType == QuantType::RowColQuant)
583  {
584  return make_naive_tensor_view<address_space_enum::global>(
585  aq_ptr,
586  make_tuple(kargs.M, kargs.N),
587  make_tuple(1, 0), // broadcasting over n
588  number<1>{},
589  number<1>{});
590  }
591  else
592  {
593  return nullptr; // TODO: use some other "empty" type for this
594  }
595  }();
596 
597  const auto& b_tensor_view = [&]() {
598  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
599  {
600  if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
601  {
602  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
603  const index_t K0 = splitk_batch_offset.splitted_k / K1;
604  constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
605  const auto b_k0_n_k1_desc =
607  make_tuple(kargs.N * K1, K1, I1),
609  number<1>{});
610  const auto b_n_k_desc = transform_tensor_descriptor(
611  b_k0_n_k1_desc,
616  return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
617  }
618  else
619  {
620  return make_naive_tensor_view<address_space_enum::global>(
621  b_ptr,
622  make_tuple(splitk_batch_offset.splitted_k, kargs.N),
623  make_tuple(kargs.stride_B, 1),
624  number<GemmPipeline::GetVectorSizeB()>{},
625  number<1>{});
626  }
627  }
628  else
629  {
630  if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
631  {
632  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
633  const index_t K0 = splitk_batch_offset.splitted_k / K1;
634  constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
635  const auto b_k0_n_k1_desc =
637  make_tuple(kargs.N * K1, K1, I1),
639  number<1>{});
640  const auto b_n_k_desc = transform_tensor_descriptor(
641  b_k0_n_k1_desc,
646  return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
647  }
648  else
649  {
650  if constexpr(PreshuffleB)
651  {
652  index_t kFlatK =
653  GemmPipeline::flatKPerWarp *
654  (splitk_batch_offset.splitted_k /
655  TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
656  index_t kFlatN = kargs.N * kargs.K / kFlatK;
657 
658  return make_naive_tensor_view<address_space_enum::global>(
659  b_ptr,
660  make_tuple(kFlatN, kFlatK),
661  make_tuple(kFlatK, 1),
662  number<GemmPipeline::GetVectorSizeB()>{},
663  number<1>{});
664  }
665  else
666  {
667  return make_naive_tensor_view<address_space_enum::global>(
668  b_ptr,
669  make_tuple(kargs.N, splitk_batch_offset.splitted_k),
670  make_tuple(kargs.stride_B, 1),
671  number<GemmPipeline::GetVectorSizeB()>{},
672  number<1>{});
673  }
674  }
675  }
676  }();
677 
678  const auto& bq_tensor_view = [&]() {
679  if constexpr(kQuantType == QuantType::RowColQuant)
680  {
681  return make_naive_tensor_view<address_space_enum::global>(
682  bq_ptr,
683  make_tuple(kargs.M, kargs.N),
684  make_tuple(0, 1), // broadcasting over m
685  number<1>{},
686  number<1>{});
687  }
688  else if constexpr(kQuantType == QuantType::BQuantGrouped)
689  {
690  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
692  return make_naive_tensor_view<address_space_enum::global>(
693  bq_ptr,
694  make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
695  make_tuple(1, kargs.stride_BQ),
696  number<GemmPipeline::GetVectorSizeBQ()>{},
697  number<1>{});
698  }
699  else
700  {
701  return nullptr; // TODO: use some other "empty" type for this
702  }
703  }();
704 
705  // TODO: enable vector write for C in ColMajor
706  const auto& c_tensor_view = [&]() {
707  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
708  {
709  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
710  c_ptr,
711  make_tuple(kargs.M, kargs.N),
712  make_tuple(kargs.stride_C, 1),
713  number<EpiloguePipeline::GetVectorSizeC()>{},
714  number<1>{});
715  }
716  else
717  {
718  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
719  c_ptr,
720  make_tuple(kargs.M, kargs.N),
721  make_tuple(1, kargs.stride_C),
722  number<1>{},
723  number<1>{});
724  }
725  }();
726 
727  return make_tuple(
728  a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view);
729  }
730 
731  template <typename TensorView>
732  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
733  {
734  const auto& a_pad_view = [&]() {
735  const auto& a_tensor_view = views.at(I0);
736  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
737  {
738  return pad_tensor_view(a_tensor_view,
742  }
743  else
744  {
745  return pad_tensor_view(a_tensor_view,
749  }
750  }();
751 
752  // no padding
753  const auto& aq_pad_view = [&]() { return views.at(I1); }();
754 
755  const auto& b_flat_view = views.at(I2); // not applying any padding to flat B view
756 
757  const auto& b_pad_view = [&]() {
758  const auto& b_tensor_view = views.at(I2);
759  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
760  {
761  return pad_tensor_view(b_tensor_view,
765  }
766  else
767  {
768  return pad_tensor_view(b_tensor_view,
772  }
773  }();
774 
775  // no padding
776  const auto& bq_pad_view = [&]() { return views.at(I3); }();
777 
778  // TODO: enable vector write for C in ColMajor
779  const auto& c_pad_view = [&]() {
780  const auto& c_tensor_view = views.at(I4);
781  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
782  {
783  return pad_tensor_view(c_tensor_view,
787  }
788  else
789  {
790  return pad_tensor_view(c_tensor_view,
794  }
795  }();
796  if constexpr(PreshuffleB)
797  {
798 
799  return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view);
800  }
801  else
802  {
803  return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
804  }
805  }
806 
807  template <typename PadView>
808  CK_TILE_DEVICE static auto
809  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
810  {
811 
812  const auto& a_pad_view = views.at(I0);
813  const auto& aq_pad_view = views.at(I1);
814  const auto& b_pad_view = views.at(I2);
815  const auto& bq_pad_view = views.at(I3);
816  const auto& c_pad_view = views.at(I4);
817  const auto& a_block_window = [&]() {
818  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
819  {
820  return make_tile_window(a_pad_view,
823  {i_m, 0});
824  }
825  else
826  {
827  return make_tile_window(a_pad_view,
830  {0, i_m});
831  }
832  }();
833 
834  const auto& aq_block_window = [&]() {
835  if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
836  {
837  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
839  constexpr auto block_m = TilePartitioner::MPerBlock;
840  constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
841  constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
842  constexpr auto tile_window_width =
843  ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
844  constexpr auto tile_window_height = block_m / warp_m;
845  auto block_m_idx = i_m / block_m;
846  return make_tile_window(
847  aq_pad_view,
849  {block_m_idx * tile_window_height, 0});
850  }
851  else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
852  {
853  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
855  constexpr auto block_m = TilePartitioner::MPerBlock;
856  constexpr auto block_k = TilePartitioner::KPerBlock;
857  return make_tile_window(
858  aq_pad_view,
859  make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
860  {i_m, 0});
861  }
862  else if constexpr(kQuantType == QuantType::RowColQuant)
863  {
864  return make_tile_window(aq_pad_view,
867  {i_m, i_n});
868  }
869  else
870  {
871  return nullptr; // TODO: use some other "empty" type?
872  }
873  }();
874 
875  const auto& b_block_window = [&]() {
876  if constexpr(PreshuffleB)
877  {
878 
879  return make_tile_window(
880  b_pad_view,
883  {static_cast<int>(i_n / TilePartitioner::BlockGemmShape::WarpTile::at(I1)), 0});
884  }
885  else
886  {
887  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
888  {
889  return make_tile_window(b_pad_view,
892  {i_n, 0});
893  }
894  else
895  {
896  return make_tile_window(b_pad_view,
899  {0, i_n});
900  }
901  }
902  }();
903 
904  const auto& bq_block_window = [&]() {
905  if constexpr(kQuantType == QuantType::RowColQuant)
906  {
907  return make_tile_window(bq_pad_view,
910  {i_m, i_n});
911  }
912  else if constexpr(kQuantType == QuantType::BQuantGrouped)
913  {
914  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
916  return make_tile_window(
917  bq_pad_view,
919  number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
920  {0, i_n / QuantGroupSize::kN});
921  }
922  else
923  {
924  return nullptr; // TODO: use some other "empty" type here
925  }
926  }();
927 
928  auto c_block_window = make_tile_window(
929  c_pad_view,
931  {i_m, i_n});
932 
933  return make_tuple(
934  a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window);
935  }
936 
953  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
954  CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
955  const BDataType* b_ptr,
956  const AQDataType* aq_ptr,
957  const BQDataType* bq_ptr,
958  CDataType* c_ptr,
959  void* smem_ptr_0,
960  const QuantGemmKernelArgs& kargs,
961  const SplitKBatchOffset& splitk_batch_offset,
962  const index_t block_idx_m,
963  const index_t block_idx_n)
964  {
965  // Create Gemm tensor views, pad views and tile windows
966  const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
967  a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
968 
969  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
970  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
971 
972  const index_t num_loop =
973  amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
974 
975  // Run GEMM cooperatively by whole workgroup.
976  const auto& a_block_window = gemm_tile_windows.at(I0);
977  const auto& b_block_window = gemm_tile_windows.at(I2);
978 
979  const auto& c_block_tile = [&]() {
980  if constexpr(kQuantType == QuantType::AQuantGrouped)
981  {
982  const auto& aq_block_window = gemm_tile_windows.at(I1);
983  return GemmPipeline{}.template operator()(
984  a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
985  }
986  else if constexpr(kQuantType == QuantType::BQuantGrouped)
987  {
988  const auto& bq_block_window = gemm_tile_windows.at(I3);
989  return GemmPipeline{}.template operator()(
990  a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
991  }
992  else if constexpr(kQuantType == QuantType::RowColQuant ||
993  kQuantType == QuantType::TensorQuant)
994  {
995  return GemmPipeline{}.template operator()(
996  a_block_window, b_block_window, num_loop, smem_ptr_0);
997  }
998  }();
999 
1000  // Run Epilogue Pipeline
1001  auto& c_block_window = gemm_tile_windows.at(I4);
1002 
1003  if constexpr(kQuantType == QuantType::AQuantGrouped ||
1004  kQuantType == QuantType::BQuantGrouped)
1005  {
1006  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
1007  }
1008  else if constexpr(kQuantType == QuantType::RowColQuant)
1009  {
1010  const auto& aq_block_window = gemm_tile_windows.at(I1);
1011  const auto& bq_block_window = gemm_tile_windows.at(I3);
1012  EpiloguePipeline{}(c_block_window,
1013  c_block_tile,
1014  c_block_window,
1015  smem_ptr_0,
1016  aq_block_window,
1017  bq_block_window);
1018  }
1019  else if constexpr(kQuantType == QuantType::TensorQuant)
1020  {
1021  // TODO: why doesn't readfirstlane work here?
1022  // const AccDataType aq_scale =
1023  // __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*aq_ptr));
1024  // const AccDataType bq_scale =
1025  // __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*bq_ptr));
1026  const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
1027  const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
1028  EpiloguePipeline{}(
1029  c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
1030  }
1031  }
1047  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
1048  CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
1049  const BDataType* b_ptr,
1050  const AQDataType* aq_ptr,
1051  const BQDataType* bq_ptr,
1052  CDataType* c_ptr,
1053  void* smem_ptr_0,
1054  void* smem_ptr_1,
1055  const QuantGemmKernelArgs& kargs,
1056  const SplitKBatchOffset& splitk_batch_offset,
1057  const index_t block_idx_m,
1058  const index_t block_idx_n)
1059  {
1060  // Create Gemm tensor views, pad views and tile windows
1061  const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
1062  a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
1063 
1064  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
1065  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
1066 
1067  const index_t num_loop = __builtin_amdgcn_readfirstlane(
1068  TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
1069 
1070  // Run GEMM cooperatively by whole workgroup.
1071  const auto& a_block_window = gemm_tile_windows.at(I0);
1072  const auto& b_block_window = gemm_tile_windows.at(I2);
1073 
1074  const auto& c_block_tile = [&]() {
1075  if constexpr(kQuantType == QuantType::BQuantGrouped)
1076  {
1077  const auto& bq_block_window = gemm_tile_windows.at(I3);
1078  return GemmPipeline{}.template operator()(a_block_window,
1079  b_block_window,
1080  bq_block_window,
1081  num_loop,
1082  smem_ptr_0,
1083  smem_ptr_1);
1084  }
1085  else
1086  {
1087  return nullptr;
1088  }
1089  }();
1090 
1091  // Run Epilogue Pipeline
1092  auto& c_block_window = gemm_tile_windows.at(I4);
1093 
1094  if constexpr(kQuantType == QuantType::BQuantGrouped)
1095  {
1096  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
1097  }
1098  else
1099  {
1100  return;
1101  // throw std::runtime_error("DoubleSmemBuffer Not implemented for AQuantGrouped or
1102  // RowColQuant"); static_assert(kQuantType == QuantType::BQuantGrouped,
1103  // "DoubleSmemBuffer Not implemented");
1104  }
1105  }
1106 
1107  CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
1108  {
1109  const auto blockId = amd_wave_read_first_lane(blockIdx.x);
1110  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1111  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1112  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1113 
1114  const SplitKBatchOffset splitk_batch_offset(kargs);
1115  // options
1116  const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
1117  const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
1118  const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
1119  const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
1120  CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
1121 
1122  // allocate LDS
1123  __shared__ char smem_ptr_0[GetSmemSize()];
1124  assert(kargs.k_batch == 1);
1125  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1126  {
1127  __shared__ char smem_ptr_1[GetSmemSize()];
1128 
1129  RunGemm2LDS(a_ptr,
1130  b_ptr,
1131  aq_ptr,
1132  bq_ptr,
1133  c_ptr,
1134  smem_ptr_0,
1135  smem_ptr_1,
1136  kargs,
1137  splitk_batch_offset,
1138  i_m,
1139  i_n);
1140  }
1141  else
1142  {
1143  RunGemm(a_ptr,
1144  b_ptr,
1145  aq_ptr,
1146  bq_ptr,
1147  c_ptr,
1148  smem_ptr_0,
1149  kargs,
1150  splitk_batch_offset,
1151  i_m,
1152  i_n);
1153  }
1154  }
1155 };
1156 
1157 } // namespace ck_tile
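// Illustrative host-side usage sketch (not part of this header; type names and sizes are
// hypothetical): build QuantGemmHostArgs, convert them with MakeKernelArgs(), validate with
// IsSupportedArgument(), and derive the launch configuration from GridSize()/BlockSize().
// The actual device launch (e.g. through the ck_tile host launch helpers) is omitted.
//
//   using Kernel = ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline,
//                                           ck_tile::QuantType::BQuantGrouped>;
//
//   ck_tile::QuantGemmHostArgs host_args(a_dev, b_dev, c_dev, aq_dev, bq_dev,
//                                        /*k_batch=*/1,
//                                        /*M=*/1024, /*N=*/1024, /*K=*/4096,
//                                        /*QK_A=*/128, /*QK_B=*/128,
//                                        /*stride_A=*/4096, /*stride_B=*/4096, /*stride_C=*/1024,
//                                        /*stride_AQ=*/128, /*stride_BQ=*/128);
//
//   const auto kargs = Kernel::MakeKernelArgs(host_args);
//   if(Kernel::IsSupportedArgument(kargs))
//   {
//       const dim3 grids  = Kernel::GridSize(kargs.M, kargs.N, kargs.k_batch);
//       const dim3 blocks = Kernel::BlockSize();
//       // launch Kernel{} with (grids, blocks) via the HIP/ck_tile launch machinery,
//       // passing kargs to the device-side operator().
//   }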