/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/arch.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/arch.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/arch.hpp Source File
arch.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 // Address Space for AMDGCN
7 // https://llvm.org/docs/AMDGPUUsage.html#address-space
8 
16 
// Legacy (gfx9-style) s_waitcnt immediate with every counter field saturated:
//   vmcnt = 63   -> bits [3:0] and [15:14]
//   expcnt = 7   -> bits [6:4]
//   lgkmcnt = 15 -> bits [11:8]
// i.e. 0b1100'1111'0111'1111.
#define CK_TILE_S_CNT_MAX 0xCF7F

// Each packing macro below uses an immediately-invoked lambda purely as a place to host a
// static_assert range check, so `cnt` must be a constant expression. The comma operator
// then yields the packed field value.

// vmcnt is 6 bits, split: low nibble -> [3:0], high two bits -> [15:14].
#define CK_TILE_VMCNT(cnt)                                                   \
    ([]() { static_assert(((cnt) >> 6) == 0, "VMCNT only has 6 bits"); }(), \
     (((cnt) & 0xF) | (((cnt) & 0x30) << 10)))

// expcnt is 3 bits at [6:4].
#define CK_TILE_EXPCNT(cnt)                                                \
    ([]() { static_assert(((cnt) >> 3) == 0, "EXP only has 3 bits"); }(), \
     ((cnt) << 4))

// lgkmcnt is 4 bits at [11:8].
#define CK_TILE_LGKMCNT(cnt)                                                \
    ([]() { static_assert(((cnt) >> 4) == 0, "LGKM only has 4 bits"); }(), \
     ((cnt) << 8))
