/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp Source File
device_contraction_utils.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <cassert>
7 #include <sstream>
8 #include <vector>
9 
10 #include "ck/ck.hpp"
11 
12 namespace ck {
13 namespace tensor_operation {
14 namespace device {
15 
32 template <index_t NumDim1, index_t NumDim2>
33 auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<index_t>& strides)
34 {
35  if(lengths.size() != NumDim1 + NumDim2)
36  {
37  std::ostringstream err;
38  err << "Incorrect number of lengths in "
39  << "device_contraction_utils.hpp"
40  << ":" << __LINE__ << ", in function: " << __func__;
41  throw std::runtime_error(err.str());
42  }
43  if(strides.size() != NumDim1 + NumDim2)
44  {
45  std::ostringstream err;
46  err << "Incorrect number of strides in "
47  << "device_contraction_utils.hpp"
48  << ":" << __LINE__ << ", in function: " << __func__;
49  throw std::runtime_error(err.str());
50  }
51 
52  // Determine the beginning and end idx of the group representing the FCD.
53  index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1;
54  if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1)
55  {
56  // MZ or KZ are ones
57  bool dims1_are_ones = true;
58  for(index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++)
59  {
60  if(lengths[dim_idx] != 1)
61  {
62  dims1_are_ones = false;
63  }
64  }
65 
66  if(dims1_are_ones)
67  {
68  begin_idx = NumDim1;
69  end_idx = NumDim1 + NumDim2 - 1;
70  continous_dim = 1;
71  }
72  else
73  {
74  begin_idx = 0;
75  end_idx = NumDim1 - 1;
76  continous_dim = 0;
77  }
78  }
79  else if(strides[NumDim1 - 1] == 1)
80  {
81  begin_idx = 0;
82  end_idx = NumDim1 - 1;
83  continous_dim = 0;
84  }
85  else if(strides[NumDim1 + NumDim2 - 1] == 1)
86  {
87  begin_idx = NumDim1;
88  end_idx = NumDim1 + NumDim2 - 1;
89  continous_dim = 1;
90  }
91  else
92  {
93  // The dimension consecutive in memory is not the last dimension of any group, so only
94  // one element can be read/written at once.
95  consecutive_stride = 1;
96  continous_dim = 0;
97  return make_tuple(continous_dim, consecutive_stride);
98  }
99 
100  for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
101  {
102  if(strides[dim_idx] == consecutive_stride)
103  {
104  consecutive_stride *= lengths[dim_idx];
105  }
106  else
107  {
108  break;
109  }
110  }
111  const index_t max_subsequent_elems = consecutive_stride;
112  return make_tuple(continous_dim, max_subsequent_elems);
113 }
114 
115 } // namespace device
116 } // namespace tensor_operation
117 } // namespace ck
auto CalculateMaxRead(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition: device_contraction_utils.hpp:33
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:289