/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp Source File
tile_fmha_shape.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
11 {
12  if(len == 96)
13  return 128;
14  if(len == 160)
15  return 256;
16 
17  // only length of 96, 160 and power-of-two is supported
18  if(!(len & (len - 1)))
19  return len;
20 
21  return 0;
22 };
23 
// Compile-time tile-shape description parameterised on a block tile, two pairs of
// block-warps / warp-tile types (gemm0 and gemm1) and a flag for the layout of V.
//
// NOTE(review): this is an extracted listing and its embedded line numbers jump: original
// lines 30, 32-36, 39, 41, 44 and 60-62 are absent. The class/struct name line, the type
// aliases and the right-hand sides of NumGemm0Warps / NumGemm1Warps are therefore missing
// and the fragment does not compile as shown. The cross-reference index at the end of this
// page lists BlockTile, Gemm0BlockWarps, Gemm0WarpTile, Gemm1BlockWarps and Gemm1WarpTile
// (lines 32-36, each remove_cvref_t of the matching template parameter), NumWarps (line 44)
// and VLayout (line 62) as being defined on the omitted lines.
//
// Members visible here:
//  - kM0, kN0, kK0, kN1, kK1, kQKHeaddim : elements 0..5 of BlockTile.
//  - kSubQKHeaddim     : ceil_to_qualified_tile_length(kQKHeaddim); that helper yields 0 for
//                        lengths that are not 96, 160 or a power of two.
//  - IsVLayoutRowMajor : forwards the IsVLayoutRowMajor_ template argument.
24 template <typename BlockTile_, // sequence<...
25  typename Gemm0BlockWarps_,
26  typename Gemm0WarpTile_,
27  typename Gemm1BlockWarps_,
28  typename Gemm1WarpTile_,
29  bool IsVLayoutRowMajor_>
31 {
37 
38  static constexpr index_t NumGemm0Warps =
40  static constexpr index_t NumGemm1Warps =
// The gemm1 warp count must be a whole multiple of the gemm0 warp count.
42  static_assert(NumGemm1Warps % NumGemm0Warps == 0);
43 
45 
46  static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
47  static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
48  static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
49  static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
50  static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
51  static constexpr index_t kQKHeaddim =
52  BlockTile::at(number<5>{}); // total length of K0, used for pipelines that need to load Q at
53  // once (or repeatedly load Q as a whole tile)
54  static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
55 
// kQKHeaddim mapped through ceil_to_qualified_tile_length (96 -> 128, 160 -> 256,
// powers of two unchanged, anything else -> 0).
56  static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
57 
58  // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
59  static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
63 };
64 
// Compile-time tile-shape description parameterised on a block tile and five pairs of
// block-warps / warp-tile types (gemm0 .. gemm4).
//
// NOTE(review): this is an extracted listing and its embedded line numbers jump: original
// lines 76, 78-88, 91 and 94 are absent. The class/struct name line, the type aliases, the
// right-hand side of NumWarps and the tail of the static_assert (it stops after `&&`) are
// therefore missing and the fragment does not compile as shown. The cross-reference index at
// the end of this page lists BlockTile and Gemm0BlockWarps .. Gemm4WarpTile (lines 78-88,
// each remove_cvref_t of the matching template parameter) as being defined on the omitted
// lines.
//
// Members visible here:
//  - kM0, kN0, kK0, kK1, kK2, kK3, kK4, kQKHeaddim, kVHeaddim : elements 0..8 of BlockTile.
65 template <typename BlockTile_, // sequence<...
66  typename Gemm0BlockWarps_,
67  typename Gemm0WarpTile_,
68  typename Gemm1BlockWarps_,
69  typename Gemm1WarpTile_,
70  typename Gemm2BlockWarps_,
71  typename Gemm2WarpTile_,
72  typename Gemm3BlockWarps_,
73  typename Gemm3WarpTile_,
74  typename Gemm4BlockWarps_,
75  typename Gemm4WarpTile_>
77 {
89 
90  static constexpr index_t NumWarps =
92 
// NumWarps must equal the product of the Gemm1BlockWarps entries; the remaining
// conditions of this assertion are on a line missing from the listing.
93  static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) &&
95 
96  static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
97  static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
98  static constexpr index_t kK0 =
99  BlockTile::at(number<2>{}); // tile size along gemm0(Q@K^T) unroll
100  static constexpr index_t kK1 =
101  BlockTile::at(number<3>{}); // tile size along gemm1(P^T@dO) unroll
102  static constexpr index_t kK2 =
103  BlockTile::at(number<4>{}); // tile size along gemm2(dO@V^T) unroll
104  static constexpr index_t kK3 =
105  BlockTile::at(number<5>{}); // tile size along gemm3(dS^T@Q) unroll
106  static constexpr index_t kK4 = BlockTile::at(number<6>{}); // tile size along gemm4(dS@K) unroll
107  static constexpr index_t kQKHeaddim =
108  BlockTile::at(number<7>{}); // Q & K headdim, used for pipelines that need to load Q/Q^T or
109  // K/K^T at once
110  static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipelines
111  // that need to load V at once
112 };
113 
114 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:973
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: tile_fmha_shape.hpp:77
remove_cvref_t< Gemm2BlockWarps_ > Gemm2BlockWarps
Definition: tile_fmha_shape.hpp:83
static constexpr index_t NumWarps
Definition: tile_fmha_shape.hpp:90
remove_cvref_t< BlockTile_ > BlockTile
Definition: tile_fmha_shape.hpp:78
remove_cvref_t< Gemm0WarpTile_ > Gemm0WarpTile
Definition: tile_fmha_shape.hpp:80
remove_cvref_t< Gemm2WarpTile_ > Gemm2WarpTile
Definition: tile_fmha_shape.hpp:84
remove_cvref_t< Gemm4WarpTile_ > Gemm4WarpTile
Definition: tile_fmha_shape.hpp:88
static constexpr index_t kQKHeaddim
Definition: tile_fmha_shape.hpp:107
remove_cvref_t< Gemm0BlockWarps_ > Gemm0BlockWarps
Definition: tile_fmha_shape.hpp:79
static constexpr index_t kK3
Definition: tile_fmha_shape.hpp:104
static constexpr index_t kVHeaddim
Definition: tile_fmha_shape.hpp:110
static constexpr index_t kK4
Definition: tile_fmha_shape.hpp:106
static constexpr index_t kN0
Definition: tile_fmha_shape.hpp:97
static constexpr index_t kM0
Definition: tile_fmha_shape.hpp:96
remove_cvref_t< Gemm3BlockWarps_ > Gemm3BlockWarps
Definition: tile_fmha_shape.hpp:85
remove_cvref_t< Gemm4BlockWarps_ > Gemm4BlockWarps
Definition: tile_fmha_shape.hpp:87
remove_cvref_t< Gemm3WarpTile_ > Gemm3WarpTile
Definition: tile_fmha_shape.hpp:86
static constexpr index_t kK1
Definition: tile_fmha_shape.hpp:100
remove_cvref_t< Gemm1WarpTile_ > Gemm1WarpTile
Definition: tile_fmha_shape.hpp:82
static constexpr index_t kK2
Definition: tile_fmha_shape.hpp:102
static constexpr index_t kK0
Definition: tile_fmha_shape.hpp:98
remove_cvref_t< Gemm1BlockWarps_ > Gemm1BlockWarps
Definition: tile_fmha_shape.hpp:81
Definition: tile_fmha_shape.hpp:31
static constexpr bool IsVLayoutRowMajor
Definition: tile_fmha_shape.hpp:59
std::conditional_t< IsVLayoutRowMajor, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition: tile_fmha_shape.hpp:62
static constexpr index_t kQKHeaddim
Definition: tile_fmha_shape.hpp:51
static constexpr index_t kK0
Definition: tile_fmha_shape.hpp:48
remove_cvref_t< Gemm1BlockWarps_ > Gemm1BlockWarps
Definition: tile_fmha_shape.hpp:35
static constexpr index_t NumGemm0Warps
Definition: tile_fmha_shape.hpp:38
static constexpr index_t NumWarps
Definition: tile_fmha_shape.hpp:44
static constexpr index_t kK1
Definition: tile_fmha_shape.hpp:50
remove_cvref_t< Gemm1WarpTile_ > Gemm1WarpTile
Definition: tile_fmha_shape.hpp:36
remove_cvref_t< Gemm0WarpTile_ > Gemm0WarpTile
Definition: tile_fmha_shape.hpp:34
static constexpr index_t kSubQKHeaddim
Definition: tile_fmha_shape.hpp:56
static constexpr index_t kN0
Definition: tile_fmha_shape.hpp:47
remove_cvref_t< Gemm0BlockWarps_ > Gemm0BlockWarps
Definition: tile_fmha_shape.hpp:33
static constexpr index_t kN1
Definition: tile_fmha_shape.hpp:49
static constexpr index_t kM0
Definition: tile_fmha_shape.hpp:46
remove_cvref_t< BlockTile_ > BlockTile
Definition: tile_fmha_shape.hpp:32
static constexpr index_t NumGemm1Warps
Definition: tile_fmha_shape.hpp:40
Definition: integral_constant.hpp:13
Definition: math.hpp:98
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17