25 
26 namespace ck_tile {
27 
// Maps an enumeration type to its underlying integer type and any other type to `void`.
// Unlike naming std::underlying_type_t<T> directly, this is safe for non-enum T because the
// underlying_type trait is only touched in the `IsEnum == true` specialization.
template <typename T, bool IsEnum>
struct safe_underlying_type
{
    // Non-enum fallback.
    using type = void;
};

template <typename T>
struct safe_underlying_type<T, true>
{
    using type = std::underlying_type_t<T>;
};

template <typename T>
using safe_underlying_type_t = typename safe_underlying_type<T, std::is_enum_v<T>>::type;
45 
// Memory / register space a buffer or tile lives in.
// NOTE: these are CK's own ordinals (spelled out explicitly below); do not assume they
// coincide with LLVM AMDGPU address-space numbers.
enum struct address_space_enum : std::uint16_t
{
    generic  = 0,
    global   = 1,
    lds      = 2,
    sgpr     = 3,
    constant = 4,
    vgpr     = 5,
};
55 
// How a store combines with the value already in memory.
enum struct memory_operation_enum : std::uint16_t
{
    set        = 0, // plain overwrite
    atomic_add = 1,
    atomic_max = 2,
    add        = 3, // non-atomic accumulate
};
63 
64 namespace core::arch {
65 
// Concrete compilation targets. Each value spells the gfx name in hex digits
// (gfx90a -> 0x090A); "generic" targets use F in the trailing digit(s).
// HOST (0) stands for a host-side compile.
enum struct amdgcn_target_id
{
    // Host-side compile (no device target)
    HOST = 0x0000,

    // gfx9
    GFX908 = 0x0908, // MI-100...
    GFX90A = 0x090A,
    GFX942 = 0x0942,
    GFX950 = 0x0950,

    // gfx10.3
    GFX1030        = 0x1030,
    GFX1031        = 0x1031,
    GFX1032        = 0x1032,
    GFX1034        = 0x1034,
    GFX1035        = 0x1035,
    GFX1036        = 0x1036,
    GFX103_GENERIC = 0x103F,

    // gfx11
    GFX1100       = 0x1100,
    GFX1101       = 0x1101,
    GFX1102       = 0x1102,
    GFX1103       = 0x1103,
    GFX1150       = 0x1150,
    GFX1151       = 0x1151,
    GFX1152       = 0x1152,
    GFX11_GENERIC = 0x11FF,

    // gfx12
    GFX1200       = 0x1200,
    GFX1201       = 0x1201,
    GFX12_GENERIC = 0x12FF,
};
96 
// Target families; the value is the leading gfx digits (gfx9 -> 0x09, gfx10.3 -> 0x10, ...).
enum struct amdgcn_target_family_id
{
    HOST    = 0x00,
    GFX9    = 0x09,
    GFX10_3 = 0x10,
    GFX11   = 0x11,
    GFX12   = 0x12,
};
105 
// Architecture class of a target (HOST for host-side compiles).
enum struct amdgcn_target_arch_id
{
    HOST = 0x00,
    CDNA = 0x01,
    RDNA = 0x02,
};
112 
// Wavefront width; the enumerator value is the number of lanes.
// Note: HOST shares the value 64 with WAVE64, so the two compare equal.
enum struct amdgcn_target_wave_size_id
{
    WAVE32 = 32u,
    WAVE64 = 64u,
    HOST   = 64u, // TODO: Is this correct? Should the host default to 64 or 1?
};
119 
120 #if 1 //__cplusplus <= 201703L
121 
122 template <amdgcn_target_id TargetId = amdgcn_target_id::HOST,
123  amdgcn_target_family_id FamilyId = amdgcn_target_family_id::HOST,
124  amdgcn_target_arch_id ArchId = amdgcn_target_arch_id::HOST,
125  amdgcn_target_wave_size_id WaveSizeId = amdgcn_target_wave_size_id::HOST>
126 struct amdgcn_target
127 {
128  static constexpr amdgcn_target_id TARGET_ID = TargetId;
129  static constexpr amdgcn_target_family_id FAMILY_ID = FamilyId;
130  static constexpr amdgcn_target_arch_id ARCH_ID = ArchId;
131  static constexpr amdgcn_target_wave_size_id WAVE_SIZE_ID = WaveSizeId;
132 };
133 
134 template <amdgcn_target_id targetId>
135 static constexpr auto make_amdgcn_gfx9_target()
136 {
137  return amdgcn_target<targetId,
138  amdgcn_target_family_id::GFX9,
139  amdgcn_target_arch_id::CDNA,
140  amdgcn_target_wave_size_id::WAVE64>{};
141 }
142 
143 template <amdgcn_target_id targetId>
144 static constexpr auto make_amdgcn_gfx10_3_target()
145 {
146  return amdgcn_target<targetId,
147  amdgcn_target_family_id::GFX10_3,
148  amdgcn_target_arch_id::RDNA,
149  amdgcn_target_wave_size_id::WAVE32>{};
150 }
151 
152 template <amdgcn_target_id targetId>
153 static constexpr auto make_amdgcn_gfx11_target()
154 {
155  return amdgcn_target<targetId,
156  amdgcn_target_family_id::GFX11,
157  amdgcn_target_arch_id::RDNA,
158  amdgcn_target_wave_size_id::WAVE32>{};
159 }
160 
161 template <amdgcn_target_id targetId>
162 static constexpr auto make_amdgcn_gfx12_target()
163 {
164  return amdgcn_target<targetId,
165  amdgcn_target_family_id::GFX12,
166  amdgcn_target_arch_id::RDNA,
167  amdgcn_target_wave_size_id::WAVE32>{};
168 }
169 
170 template <typename CompilerTarget, amdgcn_target_id... TargetIds>
171 static constexpr auto is_target_id_any_of()
172 {
173  return is_any_value_of(CompilerTarget::TARGET_ID, TargetIds...);
174 }
175 
176 template <typename CompilerTarget, amdgcn_target_family_id... FamilyIds>
177 static constexpr auto is_target_family_any_of()
178 {
179  return is_any_value_of(CompilerTarget::FAMILY_ID, FamilyIds...);
180 }
181 
182 template <typename CompilerTarget>
183 static constexpr bool is_target_family_gfx9()
184 {
185  return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX9;
186 }
187 
188 template <typename CompilerTarget>
189 static constexpr bool is_target_family_gfx10_3()
190 {
191  return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX10_3;
192 }
193 
194 template <typename CompilerTarget>
195 static constexpr bool is_target_family_gfx11()
196 {
197  return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX11;
198 }
199 
200 template <typename CompilerTarget>
201 static constexpr bool is_target_family_gfx12()
202 {
203  return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX12;
204 }
205 
206 template <typename CompilerTarget>
207 static constexpr bool is_target_arch_cdna()
208 {
209  return CompilerTarget::ARCH_ID == amdgcn_target_arch_id::CDNA;
210 }
211 
212 template <typename CompilerTarget>
213 static constexpr bool is_target_arch_rdna()
214 {
215  return CompilerTarget::ARCH_ID == amdgcn_target_arch_id::RDNA;
216 }
217 
218 template <typename CompilerTarget>
219 static constexpr bool is_target_wave_size_32()
220 {
221  return CompilerTarget::WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE32;
222 }
223 
224 template <typename CompilerTarget>
225 static constexpr bool is_target_wave_size_64()
226 {
227  return CompilerTarget::WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE64;
228 }
229 
230 // Helper to map compiler state to target arch id
231 
232 #define MAP_COMPILER_STATE_TO_GFX9_TARGET(COMPILER_STATE, TARGET_ID) \
233  if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
234  { \
235  return make_amdgcn_gfx9_target<amdgcn_target_id::TARGET_ID>(); \
236  } \
237  else
238 
239 #define MAP_COMPILER_STATE_TO_GFX10_3_TARGET(COMPILER_STATE, TARGET_ID) \
240  if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
241  { \
242  return make_amdgcn_gfx10_3_target<amdgcn_target_id::TARGET_ID>(); \
243  } \
244  else
245 
246 #define MAP_COMPILER_STATE_TO_GFX11_TARGET(COMPILER_STATE, TARGET_ID) \
247  if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
248  { \
249  return make_amdgcn_gfx11_target<amdgcn_target_id::TARGET_ID>(); \
250  } \
251  else
252 
253 #define MAP_COMPILER_STATE_TO_GFX12_TARGET(COMPILER_STATE, TARGET_ID) \
254  if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
255  { \
256  return make_amdgcn_gfx12_target<amdgcn_target_id::TARGET_ID>(); \
257  } \
258  else
259 
265 constexpr auto get_compiler_target()
266 {
267  MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX908, GFX908);
268  MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX90A, GFX90A);
269  MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX942, GFX942);
270  MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX950, GFX950);
271  MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1030, GFX1030);
272  MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1031, GFX1031);
273  MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1032, GFX1032);
274  MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1034, GFX1034);
275  MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1035, GFX1035);
276  MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1036, GFX1036);
277  MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX10_3_GENERIC, GFX103_GENERIC);
278  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1100, GFX1100);
279  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1101, GFX1101);
280  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1102, GFX1102);
281  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1103, GFX1103);
282  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1150, GFX1150);
283  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1151, GFX1151);
284  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1152, GFX1152);
285  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX11_GENERIC, GFX11_GENERIC);
286  MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1200, GFX1200);
287  MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1201, GFX1201);
288  MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX12_GENERIC, GFX12_GENERIC);
289 
290  // Return HOST by default
291  if constexpr(amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE)
292  {
293  return amdgcn_target<>{};
294  }
295 }
296 
297 // Cleanup
298 #undef MAP_COMPILER_STATE_TO_GFX9_TARGET
299 #undef MAP_COMPILER_STATE_TO_GFX10_3_TARGET
300 #undef MAP_COMPILER_STATE_TO_GFX11_TARGET
301 #undef MAP_COMPILER_STATE_TO_GFX12_TARGET
302 
303 // Sanity check: device compile must have a valid target architecture
304 static_assert(!amdgcn_compiler_target_state::CK_TILE_DEVICE_COMPILE ||
305  get_compiler_target().TARGET_ID != amdgcn_target_id::HOST,
306  "Device compile must have a valid target device architecture");
307 
308 // Sanity check: host compile must have HOST target architecture
309 static_assert(!amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE ||
310  get_compiler_target().TARGET_ID == amdgcn_target_id::HOST,
311  "Host compile must target HOST architecture");
312 
313 // TODO: c++20 use the make functions and constexpr if to avoid string construction and find at
314 // runtime
315 #define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID(NAME_STRING, TARGET_ID) \
316  if(str.find(NAME_STRING) != std::string::npos) \
317  { \
318  return amdgcn_target_id::TARGET_ID; \
319  } \
320  else
321 
328 // TODO: c++20 constexpr if and string_view to avoid std::string construction and find at runtime
329 // TODO: c++20 return amdgcn_target instance instead of just the target id
330 CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target_id(char const* testStr)
331 {
332  auto str = std::string(testStr);
333  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx908", GFX908);
334  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx90a", GFX90A);
335  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx942", GFX942);
336  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx950", GFX950);
337  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1030", GFX1030);
338  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1031", GFX1031);
339  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1032", GFX1032);
340  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1034", GFX1034);
341  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1035", GFX1035);
342  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1036", GFX1036);
343  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx10_3_generic", GFX103_GENERIC);
344  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1100", GFX1100);
345  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1101", GFX1101);
346  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1102", GFX1102);
347  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1103", GFX1103);
348  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1150", GFX1150);
349  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1151", GFX1151);
350  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1152", GFX1152);
351  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx11_generic", GFX11_GENERIC);
352  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1200", GFX1200);
353  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1201", GFX1201);
354  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx12_generic", GFX12_GENERIC);
355 
356  // Default case: return HOST target if no match is found
357  return amdgcn_target_id::HOST;
358 }
359 
360 #undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID
361 
368 template <typename CompilerTarget, amdgcn_target_id... SupportedTargetIds>
369 using enable_if_target_id_t =
370  std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID, SupportedTargetIds...)>;
371 
379 template <typename CompilerTarget, amdgcn_target_family_id... SupportedTargetFamilyIds>
380 using enable_if_target_family_id_t =
381  std::enable_if_t<is_any_value_of(CompilerTarget::FAMILY_ID, SupportedTargetFamilyIds...)>;
382 
388 template <typename CompilerTarget, amdgcn_target_arch_id... SupportedTargetArchIds>
389 using enable_if_target_arch_id_t =
390  std::enable_if_t<is_any_value_of(CompilerTarget::ARCH_ID, SupportedTargetArchIds...)>;
391 
399 template <typename CompilerTarget, amdgcn_target_wave_size_id... SupportedTargetWaveSizeIds>
400 using enable_if_target_wave_size_id_t =
401  std::enable_if_t<is_any_value_of(CompilerTarget::WAVE_SIZE_ID, SupportedTargetWaveSizeIds...)>;
402 
404 
409 template <typename CompilerTarget>
410 using enable_if_target_family_gfx9_t =
411  enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX9>;
412 
417 template <typename CompilerTarget>
418 using enable_if_target_family_gfx10_3_t =
419  enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX10_3>;
420 
425 template <typename CompilerTarget>
426 using enable_if_target_family_gfx11_t =
427  enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX11>;
428 
433 template <typename CompilerTarget>
434 using enable_if_target_family_gfx12_t =
435  enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX12>;
436 
441 template <typename CompilerTarget>
442 using enable_if_target_arch_cdna_t =
443  enable_if_target_arch_id_t<CompilerTarget, amdgcn_target_arch_id::CDNA>;
444 
449 template <typename CompilerTarget>
450 using enable_if_target_arch_rdna_t =
451  enable_if_target_arch_id_t<CompilerTarget, amdgcn_target_arch_id::RDNA>;
452 
457 template <typename CompilerTarget>
458 using enable_if_target_wave32_t =
459  enable_if_target_wave_size_id_t<CompilerTarget, amdgcn_target_wave_size_id::WAVE32>;
460 
465 template <typename CompilerTarget>
466 using enable_if_target_wave64_t =
467  enable_if_target_wave_size_id_t<CompilerTarget, amdgcn_target_wave_size_id::WAVE64>;
468 
469 #elif __cplusplus >= 202002L
470 
471 struct amdgcn_target
472 {
473  // Target architecture identifiers
474  // These are set to HOST (0) by default
475  // TARGET_ID is the specific architecture id (e.g., GFX908)
476  // FAMILY_ID is the architecture family id (e.g., GFX9)
477  // ARCH_ID is the architecture class id (e.g., CDNA, RDNA)
478  // WAVE_SIZE_ID is the wavefront size id (e.g., WAVE32, WAVE64)
479  const amdgcn_target_id TARGET_ID = amdgcn_target_id::HOST;
480  const amdgcn_target_family_id FAMILY_ID = amdgcn_target_family_id::HOST;
481  const amdgcn_target_arch_id ARCH_ID = amdgcn_target_arch_id::HOST;
482  const amdgcn_target_wave_size_id WAVE_SIZE_ID = amdgcn_target_wave_size_id::HOST;
483 };
484 
485 static constexpr auto make_amdgcn_gfx10_3_target(amdgcn_target_id targetId)
486 {
487  return amdgcn_target{.TARGET_ID = targetId,
488  .FAMILY_ID = amdgcn_target_family_id::GFX10_3,
489  .ARCH_ID = amdgcn_target_arch_id::RDNA,
490  .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32};
491 }
492 
493 static constexpr auto make_amdgcn_gfx9_target(amdgcn_target_id targetId)
494 {
495  return amdgcn_target{.TARGET_ID = targetId,
496  .FAMILY_ID = amdgcn_target_family_id::GFX9,
497  .ARCH_ID = amdgcn_target_arch_id::CDNA,
498  .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE64};
499 }
500 
501 static constexpr auto make_amdgcn_gfx11_target(amdgcn_target_id targetId)
502 {
503  return amdgcn_target{.TARGET_ID = targetId,
504  .FAMILY_ID = amdgcn_target_family_id::GFX11,
505  .ARCH_ID = amdgcn_target_arch_id::RDNA,
506  .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32};
507 }
508 
509 static constexpr auto make_amdgcn_gfx12_target(amdgcn_target_id targetId)
510 {
511  return amdgcn_target{.TARGET_ID = targetId,
512  .FAMILY_ID = amdgcn_target_family_id::GFX12,
513  .ARCH_ID = amdgcn_target_arch_id::RDNA,
514  .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32};
515 }
516 
517 static constexpr bool is_target_family_gfx9(amdgcn_target target)
518 {
519  return target.FAMILY_ID == amdgcn_target_family_id::GFX9;
520 }
521 
522 static constexpr bool is_target_family_gfx10_3(amdgcn_target target)
523 {
524  return target.FAMILY_ID == amdgcn_target_family_id::GFX10_3;
525 }
526 
527 static constexpr bool is_target_family_gfx11(amdgcn_target target)
528 {
529  return target.FAMILY_ID == amdgcn_target_family_id::GFX11;
530 }
531 
532 static constexpr bool is_target_family_gfx12(amdgcn_target target)
533 {
534  return target.FAMILY_ID == amdgcn_target_family_id::GFX12;
535 }
536 
537 static constexpr bool is_target_arch_cdna(amdgcn_target target)
538 {
539  return target.ARCH_ID == amdgcn_target_arch_id::CDNA;
540 }
541 
542 static constexpr bool is_target_arch_rdna(amdgcn_target target)
543 {
544  return target.ARCH_ID == amdgcn_target_arch_id::RDNA;
545 }
546 
547 static constexpr bool is_target_wave_size_32(amdgcn_target target)
548 {
549  return target.WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE32;
550 }
551 
552 static constexpr bool is_target_wave_size_64(amdgcn_target target)
553 {
554  return target.WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE64;
555 }
556 
557 // Helper to map compiler state to target arch id
558 #define MAP_COMPILER_STATE_TO_GFX10_3_TARGET(COMPILER_STATE, TARGET_ID) \
559  if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
560  { \
561  return make_amdgcn_gfx9_target(amdgcn_target_id::TARGET_ID); \
562  }
563 
564 #define MAP_COMPILER_STATE_TO_GFX9_TARGET(COMPILER_STATE, TARGET_ID) \
565  if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
566  { \
567  return make_amdgcn_gfx9_target(amdgcn_target_id::TARGET_ID); \
568  }
569 
570 #define MAP_COMPILER_STATE_TO_GFX11_TARGET(COMPILER_STATE, TARGET_ID) \
571  if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
572  { \
573  return make_amdgcn_gfx11_target(amdgcn_target_id::TARGET_ID); \
574  }
575 
576 #define MAP_COMPILER_STATE_TO_GFX12_TARGET(COMPILER_STATE, TARGET_ID) \
577  if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
578  { \
579  return make_amdgcn_gfx12_target(amdgcn_target_id::TARGET_ID); \
580  }
581 
586 CK_TILE_HOST_DEVICE constexpr auto get_compiler_target()
587 {
588  MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX908, GFX908);
589  MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX90A, GFX90A);
590  MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX942, GFX942);
591  MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX950, GFX950);
592  MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1030, GFX1030);
593  MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1031, GFX1031);
594  MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1032, GFX1032);
595  MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1034, GFX1034);
596  MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1035, GFX1035);
597  MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1036, GFX1036);
598  MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX10_3_GENERIC, GFX103_GENERIC);
599  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1100, GFX1100);
600  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1101, GFX1101);
601  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1102, GFX1102);
602  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1103, GFX1103);
603  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1150, GFX1150);
604  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1151, GFX1151);
605  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1152, GFX1152);
606  MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX11_GENERIC, GFX11_GENERIC);
607  MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1200, GFX1200);
608  MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1201, GFX1201);
609  MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX12_GENERIC, GFX12_GENERIC);
610 
611  // Default to HOST
612  return amdgcn_target{};
613 }
614 
615 // Cleanup
616 #undef MAP_COMPILER_STATE_TO_GFX9_TARGET
617 #undef MAP_COMPILER_STATE_TO_GFX10_3_TARGET
618 #undef MAP_COMPILER_STATE_TO_GFX11_TARGET
619 #undef MAP_COMPILER_STATE_TO_GFX12_TARGET
620 
621 // Sanity check: device compile must have a valid target architecture
622 static_assert(!amdgcn_compiler_target_state::CK_TILE_DEVICE_COMPILE ||
623  get_compiler_target().TARGET_ID != amdgcn_target_id::HOST,
624  "Device compile must have a valid target device architecture");
625 
626 // Sanity check: host compile must have HOST target architecture
627 static_assert(!amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE ||
628  get_compiler_target().TARGET_ID == amdgcn_target_id::HOST,
629  "Host compile must target HOST architecture");
630 
631 #define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET(NAME_STRING, TARGET_ID) \
632  if constexpr(str.find(NAME_STRING) != std::string::npos) \
633  { \
634  return make_amdgcn_gfx9_target(amdgcn_target_id::TARGET_ID); \
635  } \
636  else
637 
638 #define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET(NAME_STRING, TARGET_ID) \
639  if constexpr(str.find(NAME_STRING) != std::string::npos) \
640  { \
641  return make_amdgcn_gfx10_3_target(amdgcn_target_id::TARGET_ID); \
642  } \
643  else
644 
645 #define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET(NAME_STRING, TARGET_ID) \
646  if constexpr(str.find(NAME_STRING) != std::string::npos) \
647  { \
648  return make_amdgcn_gfx11_target(amdgcn_target_id::TARGET_ID); \
649  } \
650  else
651 
652 #define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET(NAME_STRING, TARGET_ID) \
653  if constexpr(str.find(NAME_STRING) != std::string::npos) \
654  { \
655  return make_amdgcn_gfx12_target(amdgcn_target_id::TARGET_ID); \
656  } \
657  else
658 
665 CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target(char const* testStr)
666 {
667  auto str = std::string(testStr);
668  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx908", GFX908);
669  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx90a", GFX90A);
670  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx942", GFX942);
671  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx950", GFX950);
672  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1030", GFX1030);
673  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1031", GFX1031);
674  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1032", GFX1032);
675  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1034", GFX1034);
676  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1035", GFX1035);
677  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1036", GFX1036);
678  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx10_3_generic", GFX103_GENERIC);
679  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1100", GFX1100);
680  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1101", GFX1101);
681  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1102", GFX1102);
682  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1103", GFX1103);
683  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1150", GFX1150);
684  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1151", GFX1151);
685  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1152", GFX1152);
686  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx11_generic", GFX11_GENERIC);
687  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx1200", GFX1200);
688  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx1201", GFX1201);
689  MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx12_generic", GFX12_GENERIC);
690 
691  // Default case
692  return amdgcn_target{};
693 }
694 
695 #undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET
696 #undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET
697 #undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET
698 #undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET
699 
706 template <amdgcn_target CompilerTarget, amdgcn_target_id... SupportedTargetIds>
707 using enable_if_target_id_t =
708  std::enable_if_t<is_any_value_of(CompilerTarget.TARGET_ID, SupportedTargetIds...)>;
709 
717 template <amdgcn_target CompilerTarget, amdgcn_target_family_id... SupportedTargetFamilyIds>
718 using enable_if_target_family_id_t =
719  std::enable_if_t<is_any_value_of(CompilerTarget.FAMILY_ID, SupportedTargetFamilyIds...)>;
720 
726 template <amdgcn_target CompilerTarget, amdgcn_target_arch_id... SupportedTargetArchIds>
727 using enable_if_target_arch_id_t =
728  std::enable_if_t<is_any_value_of(CompilerTarget.ARCH_ID, SupportedTargetArchIds...)>;
729 
737 template <amdgcn_target CompilerTarget, amdgcn_target_wave_size_id... SupportedTargetWaveSizeIds>
738 using enable_if_target_wave_size_id_t =
739  std::enable_if_t<is_any_value_of(CompilerTarget.WAVE_SIZE_ID, SupportedTargetWaveSizeIds...)>;
740 
742 
747 template <amdgcn_target CompilerTarget>
748 using enable_if_target_family_gfx9_t =
749  enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX9>;
750 
755 template <amdgcn_target CompilerTarget>
756 using enable_if_target_family_gfx10_3_t =
757  enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX10_3>;
758 
763 template <amdgcn_target CompilerTarget>
764 using enable_if_target_family_gfx11_t =
765  enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX11>;
766 
771 template <amdgcn_target CompilerTarget>
772 using enable_if_target_family_gfx12_t =
773  enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX12>;
774 
779 template <amdgcn_target CompilerTarget>
780 using enable_if_target_arch_cdna_t =
781  enable_if_target_arch_id_t<CompilerTarget, amdgcn_target_arch_id::CDNA>;
782 
787 template <amdgcn_target CompilerTarget>
788 using enable_if_target_arch_rdna_t =
789  enable_if_target_arch_id_t<CompilerTarget, amdgcn_target_arch_id::RDNA>;
790 
795 template <amdgcn_target CompilerTarget>
796 using enable_if_target_wave32_t =
797  enable_if_target_wave_size_id_t<CompilerTarget, amdgcn_target_wave_size_id::WAVE32>;
798 
803 template <amdgcn_target CompilerTarget>
804 using enable_if_target_wave64_t =
805  enable_if_target_wave_size_id_t<CompilerTarget, amdgcn_target_wave_size_id::WAVE64>;
806 
807 #endif // __cplusplus <= 201703L
808 
809 } // namespace core::arch
810 
811 CK_TILE_HOST bool is_wave32()
812 {
813  hipDeviceProp_t props{};
814  int device;
815  auto status = hipGetDevice(&device);
816  if(status != hipSuccess)
817  {
818  return false;
819  }
820  status = hipGetDeviceProperties(&props, device);
821  if(status != hipSuccess)
822  {
823  return false;
824  }
825  return props.major > 9;
826 }
827 
831 {
832  return static_cast<index_t>(core::arch::get_compiler_target().WAVE_SIZE_ID);
833 }
834 
835 CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }
836 
837 CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; }
838 
839 // TODO: deprecate these
840 CK_TILE_DEVICE index_t get_thread_local_1d_id() { return threadIdx.x; }
841 
842 CK_TILE_DEVICE index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }
843 
844 CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }
845 
846 // Use these instead
847 CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
848 
849 template <bool ReturnSgpr = true>
850 CK_TILE_DEVICE index_t get_warp_id(bool_constant<ReturnSgpr> = {})
851 {
852  const index_t warp_id = threadIdx.x / get_warp_size();
853  if constexpr(ReturnSgpr)
854  {
855  return amd_wave_read_first_lane(warp_id);
856  }
857  else
858  {
859  return warp_id;
860  }
861 }
862 
863 CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
864 
865 CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
866 
867 CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
868 {
869 #ifdef __gfx12__
870  asm volatile("s_wait_loadcnt %0 \n"
871  "s_barrier_signal -1 \n"
872  "s_barrier_wait -1"
873  :
874  : "n"(cnt)
875  : "memory");
876 #else
877  asm volatile("s_waitcnt vmcnt(%0) \n"
878  "s_barrier"
879  :
880  : "n"(cnt)
881  : "memory");
882 #endif
883 }
884 
885 struct WaitcntLayoutGfx12
886 { // s_wait_loadcnt_dscnt: mem[13:8], ds[5:0]
887  CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // mem
888  CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; // ds
889  CK_TILE_DEVICE static constexpr bool HAS_EXP = false;
890 
891  CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 8); }
892  CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 0); }
893  CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; }
894 };
895 
896 struct WaitcntLayoutGfx11
897 { // vm[15:10] (6), lgkm[9:4] (6), exp unused
898  CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F;
899  CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F;
900  CK_TILE_DEVICE static constexpr bool HAS_EXP = false;
901 
902  CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 10); }
903  CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 4); }
904  CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; }
905 };
906 
907 struct WaitcntLayoutLegacy
908 { // FE'DC'BA98'7'654'3210 => VV'UU'LLLL'U'EEE'VVVV
909  CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // split: low4 + hi2
910  CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x0F; // [11:8]
911  CK_TILE_DEVICE static constexpr index_t EXP_MASK = 0x07; // [6:4]
912  CK_TILE_DEVICE static constexpr bool HAS_EXP = true;
913 
914  CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c)
915  {
916  c &= VM_MASK;
917  return ((c & 0xF) << 0) | ((c & 0x30) << 10);
918  }
919  CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 8); }
920  CK_TILE_DEVICE static constexpr index_t pack_exp(index_t c) { return ((c & EXP_MASK) << 4); }
921 };
922 
923 // Select active layout
924 #if defined(__gfx12__)
925 using Waitcnt = WaitcntLayoutGfx12;
926 #elif defined(__gfx11__)
927 using Waitcnt = WaitcntLayoutGfx11;
928 #else
929 using Waitcnt = WaitcntLayoutLegacy;
930 #endif
931 
932 //----------------------------------------------
933 // Public API: only from_* (constexpr templates)
934 //----------------------------------------------
935 struct waitcnt_arg
936 {
937  // kMax* exposed for callers; match field widths per-arch
938 #if defined(__gfx12__) || defined(__gfx11__)
939  CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits
940  CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits
941  CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x0; // none
942 #else
943  CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits (split)
944  CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x0F; // 4 bits
945  CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x07; // 3 bits
946 #endif
947 
948  template <index_t cnt>
949  CK_TILE_DEVICE static constexpr index_t from_vmcnt()
950  {
951  static_assert((cnt & ~Waitcnt::VM_MASK) == 0, "vmcnt out of range");
952  return Waitcnt::pack_vm(cnt);
953  }
954 
955  template <index_t cnt>
956  CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
957  {
958  static_assert((cnt & ~Waitcnt::LGKM_MASK) == 0, "lgkmcnt out of range");
959  return Waitcnt::pack_lgkm(cnt);
960  }
961 
962  template <index_t cnt>
963  CK_TILE_DEVICE static constexpr index_t from_expcnt()
964  {
965  if constexpr(Waitcnt::HAS_EXP)
966  {
967  // EXP_MASK only exists on legacy
968 #if !defined(__gfx12__) && !defined(__gfx11__)
969  static_assert((cnt & ~Waitcnt::EXP_MASK) == 0, "expcnt out of range");
970  return Waitcnt::pack_exp(cnt);
971 #else
972  (void)cnt;
973  return 0;
974 #endif
975  }
976  else
977  {
978  static_assert(cnt == 0, "expcnt unsupported on this arch");
979  return 0;
980  }
981  }
982 };
983 
984 template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
985  index_t expcnt = waitcnt_arg::kMaxExpCnt,
986  index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
987 CK_TILE_DEVICE void s_waitcnt()
988 {
989 #if defined(__gfx12__)
990  // GFX12 do't use __builtin_amdgcn_s_waitcnt
991  constexpr index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
992  waitcnt_arg::from_expcnt<expcnt>() |
993  waitcnt_arg::from_lgkmcnt<lgkmcnt>();
994 
995  asm volatile("s_wait_loadcnt_dscnt %0" : : "n"(wait_mask) : "memory");
996 #else
997  __builtin_amdgcn_s_waitcnt(waitcnt_arg::from_vmcnt<vmcnt>() |
998  waitcnt_arg::from_expcnt<expcnt>() |
999  waitcnt_arg::from_lgkmcnt<lgkmcnt>());
1000 #endif
1001 }
1002 
1003 template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
1004  index_t expcnt = waitcnt_arg::kMaxExpCnt,
1005  index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
1006 CK_TILE_DEVICE void s_waitcnt_barrier()
1007 {
1008 #if defined(__gfx12__)
1009  // GFX12 optimization: Manual barrier implementation avoids performance penalty
1010  // from __builtin_amdgcn_s_barrier which inserts extra s_wait_loadcnt_dscnt 0x0
1011  constexpr index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
1012  waitcnt_arg::from_expcnt<expcnt>() |
1013  waitcnt_arg::from_lgkmcnt<lgkmcnt>();
1014 
1015  asm volatile("s_wait_loadcnt_dscnt %0\n"
1016  "s_barrier_signal -1\n"
1017  "s_barrier_wait -1"
1018  :
1019  : "n"(wait_mask)
1020  : "memory");
1021 #else
1022  s_waitcnt<vmcnt, expcnt, lgkmcnt>();
1023  __builtin_amdgcn_s_barrier();
1024 #endif
1025 }
1026 
// Barrier helper where lgkmcnt is the only counter given a caller-chosen bound (default 0).
// vmcnt and expcnt are passed to s_waitcnt_barrier at waitcnt_arg::kMaxVmCnt / kMaxExpCnt.
// NOTE(review): the function declarator line is absent from this listing (numbering jumps
// from 1027 to 1029); confirm the name and signature against the real header.
1027 template <index_t lgkmcnt = 0>
1029 {
1030  s_waitcnt_barrier<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, lgkmcnt>();
1031 }
1032 
// Barrier helper where vmcnt is the only counter given a caller-chosen bound (default 0).
// expcnt and lgkmcnt are passed to s_waitcnt_barrier at waitcnt_arg::kMaxExpCnt / kMaxLgkmCnt.
// NOTE(review): the function declarator line is absent from this listing (numbering jumps
// from 1033 to 1035); confirm the name and signature against the real header.
1033 template <index_t vmcnt = 0>
1035 {
1036  s_waitcnt_barrier<vmcnt, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
1037 }
1038 
1039 CK_TILE_DEVICE void s_nop(index_t cnt = 0)
1040 {
1041 #if 1
1042  asm volatile("s_nop %0" : : "n"(cnt) :);
1043 #else
1044  __builtin_amdgcn_sched_barrier(cnt);
1045 #endif
1046 }
1047 
// Clang address-space attribute for the AMDGPU "constant" address space. The number comes
// from ck_tile::address_space_enum::constant. Names are fully qualified so the macro also
// expands correctly when used outside namespace ck_tile.
#define CK_TILE_CONSTANT_ADDRESS_SPACE                                                   \
    __attribute__((address_space(                                                        \
        static_cast<::ck_tile::safe_underlying_type_t<::ck_tile::address_space_enum>>(   \
            ::ck_tile::address_space_enum::constant))))
