/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/common/generic_2d_block_shape.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/common/generic_2d_block_shape.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/common/generic_2d_block_shape.hpp Source File
generic_2d_block_shape.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 namespace ck_tile {
7 
8 /*
9 // clang-format off
10 
11 4-level descriptor: BlockTile -> WarpPerBlock -> WarpTile -> Vector
12
13                 Block_N (Warp_N * WarpPerBlock_N * Repeat_N)
14         +<---------------------- Repeat_N(2) ---------------------->+
15         |                                                           |
16         +<---- WarpPerBlock_N(2) ---->+
17              Warp_N
18         +--------------+--------------+--------------+--------------+----+----------------+
19  Warp_M |    warp_0    |    warp_1    |                             |    ^                ^
20         +--------------+--------------+                             | <WarpPerBlock_M(2)> |
21         |    warp_2    |    warp_3    |                             |    v                |
22         +--------------+--------------+--------------+--------------+----+             Block_M
23         |                             |                             |                     |
24         +                             +                             |                     |
25         |                             |                             |                     v
26         +--------------+--------------+--------------+--------------+                     +
27
28         each warp tile (e.g. 16 threads per row)
29
30         Vector_N (contiguous pixels each thread holds along N, i.e. the vector size)
31         +-----------+-----------+-----------+-----------+-----------+
32         |  thrd_0   |  thrd_1   |  thrd_2   |  thrd_3   |  ...      |  Vector_M
33         +-----------+-----------+-----------+-----------+-----------+
34         |  thrd_16  |  thrd_17  |  thrd_18  |  thrd_19  |  ...      |
35         +-----------+-----------+-----------+-----------+-----------+
36 // clang-format on
37 */
38 template <typename BlockTile_, // block tile extent, seq<M, N>
39  typename WarpPerBlock_, // num warps along seq<M, N> in one repeat of the warp grid
40  typename WarpTile_, // warp tile extent (elements covered by one warp), seq<M, N>
41  typename Vector_> // contiguous pixels each thread holds (vector size) along seq<M, N>
43 {
44  // block tile extent along seq<M, N>
45  static constexpr index_t Block_M = BlockTile_::at(number<0>{});
46  static constexpr index_t Block_N = BlockTile_::at(number<1>{});
47 
48  // num warps along seq<M, N>; the block is covered by Repeat_M x Repeat_N copies of this warp grid
49  static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
50  static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});
51 
52  // warp tile extent along seq<M, N> (an element count, not the hardware warp/wavefront lane count)
53  static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
54  static constexpr index_t Warp_N = WarpTile_::at(number<1>{});
55 
56  static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
57  static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
58  // repeats of the (WarpPerBlock * Warp) tile needed to cover the block along seq<M, N>; exact per the static_asserts above
59  static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
60  static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
61 
62  // vector size: contiguous elements each thread holds along seq<M, N>
63  static constexpr index_t Vector_M = Vector_::at(number<0>{});
64  static constexpr index_t Vector_N = Vector_::at(number<1>{});
65 
66  static_assert(Warp_M % Vector_M == 0);
67  static_assert(Warp_N % Vector_N == 0);
68  // num threads along seq<M, N> within each warp tile (Warp / Vector; exact per the static_asserts above)
69  static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
70  static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
73 
75 };
76 
77 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
Definition: generic_2d_block_shape.hpp:43
static constexpr index_t Repeat_N
Definition: generic_2d_block_shape.hpp:60
static constexpr index_t Vector_M
Definition: generic_2d_block_shape.hpp:63
static constexpr index_t Warp_M
Definition: generic_2d_block_shape.hpp:53
static constexpr index_t WarpPerBlock_M
Definition: generic_2d_block_shape.hpp:49
static constexpr index_t WarpPerBlock_N
Definition: generic_2d_block_shape.hpp:50
static constexpr index_t ThreadPerWarp_N
Definition: generic_2d_block_shape.hpp:70
static constexpr index_t ThreadPerWarp_M
Definition: generic_2d_block_shape.hpp:69
static constexpr index_t Repeat_M
Definition: generic_2d_block_shape.hpp:59
static constexpr index_t Warp_N
Definition: generic_2d_block_shape.hpp:54
static constexpr index_t Block_M
Definition: generic_2d_block_shape.hpp:45
static constexpr index_t Block_N
Definition: generic_2d_block_shape.hpp:46
static constexpr index_t BlockSize
Definition: generic_2d_block_shape.hpp:74
static constexpr index_t ThreadPerBlock_M
Definition: generic_2d_block_shape.hpp:71
static constexpr index_t Vector_N
Definition: generic_2d_block_shape.hpp:64
static constexpr index_t ThreadPerBlock_N
Definition: generic_2d_block_shape.hpp:72
Definition: integral_constant.hpp:13