23 template <
typename LowLengths>
49 static_assert(LowerIndex::Size() == NDimLow,
"wrong!");
56 __host__ __device__ constexpr
const auto&
GetUpperLengths()
const {
return up_lengths_; }
58 template <
typename LowIdx,
typename UpIdx>
60 const UpIdx& idx_up)
const
62 static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
63 "wrong! inconsistent # of dimension");
69 idx_low(i) = tmp / this->low_lengths_scan_[i];
70 tmp %= this->low_lengths_scan_[i];
76 template <
typename LowIdxDiff,
82 const UpIdxDiff& idx_up_diff,
84 const UpIdx& idx_up_new,
87 static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
88 LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
89 "wrong! inconsistent # of dimension");
92 constexpr
auto INm1 =
Number<NDimLow - 1>{};
97 idx_diff_low(INm1) = idx_up_diff[I0];
114 template <
typename UpIdx>
115 __host__ __device__
static constexpr
bool
121 __host__ __device__
void Print()
const
124 printf(
"Merge_v3_direct_division_mod_wrw, ");
125 printf(
"low_lengths_ ");
127 printf(
"low_lengths_scan_ ");
129 printf(
"up_lengths_ ");
135 template <
typename LowLengths>
141 template <
typename GridwiseGemm,
145 typename AGridDesc_B_K0_M_K1,
146 typename BGridDesc_B_K0_N_K1,
147 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
148 typename AElementwiseOperation,
149 typename BElementwiseOperation,
150 typename CElementwiseOperation,
151 typename CBlockClusterAdaptor,
152 bool HasMainKBlockLoop>
154 #if CK_USE_LAUNCH_BOUNDS
158 const FloatB* __restrict__ p_b_grid,
159 FloatC* __restrict__ p_c_grid,
160 const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
161 const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
162 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
163 c_grid_desc_mblock_mperblock_nblock_nperblock,
164 const AElementwiseOperation a_element_op,
165 const BElementwiseOperation b_element_op,
166 const CElementwiseOperation c_element_op,
167 const CBlockClusterAdaptor c_block_cluster_adaptor)
169 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
171 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
173 GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
177 a_b_k0_m_k1_grid_desc,
178 b_b_k0_n_k1_grid_desc,
179 c_grid_desc_mblock_mperblock_nblock_nperblock,
183 c_block_cluster_adaptor);
188 ignore = a_b_k0_m_k1_grid_desc;
189 ignore = b_b_k0_n_k1_grid_desc;
190 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
194 ignore = c_block_cluster_adaptor;
204 typename AGridDesc_B_K0_M_K1,
205 typename BGridDesc_B_K0_N_K1,
206 typename CMNGridDesc,
207 typename AElementwiseOperation,
208 typename BElementwiseOperation,
209 typename CElementwiseOperation,
218 typename ABlockTransferThreadClusterLengths_K0_M_K1,
219 typename ABlockTransferThreadClusterArrangeOrder,
220 typename ABlockTransferSrcAccessOrder,
221 index_t ABlockTransferSrcVectorDim,
222 index_t ABlockTransferSrcScalarPerVector,
223 index_t ABlockTransferDstScalarPerVector_K1,
224 bool AThreadTransferSrcResetCoordinateAfterRun,
225 bool ABlockLdsExtraM,
229 typename BBlockTransferThreadClusterLengths_K0_N_K1,
230 typename BBlockTransferThreadClusterArrangeOrder,
231 typename BBlockTransferSrcAccessOrder,
232 index_t BBlockTransferSrcVectorDim,
233 index_t BBlockTransferSrcScalarPerVector,
234 index_t BBlockTransferDstScalarPerVector_K1,
235 bool BThreadTransferSrcResetCoordinateAfterRun,
236 bool BBlockLdsExtraN,
240 index_t CShuffleMRepeatPerShuffle,
241 index_t CShuffleNRepeatPerShuffle,
242 index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
243 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
244 bool ABlockLdsExtraM1Wrw =
false,
245 bool BBlockLdsExtraN1Wrw =
false,
246 index_t NumGemmKPrefetchStage = 1,
248 typename ComputeTypeA = FloatA,
249 typename ComputeTypeB = ComputeTypeA>
267 decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
274 #if CK_GFX90A_DENORM_WORKAROUND
296 constexpr
auto max_lds_align = K1;
299 constexpr
auto a_block_desc_k0_m_k1 = [&]() {
300 if constexpr(ABlockLdsExtraM)
302 if constexpr(ABlockLdsExtraM1Wrw)
313 a_block_desc_k0_m0_m1_k1,
321 return a_block_desc_k0_m_k1_tmp;
337 return a_block_desc_k0_m_k1;
342 constexpr
auto max_lds_align = K1;
345 constexpr
auto a_block_desc_b_k0_m_k1 = [&]() {
346 if constexpr(ABlockLdsExtraM)
348 if constexpr(ABlockLdsExtraM1Wrw)
364 a_block_desc_b_k0_m0_m1_k1,
373 return a_block_desc_b_k0_m_k1_tmp;
393 return a_block_desc_b_k0_m_k1;
398 constexpr
auto max_lds_align = K1;
401 constexpr
auto b_block_desc_k0_n_k1 = [&]() {
402 if constexpr(BBlockLdsExtraN)
404 if constexpr(BBlockLdsExtraN1Wrw)
415 b_block_desc_k0_n0_n1_k1,
423 return b_block_desc_k0_n_k1_tmp;
439 return b_block_desc_k0_n_k1;
444 constexpr
auto max_lds_align = K1;
447 constexpr
auto b_block_desc_b_k0_n_k1 = [&]() {
448 if constexpr(BBlockLdsExtraN)
450 if constexpr(BBlockLdsExtraN1Wrw)
466 b_block_desc_b_k0_n0_n1_k1,
475 return b_block_desc_b_k0_n_k1_tmp;
495 return b_block_desc_b_k0_n_k1;
500 constexpr
auto max_lds_align = K1;
503 constexpr
auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();
506 constexpr
auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();
510 a_b_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
513 b_b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
515 constexpr
auto c_block_size =
516 GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
520 c_block_size *
sizeof(FloatC));
524 template <
typename Block2CTileMap>
525 __host__ __device__
static constexpr
bool
527 const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
528 const CMNGridDesc& c_m_n_grid_desc,
529 const Block2CTileMap& block_2_ctile_map)
532 "wrong! K1 need to be known at compile-time");
534 static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
535 (NPerBlock % (NRepeat * NPerXDL)) == 0,
536 "Invalid tuning param!");
538 const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
539 const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
540 const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
541 const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
544 const auto num_k_loop = K0 / K0PerBlock;
546 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
551 if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
552 K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
553 K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
554 K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
555 KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
558 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
561 if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
573 const index_t num_loop = K0 / K0PerBlock;
575 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
580 __host__ __device__
static constexpr
auto
583 const auto M = c_m_n_grid_desc.GetLength(I0);
584 const auto N = c_m_n_grid_desc.GetLength(I1);
586 const auto MBlock = M / MPerBlock;
587 const auto NBlock = N / NPerBlock;
602 c_m_n_grid_desc, M01, N01, KBatch);
605 __host__ __device__
static constexpr
auto
608 constexpr
index_t MWave = MPerBlock / (MRepeat * MPerXDL);
609 constexpr
index_t NWave = NPerBlock / (NRepeat * NPerXDL);
619 decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
622 template <
bool HasMainKBlockLoop>
623 __device__
static void Run(
const FloatA* __restrict__ p_a_grid,
624 const FloatB* __restrict__ p_b_grid,
625 FloatC* __restrict__ p_c_grid,
626 void* __restrict__ p_shared,
627 const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
628 const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
630 c_grid_desc_mblock_mperblock_nblock_nperblock,
631 const AElementwiseOperation& a_element_op,
632 const BElementwiseOperation& b_element_op,
633 const CElementwiseOperation& c_element_op,
636 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
637 p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
638 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
639 p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
640 auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
641 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
643 const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
646 const auto block_work_idx =
649 const index_t k_batch_id = block_work_idx[I0];
651 if(!c_block_cluster_adaptor.ValidCTileIndex(
652 make_tuple(block_work_idx[I1], block_work_idx[I2]),
653 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
654 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
660 const index_t m_block_data_idx_on_grid =
661 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
663 const index_t n_block_data_idx_on_grid =
664 __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
667 constexpr
auto max_lds_align = K1;
670 constexpr
auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
672 constexpr
auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();
674 constexpr
auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
676 constexpr
auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();
678 auto a_blockwise_copy =
680 AElementwiseOperation,
682 InMemoryDataOperationEnum::Set,
684 ABlockTransferThreadClusterLengths_K0_M_K1,
685 ABlockTransferThreadClusterArrangeOrder,
688 decltype(a_b_k0_m_k1_grid_desc),
689 decltype(a_b_k0_m_k1_block_desc),
690 ABlockTransferSrcAccessOrder,
692 ABlockTransferSrcVectorDim,
694 ABlockTransferSrcScalarPerVector,
695 ABlockTransferDstScalarPerVector_K1,
698 AThreadTransferSrcResetCoordinateAfterRun,
700 a_b_k0_m_k1_grid_desc,
703 a_b_k0_m_k1_block_desc,
708 auto b_blockwise_copy =
710 BElementwiseOperation,
712 InMemoryDataOperationEnum::Set,
714 BBlockTransferThreadClusterLengths_K0_N_K1,
715 BBlockTransferThreadClusterArrangeOrder,
718 decltype(b_b_k0_n_k1_grid_desc),
719 decltype(b_b_k0_n_k1_block_desc),
720 BBlockTransferSrcAccessOrder,
722 BBlockTransferSrcVectorDim,
724 BBlockTransferSrcScalarPerVector,
725 BBlockTransferDstScalarPerVector_K1,
728 BThreadTransferSrcResetCoordinateAfterRun,
730 b_b_k0_n_k1_grid_desc,
733 b_b_k0_n_k1_block_desc,
750 auto blockwise_gemm =
755 decltype(a_k0_m_k1_block_desc),
756 decltype(b_k0_n_k1_block_desc),
766 constexpr
auto a_block_space_size =
769 constexpr
auto a_block_slice_copy_step =
make_multi_index(0, K0PerBlock, 0, 0);
770 constexpr
auto b_block_slice_copy_step =
make_multi_index(0, K0PerBlock, 0, 0);
772 auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
773 static_cast<FloatAAdjusted*
>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
775 auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
777 b_k0_n_k1_block_desc.GetElementSpaceSize());
780 const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
782 GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
783 a_b_k0_m_k1_block_desc,
787 a_block_slice_copy_step,
788 b_b_k0_n_k1_grid_desc,
789 b_b_k0_n_k1_block_desc,
793 b_block_slice_copy_step,
800 constexpr
index_t MWave = MPerBlock / (MRepeat * MPerXDL);
801 constexpr
index_t NWave = NPerBlock / (NRepeat * NPerXDL);
803 constexpr
auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
804 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
806 constexpr
auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
807 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
809 constexpr
auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
810 constexpr
auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
811 constexpr
auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
812 constexpr
auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
813 constexpr
auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
814 constexpr
auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
815 constexpr
auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
816 constexpr
auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
818 constexpr
auto c_block_desc_mblock_mperblock_nblock_nperblock =
819 GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
821 auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
822 static_cast<FloatC*
>(p_shared),
823 c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
825 static_assert(M1 == MWave,
"");
826 static_assert(N1 == NWave,
"");
827 static_assert(M2 * M3 * M4 == MPerXDL,
"");
828 static_assert(N2 == NPerXDL,
"");
831 c_block_desc_mblock_mperblock_nblock_nperblock,
849 const auto c_thread_mtx_on_block =
850 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
852 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
853 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
855 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
861 const auto m_thread_data_on_block_idx =
862 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
865 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
871 const auto n_thread_data_on_block_idx =
872 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
876 auto c_thread_copy_vgpr_to_lds =
879 decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
880 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
883 CShuffleNRepeatPerShuffle,
893 InMemoryDataOperationEnum::Set,
896 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
899 m_thread_data_on_block_idx[I1],
900 n_thread_data_on_block_idx[I1],
901 m_thread_data_on_block_idx[I2],
902 m_thread_data_on_block_idx[I3],
903 m_thread_data_on_block_idx[I4],
904 n_thread_data_on_block_idx[I2]),
910 CElementwiseOperation,
911 CGlobalMemoryDataOperation,
913 CShuffleMRepeatPerShuffle * MWave * MPerXDL,
915 CShuffleNRepeatPerShuffle * NWave * NPerXDL>,
916 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
920 decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
921 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
924 CBlockTransferScalarPerVector_NWaveNPerXDL,
927 {c_block_desc_mblock_mperblock_nblock_nperblock,
929 c_grid_desc_mblock_mperblock_nblock_nperblock,
933 constexpr
auto mxdlperwave_forward_step =
935 constexpr
auto nxdlperwave_forward_step =
937 constexpr
auto nxdlperwave_backward_step =
941 constexpr
auto mxdlperwave = mxdlperwave_iter;
944 constexpr
bool nxdlperwave_forward_sweep =
945 (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
947 constexpr
index_t nxdlperwave_value =
948 nxdlperwave_forward_sweep
950 : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
958 c_thread_copy_vgpr_to_lds.Run(
959 c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
960 make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
962 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
969 c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
971 c_grid_desc_mblock_mperblock_nblock_nperblock,
975 if constexpr(nxdlperwave_forward_sweep &&
976 (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
978 c_block_copy_lds_to_global.MoveDstSliceWindow(
979 c_grid_desc_mblock_mperblock_nblock_nperblock,
980 nxdlperwave_forward_step);
982 else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
984 c_block_copy_lds_to_global.MoveDstSliceWindow(
985 c_grid_desc_mblock_mperblock_nblock_nperblock,
986 nxdlperwave_backward_step);
991 if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
993 c_block_copy_lds_to_global.MoveDstSliceWindow(
994 c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:34
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:33
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__global__ void kernel_gemm_xdlops_bwd_weight(const FloatA *__restrict__ p_a_grid, const FloatB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:157
InMemoryDataOperationEnum
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
__host__ constexpr __device__ auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition: tensor_descriptor_helper.hpp:132
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
ushort bhalf_t
Definition: data_type.hpp:24
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:22
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
__host__ constexpr __device__ auto container_reverse_exclusive_scan(const Array< TData, NSize > &x, Reduce f, TData init)
Definition: container_helper.hpp:213
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition: container_helper.hpp:111
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:17
__host__ __device__ void print_multi_index(const Tuple< Xs... > &x)
Definition: statically_indexed_array_multi_index.hpp:147
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:298
__host__ constexpr __device__ auto make_merge_transform_v4_no_carry(const LowLengths &low_lengths)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:136
Definition: block_to_ctile_map.hpp:718
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_smfmac_xdlops.hpp:79
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:251
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVer, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:267
__host__ static constexpr __device__ auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:396
ComputeTypeB FloatBAdjusted
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:281
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:264
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:498
decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:619
__host__ static constexpr __device__ auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:294
ComputeTypeA FloatAAdjusted
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:280
__host__ static constexpr __device__ auto GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:340
__host__ static constexpr __device__ auto GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:606
__host__ static constexpr __device__ bool CalculateHasMainK0BlockLoop(index_t K0)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:570
__host__ static constexpr __device__ auto MakeCBlockClusterAdaptor(const CMNGridDesc &c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:598
static __device__ void Run(const FloatA *__restrict__ p_a_grid, const FloatB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, void *__restrict__ p_shared, const AGridDesc_B_K0_M_K1 &a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 &b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const CBlockClusterAdaptor &c_block_cluster_adaptor)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:623
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_B_K0_M_K1 &a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 &b_b_k0_n_k1_grid_desc, const CMNGridDesc &c_m_n_grid_desc, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:526
__host__ static constexpr __device__ auto MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc &c_m_n_grid_desc)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:581
decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)) CBlockClusterAdaptor
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:620
__host__ static constexpr __device__ auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:442
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:25
__host__ constexpr __device__ Merge_v4_no_carry(const LowLengths &low_lengths)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:43
LowLengthsScan low_lengths_scan_
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:38
__host__ constexpr __device__ Merge_v4_no_carry()=default
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:35
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:116
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_up_diff, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:81
static constexpr index_t NDimLow
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:26
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:52
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:56
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:59
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:107
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:100
UpLengths up_lengths_
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:39
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number< 1 >{})) LowLengthsScan
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:32
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:54
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:102
LowLengths low_lengths_
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:37
__host__ __device__ void Print() const
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:121
Definition: xdlops_gemm.hpp:886
Definition: sequence.hpp:43
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:10
Definition: is_known_at_compile_time.hpp:14
Definition: functional2.hpp:31
Definition: unary_element_wise_operation.hpp:241