/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp Source File
fused_moegemm_pipeline_flatmm_uk.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 /*
13 This pipeline deals with a gemm (actually 2 gemms) with one very small operand (token) and one very big operand (weight).
14 We need to design the pipeline such that all waves are laid out along the gemm-N dim (gemm-m has only 1 wave):
15 
16  <----- gemm-N ------>
17  +----+----+----+----+
18  | w0 | w1 | w2 | w3 | gemm-m
19  +----+----+----+----+
20 */
21 template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
23 {
26 
27  using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
28 
29  using ADataType = typename Problem::ADataType;
30  using GDataType = typename Problem::GDataType;
31  using DDataType = typename Problem::DDataType;
32  using AccDataType = typename Problem::AccDataType;
33  using ODataType = typename Problem::ODataType;
34  using AScaleDataType = typename Problem::AScaleDataType;
35  using GScaleDataType = typename Problem::GScaleDataType;
36  using DScaleDataType = typename Problem::DScaleDataType;
37  using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
38  using TopkWeightDataType = typename Problem::TopkWeightDataType;
39  using IndexDataType = typename Problem::IndexDataType;
40  using YDataType = typename Problem::YDataType;
41 
42  using Traits = typename Problem::Traits;
43 
44  static constexpr bool IsGateOnly = Traits::IsGateOnly;
45  static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
46  static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
47  static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
48 
49  static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
50  static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
51  static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
52  static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
53 
58 
59  static constexpr index_t kBlockPerCu = []() {
60  if constexpr(Problem::kBlockPerCu != -1)
61  return Problem::kBlockPerCu;
62  else
63  {
64  // minimize occupancy
65  return 2;
66  }
67  }();
68 
69  static constexpr const char* name = "flatmm_uk";
70 
72  {
73 #if 1
74  constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
75  constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
76  constexpr index_t smem_bridge =
77  BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
78  return max(smem_0 + smem_1, smem_bridge);
79 #else
80  // keep it here purposely in case we have regression
81  return 65536;
82 #endif
83  }
84 
85  // this is the thread-offset along row/col
87  {
88  constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
89  const auto a_coord = a_dist.calculate_index();
90  return a_coord;
91  }
92 
93  // this is the thread-offset along row/col
95  {
96  constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
97  const auto o_coord = o_dist.calculate_index();
98  return o_coord;
99  }
100 
102  {
103  constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
104  constexpr index_t MLans = BlockShape::BlockSize / KLans;
105  constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
106 
107  return MRepeat;
108  }
109 
110  // TODO: properly support scatter/gather
112  {
113  constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
114  constexpr index_t MLans = BlockShape::BlockSize / KLans;
115  constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
116 
117  auto base_coord = threadIdx.x / KLans + base_offset;
118 
120  static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
121 
122  return coords;
123  }
124 
// Gathers one row id per entry of `coords` from the sorted-token-id table.
//
// coords               : container of indices into sorted_token_ids_ptr; its
//                        compile-time size() fixes the number of ids returned.
// sorted_token_ids_ptr : table that is indexed by each coordinate.
//
// Returns array<index_t, coords.size()> with
// row_ids[i] = sorted_token_ids_ptr[coords[i]].
125  template <typename ROW_COORDS>
126  CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType* sorted_token_ids_ptr)
127  {
128  constexpr index_t n_size = coords.size();
129 
130  array<index_t, n_size> row_ids;
// static_for unrolls over i in [0, n_size); each iteration does one table lookup.
131  static_for<0, n_size, 1>{}([&](auto i) {
132  row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
// With the mock-id build option enabled, only the low 24 bits of each entry are
// kept. NOTE(review): presumably the upper 8 bits carry extra mock information
// that is not part of the row id — confirm against the sorting reference.
133 #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
134  row_ids.at(i) &= 0xffffff;
135 #endif
136  });
137 
138  return row_ids;
139  }
140 
141  template <typename ROW_COORDS>
142  CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
143  const TopkWeightDataType* sorted_weight_ptr)
144  {
145  constexpr index_t n_size = coords.size();
146 
148  static_for<0, n_size, 1>{}([&](auto i) {
149  w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans;
150  });
151 
152  return w;
153  }
154 
155  // TODO: this row id is computed before the shuffle/atomic step; need to use the acc distribution
157  {
158  constexpr index_t MLanes = BlockShape::Warp_M1;
159  constexpr index_t Repeat_M = BlockShape::Repeat_M1;
160 
161  auto base_coord = threadIdx.x % MLanes + base_offset;
162 
164  static_for<0, Repeat_M, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLanes; });
165 
166  return coords;
167  }
168 
169  template <typename Karg>
170  CK_TILE_DEVICE auto operator()(const Karg& kargs,
171  CK_TILE_LDS_ADDR void* smem,
172  index_t sorted_tile_id,
173  index_t intermediate_tile_id)
174  {
175  constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
176  ck_tile::index_t shared_intermediate_size_0 =
177  kargs.intermediate_size * hidden_radio_0; // total gate+up
178  ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size;
179 
180  // after weight shuffling, gate-only: [nr0, kr0, w0], gate+up: [nr0_gate + nr0_up, kr0, w0]
181 
182  index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
183  index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
184  index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
185  index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;
186 
187  const IndexDataType expert_id = amd_wave_read_first_lane(
188  reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
189  index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
190  index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
191 
192  // nr*kr*w
193  index_t interm_idx_nr0 = amd_wave_read_first_lane(
194  intermediate_tile_id *
195  BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W)
196 
197  index_t interm_idx_kr1 = amd_wave_read_first_lane(
198  intermediate_tile_id *
199  BlockShape::Block_Kr1); // intermediate_tile_id * Block_N / (N in W)
200 
201  auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
202  auto row_ids_a = GetRowID(
203  row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
204  auto a_coords = generate_tuple(
205  [&](auto i) {
206  return row_ids_a[i] * kargs.stride_token +
207  threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
208  },
209  number<row_ids_a.size()>{});
210 
211  auto a_res =
212  make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
213  kargs.num_tokens * kargs.stride_token * sizeof(ADataType),
214  std::true_type{});
215 
216  auto make_gu_win = [&](const auto* ptr_) {
217  auto view_ = make_naive_tensor_view<address_space_enum::global>(
218  ptr_,
220  make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
222  number<1>{});
223 
224  auto win_ = make_tile_window_linear_raw(
225  view_,
229  {0, 0, 0},
230  Policy::template MakeGlobalTileDistribution_G<Problem>(),
232  return win_;
233  };
234 
235  const GDataType* gu_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
236  static_cast<long_index_t>(expert_id) * expert_stride_0 +
237  interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
238 
239  auto g_win = make_gu_win(gu_ptr);
240  // Note: gu swizzled, [nr_u+nr_g, kr, w], hence base offset to up is just interm*hidden
241  auto u_win = make_gu_win(gu_ptr + kargs.intermediate_size * kargs.hidden_size);
242 
243  auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
244  auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
245  auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
246  number<decltype(g_win)::NumAccess_NonLinear>{});
247 
248  const auto d_win = [&]() {
249  const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
250  static_cast<long_index_t>(expert_id) * expert_stride_1 +
251  interm_idx_kr1 * BlockShape::Block_W1;
252  // note interm_idx_nr0 is along the gemm-k dim of 2nd gemm
253 
254  const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
255  d_ptr,
256  make_tuple(nr_1, kr_1, BlockShape::Block_W1),
257  make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
259  number<1>{});
260 
261  const auto d_window_ = make_tile_window_linear_raw(
262  d_view_,
266  {0, 0, 0},
267  Policy::template MakeGlobalTileDistribution_D<Problem>(),
269  return d_window_;
270  }();
271  auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
272 
273  // TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
274  // block-k=512, block-n=128
275  // wg |<----- W_ ----->|
276  // Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue
277  // y p y y p p y
278  // 1 2 0(imm)
279  auto d_coords = [&]() {
280  constexpr index_t Nr_ = 2;
281  constexpr index_t Nw_ = 4;
282  constexpr index_t Kr0_ = 4;
283  constexpr index_t Kr1_ = 4;
284  constexpr index_t Kl_ = 4;
285  constexpr index_t Nl_ = 16;
286  constexpr index_t Kv_ = 8;
287  constexpr index_t W_ = Kl_ * Nl_ * Kv_;
288  constexpr index_t num_offsets_ = Nr_ * Kr0_;
289  index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) *
290  shared_intermediate_size_1 *
291  Nl_; // Kr0_ * Kr1_ * W_;
292  return generate_tuple(
293  [&](auto i) {
294  constexpr auto i_nr_ = number<i % Nr_>{};
295  constexpr auto i_kr0_ = number<i / Nr_>{};
296 
297  return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
298  base_os_;
299  },
301  }();
302 
303  auto o_coords = generate_tuple(
304  [&](auto i) {
305  return row_ids_a[i] * kargs.stride_token +
306  threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO;
307  },
308  number<row_ids_a.size()>{});
309 
310  auto o_flags =
311  generate_tuple([&](auto i) { return cmp_lt_to_exec(row_ids_a[i], kargs.num_tokens); },
312  number<row_ids_a.size()>{});
313 
314  auto bridge_sst_win = [&]() {
315  constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
316  constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
317  return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
318  reinterpret_cast<YDataType*>(smem), desc_),
319  desc_.get_lengths(),
320  {0, 0},
321  dist_);
322  }();
323 
324  auto o_res =
325  make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
326  kargs.num_tokens * kargs.stride_token * sizeof(ODataType),
327  std::true_type{});
328  auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
329  auto w_scale = GetWeightScale(
330  row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
331 
332  auto uk_0 = Policy::template GetUK_0<Problem>();
333 
334  auto y_pre = [&]() {
335  if constexpr(IsGateOnly)
336  {
337  auto acc_0 = uk_0(a_res,
338  a_coords,
339  g_res,
340  g_coords,
341  smem,
342  kargs.hidden_size,
343  BlockShape::Block_K0, // tile offset for B matrix each unroll
344  BlockShape::Block_Kr0 *
345  BlockShape::Block_W0); // tile offset for B matrix each unroll
346 
347  sweep_tile(
348  acc_0,
349  [&](auto idx0, auto idx1) {
350  fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
351  typename Problem::GateActivation{}(v_, v_);
352  acc_0(idx0) = v_.x;
353  acc_0(idx1) = v_.y;
354  },
355  sequence<1, 2>{});
356 
357  return cast_tile<YDataType>(acc_0);
358  }
359  else
360  {
361  uint32x8_t gu_res;
362  gu_res[0] = g_res[0];
363  gu_res[1] = g_res[1];
364  gu_res[2] = g_res[2];
365  gu_res[3] = g_res[3];
366  gu_res[4] = u_res[0];
367  gu_res[5] = u_res[1];
368  gu_res[6] = u_res[2];
369  gu_res[7] = u_res[3];
370 
371  auto acc_0 = uk_0(a_res,
372  a_coords,
373  gu_res,
374  g_coords,
375  smem,
376  kargs.hidden_size,
377  BlockShape::Block_K0, // tile offset for B matrix each unroll
378  BlockShape::Block_Kr0 * BlockShape::Block_W0,
379  bool_constant<true>{}); // tile offset for B matrix each unroll
380 
381  sweep_tile(
382  acc_0.at(number<0>{}),
383  [&](auto idx0, auto idx1) {
384  fp32x2_t v_{acc_0.at(number<0>{})(idx0), acc_0.at(number<0>{})(idx1)};
385  typename Problem::GateActivation{}(v_, v_);
386  acc_0.at(number<0>{})(idx0) = v_.x;
387  acc_0.at(number<0>{})(idx1) = v_.y;
388  },
389  sequence<1, 2>{});
390 
391  auto reduced_acc_0 =
392  tile_elementwise_in([&](const auto& a_, const auto& b_) { return a_ * b_; },
393  acc_0.at(number<0>{}),
394  acc_0.at(number<1>{}));
395 
396  return cast_tile<YDataType>(reduced_acc_0);
397  }
398  }();
399 
400  block_sync_lds();
401 
402  store_tile(bridge_sst_win, y_pre);
403  block_sync_lds();
404 
405  auto uk_1 = Policy::template GetUK_1<Problem>();
406  uk_1(d_res,
407  d_coords,
408  o_res,
409  o_coords,
410  o_flags,
411  smem,
412  kargs.hidden_size, // total n number
413  w_scale,
414  BlockShape::Block_Nr1 * kr_1 * BlockShape::Block_W1, // along N
415  BlockShape::Block_N1); // along N
416  }
417 };
418 
419 } // namespace ck_tile
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:245
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_LDS_ADDR
Definition: config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE auto cmp_lt_to_exec(const X &x, const Y &y)
Definition: utility.hpp:133
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition: tile_elementwise.hpp:40
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:33
float fp32x2_t
Definition: pk_fp4.hpp:22
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
uint32_t uint32x8_t
Definition: vector_type.hpp:154
CK_TILE_DEVICE auto make_tile_window_linear_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:1029
int64_t long_index_t
Definition: integer.hpp:11
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:993
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr, uint32_t size=0xffffffff, ForceSGPR={})
Definition: amd_buffer_addressing.hpp:95
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr CK_TILE_HOST_DEVICE void sweep_tile(const F &f, UnpacksPerXDim={})
Definition: sweep_tile.hpp:231
bool_constant< true > true_type
Definition: integral_constant.hpp:62
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:23
typename Problem::IndexDataType IndexDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:39
typename Problem::ADataType ADataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:29
static constexpr index_t kAlignmentO
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:52
constexpr CK_TILE_DEVICE auto GetNumRowCoords_A()
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:101
typename Problem::DDataType DDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:31
static constexpr bool PadIntermediateSize
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:47
CK_TILE_DEVICE auto operator()(const Karg &kargs, CK_TILE_LDS_ADDR void *smem, index_t sorted_tile_id, index_t intermediate_tile_id)
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:170
static constexpr const char * name
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:69
CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:111
static CK_TILE_HOST_DEVICE auto GetOCoord()
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:94
typename Problem::DScaleDataType DScaleDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:36
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:71
static constexpr index_t kAlignmentD
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:51
static constexpr index_t kAlignmentA
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:49
static constexpr index_t GLD_B
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:56
static constexpr index_t kBlockPerCu
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:59
typename Problem::GScaleDataType GScaleDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:35
typename Problem::TopkWeightDataType TopkWeightDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:38
static constexpr bool IsGateOnly
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:44
typename Problem::ODataType ODataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:33
typename Problem::BlockShape BlockShape
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:27
static constexpr index_t kAlignmentG
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:50
static constexpr bool PadHiddenSize
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:46
static constexpr index_t GST_O
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:57
static constexpr index_t SLD_A
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:54
CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords, const TopkWeightDataType *sorted_weight_ptr)
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:142
static constexpr index_t GLD_A
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:55
remove_cvref_t< Problem_ > Problem
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:24
typename Problem::YSmoothScaleDataType YSmoothScaleDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:37
CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType *sorted_token_ids_ptr)
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:126
typename Problem::GDataType GDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:30
typename Problem::YDataType YDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:40
CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:156
remove_cvref_t< Policy_ > Policy
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:25
static constexpr bool UseSmoothQuant
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:45
typename Problem::Traits Traits
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:42
typename Problem::AccDataType AccDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:32
typename Problem::AScaleDataType AScaleDataType
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:34
static CK_TILE_HOST_DEVICE auto GetACoord()
Definition: fused_moegemm_pipeline_flatmm_uk.hpp:86
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
constexpr CK_TILE_HOST_DEVICE auto & at(index_t i)
Definition: array.hpp:110
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: functional.hpp:43