/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp Source File
batched_transpose_lds_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 // Supports 2D transpose: the data is first stored to LDS,
11 // then read back transposed with the ds_read_b*_tr_b* instructions.
12 template <typename DataType_,
13  typename BlockTile, // sequence<block_x, block_y>
14  typename NumWarps,
15  bool kPadM_,
16  bool kPadN_>
18 {
20 // NOTE(review): listing lines 17 and 19 (presumably the struct header and the DataType alias) are absent from this extract
21  static constexpr index_t kRowWarps_ = NumWarps::at(number<1>{}); // row warp count is element 1 of NumWarps
22  static constexpr index_t kColWarps_ = NumWarps::at(number<0>{}); // column warp count is element 0 of NumWarps
24  static constexpr index_t kRowPerBlock_ = BlockTile::at(number<1>{}); // rows per block: element 1 of BlockTile
25  static constexpr index_t kColPerBlock_ = BlockTile::at(number<0>{}); // columns per block: element 0 of BlockTile
26 
27  static constexpr index_t kBlockSize = kBlockSize_; // kBlockSize_ is declared on listing line 23, not shown in this extract
28  // warps per block along the lead (row) and second (column) dimensions
29  static constexpr index_t kLeadNumWarps = kRowWarps_;
30  static constexpr index_t kSecondNumWarps = kColWarps_;
31 
34 
37 
38  static_assert(kLeadSizePerBlock % kLeadNumWarps == 0,
39  "block dim should be divided by warp count!");
40  static_assert(kSecondSizePerBlock % kSecondNumWarps == 0,
41  "block dim should be divided by warp count!");
42  // rows/cols per warp
45 
46  static_assert(kLeadSizePerWarp % kQuadrantLeadDim == 0,
47  "xdl dim should be divided by quad dim!");
48  static_assert(kSecondSizePerWarp % kQuadrantSecondDim == 0,
49  "xdl dim should be divided by quad dim!");
50  // xdl rows/cols are divided into quadrants.
53 
54  static constexpr index_t kIterationsInSecondDim =
56 
57  // definitions to adapt to BatchedTransposeKernel
58 
59  // FIXME: support padding
60  static constexpr bool kPadM = kPadM_;
61  static constexpr bool kPadN = kPadN_;
62 
63  static constexpr auto kMPerBlock = kLeadSizePerBlock;
64  static constexpr auto kNPerBlock = kSecondSizePerBlock;
65 
66  // 128 bits (16 bytes) is the max single-instruction bandwidth for load/store
67  static constexpr index_t MaxLoadStoreSize = 16;
68  static constexpr auto VectorSizeInput = kPadN ? 1 : MaxLoadStoreSize / sizeof(DataType); // falls back to 1 element when kPadN is set
69  static constexpr auto VectorSizeOutput = kPadM ? 1 : MaxLoadStoreSize / sizeof(DataType); // falls back to 1 element when kPadM is set
70  static constexpr auto LDSVectorSize = MaxLoadStoreSize / sizeof(DataType);
71 };
72 
73 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: batched_transpose_lds_problem.hpp:18
static constexpr index_t kBlockSize_
Definition: batched_transpose_lds_problem.hpp:23
static constexpr index_t kRowPerBlock_
Definition: batched_transpose_lds_problem.hpp:24
static constexpr index_t kBlockSize
Definition: batched_transpose_lds_problem.hpp:27
static constexpr index_t kSecondSizePerWarp
Definition: batched_transpose_lds_problem.hpp:44
static constexpr auto VectorSizeOutput
Definition: batched_transpose_lds_problem.hpp:69
static constexpr index_t kQuadrantSecondDim
Definition: batched_transpose_lds_problem.hpp:36
static constexpr index_t kRowWarps_
Definition: batched_transpose_lds_problem.hpp:21
static constexpr index_t kQuadNumPerLeadDim
Definition: batched_transpose_lds_problem.hpp:51
static constexpr index_t MaxLoadStoreSize
Definition: batched_transpose_lds_problem.hpp:67
static constexpr auto kMPerBlock
Definition: batched_transpose_lds_problem.hpp:63
static constexpr index_t kColPerBlock_
Definition: batched_transpose_lds_problem.hpp:25
static constexpr bool kPadM
Definition: batched_transpose_lds_problem.hpp:60
static constexpr index_t kIterationsInSecondDim
Definition: batched_transpose_lds_problem.hpp:54
static constexpr index_t kLeadSizePerWarp
Definition: batched_transpose_lds_problem.hpp:43
static constexpr index_t kQuadNumPerSecondDim
Definition: batched_transpose_lds_problem.hpp:52
static constexpr index_t kLeadNumWarps
Definition: batched_transpose_lds_problem.hpp:29
static constexpr auto LDSVectorSize
Definition: batched_transpose_lds_problem.hpp:70
static constexpr auto kNPerBlock
Definition: batched_transpose_lds_problem.hpp:64
static constexpr index_t kQuadrantLeadDim
Definition: batched_transpose_lds_problem.hpp:35
static constexpr index_t kColWarps_
Definition: batched_transpose_lds_problem.hpp:22
static constexpr index_t kSecondSizePerBlock
Definition: batched_transpose_lds_problem.hpp:33
static constexpr index_t kLeadSizePerBlock
Definition: batched_transpose_lds_problem.hpp:32
static constexpr bool kPadN
Definition: batched_transpose_lds_problem.hpp:61
static constexpr index_t kSecondNumWarps
Definition: batched_transpose_lds_problem.hpp:30
static constexpr auto VectorSizeInput
Definition: batched_transpose_lds_problem.hpp:68
remove_cvref_t< DataType_ > DataType
Definition: batched_transpose_lds_problem.hpp:19
Definition: amd_transpose_load_encoding.hpp:14
Definition: integral_constant.hpp:13