/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp Source File
device_moe_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 namespace tensor_operation {
22 namespace device {
23 
// Device-level Mixture-of-Experts GEMM operator. The template parameters
// select tensor layouts, element types, elementwise epilogues, the GEMM
// specialization (padding mode), and the complete XDL block/wave tiling
// configuration, all of which are forwarded verbatim to GridwiseMoeGemm.
24 template <typename ALayout,
25  typename BLayout,
26  typename DsLayout,
27  typename CLayout,
28  typename ADataType,
29  typename BDataType,
30  typename DsDataType,
31  typename CDataType,
32  typename GemmAccDataType,
33  typename CShuffleDataType,
34  typename AElementwiseOperation,
35  typename BElementwiseOperation,
36  typename CElementwiseOperation,
37  GemmSpecialization GemmSpec,
38  index_t BlockSize,
39  index_t MPerBlock,
40  index_t NPerBlock,
41  index_t KPerBlock,
42  index_t AK1,
43  index_t BK1,
44  index_t MPerXDL,
45  index_t NPerXDL,
46  index_t MXdlPerWave,
47  index_t NXdlPerWave,
48  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
49  typename ABlockTransferThreadClusterArrangeOrder,
50  typename ABlockTransferSrcAccessOrder,
51  index_t ABlockTransferSrcVectorDim,
52  index_t ABlockTransferSrcScalarPerVector,
53  index_t ABlockTransferDstScalarPerVector_AK1,
54  bool ABlockLdsExtraM,
55  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
56  typename BBlockTransferThreadClusterArrangeOrder,
57  typename BBlockTransferSrcAccessOrder,
58  index_t BBlockTransferSrcVectorDim,
59  index_t BBlockTransferSrcScalarPerVector,
60  index_t BBlockTransferDstScalarPerVector_BK1,
61  bool BBlockLdsExtraN,
62  index_t CShuffleMXdlPerWavePerShuffle,
63  index_t CShuffleNXdlPerWavePerShuffle,
64  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
65  typename CDEShuffleBlockTransferScalarPerVectors,
// NOTE(review): original lines 66-67 are missing from this extracted listing.
// They presumably declared the BlkGemmPipeSched / BlkGemmPipelineVer template
// parameters, which ARE referenced later in this file (GridwiseGemm alias and
// GetTypeString) -- confirm against the repository source.
68  index_t ActivationOP = 0,
69  bool NSwizzle = false,
70  bool IsInputGemm = true,
71  bool MulRoutedWeight = true,
72  bool PerTokenQuant = true,
73  typename IndexType = index_t,
74  typename ComputeTypeA = CDataType,
75  typename ComputeTypeB = ComputeTypeA,
76  typename LDSTypeA = ComputeTypeA,
77  typename LDSTypeB = ComputeTypeB>
// NOTE(review): original line 78 -- the `struct DeviceMoeGemm : public ...<ALayout,`
// declaration itself -- is missing from this extracted listing; only the
// remaining base-class template arguments survive below. Confirm the exact
// base class name against the repository source.
79  BLayout,
80  DsLayout,
81  CLayout,
82  ADataType,
83  BDataType,
84  DsDataType,
85  CDataType,
86  AElementwiseOperation,
87  BElementwiseOperation,
88  CElementwiseOperation>
89 {
// Number of auxiliary D tensors fused into the epilogue (bias/scales/etc.).
90  static constexpr index_t NumDTensor = DsDataType::Size();
// Instantiate the grid-level MoE GEMM kernel implementation with this device
// op's full configuration. Every template parameter is forwarded unchanged,
// except the two hard-coded `false` literals below (original lines 122, 130),
// whose meaning is not visible in this listing -- presumably per-operand
// "reset coordinate after run"-style flags; confirm against
// gridwise_moe_gemm.hpp.
91  using GridwiseGemm =
92  GridwiseMoeGemm<ALayout,
93  BLayout,
94  DsLayout,
95  CLayout,
96  ADataType,
97  BDataType,
98  GemmAccDataType,
99  CShuffleDataType,
100  DsDataType,
101  CDataType,
102  AElementwiseOperation,
103  BElementwiseOperation,
104  CElementwiseOperation,
105  GemmSpec,
106  BlockSize,
107  MPerBlock,
108  NPerBlock,
109  KPerBlock,
110  AK1,
111  BK1,
112  MPerXDL,
113  NPerXDL,
114  MXdlPerWave,
115  NXdlPerWave,
116  ABlockTransferThreadClusterLengths_AK0_M_AK1,
117  ABlockTransferThreadClusterArrangeOrder,
118  ABlockTransferSrcAccessOrder,
119  ABlockTransferSrcVectorDim,
120  ABlockTransferSrcScalarPerVector,
121  ABlockTransferDstScalarPerVector_AK1,
122  false,
123  ABlockLdsExtraM,
124  BBlockTransferThreadClusterLengths_BK0_N_BK1,
125  BBlockTransferThreadClusterArrangeOrder,
126  BBlockTransferSrcAccessOrder,
127  BBlockTransferSrcVectorDim,
128  BBlockTransferSrcScalarPerVector,
129  BBlockTransferDstScalarPerVector_BK1,
130  false,
131  BBlockLdsExtraN,
132  CShuffleMXdlPerWavePerShuffle,
133  CShuffleNXdlPerWavePerShuffle,
134  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
135  CDEShuffleBlockTransferScalarPerVectors,
136  BlkGemmPipeSched,
137  BlkGemmPipelineVer,
138  ActivationOP,
139  NSwizzle,
140  IsInputGemm,
141  MulRoutedWeight,
142  PerTokenQuant,
143  IndexType,
144  ComputeTypeA,
145  ComputeTypeB,
146  LDSTypeA,
147  LDSTypeB>;
148 
150 
// Packing factor for operand A: number of logical elements stored per byte
// grouping. Used below to convert descriptor element counts into byte sizes
// (size = elements * sizeof(T) / PackedSize).
// NOTE(review): the `if constexpr(...)` condition line (original line 152) is
// missing from this extracted listing -- presumably a test for a packed 4-bit
// element type (which would yield 2 elements per byte); confirm against the
// repository source.
151  static constexpr index_t APackedSize = []() {
153  return 2;
154  else
155  return 1;
156  }();
157 
// Packing factor for operand B; same convention as APackedSize.
// NOTE(review): the condition line (original line 159) is likewise missing
// from this extract.
158  static constexpr index_t BPackedSize = []() {
160  return 2;
161  else
162  return 1;
163  }();
164 
165  int GetPreShuffleParameters() override { return NPerXDL; }
166 
167  // Invoker
// Invoker: launches the selected gridwise MoE GEMM kernel and returns the
// measured (averaged) execution time in milliseconds.
168  struct Invoker : public BaseInvoker
169  {
// Run the kernel described by `arg`. Selects the kernel variant at compile
// time from the pipeline version and at run time from whether a main K loop
// is needed, then launches and times it via the stream_config settings.
170  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
171  {
172  if(stream_config.log_level_ > 0)
173  {
174  arg.Print();
175  }
176 
// NOTE(review): the condition line (original line 177) is missing from this
// extracted listing -- presumably `if(!GridwiseGemm::CheckValidity(arg))`,
// matching the CheckValidity declaration in gridwise_moe_gemm.hpp; confirm.
178  {
179  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
180  }
181 
// Grid dimensions are derived from the problem size.
182  index_t gdx, gdy, gdz;
183  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
184 
185  float ave_time = 0;
186 
// Round K up to a whole number of (KBatch x KPerBlock) grains to decide
// whether the kernel needs its main K block loop.
187  index_t k_grain = arg.KBatch * KPerBlock;
188  index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
189 
190  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
191 
// Common launch helper: either a cache-flushing, rotating-buffer timing run
// (for stable benchmarking) or a plain timed launch.
192  const auto RunKernel = [&](const auto& kernel) {
193  if(stream_config.flush_cache)
194  {
195 
196  std::array<std::size_t, NumDTensor> DsSize;
197 
// Work on a copy so rotating buffers can rewrite the pointers.
198  Argument arg_ = arg;
199 
// Recreate the grid descriptors to compute the byte footprint of each
// operand buffer (needed to size the rotating memory).
200  const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
201  arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
202  const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
203  arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
204 
// PackedSize converts element counts to bytes for sub-byte element types.
205  auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
206  sizeof(ADataType) / APackedSize;
207  auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
208  sizeof(BDataType) / BPackedSize;
209 
210  const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
211  arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
212 
// Compute the byte size of each fused D tensor.
213  static_for<0, NumDTensor, 1>{}([&](auto i) {
214  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
215  DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
216  });
// NOTE(review): the declaration of `rotating_mem` (original line 217) is
// missing from this extracted listing -- presumably a RotatingMemWrapper
// constructed from flush_cache.hpp; confirm against the repository source.
218  arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
219  rotating_mem.Print();
220 
// Pre-launch hook run before every timed iteration.
221  auto run_flush_cache = [&]() {
222  // flush icache
// NOTE(review): the call on original line 223 is missing from this extract;
// the documentation index shows a free function `flush_icache()` -- confirm.
224  // rotating mem
225  rotating_mem.Next();
226  // clear c mem
// Split-K accumulates into C with atomics, so C must be zeroed first.
227  if(arg_.KBatch > 1)
228  hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
229  0,
230  arg_.M * arg_.N * sizeof(CDataType),
231  stream_config.stream_id_));
232  };
233 
234  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
235  stream_config,
236  run_flush_cache,
237  kernel,
238  dim3(gdx, gdy, gdz),
239  dim3(BlockSize),
240  0,
241  arg_);
242  }
243  else
244  {
// Plain timed launch; same split-K zeroing of C as above.
245  if(arg.KBatch > 1)
246  hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
247  0,
248  arg.M * arg.N * sizeof(CDataType),
249  stream_config.stream_id_));
250 
251  ave_time = launch_and_time_kernel(
252  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
253  }
254  };
255 
// Rough compile-time estimate of per-thread register pressure (in VGPRs,
// dividing bytes by 4) for the A, B and accumulator tiles; used only to pick
// the __launch_bounds__ minimum-occupancy hint below.
256  constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
257  4 * (1 + GridwiseGemm::NWave);
258  constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize /
259  4 * (2) * (IsInputGemm ? 2 : 1);
260  constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
261  BlockSize / 4 * (IsInputGemm ? 2 : 1);
262  constexpr auto estimated_reg_total =
263  estimated_reg_a + estimated_reg_b + estimated_reg_c;
264 
// High register use -> request occupancy 1, otherwise 2.
265  constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
266 
// NOTE(review): the initializer of MemoryDataOp (original line 268) is
// missing from this extracted listing -- presumably an InMemoryDataOperationEnum
// chosen by KBatch (Set vs AtomicAdd); confirm against the repository source.
267  constexpr auto MemoryDataOp =
// Kernel selection: pipeline version (compile time) x main-K-loop (run time).
// NOTE(review): throughout the branches below, each kernel_moe_gemm /
// kernel_moe_gemm_2lds instantiation is truncated in this extract -- the
// trailing template argument(s) (tail number) and the enclosing condition
// lines (original lines 275, 281, 290, 298, 304, 313, 330, 332) are missing.
269  if(has_main_k_block_loop)
270  {
271  // Tail number always full
272  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
273  {
274  {
276  {
277  const auto kernel = kernel_moe_gemm<GridwiseGemm,
278  true,
279  MemoryDataOp,
280  minimum_occupancy,
282  RunKernel(kernel);
283  }
284  else
285  {
286  const auto kernel = kernel_moe_gemm<GridwiseGemm,
287  true,
288  MemoryDataOp,
289  minimum_occupancy,
291  RunKernel(kernel);
292  }
293  }
294  }
295  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
296  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
297  {
// v2/v3 pipelines use the double-LDS kernel variant.
299  {
300  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
301  true,
302  MemoryDataOp,
303  minimum_occupancy,
305  RunKernel(kernel);
306  }
307  else
308  {
309  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
310  true,
311  MemoryDataOp,
312  minimum_occupancy,
314  RunKernel(kernel);
315  }
316  }
317  else
318  {
319  throw std::runtime_error("todo: only v1 & v2 support now");
320  }
321  }
322 #if 1
323  else
324  {
325  // Tail number always 1
326  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
327  {
328  const auto kernel = kernel_moe_gemm<GridwiseGemm,
329  true,
331  minimum_occupancy,
333  RunKernel(kernel);
334  }
335  }
336 #endif
337 
338  return ave_time;
339  }
340 
341  // polymorphic
// Type-erased entry point; forwards to the typed Run above.
342  float Run(const BaseArgument* p_arg,
343  const StreamConfig& stream_config = StreamConfig{}) override
344  {
345  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
346  }
347  };
348 
/// Compile-time validity check of the chosen tuning parameters.
/// TODO: properly implement this check
static constexpr bool IsValidCompilationParameter()
{
    // No constraints are enforced yet; every instantiation is accepted.
    return true;
}
354 
355  static bool IsSupportedArgument(const Argument& arg)
356  {
357  // only impl kbatch 1 now
358  if(arg.KBatch > 1)
359  {
360  return false;
361  }
362  if(!ck::is_xdl_supported())
363  {
364  return false;
365  }
366 
367  if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
368  {
369  return false;
370  }
371 
372  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
373  GemmSpec == GemmSpecialization::NKPadding ||
374  GemmSpec == GemmSpecialization::MNKPadding ||
375  GemmSpec == GemmSpecialization::KPadding))
376  {
377  return false;
378  }
379  if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
380  {
381  return false;
382  }
383 
384  return GridwiseGemm::CheckValidity(arg);
385  }
386 
387  // polymorphic
388  bool IsSupportedArgument(const BaseArgument* p_arg) override
389  {
390  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
391  }
392 
393  static auto MakeArgument(const void* p_sorted_token_ids,
394  const void* p_sorted_expert_ids,
395  const void* p_max_token_id,
396  const void* p_a,
397  const void* p_b,
398  std::array<const void*, NumDTensor> p_ds,
399  void* p_c,
400  index_t NumTokens,
401  index_t TopK,
402  index_t M,
403  index_t N,
404  index_t K,
405  index_t StrideA,
406  index_t StrideB,
407  std::array<index_t, NumDTensor> StrideDs,
408  index_t StrideC,
409  index_t KBatch,
410  AElementwiseOperation a_element_op,
411  BElementwiseOperation b_element_op,
412  CElementwiseOperation c_element_op)
413  {
414  return Argument{static_cast<const index_t*>(p_sorted_token_ids),
415  static_cast<const index_t*>(p_sorted_expert_ids),
416  static_cast<const index_t*>(p_max_token_id),
417  static_cast<const ADataType*>(p_a),
418  static_cast<const BDataType*>(p_b),
419  p_ds,
420  static_cast<CDataType*>(p_c),
421  NumTokens,
422  TopK,
423  M,
424  N,
425  K,
426  StrideA,
427  StrideB,
428  StrideDs,
429  StrideC,
430  KBatch,
431  a_element_op,
432  b_element_op,
433  c_element_op};
434  }
435 
436  static auto MakeInvoker() { return Invoker{}; }
437 
438  // polymorphic
439  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
440  const void* p_b,
441  std::array<const void*, NumDTensor> p_ds,
442  void* p_c,
443  index_t M,
444  index_t N,
445  index_t K,
446  index_t StrideA,
447  index_t StrideB,
448  std::array<ck::index_t, NumDTensor> StrideDs,
449  index_t StrideC,
450  index_t KBatch,
451  AElementwiseOperation a_element_op,
452  BElementwiseOperation b_element_op,
453  CElementwiseOperation c_element_op) override
454  {
455  return std::make_unique<Argument>(nullptr,
456  nullptr,
457  nullptr,
458  static_cast<const ADataType*>(p_a),
459  static_cast<const BDataType*>(p_b),
460  p_ds,
461  static_cast<CDataType*>(p_c),
462  M, // randoms set, no use
463  0,
464  M,
465  N,
466  K,
467  StrideA,
468  StrideB,
469  StrideDs,
470  StrideC,
471  KBatch,
472  a_element_op,
473  b_element_op,
474  c_element_op);
475  }
476 
477  // polymorphic
478  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
479  {
480  return std::make_unique<Invoker>(Invoker{});
481  }
482 
483  // polymorphic
// Human-readable description of this instantiation's full configuration,
// used for logging and kernel selection reports.
// NOTE(review): "DeviceMoeGEmm" below looks like a capitalization typo for
// "DeviceMoeGemm" -- it is a runtime string, so confirm the intended spelling
// before changing it (tooling may match on the current text).
484  std::string GetTypeString() const override
485  {
486  auto str = std::stringstream();
487 
// NOTE(review): the initializer lists of both maps (original lines 489-490
// and 493) are missing from this extracted listing -- presumably
// scheduler/version enum -> name string pairs; confirm against the
// repository source.
488  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
491 
492  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
494 
495  // clang-format off
496  str << "DeviceMoeGEmm"
497  << "<"
498  << getGemmSpecializationString(GemmSpec) << ", "
// Layouts are abbreviated to their first letter (e.g. R/C).
499  << std::string(ALayout::name)[0]
500  << std::string(BLayout::name)[0]
501  << std::string(CLayout::name)[0]
502  << ">"
503  << " BlkSize: "
504  << BlockSize << ", "
505  << "BlkTile: "
506  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
507  << "WaveTile: "
508  << MPerXDL<<"x"<<NPerXDL << ", "
509  << "WaveMap: "
510  << MXdlPerWave<<"x" << NXdlPerWave<<", "
511  << "VmemReadVec: "
512  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
513  << "BlkGemmPipelineScheduler: "
514  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
515  << "BlkGemmPipelineVersion: "
516  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
517  << "BlkGemmPipelinePrefetchStages: "
518  << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
519  // clang-format on
520 
521  return str.str();
522  }
523 };
524 
525 } // namespace device
526 } // namespace tensor_operation
527 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:216
Definition: ck.hpp:269
bool is_xdl_supported()
Definition: device_prop.hpp:55
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:46
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
constexpr bool is_same_v
Definition: type.hpp:283
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:300
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:81
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:68
Definition: stream_config.hpp:10
Definition: gridwise_moe_gemm.hpp:666
Definition: gridwise_moe_gemm.hpp:165
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_moe_gemm.hpp:240
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
Definition: gridwise_moe_gemm.hpp:324
__host__ static constexpr __device__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_moe_gemm.hpp:1134
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_moe_gemm.hpp:414
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_moe_gemm.hpp:1127
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_moe_gemm.hpp:562
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_gemm.hpp:954
static constexpr index_t NWave
Definition: gridwise_moe_gemm.hpp:207
Definition: data_type.hpp:186
Definition: functional2.hpp:33
Definition: device_base.hpp:51
Definition: device_base.hpp:62
Definition: device_moe_gemm.hpp:169
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_moe_gemm.hpp:170
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_moe_gemm.hpp:342
Definition: device_moe_gemm.hpp:89
int GetPreShuffleParameters() override
Definition: device_moe_gemm.hpp:165
typename GridwiseGemm::Argument Argument
Definition: device_moe_gemm.hpp:149
GridwiseMoeGemm< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, PerTokenQuant, IndexType, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > GridwiseGemm
Definition: device_moe_gemm.hpp:147
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_moe_gemm.hpp:393
static constexpr bool IsValidCompilationParameter()
Definition: device_moe_gemm.hpp:349
static constexpr index_t BPackedSize
Definition: device_moe_gemm.hpp:158
static bool IsSupportedArgument(const Argument &arg)
Definition: device_moe_gemm.hpp:355
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_moe_gemm.hpp:388
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_moe_gemm.hpp:439
static constexpr index_t NumDTensor
Definition: device_moe_gemm.hpp:90
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_moe_gemm.hpp:478
std::string GetTypeString() const override
Definition: device_moe_gemm.hpp:484
static auto MakeInvoker()
Definition: device_moe_gemm.hpp:436
static constexpr index_t APackedSize
Definition: device_moe_gemm.hpp:151
Definition: flush_cache.hpp:20