/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp Source File
batched_transpose_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 #include <string>
11 #include <type_traits>
12 
13 namespace ck_tile {
14 
 // Plain aggregate of launch arguments.
 // NOTE(review): the declaration line (listing line 15) and the members on listing
 // lines 19-21 and 23-25 are not visible in this view; presumably this is the type the
 // kernel refers to as Hargs, which GridSize and MakeKargs read — confirm.
16 {
17  const void* p_input; // read-only pointer to the source data
18  void* p_output; // writable pointer to the destination data
22  // index_t dim_blocks;
26 };
27 
28 template <typename Pipeline_>
30 {
33 
34  using Type = typename Problem::InputType;
35 
 // Nested aggregate holding the arguments handed to the device code.
 // NOTE(review): the declaration line (listing line 36) and the members on listing
 // lines 40-43 are not visible in this view; presumably this is the type aliased as
 // Kargs, which MakeKargs fills and operator() consumes — confirm.
37  {
38  const void* p_input; // read-only pointer to the source data
39  void* p_output; // writable pointer to the destination data
44  };
45 
48 
49  CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
50  {
51  size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w;
52  size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h;
53  size_t grid_size_z = h.batch;
54  return dim3(grid_size_x, grid_size_y, grid_size_z);
55  }
56 
57  CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
58  {
59  Kargs k;
60  k.p_input = h.p_input;
61  k.p_output = h.p_output;
62  k.batch = h.batch;
63  k.height = h.height;
64  k.width = h.width;
65  k.dim_stride = h.dim_stride;
66  return k;
67  }
68 
69  CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
70 
 // Kernel body. Selects one batch slice with blockIdx.z, builds a global-memory tensor
 // view over the input and one over the output, places a tile window on each, and runs
 // Pipeline on the two windows.
 // NOTE(review): the embedded listing numbers jump over 94-95, 106, 110-111, 116 and
 // 122, so several call arguments are not visible in this view.
71  CK_TILE_DEVICE void operator()(Kargs kargs) const
72  {
73 
 // Tile extents and padding flags come from the pipeline's Problem type.
74  static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
75  static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
76  static constexpr bool kPadM = Problem::kPadM;
77  static constexpr bool kPadN = Problem::kPadN;
78 
 // Per-thread extents; the static_assert below only admits the 1 x 1 case.
79  static constexpr ck_tile::index_t kMPerThread = Problem::kMPerThread;
80  static constexpr ck_tile::index_t kNPerThread = Problem::kNPerThread;
81 
82  static_assert(kMPerThread == 1 && kNPerThread == 1);
83 
 // blockIdx.z picks the slice; it is scaled by dim_stride to offset both base pointers.
84  const auto iDim = blockIdx.z;
 // Input view: lengths (height, width), strides (width, 1), then passed through
 // pad_tensor_view (its remaining arguments are on listing lines not shown here).
85  const auto x_m_n = [&]() {
86  const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
87  static_cast<const Type*>(kargs.p_input) + iDim * kargs.dim_stride,
88  make_tuple(kargs.height, kargs.width),
89  make_tuple(kargs.width, 1),
90  number<kNPerThread>{}, // TODO thread load value
91  number<1>{});
92 
93  return pad_tensor_view(x_dram_naive,
96  }();
97 
 // NOTE(review): iM and iN already include the kMPerBlock / kNPerBlock factor here, and
 // the window origins further down multiply by the same constants again — confirm the
 // double scaling is intended.
 // NOTE(review): GridSize derives grid x from width, while blockIdx.x feeds iM, which is
 // used as the first (height-axis) coordinate of the input window — confirm the mapping.
98  const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
99  const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
100 
 // Output view: lengths (width, height), strides (height, 1) — the two axes swapped
 // relative to the input view. It uses the same per-slice offset as the input.
101  const auto y_n_m = [&]() {
102  const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
103  static_cast<Type*>(kargs.p_output) + iDim * kargs.dim_stride,
104  make_tuple(kargs.width, kargs.height),
105  make_tuple(kargs.height, 1),
107  number<1>{});
108 
109  return pad_tensor_view(y_dram_naive,
112  }();
113 
 // Input window: origin given as (M coordinate, N coordinate).
114  auto x_block_window =
115  make_tile_window(x_m_n,
117  {static_cast<ck_tile::index_t>(iM * kMPerBlock),
118  static_cast<ck_tile::index_t>(iN * kNPerBlock)});
119 
 // Output window: the same two coordinates in swapped order (N, M).
120  auto y_block_window =
121  make_tile_window(y_n_m,
123  {static_cast<ck_tile::index_t>(iN * kNPerBlock),
124  static_cast<ck_tile::index_t>(iM * kMPerBlock)});
125 
 // The pipeline object is default-constructed and invoked with (input window, output window).
126  Pipeline{}(x_block_window, y_block_window);
127  }
128 };
129 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:480
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
Definition: batched_transpose_kernel.hpp:16
index_t height
Definition: batched_transpose_kernel.hpp:20
index_t batch
Definition: batched_transpose_kernel.hpp:19
void * p_output
Definition: batched_transpose_kernel.hpp:18
index_t dim_block_w
Definition: batched_transpose_kernel.hpp:25
index_t dim_stride
Definition: batched_transpose_kernel.hpp:23
const void * p_input
Definition: batched_transpose_kernel.hpp:17
index_t width
Definition: batched_transpose_kernel.hpp:21
index_t dim_block_h
Definition: batched_transpose_kernel.hpp:24
Definition: batched_transpose_kernel.hpp:37
index_t width
Definition: batched_transpose_kernel.hpp:42
index_t height
Definition: batched_transpose_kernel.hpp:41
index_t dim_stride
Definition: batched_transpose_kernel.hpp:43
index_t batch
Definition: batched_transpose_kernel.hpp:40
const void * p_input
Definition: batched_transpose_kernel.hpp:38
void * p_output
Definition: batched_transpose_kernel.hpp:39
Definition: batched_transpose_kernel.hpp:30
remove_cvref_t< typename Pipeline::Problem > Problem
Definition: batched_transpose_kernel.hpp:32
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: batched_transpose_kernel.hpp:71
static constexpr CK_TILE_HOST auto MakeKargs(const Hargs &h)
Definition: batched_transpose_kernel.hpp:57
static constexpr CK_TILE_HOST_DEVICE auto BlockSize()
Definition: batched_transpose_kernel.hpp:69
typename Problem::InputType Type
Definition: batched_transpose_kernel.hpp:34
static constexpr CK_TILE_HOST auto GridSize(const Hargs &h)
Definition: batched_transpose_kernel.hpp:49
remove_cvref_t< Pipeline_ > Pipeline
Definition: batched_transpose_kernel.hpp:31
Definition: integral_constant.hpp:13
Definition: sequence.hpp:52