include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp Source File

include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp Source File

Composable Kernel: include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp Source File
gemm_group_quant_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck_tile {
9 
10 template <typename Problem, typename DataType, index_t YPerTile, index_t XPerTile>
11 CK_TILE_HOST_DEVICE static constexpr auto GetAQGlobalVectorLoadSize()
12 {
13  using I1 = number<1>;
14  constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{});
15 
16  constexpr index_t BlockSize = Problem::kBlockSize;
17 
18  // Data is replicated across warps along NWarps, so we divide BlockSize by NWarps
19  constexpr index_t elements_per_thread = (YPerTile * XPerTile) / (BlockSize / NWarps);
20  constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
21 
22  // Define vector load candidates in descending order of priority
23  constexpr std::array<index_t, 5> candidates{
24  PackedSize * 32 / sizeof(DataType),
25  PackedSize * 16 / sizeof(DataType),
26  PackedSize * 8 / sizeof(DataType),
27  PackedSize * 4 / sizeof(DataType),
28  PackedSize * 2 / sizeof(DataType),
29  };
30 
31  for(const auto vec_size : candidates)
32  {
33  if(vec_size <= 0 || XPerTile % vec_size != 0 || elements_per_thread % vec_size != 0)
34  continue;
35  bool is_valid = (vec_size > 0) && (XPerTile % vec_size == 0) &&
36  (elements_per_thread % vec_size == 0) && vec_size != candidates[4];
37  if(is_valid)
38  {
39  return vec_size;
40  }
41  }
42  return PackedSize; // Absolute fallback
43 }
44 
45 // AQ holds groupquant scale data for A. Data is loaded from DRAM and partitioned across
46 // threads. Post mfma scales are shuffled across threads in the warp and applied to
47 // accum registers.
48 template <typename BlockGemmShape,
49  typename WarpGemm,
50  index_t BlockSize,
51  index_t YPerTile,
52  index_t XPerTile,
53  index_t VecSize>
55 {
// Compile-time layout constants for the AQ (group-quant scale for A) tile of
// YPerTile x XPerTile elements. The Y0..Y3 / X factors below describe how that tile is
// split; Make2DStaticTileDistribution() presumably encodes them into a static tile
// distribution.
56  // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
57  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
// num_warps uses integer division. NOTE(review): it assumes BlockSize is a multiple of
// the warp size — confirm callers guarantee this.
58  static constexpr index_t warp_size = get_warp_size();
59  static constexpr index_t num_warps = BlockSize / get_warp_size();
60 
// Warp counts along the M, N and K dimensions of the block GEMM.
61  static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
62  static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
63  static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
64 
// Number of WarpGemm::kM-row iterations each M-warp needs to cover BlockGemmShape::kM.
// Integer division: this is 0 when kM < MWarps * WarpGemm::kM.
65  static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);
66 
67  static_assert(num_warps == MWarps * NWarps * KWarps);
68 
69  // KWarps > 1 isn't supported
70  static_assert(KWarps == 1);
71 
72  // # of elements per thread
// NOTE(review): X is the full XPerTile extent; whether each thread really holds all of it
// depends on the encoding built in Make2DStaticTileDistribution() — confirm.
73  static constexpr index_t X = XPerTile;
74 
// The Y extent is factored as Y0 (always 1) x Y1 (M iterations per warp) x
// Y2 (M warps) x Y3 (rows per WarpGemm). Y1 falls back to 1 when MIterPerWarp is 0.
75  static constexpr index_t Y0 = 1;
76  static constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
77  static constexpr index_t Y2 = MWarps;
78  static constexpr index_t Y3 = WarpGemm::kM;
// As written this assertion always holds, because Y3 is defined as WarpGemm::kM.
// It presumably guards against future changes to Y3's definition.
79  static_assert(Y3 >= WarpGemm::kM, "Scales for all rows must be available within the warp.");
// The four Y factors must multiply out exactly to the tile's Y extent.
80  static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
81  "Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
82 
84  {
91  sequence<1, 0>>{});
92  }
93 };
94 
95 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: gemm_group_quant_utils.hpp:55
static constexpr index_t NWarps
Definition: gemm_group_quant_utils.hpp:62
static constexpr index_t MIterPerWarp
Definition: gemm_group_quant_utils.hpp:65
static constexpr index_t Y3
Definition: gemm_group_quant_utils.hpp:78
static constexpr index_t num_warps
Definition: gemm_group_quant_utils.hpp:59
static constexpr CK_TILE_HOST_DEVICE auto Make2DStaticTileDistribution()
Definition: gemm_group_quant_utils.hpp:83
static constexpr index_t KWarps
Definition: gemm_group_quant_utils.hpp:63
static constexpr index_t MWarps
Definition: gemm_group_quant_utils.hpp:61
static constexpr index_t warp_size
Definition: gemm_group_quant_utils.hpp:58
static constexpr index_t Y0
Definition: gemm_group_quant_utils.hpp:75
static constexpr index_t X
Definition: gemm_group_quant_utils.hpp:73
static constexpr index_t Y2
Definition: gemm_group_quant_utils.hpp:77
static constexpr index_t Y1
Definition: gemm_group_quant_utils.hpp:76
Definition: static_encoding_pattern.hpp:107
Definition: integral_constant.hpp:13
Definition: numeric.hpp:81
Definition: sequence.hpp:52
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192