template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_K0_M0_M1_K1,
          typename BGridDesc_K0_N0_N1_K1,
          typename CGridDesc_M0_M10_M11_N0_N10_N11,
          typename Block2CTileMap,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
                        const FloatAB* __restrict__ p_b_grid,
                        FloatC* __restrict__ p_c_grid,
                        const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
                        const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
                        const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
                        const Block2CTileMap block_2_ctile_map)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_shared_block,
                      a_grid_desc_k0_m0_m1_k1,
                      b_grid_desc_k0_n0_n1_k1,
                      c_grid_desc_m0_m10_m11_n0_n10_n11,
                      block_2_ctile_map,
                      integral_constant<bool, HasMainKBlockLoop>{},
                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
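On the host side, a launch of this kernel might look like the sketch below. It is only an illustration: every name in it (Gemm, a_desc, b_desc, c_desc, tile_map, BlockSize, M, N, stream, p_a, p_b, p_c) is a placeholder assumption, not something defined in this header.

// Minimal host-side launch sketch (HIP/CUDA syntax), assuming "Gemm" is a concrete
// instantiation of the gridwise GEMM struct below, and the descriptors / tile map were
// built with the Make* helpers further down.
const index_t grid_size = Gemm::CalculateGridSize(M, N);

const auto kernel = kernel_gemm_dl_v1r3<Gemm,
                                        float, // FloatAB (assumed)
                                        float, // FloatC  (assumed)
                                        decltype(a_desc),
                                        decltype(b_desc),
                                        decltype(c_desc),
                                        decltype(tile_map),
                                        true,   // HasMainKBlockLoop, from CalculateHasMainKBlockLoop(K0)
                                        false>; // HasDoubleTailKBlockLoop, from CalculateHasDoubleTailKBlockLoop(K0)

// LDS is allocated statically inside the kernel, so no dynamic shared memory is passed.
kernel<<<dim3(grid_size), dim3(BlockSize), 0, stream>>>(
    p_a, p_b, p_c, a_desc, b_desc, c_desc, tile_map);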
template <// ...
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename CGridDesc_M_N,
          // ...
          typename M11N11ThreadClusterM110Xs,
          typename M11N11ThreadClusterN110Xs,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector>
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = K1;

        // ...

        // A and B tile sizes in LDS, each padded up to the alignment
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);

        // factor 2: double-buffered LDS tiles for both A and B
        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
    }
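As a rough sanity check of the size this returns, here is a small back-of-the-envelope sketch with hypothetical tile parameters (K0PerBlock = 8, MPerBlock = NPerBlock = 128, K1 = 2, FloatAB = float); none of these values come from this header.

// Hypothetical tile parameters, for illustration only.
constexpr int K0PerBlock = 8, MPerBlock = 128, NPerBlock = 128, K1 = 2;
constexpr int a_tile_elems = K0PerBlock * MPerBlock * K1; // 2048 elements per A tile
constexpr int b_tile_elems = K0PerBlock * NPerBlock * K1; // 2048 elements per B tile
// Double-buffered A and B tiles in float: 2 * (2048 + 2048) * 4 B = 32 KiB of LDS.
constexpr auto lds_bytes = 2 * (a_tile_elems + b_tile_elems) * sizeof(float);
static_assert(lds_bytes == 32 * 1024, "32 KiB");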
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                  const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        // A, B and C must agree on M, N, K0 and K1, and each extent must be
        // divisible by its per-block tile size
        return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
                K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
                K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
                K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
               (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
    {
        const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;

        return has_main_k_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
    {
        const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;

        return has_double_tail_k_block_loop;
    }
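A quick worked check of the two predicates, assuming a hypothetical K0PerBlock = 8 (this value is a template parameter, not fixed here):

// K0 = 32: two or more double-buffered iterations, and an even number of K blocks.
static_assert((32 + 8) / (2 * 8) > 1, "K0 = 32 -> main loop executes");
static_assert((32 / 8) % 2 == 0, "K0 = 32 -> double-tail epilogue");
// K0 = 16: the main loop is skipped and only the two tail iterations run.
static_assert(!((16 + 8) / (2 * 8) > 1), "K0 = 16 -> no main loop");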
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
    {
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);

        // ...
        const auto M0 = M / M1;

        const auto a_grid_desc_k0_m0_m1_k1 = transform_tensor_descriptor(
            a_grid_desc_k0_m_k1,
            make_tuple(make_pass_through_transform(K0),
                       make_unmerge_transform(make_tuple(M0, M1)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return a_grid_desc_k0_m0_m1_k1;
    }
    __host__ __device__ static constexpr auto
    MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
    {
        const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);

        // ...
        const auto N0 = N / N1;

        const auto b_grid_desc_k0_n0_n1_k1 = transform_tensor_descriptor(
            b_grid_desc_k0_n_k1,
            make_tuple(make_pass_through_transform(K0),
                       make_unmerge_transform(make_tuple(N0, N1)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return b_grid_desc_k0_n0_n1_k1;
    }
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        // ...
        const auto M0 = M / M1;
        const auto N0 = N / N1;

        // ...
        constexpr auto M10 = M1 / M11;
        constexpr auto N10 = N1 / N11;

        // ...
        return c_grid_desc_m0_m10_m11_n0_n10_n11;
    }
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        // ...
    }
    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        FloatAB* __restrict__ p_shared_block,
        const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
        const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
        const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
        const Block2CTileMap& block_2_ctile_map,
        integral_constant<bool, HasMainKBlockLoop>,
        integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());

        // divide block work over the [M0, N0] tile grid
        const auto c_m0_n0_block_cluster_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // keep the block tile indices in scalar registers
        const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
        const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);

        if(!block_2_ctile_map.ValidCTileIndex(
               make_tuple(im0, in0),
               make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
                          c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
        {
            return;
        }

        constexpr auto max_lds_align = K1;
        // ...

        static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
                              a_k0_m_k1_block_desc.GetElementSpaceSize() &&
                          b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
                              b_k0_n_k1_block_desc.GetElementSpaceSize() &&
                      // ...

        // A matrix blockwise copy: global memory -> LDS
        auto a_blockwise_copy =
            // ...
                ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                ABlockTransferThreadClusterArrangeOrder,
                // ...
                decltype(a_block_desc_k0_m0_m1_k1),
                ABlockTransferSrcAccessOrder,
                // ...
                ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
                ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
                ABlockTransferSrcVectorTensorContiguousDimOrder,
                // ...
                true>(a_grid_desc_k0_m0_m1_k1,
                      // ...
                      a_block_desc_k0_m0_m1_k1,
                      // ...);

        // B matrix blockwise copy: global memory -> LDS
        auto b_blockwise_copy =
            // ...
                BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                BBlockTransferThreadClusterArrangeOrder,
                // ...
                decltype(b_block_desc_k0_n0_n1_k1),
                BBlockTransferSrcAccessOrder,
                // ...
                BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
                BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
                BBlockTransferSrcVectorTensorContiguousDimOrder,
                // ...
                true>(b_grid_desc_k0_n0_n1_k1,
                      // ...
                      b_block_desc_k0_n0_n1_k1,
                      // ...);
        // blockwise GEMM: consumes the A/B tiles in LDS, accumulates C in registers
        const auto blockwise_gemm =
            // ...
                decltype(a_k0_m_k1_block_desc),
                decltype(b_k0_n_k1_block_desc),
                // ...
                M11N11ThreadClusterM110Xs,
                M11N11ThreadClusterN110Xs,
                // ...

        constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
            decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
        // LDS double buffer: sizes padded to max_lds_align
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block_double = p_shared_block;
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

        // register allocation for the C thread tile
        auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
            c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());

        c_thread_buf.Clear();
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);

        // even/odd LDS double buffers for A and B
        auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
        auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

        auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double + a_block_aligned_space_size,
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
        auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double + b_block_aligned_space_size,
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

        // pipeline prologue: preload the first K-block tile into the even LDS buffers
        a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
        b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

        a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
        b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);

        // LDS double buffer: main body, two K-block tiles per do-while iteration
        if constexpr(HasMainKBlockLoop)
        {
            const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);

            index_t k_block_data_begin = 0;

            do
            {
                // even iteration: prefetch the next tile while consuming the even buffers
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);

                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

                block_sync_lds();

                blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                                   a_block_even_buf, b_block_even_buf, c_thread_buf);

                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);

                // odd iteration: same schedule, roles of even/odd buffers swapped
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);

                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

                block_sync_lds();

                blockwise_gemm.Run(
                    c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);

                k_block_data_begin += 2 * K0PerBlock;
            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
        }

        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // 2 K-block tiles left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);

            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

            block_sync_lds();

            // GEMM on the second-to-last tile (even buffers)
            blockwise_gemm.Run(
                c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);

            block_sync_lds();

            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);

            block_sync_lds();

            // GEMM on the last tile (odd buffers)
            blockwise_gemm.Run(
                c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
        }
        else // 1 K-block tile left
        {
            block_sync_lds();

            blockwise_gemm.Run(
                c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
        }
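Summarized, the double-buffered loop above overlaps the global load of the next K-block tile with the blockwise GEMM on the current one (sketch only, not part of the header):

// even half-iteration: read tile i+1 (global -> registers) | GEMM on LDS even (tile i)   | write tile i+1 -> LDS odd
// odd  half-iteration: read tile i+2 (global -> registers) | GEMM on LDS odd  (tile i+1) | write tile i+2 -> LDS even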
        // epilogue: copy the C thread tile from registers (VGPR) to global memory
        constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
            // ...
                Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                // ...

        const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
            blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
                // ...

        // threadwise transfer of the C thread tile to its destination in the C grid
        // ...
            decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
            decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
            // ...
            c_m10_m11_n10_n11_thread_tensor_lengths[I0],
            c_m10_m11_n10_n11_thread_tensor_lengths[I1],
            // ...
            c_m10_m11_n10_n11_thread_tensor_lengths[I2],
            c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
            CThreadTransferSrcDstAccessOrder,
            CThreadTransferSrcDstVectorDim,
            CThreadTransferDstScalarPerVector,
            CGlobalMemoryDataOperation,
            // ...
            true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
                  make_multi_index(
                      im0,
                      c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                      c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
                      in0,
                      c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                      c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
                  // ...
            .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
                 // ...
                 c_grid_desc_m0_m10_m11_n0_n10_n11,
                 // ...
    }
template <// ...
          typename AGridDesc_B_K0_M_K1,
          typename BGridDesc_B_K0_N_K1,
          typename CGridDesc_M_N,
          // ...
          typename M11N11ThreadClusterM110Xs,
          typename M11N11ThreadClusterN110Xs,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector>
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = K1;

        // ...
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_b_k0_m_k1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_block_desc_b_k0_n_k1.GetElementSpaceSize(), max_lds_align);

        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
    }
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_B_K0_M_K1& a_grid_desc_b_k0_m_k1,
                  const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1,
                  const CGridDesc_M_N& c_grid_desc_m_n)
    {
        // each tensor's element space must fit in a 2 GB window
        constexpr long_index_t TwoGB = (long_index_t{1} << 31);

        if(!(a_grid_desc_b_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
             b_grid_desc_b_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
             c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
        {
            return false;
        }

        const auto M      = a_grid_desc_b_k0_m_k1.GetLength(I2);
        const auto N      = b_grid_desc_b_k0_n_k1.GetLength(I2);
        const auto K0     = a_grid_desc_b_k0_m_k1.GetLength(I1);
        const auto KBatch = a_grid_desc_b_k0_m_k1.GetLength(I0);

        return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
                K0 == b_grid_desc_b_k0_n_k1.GetLength(I1) &&
                K1 == a_grid_desc_b_k0_m_k1.GetLength(I3) &&
                K1 == b_grid_desc_b_k0_n_k1.GetLength(I3)) &&
               KBatch == b_grid_desc_b_k0_n_k1.GetLength(I0) &&
               (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
    {
        const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;

        return has_main_k_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
    {
        const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;

        return has_double_tail_k_block_loop;
    }
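To put the 2 GB guard in CheckValidity above into numbers, here is a hedged example; the shapes are hypothetical and chosen only for the arithmetic.

// Hypothetical fp32 A tensor: KBatch * K0 * K1 = 8192 and M = 131072.
constexpr long long elems = 8192LL * 131072LL; // 2^30 elements
constexpr long long bytes = elems * 4LL;       // 4 GiB in fp32
static_assert(bytes > (1LL << 31), "element space over 2 GB -> CheckValidity returns false");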
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_B_K0_M0_M1_K1(const AGridDesc_B_K0_M_K1& a_grid_desc_b_k0_m_k1)
    {
        const auto KBatch = a_grid_desc_b_k0_m_k1.GetLength(I0);
        const auto K0     = a_grid_desc_b_k0_m_k1.GetLength(I1);
        const auto M      = a_grid_desc_b_k0_m_k1.GetLength(I2);

        // ...
        const auto M0 = M / M1;

        const auto a_grid_desc_b_k0_m0_m1_k1 = transform_tensor_descriptor(
            a_grid_desc_b_k0_m_k1,
            make_tuple(make_pass_through_transform(KBatch),
                       make_pass_through_transform(K0),
                       make_unmerge_transform(make_tuple(M0, M1)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));

        return a_grid_desc_b_k0_m0_m1_k1;
    }
    __host__ __device__ static constexpr auto
    MakeBGridDescriptor_B_K0_N0_N1_K1(const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1)
    {
        const auto KBatch = b_grid_desc_b_k0_n_k1.GetLength(I0);
        const auto K0     = b_grid_desc_b_k0_n_k1.GetLength(I1);
        const auto N      = b_grid_desc_b_k0_n_k1.GetLength(I2);

        // ...
        const auto N0 = N / N1;

        const auto b_grid_desc_b_k0_n0_n1_k1 = transform_tensor_descriptor(
            b_grid_desc_b_k0_n_k1,
            make_tuple(make_pass_through_transform(KBatch),
                       make_pass_through_transform(K0),
                       make_unmerge_transform(make_tuple(N0, N1)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));

        return b_grid_desc_b_k0_n0_n1_k1;
    }
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        // ...
        const auto M0 = M / M1;
        const auto N0 = N / N1;

        // ...
        constexpr auto M10 = M1 / M11;
        constexpr auto N10 = N1 / N11;

        // ...
        return c_grid_desc_m0_m10_m11_n0_n10_n11;
    }

    __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
        const CGridDesc_M_N& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
    {
        // ...
            c_m_n_grid_desc, M01, N01, KBatch);
    }
    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        FloatAB* __restrict__ p_shared_block,
        const AGridDesc_B_K0_M0_M1_K1& a_grid_desc_b_k0_m0_m1_k1,
        const BGridDesc_B_K0_N0_N1_K1& b_grid_desc_b_k0_n0_n1_k1,
        const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
        const CBlockClusterAdaptor& c_block_cluster_adaptor,
        integral_constant<bool, HasMainKBlockLoop>,
        integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_b_k0_m0_m1_k1.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_b_k0_n0_n1_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());

        // divide block work over [KBatch, M0, N0]
        const auto block_work_idx =
            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t k_batch_id = block_work_idx[I0];

        if(!c_block_cluster_adaptor.ValidCTileIndex(
               // ...
               make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
                          c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
        {
            return;
        }

        // keep the block tile indices in scalar registers
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I2]);

        constexpr auto max_lds_align = K1;
        // ...

        static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
                              a_k0_m_k1_block_desc.GetElementSpaceSize() &&
                          b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
                              b_k0_n_k1_block_desc.GetElementSpaceSize() &&
                      // ...

        // A matrix blockwise copy: global memory -> LDS
        auto a_blockwise_copy =
            // ...
                Sequence<1, K0PerBlock, 1, MPerBlock, K1.value>,
                ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                ABlockTransferThreadClusterArrangeOrder,
                // ...
                decltype(a_block_desc_b_k0_m0_m1_k1),
                ABlockTransferSrcAccessOrder,
                // ...
                ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
                ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
                ABlockTransferSrcVectorTensorContiguousDimOrder,
                // ...
                true>(a_grid_desc_b_k0_m0_m1_k1,
                      // ...
                      a_block_desc_b_k0_m0_m1_k1,
                      // ...);

        // B matrix blockwise copy: global memory -> LDS
        auto b_blockwise_copy =
            // ...
                Sequence<1, K0PerBlock, 1, NPerBlock, K1.value>,
                BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                BBlockTransferThreadClusterArrangeOrder,
                // ...
                decltype(b_block_desc_b_k0_n0_n1_k1),
                BBlockTransferSrcAccessOrder,
                // ...
                BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
                BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
                BBlockTransferSrcVectorTensorContiguousDimOrder,
                // ...
                true>(b_grid_desc_b_k0_n0_n1_k1,
                      // ...
                      b_block_desc_b_k0_n0_n1_k1,
                      // ...);
        // blockwise GEMM: consumes the A/B tiles in LDS, accumulates C in registers
        const auto blockwise_gemm =
            // ...
                decltype(a_k0_m_k1_block_desc),
                decltype(b_k0_n_k1_block_desc),
                // ...
                M11N11ThreadClusterM110Xs,
                M11N11ThreadClusterN110Xs,
                // ...

        constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
            decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

        // LDS double buffer: sizes padded to max_lds_align
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block_double = p_shared_block;
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

        // register allocation for the C thread tile
        auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
            c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());

        c_thread_buf.Clear();
        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);

        // even/odd LDS double buffers for A and B
        auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
        auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

        auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double + a_block_aligned_space_size,
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
        auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double + b_block_aligned_space_size,
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

        // pipeline prologue: preload the first K-block tile into the even LDS buffers
        a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
        b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);

        a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_even_buf);
        b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_even_buf);

        // LDS double buffer: main body, two K-block tiles per do-while iteration
        if constexpr(HasMainKBlockLoop)
        {
            const auto K0 = a_grid_desc_b_k0_m0_m1_k1.GetLength(I1);

            index_t k_block_data_begin = 0;

            do
            {
                // even iteration: prefetch the next tile while consuming the even buffers
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1, a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, b_block_slice_copy_step);

                a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
                b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);

                block_sync_lds();

                blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                                   a_block_even_buf, b_block_even_buf, c_thread_buf);

                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_odd_buf);
                b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_odd_buf);

                // odd iteration: same schedule, roles of even/odd buffers swapped
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1, a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, b_block_slice_copy_step);

                a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
                b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);

                block_sync_lds();

                blockwise_gemm.Run(
                    c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_even_buf);
                b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_even_buf);

                k_block_data_begin += 2 * K0PerBlock;
            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
        }

        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // 2 K-block tiles left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1, a_block_slice_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, b_block_slice_copy_step);

            a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);

            block_sync_lds();

            // GEMM on the second-to-last tile (even buffers)
            blockwise_gemm.Run(
                c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);

            block_sync_lds();

            a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_odd_buf);
            b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_odd_buf);

            block_sync_lds();

            // GEMM on the last tile (odd buffers)
            blockwise_gemm.Run(
                c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
        }
        else // 1 K-block tile left
        {
            block_sync_lds();

            blockwise_gemm.Run(
                c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
        }
        // epilogue: copy the C thread tile from registers (VGPR) to global memory
        constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
            // ...
                Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                // ...

        const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
            blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
                // ...

        // threadwise transfer of the C thread tile to its destination in the C grid
        // ...
            decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
            decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
            // ...
            c_m10_m11_n10_n11_thread_tensor_lengths[I0],
            c_m10_m11_n10_n11_thread_tensor_lengths[I1],
            // ...
            c_m10_m11_n10_n11_thread_tensor_lengths[I2],
            c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
            CThreadTransferSrcDstAccessOrder,
            CThreadTransferSrcDstVectorDim,
            CThreadTransferDstScalarPerVector,
            CGlobalMemoryDataOperation,
            // ...
            true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
                  make_multi_index(
                      m_block_data_idx_on_grid,
                      c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                      c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
                      n_block_data_idx_on_grid,
                      c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                      c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
                  // ...
            .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
                 // ...
                 c_grid_desc_m0_m10_m11_n0_n10_n11,
                 // ...
    }