Composable Kernel: include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"

namespace ck_tile {

struct AQuantGemmProblem
{
    CK_TILE_HOST AQuantGemmProblem() = default;
    CK_TILE_HOST AQuantGemmProblem(index_t M_,
                                   index_t N_,
                                   index_t K_,
                                   index_t QK_,
                                   index_t stride_A_,
                                   index_t stride_B_,
                                   index_t stride_C_,
                                   index_t stride_AQ_)
        : M(M_),
          N(N_),
          K(K_),
          QK(QK_),
          stride_A(stride_A_),
          stride_B(stride_B_),
          stride_C(stride_C_),
          stride_AQ(stride_AQ_)
    {
    }

    index_t M;
    index_t N;
    index_t K;
    index_t QK;
    index_t stride_A;
    index_t stride_B;
    index_t stride_C;
    index_t stride_AQ;
};

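// Note: QK is the number of per-group quantization scales along K for the A tensor;
// with a group size of GemmPipeline::QuantGroupSize it is expected to equal
// K / QuantGroupSize (e.g. K = 4096 with a group size of 128 gives QK = 32). The AQ
// tensor view and tile windows below are sized under that assumption.
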
struct AQuantGemmHostArgs : public AQuantGemmProblem
{
    CK_TILE_HOST AQuantGemmHostArgs() = default;
    CK_TILE_HOST AQuantGemmHostArgs(const void* a_ptr_,
                                    const void* b_ptr_,
                                    void* c_ptr_,
                                    const void* aq_ptr_,
                                    index_t k_batch_,
                                    index_t M_,
                                    index_t N_,
                                    index_t K_,
                                    index_t QK_,
                                    index_t stride_A_,
                                    index_t stride_B_,
                                    index_t stride_C_,
                                    index_t stride_AQ_)
        : AQuantGemmProblem(M_, N_, K_, QK_, stride_A_, stride_B_, stride_C_, stride_AQ_),
          a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          aq_ptr(aq_ptr_),
          c_ptr(c_ptr_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    const void* aq_ptr;
    void* c_ptr;
    index_t k_batch;
};

struct AQuantGemmKernelArgs
{
    const void* a_ptr;
    const void* b_ptr;
    const void* aq_ptr;
    void* c_ptr;
    index_t M;
    index_t N;
    index_t K;
    index_t QK;
    index_t stride_A;
    index_t stride_B;
    index_t stride_C;
    index_t stride_AQ;
    index_t k_batch;
};

template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct AQuantGemmKernel
{
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using ALayout          = remove_cvref_t<typename GemmPipeline::ALayout>;
    using AQLayout         = remove_cvref_t<typename GemmPipeline::AQLayout>;
    using BLayout          = remove_cvref_t<typename GemmPipeline::BLayout>;
    using CLayout          = remove_cvref_t<typename GemmPipeline::CLayout>;
    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;

    using ADataType  = remove_cvref_t<typename GemmPipeline::ADataType>;
    using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
    using BDataType  = remove_cvref_t<typename GemmPipeline::BDataType>;
    using CDataType  = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();
    static constexpr auto I3 = number<3>();

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
        // clang-format on
    }

    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
    {
        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }

    CK_TILE_HOST static constexpr AQuantGemmKernelArgs
    MakeKernelArgs(const AQuantGemmHostArgs& hostArgs)
    {
        return AQuantGemmKernelArgs{hostArgs.a_ptr,
                                    hostArgs.b_ptr,
                                    hostArgs.aq_ptr,
                                    hostArgs.c_ptr,
                                    hostArgs.M,
                                    hostArgs.N,
                                    hostArgs.K,
                                    hostArgs.QK,
                                    hostArgs.stride_A,
                                    hostArgs.stride_B,
                                    hostArgs.stride_C,
                                    hostArgs.stride_AQ,
                                    hostArgs.k_batch};
    }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(const AQuantGemmKernelArgs& kargs,
                                     const std::size_t k_id = blockIdx.z)
        {
            constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
            const index_t K_t   = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
            const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
            }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A);
            }

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B);
            }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
            }

            if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
            {
                splitted_k = __builtin_amdgcn_readfirstlane(KRead);
            }
            else
            {
                splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
            }
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t splitted_k;
    };

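    // Worked example of the split-K bookkeeping above (hypothetical numbers, for
    // illustration only): with K = 4000, k_batch = 4 and a warp-tile K extent K1 = 32,
    // K_t = k_batch * K1 = 128 and KRead = ceil(4000 / 128) * 32 = 1024. Blocks with
    // k_id < 3 each process splitted_k = 1024 elements along K, while the last block
    // processes the remainder, splitted_k = 4000 - 3 * 1024 = 928.
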
    CK_TILE_HOST static bool IsSupportedArgument(const AQuantGemmKernelArgs& kargs)
    {
        if(kargs.k_batch != 1)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
            }
            return false;
        }

        static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
        if(kargs.QK % GemmPipeline::GetVectorSizeAQ() != 0)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("QK is not a multiple of vector load size for AQ tensor!");
            }
            return false;
        }

        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
               GemmPipeline::kPadK == false)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
                                  "without padding!");
                }
                return false;
            }
            if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
                }
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR(
                        "Can't support M that is not a multiple of MPerBlock without padding!");
                }
                return false;
            }
            if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
                }
                return false;
            }
        }

        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR(
                        "Can't support N that is not a multiple of NPerBlock without padding!");
                }
                return false;
            }
            if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
                }
                return false;
            }
        }
        else
        {
            if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
               GemmPipeline::kPadK == false)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
                                  "without padding!");
                }
                return false;
            }
            if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
                }
                return false;
            }
        }

        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR(
                        "Can't support N that is not a multiple of NPerBlock without padding!");
                }
                return false;
            }
            if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
                }
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR(
                        "Can't support M that is not a multiple of MPerBlock without padding!");
                }
                return false;
            }
            if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
                }
                return false;
            }
        }
        return true;
    }

    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
                                                   const BDataType* b_ptr,
                                                   const AQDataType* aq_ptr,
                                                   CDataType* c_ptr,
                                                   const AQuantGemmKernelArgs& kargs,
                                                   const SplitKBatchOffset& splitk_batch_offset)
    {
        static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
        const auto& a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(kargs.M, splitk_batch_offset.splitted_k),
                    make_tuple(kargs.stride_A, 1),
                    number<GemmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(splitk_batch_offset.splitted_k, kargs.M),
                    make_tuple(kargs.stride_A, 1),
                    number<GemmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
        }();

        const auto& aq_tensor_view = [&]() {
            static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
            return make_naive_tensor_view<address_space_enum::global>(
                aq_ptr,
                make_tuple(kargs.M, kargs.QK),
                make_tuple(kargs.stride_AQ, 1),
                number<GemmPipeline::GetVectorSizeAQ()>{},
                number<1>{});
        }();

        const auto& b_tensor_view = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
                {
                    constexpr index_t K1          = GemmPipeline::GetSmemPackB();
                    const index_t K0              = splitk_batch_offset.splitted_k / K1;
                    constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
                    const auto b_k0_n_k1_desc =
                        make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
                                                     make_tuple(kargs.N * K1, K1, I1),
                                                     number<VectorSizeB>{},
                                                     number<1>{});
                    const auto b_n_k_desc = transform_tensor_descriptor(
                        b_k0_n_k1_desc,
                        make_tuple(make_pass_through_transform(kargs.N),
                                   make_merge_transform(make_tuple(K0, K1))),
                        make_tuple(sequence<1>{}, sequence<0, 2>{}),
                        make_tuple(sequence<0>{}, sequence<1>{}));
                    return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        b_ptr,
                        make_tuple(splitk_batch_offset.splitted_k, kargs.N),
                        make_tuple(kargs.stride_B, 1),
                        number<GemmPipeline::GetVectorSizeB()>{},
                        number<1>{});
                }
            }
            else
            {
                if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
                {
                    constexpr index_t K1          = GemmPipeline::GetSmemPackB();
                    const index_t K0              = splitk_batch_offset.splitted_k / K1;
                    constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
                    const auto b_k0_n_k1_desc =
                        make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
                                                     make_tuple(kargs.N * K1, K1, I1),
                                                     number<VectorSizeB>{},
                                                     number<1>{});
                    const auto b_n_k_desc = transform_tensor_descriptor(
                        b_k0_n_k1_desc,
                        make_tuple(make_pass_through_transform(kargs.N),
                                   make_merge_transform(make_tuple(K0, K1))),
                        make_tuple(sequence<1>{}, sequence<0, 2>{}),
                        make_tuple(sequence<0>{}, sequence<1>{}));
                    return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        b_ptr,
                        make_tuple(kargs.N, splitk_batch_offset.splitted_k),
                        make_tuple(kargs.stride_B, 1),
                        number<GemmPipeline::GetVectorSizeB()>{},
                        number<1>{});
                }
            }
        }();

        // TODO: enable vector write for C in ColMajor
        const auto& c_tensor_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    c_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_C, 1),
                    number<EpiloguePipeline::GetVectorSizeC()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    c_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(1, kargs.stride_C),
                    number<1>{},
                    number<1>{});
            }
        }();

        return make_tuple(a_tensor_view, aq_tensor_view, b_tensor_view, c_tensor_view);
    }

    template <typename TensorView>
    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
    {
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(a_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::KPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(a_tensor_view,
                                       make_tuple(number<TilePartitioner::KPerBlock>{},
                                                  number<TilePartitioner::MPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadM>{});
            }
        }();

        const auto& aq_pad_view = [&]() {
            const auto& aq_tensor_view = views.at(I1);
            static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
            return pad_tensor_view(
                aq_tensor_view,
                make_tuple(number<TilePartitioner::MPerBlock>{},
                           number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
                // TODO: Add support for padding.
                sequence<false, false>{});
        }();

        const auto& b_pad_view = [&]() {
            const auto& b_tensor_view = views.at(I2);
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            {
                return pad_tensor_view(b_tensor_view,
                                       make_tuple(number<TilePartitioner::NPerBlock>{},
                                                  number<TilePartitioner::KPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(b_tensor_view,
                                       make_tuple(number<TilePartitioner::KPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadN>{});
            }
        }();

        // TODO vector write in for C in ColMajor
        const auto& c_pad_view = [&]() {
            const auto& c_tensor_view = views.at(I3);
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(c_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadN>{});
            }
            else
            {
                return pad_tensor_view(c_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<GemmPipeline::kPadM, false>{});
            }
        }();

        return make_tuple(a_pad_view, aq_pad_view, b_pad_view, c_pad_view);
    }

    template <typename PadView>
    CK_TILE_DEVICE static auto
    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
    {
        const auto& a_pad_view  = views.at(I0);
        const auto& aq_pad_view = views.at(I1);
        const auto& b_pad_view  = views.at(I2);
        const auto& c_pad_view  = views.at(I3);

        const auto& a_block_window = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_tile_window(a_pad_view,
                                        make_tuple(number<TilePartitioner::MPerBlock>{},
                                                   number<TilePartitioner::KPerBlock>{}),
                                        {i_m, 0});
            }
            else
            {
                return make_tile_window(a_pad_view,
                                        make_tuple(number<TilePartitioner::KPerBlock>{},
                                                   number<TilePartitioner::MPerBlock>{}),
                                        {0, i_m});
            }
        }();

        const auto& aq_block_window = [&]() {
            static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
            return make_tile_window(
                aq_pad_view,
                make_tuple(number<TilePartitioner::MPerBlock>{},
                           number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
                {i_m, 0});
        }();

        const auto& b_block_window = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            {
                return make_tile_window(b_pad_view,
                                        make_tuple(number<TilePartitioner::NPerBlock>{},
                                                   number<TilePartitioner::KPerBlock>{}),
                                        {i_n, 0});
            }
            else
            {
                return make_tile_window(b_pad_view,
                                        make_tuple(number<TilePartitioner::KPerBlock>{},
                                                   number<TilePartitioner::NPerBlock>{}),
                                        {0, i_n});
            }
        }();

        auto c_block_window = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});

        return make_tuple(a_block_window, aq_block_window, b_block_window, c_block_window);
    }

    /// @brief Runs single GEMM problem cooperatively by whole workgroup.
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
                                       const BDataType* b_ptr,
                                       const AQDataType* aq_ptr,
                                       CDataType* c_ptr,
                                       void* smem_ptr_0,
                                       const AQuantGemmKernelArgs& kargs,
                                       const SplitKBatchOffset& splitk_batch_offset,
                                       const index_t block_idx_m,
                                       const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
            a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);

        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = __builtin_amdgcn_readfirstlane(
            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window  = gemm_tile_windows.at(I0);
        const auto& aq_block_window = gemm_tile_windows.at(I1);
        const auto& b_block_window  = gemm_tile_windows.at(I2);

        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I3);

        EpiloguePipeline{}.template
        operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
            c_block_window, c_block_tile, c_block_window, smem_ptr_0);
    }

    CK_TILE_DEVICE void operator()(AQuantGemmKernelArgs kargs) const
    {
        const auto blockId  = __builtin_amdgcn_readfirstlane(blockIdx.x);
        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
        const index_t i_m   = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n   = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        const SplitKBatchOffset splitk_batch_offset(kargs);
        // options
        const ADataType* a_ptr   = static_cast<const ADataType*>(kargs.a_ptr);
        const BDataType* b_ptr   = static_cast<const BDataType*>(kargs.b_ptr);
        const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
        CDataType* c_ptr         = static_cast<CDataType*>(kargs.c_ptr);

        // allocate LDS
        __shared__ char smem_ptr_0[GetSmemSize()];

        assert(kargs.k_batch == 1);
        RunGemm(a_ptr, b_ptr, aq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
    }
};

} // namespace ck_tile
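
// ---------------------------------------------------------------------------------
// Usage sketch (not part of the header above): a minimal host-side flow, assuming
// this header is included and concrete TilePartitioner / GemmPipeline /
// EpiloguePipeline types are defined elsewhere. The helper name
// `launch_aquant_gemm_example` and the final launch step are illustrative
// placeholders, not part of this file's API.
// ---------------------------------------------------------------------------------
template <typename TilePartitioner, typename GemmPipeline, typename EpiloguePipeline>
bool launch_aquant_gemm_example(const ck_tile::AQuantGemmHostArgs& host_args)
{
    using Kernel = ck_tile::AQuantGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;

    // Convert host-side arguments into the POD kernel-argument struct.
    const auto kargs = Kernel::MakeKernelArgs(host_args);

    // Reject configurations the kernel cannot handle (k_batch != 1, vector-size or
    // padding constraints violated, etc.).
    if(!Kernel::IsSupportedArgument(kargs))
    {
        return false;
    }

    // Compute the launch configuration from the problem size.
    const dim3 grids  = Kernel::GridSize(host_args.M, host_args.N, host_args.k_batch);
    const dim3 blocks = Kernel::BlockSize();

    // Launch Kernel{} with `grids`/`blocks` through the framework's kernel launcher
    // (e.g. a HIP launch wrapper); the exact launch utility is outside this header.
    (void)grids;
    (void)blocks;
    return true;
}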