6 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
15 #ifdef CK_EXPERIMENTAL_BUILDER
16 #include "ck_tile/builder/reflect/description.hpp"
23 namespace tensor_operation {
26 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
27 #define GET_OBJECT_NAME_IMLP \
28 std::optional<std::string> GetObjectName() const override \
30 std::string str = __PRETTY_FUNCTION__; \
31 static std::regex obj_name_expr{"<std::string> (.*)::GetObjectName"}; \
33 if(!std::regex_search(str, match, obj_name_expr)) \
37 return std::string(match[1]) + ';'; \
40 #define GET_TEMPLATE_INFO_IMPL \
41 std::optional<std::string> GetTemplateInfo() const override \
43 std::string str = __PRETTY_FUNCTION__; \
44 static std::regex template_expr{"\\[(.*)\\]"}; \
46 if(!std::regex_search(str, match, template_expr)) \
48 return std::nullopt; \
50 return std::string(match[1]); \
53 #define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
63 static constexpr
auto GetXdlPerWave2()
65 constexpr
index_t Waves = IsWave64 ? BlockSize_ / 64 : BlockSize_ / 32;
66 constexpr
index_t MWaves = MPerBlock_ / (MXdlPerWave_ * MPerXDL_);
67 static_assert(MWaves > 0);
69 constexpr
index_t NWaves = Waves / MWaves;
70 if constexpr(NWaves == 0)
76 if constexpr(NPerBlock_ % (NPerXDL_ * NWaves) == 0)
78 return NPerBlock_ / (NWaves * NPerXDL_);
87 #define GET_NXDL_PER_WAVE_IMPL \
88 template <bool IsWave64> \
89 static constexpr auto GetNXdlPerWave() \
91 return GetXdlPerWave2<BlockSize, \
100 #define GET_MXDL_PER_WAVE_IMPL \
101 template <bool IsWave64, \
102 index_t MPerXDLAligned = MPerXDL, \
103 index_t NPerXDLAligned = NPerXDL, \
104 index_t NXdlPerWaveAligned = NXdlPerWave> \
105 static constexpr auto GetMXdlPerWave() \
107 return GetXdlPerWave2<BlockSize, \
112 NXdlPerWaveAligned, \
122 index_t CShuffleMXdlPerWavePerShuffle_,
123 index_t CShuffleNXdlPerWavePerShuffle_,
125 static constexpr
auto GetWarpTileConfig()
127 constexpr
auto MXdlPerWave64 = MXdlPerWave_;
128 constexpr
auto MXdlPerWave32 = MXdlPerWave_ * MPerXDL_ / 16;
129 constexpr
auto CShuffleMXdlPerWavePerShuffle32 = CShuffleMXdlPerWavePerShuffle_ * MPerXDL_ / 16;
131 constexpr
auto NXdlPerWave =
133 ? GetXdlPerWave2<BlockSize_,
140 : GetXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
142 if constexpr(IsWave64 ==
false && NXdlPerWave != 0)
144 constexpr
auto CShuffleNXdlPerWavePerShuffle32 =
145 NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
146 ? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
147 : CShuffleNXdlPerWavePerShuffle_;
148 static_assert(CShuffleNXdlPerWavePerShuffle32 > 0);
153 CShuffleMXdlPerWavePerShuffle32,
154 CShuffleNXdlPerWavePerShuffle32>{};
158 return Sequence<MPerXDL_,
162 CShuffleMXdlPerWavePerShuffle_,
163 CShuffleNXdlPerWavePerShuffle_>{};
167 #define INVOKER_RUN_IMPL \
168 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
170 if(get_warp_size() == 64) \
172 if constexpr(NXdlPerWave64 > 0) \
174 return RunImp<GridwiseGemm64>(arg, stream_config); \
179 if constexpr(NXdlPerWave32 > 0) \
181 return RunImp<GridwiseGemm32>(arg, stream_config); \
187 #define INVOKER_RUN3_IMPL \
188 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
190 if(get_warp_size() == 64) \
192 if constexpr(NXdlPerWave64 > 0) \
194 return RunImp<GridwiseGemm64>(arg, stream_config); \
199 if constexpr(NXdlPerWave32 > 0) \
201 return RunImp<GridwiseGemm32>( \
202 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg), \
218 __device__
static bool constexpr IsValidGemmCompilationParameter()
220 #if defined(__gfx11__) || defined(__gfx12__)
221 if constexpr(MPerXdl != 16 || NPerXdl != 16)
227 #if defined(__gfx11__)
230 constexpr
bool SupportMemOp =
233 if constexpr(SupportMemOp ==
false)
238 if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
240 constexpr
index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
241 constexpr
index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
242 if constexpr(MWaves > 0 && NWaves > 0)
244 constexpr
index_t WaveSize = BlockSize / (MWaves * NWaves);
251 #define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_) \
252 template <InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = \
253 InMemoryDataOperationEnum::Set> \
254 __device__ static bool constexpr IsValidCompilationParameter() \
256 return ck::tensor_operation::device::IsValidGemmCompilationParameter< \
265 CGlobalMemoryDataOperation_>(); \
268 #ifndef CK_CODE_GEN_RTC
300 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
304 #ifdef CK_EXPERIMENTAL_BUILDER
306 virtual std::unique_ptr<ck_tile::reflect::Description> describe()
const {
return nullptr; }
313 virtual std::optional<std::string>
GetObjectName()
const {
return std::nullopt; }
319 std::ostringstream oss;
321 oss << std::hex <<
typeid(*this).hash_code();
InMemoryDataOperationEnum
Definition: ck.hpp:279
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
int32_t index_t
Definition: ck.hpp:301
Definition: stream_config.hpp:10
Definition: device_base.hpp:270
BaseArgument & operator=(const BaseArgument &)=default
BaseArgument(const BaseArgument &)=default
void * p_workspace_
Definition: device_base.hpp:277
virtual __host__ __device__ ~BaseArgument()
Definition: device_base.hpp:275
Definition: device_base.hpp:281
virtual ~BaseInvoker()
Definition: device_base.hpp:291
BaseInvoker & operator=(const BaseInvoker &)=default
virtual float Run(const BaseArgument *, const StreamConfig &=StreamConfig{})
Definition: device_base.hpp:286
BaseInvoker(const BaseInvoker &)=default
Definition: device_base.hpp:296
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition: device_base.hpp:328
virtual std::string GetInstanceString() const
Definition: device_base.hpp:309
virtual bool IsSupportedArgument(const BaseArgument *)
Definition: device_base.hpp:301
virtual size_t GetWorkSpaceSize(const BaseArgument *) const
Definition: device_base.hpp:326
virtual std::optional< std::string > GetTemplateInfo() const
Definition: device_base.hpp:315
virtual std::string GetTypeString() const
Definition: device_base.hpp:302
BaseOperator(const BaseOperator &)=default
virtual std::string GetTypeIdHashCode() const
Definition: device_base.hpp:317
virtual std::optional< std::string > GetObjectName() const
Definition: device_base.hpp:313
BaseOperator & operator=(const BaseOperator &)=default
virtual std::string GetTypeIdName() const
Definition: device_base.hpp:311
virtual ~BaseOperator()
Definition: device_base.hpp:336