1051 
1052 template <typename T>
1053 __device__ T* cast_pointer_to_generic_address_space(T CK_TILE_CONSTANT_ADDRESS_SPACE* p)
1054 {
1055  // cast a pointer in "Constant" address space (4) to "Generic" address space (0)
1056  // only c-style pointer cast seems be able to be compiled
1057 #pragma clang diagnostic push
1058 #pragma clang diagnostic ignored "-Wold-style-cast"
1059  return (T*)(p); // NOLINT(old-style-cast)
1060 #pragma clang diagnostic pop
1061 }
1062 
1063 template <typename T>
1064 __host__ __device__ T CK_TILE_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
1065 {
1066  // cast a pointer in "Generic" address space (0) to "Constant" address space (4)
1067  // only c-style pointer cast seems be able to be compiled;
1068 #pragma clang diagnostic push
1069 #pragma clang diagnostic ignored "-Wold-style-cast"
1070  return (T CK_TILE_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
1071 #pragma clang diagnostic pop
1072 }
1073 
1074 CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
1075 {
1076 #if defined(__gfx950__)
1077  return 163840;
1078 #else
1079  return 65536;
1080 #endif
1081 }
1082 
1084 CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_enum addr_space)
1085 {
1086  switch(addr_space)
1087  {
1088  case address_space_enum::generic: return "generic";
1089  case address_space_enum::global: return "global";
1090  case address_space_enum::lds: return "lds";
1091  case address_space_enum::sgpr: return "sgpr";
1092  case address_space_enum::constant: return "constant";
1093  case address_space_enum::vgpr: return "vgpr";
1094  default: return "unknown";
1095  }
1096 }
1097 
// Empty architecture tag types. They carry no data; they only select overloads such as
// detail::get_n_lds_banks.
struct gfx9_t {};
struct gfx950_t {};
struct gfx103_t {};
struct gfx11_t {};
struct gfx12_t {};
struct gfx_invalid_t {};
1117 
1118 CK_TILE_DEVICE static constexpr auto get_device_arch()
1119 {
1120 // FIXME(0): on all devices except gfx11 it returns gfx12_t
1121 // FIXME(1): during the host compilation pass it returns gfx12_t
1122 #if defined(__gfx11__)
1123  return gfx11_t{};
1124 #else
1125  return gfx12_t{};
1126 #endif
1127 }
1128 
1129 CK_TILE_DEVICE static constexpr auto get_n_words_per_128b() { return 4; }
1130 
1131 namespace detail {
1132 CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx9_t) { return 32; }
1133 
1134 CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx103_t) { return 32; }
1135 
1136 CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx11_t) { return 32; }
1137 
1138 CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx12_t) { return 32; }
1139 
1140 CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; }
1141 
1142 CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx_invalid_t) { return 0; }
1143 
1144 CK_TILE_DEVICE static constexpr auto arch_tag_dispatch()
1145 {
1146 #if defined(__gfx103__)
1147  return gfx103_t{};
1148 #elif defined(__gfx11__)
1149  return gfx11_t{};
1150 #elif defined(__gfx12__)
1151  return gfx12_t{};
1152 #elif defined(__gfx950__)
1153  return gfx950_t{};
1154 #elif defined(__gfx9__)
1155  return gfx9_t{};
1156 #else
1157  return gfx_invalid_t{};
1158 #endif
1159 }
1160 } // namespace detail
1161 CK_TILE_DEVICE static constexpr auto get_n_lds_banks()
1162 {
1163  return detail::get_n_lds_banks(detail::arch_tag_dispatch());
1164 }
1165 
// Instruction-class bit mask. Presumably this is the mask argument of the AMDGPU
// scheduling builtins (__builtin_amdgcn_sched_barrier / sched_group_barrier); verify
// against call sites. ALL is the union of every bit listed here.
enum LLVMSchedGroupMask : int32_t
{
    NONE       = 0,
    ALU        = 0x001,
    VALU       = 0x002,
    SALU       = 0x004,
    MFMA       = 0x008,
    VMEM       = 0x010,
    VMEM_READ  = 0x020,
    VMEM_WRITE = 0x040,
    DS         = 0x080,
    DS_READ    = 0x100,
    DS_WRITE   = 0x200,
    ALL        = (DS_WRITE << 1) - 1,
};
1181 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE void atomic_add(X *p_dst, const X &x)
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE T add(const T &a, const T &b)
Definition: generic_memory_space_atomic.hpp:16
int32_t index_t
Definition: integer.hpp:9
__device__ index_t get_grid_size()
Definition: get_id.hpp:49
__device__ void s_nop()
Definition: synchronization.hpp:61
__device__ index_t get_block_size()
Definition: get_id.hpp:51
__device__ void block_sync_lds_direct_load()
Definition: synchronization.hpp:43
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
__device__ index_t get_thread_global_1d_id()
Definition: get_id.hpp:43
__device__ X atomic_max(X *p_dst, const X &x)
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition: amd_address_space.hpp:35
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition: amd_address_space.hpp:24
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1697
unsigned short uint16_t
Definition: stdint.h:125
signed int int32_t
Definition: stdint.h:123