StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference
The Stream-K GEMM kernel class.
#include <streamk_gemm_kernel.hpp>
Classes | |
| struct | StreamKKernelArgs |
| ALayout and ADataType are expected to be scalars, not a tuple. | |
Public Types | |
| using | UniversalGemmKernel = UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > |
| Inject the UniversalGemmKernel base class to support execution of all necessary functions. | |
| using | TilePartitioner = TilePartitioner_ |
| using | GemmPipeline = GemmPipeline_ |
| using | EpiloguePipeline = EpiloguePipeline_ |
| using | ALayout = typename GemmPipeline::ALayout |
| Specify the layout configurations for A, B, and C. | |
| using | BLayout = typename GemmPipeline::BLayout |
| using | CLayout = typename GemmPipeline::CLayout |
| using | ADataType = typename GemmPipeline::ADataType |
| Specify the data type configurations for A, B, and C. | |
| using | BDataType = typename GemmPipeline::BDataType |
| using | CDataType = typename EpiloguePipeline::ODataType |
| using | AccDataType = typename EpiloguePipeline::AccDataType |
| using | KernelArgs = StreamKKernelArgs |
| using | Kernel = StreamKKernel< TilePartitioner, GemmPipeline, EpiloguePipeline > |
Public Member Functions | |
| CK_TILE_DEVICE void | BaseGemm (StreamKKernelArgs &kargs, index_t tile_idx, index_t num_loop, index_t i_k_a, index_t i_k_b, index_t k_size, void *smem_ptr_0) const |
| Computes offsets into A, B, and C tensors, then runs the GEMM pipeline and epilogue. | |
| CK_TILE_DEVICE void | SignalStorePartialDone (const StreamKKernelArgs &kargs, index_t cta_idx) const |
| Signals that the current thread block (CTA) has completed storing its partial results. | |
| CK_TILE_DEVICE void | WaitStorePartialDone (const StreamKKernelArgs &kargs, index_t cta_idx) const |
| Waits for the thread block (cta_idx) to complete storing its partial results. | |
| template<typename OAccTile > | |
| CK_TILE_DEVICE void | AddBlockTile (OAccTile &in_out_block_tile, const OAccTile &in_block_tile) const |
| Adds the values of a block tile to an output block tile. | |
| template<typename DataType , typename OAccTileDist > | |
| CK_TILE_DEVICE auto | LoadPartial (const StreamKKernelArgs &kargs, index_t cta_idx, const OAccTileDist &c_block_tile_dist) const |
| Loads a partial block tile from the workspace buffer. | |
| template<typename OAccTile > | |
| CK_TILE_DEVICE void | StorePartial (const StreamKKernelArgs &kargs, index_t cta_idx, const OAccTile &c_block_tile) const |
| Stores a partial block tile to the workspace buffer. | |
| CK_TILE_DEVICE void | StreamKGemm (StreamKKernelArgs &kargs, index_t cta_idx, void *smem_ptr_0) const |
| Runs the main Stream-K algorithm. | |
| template<bool U = PersistentDP> | |
| CK_TILE_DEVICE std::enable_if_t<!U > | operator() (StreamKKernelArgs kargs) const |
| Entry point for the Stream-K kernel with non-persistent DP. | |
| template<bool U = PersistentDP> | |
| CK_TILE_DEVICE std::enable_if_t< U > | operator() (StreamKKernelArgs kargs) const |
| Entry point for the Stream-K kernel with persistent DP. | |
Static Public Member Functions | |
| static CK_TILE_HOST const std::string | GetName () |
| static CK_TILE_HOST auto | GridSize (const TilePartitioner &tile_partitioner) -> dim3 |
| Compute the grid size for the Stream-K kernel using the tile_partitioner. | |
| static CK_TILE_HOST auto | MaxOccupancyGridSize (const stream_config &s) -> dim3 |
| Get the maximum occupancy grid size for the persistent kernel on the current device. | |
| static constexpr CK_TILE_HOST auto | BlockSize () -> dim3 |
| static CK_TILE_HOST StreamKKernelArgs | MakeKernelArgs (const StreamKHostArgs &host_args, int num_cu=NumCU(), int occupancy=Occupancy()) |
| Constructs kernel arguments for the Stream-K kernel. | |
| template<bool UseDefaultScheduler = true> | |
| static CK_TILE_DEVICE void | RunGemm (const std::array< const ADataType *, UniversalGemmKernel::NumATensor > &as_ptr, const std::array< const BDataType *, UniversalGemmKernel::NumBTensor > &bs_ptr, const std::array< const void *, UniversalGemmKernel::NumDTensor > &ds_ptr, CDataType *c_ptr, void *smem_ptr_0, const typename UniversalGemmKernel::KernelArgs &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t k_size) |
| static CK_TILE_HOST bool | IsSupportedArgument (const StreamKKernelArgs &kargs) |
| static CK_TILE_HOST uint32_t | GetWorkSpaceSize (const StreamKKernelArgs &kargs) |
| Computes the buffer size needed to store accumulation results for Stream-K. | |
| static CK_TILE_HOST void | SetWorkSpacePointer (StreamKKernelArgs &kargs, void *workspace_ptr) |
| Sets the kargs' current workspace_ptr to the given workspace_ptr. | |
Static Public Attributes | |
| static constexpr index_t | kBlockSize = UniversalGemmKernel::kBlockSize |
| static constexpr bool | PersistentDP = UniversalGemmKernel::PersistentKernel |
| template<typename T > | |
| static constexpr bool | is_tuple_v = is_detected<is_tuple, T>::value |
Detailed Description
template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >
The Stream-K GEMM kernel class.
- Overview
- This class implements the Stream-K GEMM kernel on top of UniversalGemmKernel. The kernel entry points are the two operator() overloads: one for the persistent and one for the non-persistent data-parallel (DP) section of the Stream-K algorithm.
Both the non-persistent and persistent kernels use BaseGemm() and StreamKGemm(). BaseGemm() computes offsets into the A, B, and C tensors, then calls RunGemm(), which runs the GEMM pipeline and epilogue. StreamKGemm() performs the main Stream-K algorithm; each iteration of the Stream-K loop calls BaseGemm().
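The two operator() overloads listed in the summary are chosen at compile time through a defaulted bool template parameter (template<bool U = PersistentDP>) combined with std::enable_if_t. The standalone sketch reproduces only that selection mechanism. ToyKernel is a hypothetical stand-in, and its overloads return a string (the real kernel's return void) so that the chosen path can be checked.

```cpp
#include <string>
#include <type_traits>

// Hypothetical stand-in for StreamKKernel: models only how one of two
// operator() overloads is enabled, not any GEMM work.
template <bool PersistentDP_>
struct ToyKernel
{
    static constexpr bool PersistentDP = PersistentDP_;

    // Participates in overload resolution only when PersistentDP is false.
    template <bool U = PersistentDP>
    std::enable_if_t<!U, std::string> operator()() const
    {
        return "non-persistent";
    }

    // Participates in overload resolution only when PersistentDP is true.
    template <bool U = PersistentDP>
    std::enable_if_t<U, std::string> operator()() const
    {
        return "persistent";
    }
};
```

Exactly one overload survives substitution for a given PersistentDP value, so callers invoke the kernel object the same way in both configurations.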
Member Typedef Documentation
◆ AccDataType
| using ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::AccDataType = typename EpiloguePipeline::AccDataType |
◆ ADataType
| using ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::ADataType = typename GemmPipeline::ADataType |
Specify the data type configurations for A, B, and C.
◆ ALayout
| using ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::ALayout = typename GemmPipeline::ALayout |
Specify the layout configurations for A, B, and C.
◆ BDataType
| using ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BDataType = typename GemmPipeline::BDataType |
◆ BLayout
| using ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BLayout = typename GemmPipeline::BLayout |
◆ CDataType
| using ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::CDataType = typename EpiloguePipeline::ODataType |
◆ CLayout
| using ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::CLayout = typename GemmPipeline::CLayout |
◆ EpiloguePipeline
| using ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::EpiloguePipeline = EpiloguePipeline_ |
◆ GemmPipeline
| using ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GemmPipeline = GemmPipeline_ |
◆ Kernel
| using ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::Kernel = StreamKKernel<TilePartitioner, GemmPipeline, EpiloguePipeline> |
◆ KernelArgs
| using ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::KernelArgs = StreamKKernelArgs |
◆ TilePartitioner
| using ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::TilePartitioner = TilePartitioner_ |
◆ UniversalGemmKernel
| using ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::UniversalGemmKernel = UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_> |
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Member Function Documentation
◆ AddBlockTile()
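template<typename OAccTile >
| CK_TILE_DEVICE void ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::AddBlockTile (OAccTile &in_out_block_tile, const OAccTile &in_block_tile) const |
inline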
Adds the values of a block tile to an output block tile.
- Parameters
  - in_out_block_tile: The output block tile to which values are added.
  - in_block_tile: The input block tile whose values are added.
- Note
- This function iterates over the distributed spans of the block tiles and updates the output block tile with accumulated values.
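In ck_tile the accumulation walks the distributed spans of the tile that each thread owns. As a simplified sketch, assuming each thread's slice can be viewed as a flat array, the operation reduces to an element-wise in-place add. ToyAccTile and ToyAddBlockTile are hypothetical names, not ck_tile APIs.

```cpp
#include <array>
#include <cstddef>

// Hypothetical stand-in for one thread's slice of a distributed accumulator tile.
template <typename AccT, std::size_t N>
using ToyAccTile = std::array<AccT, N>;

// in_out_block_tile[i] += in_block_tile[i] for every element this thread owns.
template <typename AccT, std::size_t N>
void ToyAddBlockTile(ToyAccTile<AccT, N>& in_out_block_tile,
                     const ToyAccTile<AccT, N>& in_block_tile)
{
    for (std::size_t i = 0; i < N; ++i)
    {
        in_out_block_tile[i] += in_block_tile[i];
    }
}
```

Because both arguments share one tile type, their sizes match by construction, which mirrors the single OAccTile template parameter of AddBlockTile().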
◆ BaseGemm()
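| CK_TILE_DEVICE void ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BaseGemm (StreamKKernelArgs &kargs, index_t tile_idx, index_t num_loop, index_t i_k_a, index_t i_k_b, index_t k_size, void *smem_ptr_0) const |
inline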
Computes offsets into A, B, and C tensors then runs the GEMM pipeline and epilogue.
- Parameters
  - kargs: Stream-K kernel arguments.
  - tile_idx: The 1D tile index in the C tensor for this workgroup.
  - num_loop: The number of iterations (at the macro-tile level) along the K dimension that this workgroup performs for the C tile.
  - i_k_a: The K offset in the A tensor.
  - i_k_b: The K offset in the B tensor.
  - k_size: The portion of the K dimension this workgroup processes in the assigned tile_idx.
  - smem_ptr_0: Pointer to LDS.
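How a caller might derive these arguments is sketched here under explicit assumptions: C tiles are numbered row-major, iter_begin and iter_end are K-loop iterations local to tile_idx, block offsets are element offsets, and one K offset serves both A and B. ToyBaseGemmArgs and MakeToyBaseGemmArgs are hypothetical; the real kernel presumably derives these values through its TilePartitioner.

```cpp
#include <cassert>
#include <cstdint>

using index_t = std::int32_t;

// Hypothetical bundle of the per-call values a Stream-K loop passes to BaseGemm().
struct ToyBaseGemmArgs
{
    index_t block_idx_m; // first row of the C macro tile
    index_t block_idx_n; // first column of the C macro tile
    index_t i_k;         // K offset (used for both A and B in this sketch)
    index_t num_loop;    // number of KPerBlock iterations to run
    index_t k_size;      // K extent covered by this call
};

ToyBaseGemmArgs MakeToyBaseGemmArgs(index_t tile_idx, index_t iter_begin, index_t iter_end,
                                    index_t N, index_t K,
                                    index_t MPerBlock, index_t NPerBlock, index_t KPerBlock)
{
    assert(iter_begin < iter_end);
    const index_t tiles_n = (N + NPerBlock - 1) / NPerBlock; // tiles per C row
    ToyBaseGemmArgs a{};
    a.block_idx_m = (tile_idx / tiles_n) * MPerBlock;
    a.block_idx_n = (tile_idx % tiles_n) * NPerBlock;
    a.i_k         = iter_begin * KPerBlock;
    a.num_loop    = iter_end - iter_begin;
    // The last K iteration of a tile is partial when K % KPerBlock != 0.
    const index_t k_end = iter_end * KPerBlock < K ? iter_end * KPerBlock : K;
    a.k_size            = k_end - a.i_k;
    return a;
}
```

For example, with N = 256, K = 1000 and 128 x 128 x 64 tiles, tile 3 covering local iterations [10, 16) starts at row 128, column 128 and K offset 640, runs 6 iterations, and has k_size 360 because the last iteration is clipped at K.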
◆ BlockSize()
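| CK_TILE_HOST auto ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BlockSize () -> dim3 |
inline static constexpr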
◆ GetName()
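| CK_TILE_HOST const std::string ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetName () |
inline static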
◆ GetWorkSpaceSize()
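| CK_TILE_HOST uint32_t ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetWorkSpaceSize (const StreamKKernelArgs &kargs) |
inline static
Computes the buffer size needed to store accumulation results for Stream-K.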
- Returns
- The buffer size needed.
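The size depends on the workspace layout, which this page does not specify. Assuming one accumulator tile per Stream-K workgroup followed by one 32-bit flag per workgroup for the signal/wait handshake, the computation could look like this. ToyWorkspaceLayout and ToyWorkspaceSize are hypothetical, and the actual ck_tile layout may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Assumed layout (illustrative only):
//   [ partials: num_sk_ctas * MPerBlock * NPerBlock * acc_bytes ]
//   [ flags:    num_sk_ctas * sizeof(uint32_t)                  ]
struct ToyWorkspaceLayout
{
    std::size_t partials_bytes;
    std::size_t flags_offset; // byte offset of the first flag
    std::size_t total_bytes;
};

ToyWorkspaceLayout ToyWorkspaceSize(std::size_t num_sk_ctas, std::size_t MPerBlock,
                                    std::size_t NPerBlock, std::size_t acc_bytes)
{
    assert(acc_bytes > 0);
    ToyWorkspaceLayout w{};
    w.partials_bytes = num_sk_ctas * MPerBlock * NPerBlock * acc_bytes;
    // Round up so the flag array is suitably aligned for uint32_t.
    constexpr std::size_t align = alignof(std::uint32_t);
    w.flags_offset = (w.partials_bytes + align - 1) / align * align;
    w.total_bytes  = w.flags_offset + num_sk_ctas * sizeof(std::uint32_t);
    return w;
}
```

With 8 Stream-K workgroups, 128 x 128 tiles and 4-byte accumulators, this yields 524288 bytes of partials plus 32 bytes of flags, 524320 bytes in total.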
◆ GridSize()
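| CK_TILE_HOST auto ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GridSize (const TilePartitioner &tile_partitioner) -> dim3 |
inline static
Compute the grid size for the Stream-K kernel using the tile_partitioner.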
- Returns
- The grid size.
◆ IsSupportedArgument()
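| CK_TILE_HOST bool ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::IsSupportedArgument (const StreamKKernelArgs &kargs) |
inline static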
◆ LoadPartial()
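template<typename DataType , typename OAccTileDist >
| CK_TILE_DEVICE auto ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::LoadPartial (const StreamKKernelArgs &kargs, index_t cta_idx, const OAccTileDist &c_block_tile_dist) const |
inline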
Loads a partial block tile from the workspace buffer.
- Parameters
  - kargs: Kernel arguments, including the workspace pointer.
  - cta_idx: The index of the thread block (CTA).
  - c_block_tile_dist: The tile distribution for the block.
- Returns
- The loaded partial block tile.
- Note
- This function calculates the buffer pointer and uses the tile distribution for loading the partial block tile.
◆ MakeKernelArgs()
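| CK_TILE_HOST StreamKKernelArgs ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::MakeKernelArgs (const StreamKHostArgs &host_args, int num_cu=NumCU(), int occupancy=Occupancy()) |
inline static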
Constructs kernel arguments for the Stream-K kernel.
- Parameters
  - host_args: Stream-K host arguments.
  - num_cu: Number of compute units (CUs). Defaults to the number of CUs on the device; callers may override it, for example to make tests reproducible.
  - occupancy: The maximum number of active blocks per CU for this kernel. Callers may override it, for example to make tests reproducible.
- Returns
- The kernel arguments for Stream-K.
◆ MaxOccupancyGridSize()
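| CK_TILE_HOST auto ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::MaxOccupancyGridSize (const stream_config &s) -> dim3 |
inline static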
Get the maximum occupancy grid size for the persistent kernel on the current device.
- Returns
- The maximum occupancy grid size.
- Note
- This function queries the maximum occupancy of the kernel using hipOccupancyMaxActiveBlocksPerMultiprocessor.
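A persistent grid is typically sized so that every workgroup can be resident at once. Assuming the grid is the number of CUs times the maximum active blocks per CU (the same two quantities MakeKernelArgs() accepts as num_cu and occupancy), the arithmetic is shown below. ToyDim3 and ToyMaxOccupancyGridSize are hypothetical; on a real device the inputs would come from the HIP runtime, for example hipDeviceProp_t::multiProcessorCount and hipOccupancyMaxActiveBlocksPerMultiprocessor.

```cpp
#include <cassert>

// Hypothetical stand-in for dim3.
struct ToyDim3
{
    unsigned x, y, z;
};

// Persistent grid model: one 1D workgroup slot per resident block.
ToyDim3 ToyMaxOccupancyGridSize(int num_cu, int occupancy)
{
    assert(num_cu > 0 && occupancy > 0);
    return ToyDim3{static_cast<unsigned>(num_cu * occupancy), 1u, 1u};
}
```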
◆ operator()() [1/2]
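template<bool U = PersistentDP>
| CK_TILE_DEVICE std::enable_if_t<!U > ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::operator() (StreamKKernelArgs kargs) const |
inline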
Entry point for the Stream-K Kernel with non-persistent DP.
- Overview
- For the non-persistent kernel, each data-parallel workgroup computes the results for its assigned macro tile by calling BaseGemm(). The Stream-K workgroups do their assigned work by calling StreamKGemm(), which calls BaseGemm() in the Stream-K loop.
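A host-side model of this dispatch follows. It assumes, as the StreamKGemm() note suggests, that DP workgroups occupy the first block indices, and additionally that DP workgroup i computes tile i. The tile mapping and all Toy* names are hypothetical.

```cpp
#include <cassert>

// Hypothetical kinds of work a workgroup can be assigned.
enum class ToyWork { DataParallelTile, StreamK };

struct ToyAssignment
{
    ToyWork kind;
    int index; // tile_idx for DP work, Stream-K cta_idx for Stream-K work
};

// Block ids [0, num_dp_ctas) each own one whole C tile; the remaining blocks
// run the Stream-K loop with cta_idx rebased so the first one gets zero.
ToyAssignment ToyNonPersistentDispatch(int block_id, int num_dp_ctas)
{
    assert(block_id >= 0 && num_dp_ctas >= 0);
    if (block_id < num_dp_ctas)
        return {ToyWork::DataParallelTile, block_id}; // would call BaseGemm()
    return {ToyWork::StreamK, block_id - num_dp_ctas}; // would call StreamKGemm()
}
```

With 8 DP workgroups, block 3 computes tile 3, while block 10 becomes Stream-K workgroup 2.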
◆ operator()() [2/2]
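template<bool U = PersistentDP>
| CK_TILE_DEVICE std::enable_if_t< U > ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::operator() (StreamKKernelArgs kargs) const |
inline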
Entry point for the Stream-K Kernel with persistent DP.
- Overview
- For the persistent kernel, each workgroup first computes its assigned data-parallel tiles, calling BaseGemm() once per tile. The workgroups then proceed with the Stream-K portion by calling StreamKGemm(), which calls BaseGemm() in the Stream-K loop.
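A common way for a persistent kernel to hand out data-parallel tiles is a grid-stride loop. The sketch assumes that distribution, and that each workgroup then enters StreamKGemm() with its own block index as cta_idx; both are assumptions, and ToyPersistentDpTiles is a hypothetical helper.

```cpp
#include <cassert>
#include <vector>

// Returns the DP tiles one workgroup would compute before its Stream-K phase.
std::vector<int> ToyPersistentDpTiles(int block_id, int grid_size, int num_dp_tiles)
{
    assert(grid_size > 0 && block_id >= 0 && block_id < grid_size);
    std::vector<int> tiles;
    for (int tile = block_id; tile < num_dp_tiles; tile += grid_size)
        tiles.push_back(tile); // BaseGemm(kargs, tile, ...) would run here
    return tiles;              // afterwards: StreamKGemm(kargs, block_id, ...)
}
```

With a grid of 4 workgroups and 10 DP tiles, workgroup 1 computes tiles 1, 5 and 9 before joining the Stream-K phase.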
◆ RunGemm()
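template<bool UseDefaultScheduler = true>
| CK_TILE_DEVICE void ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::RunGemm (const std::array< const ADataType *, UniversalGemmKernel::NumATensor > &as_ptr, const std::array< const BDataType *, UniversalGemmKernel::NumBTensor > &bs_ptr, const std::array< const void *, UniversalGemmKernel::NumDTensor > &ds_ptr, CDataType *c_ptr, void *smem_ptr_0, const typename UniversalGemmKernel::KernelArgs &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t k_size) |
inline static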
◆ SetWorkSpacePointer()
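| CK_TILE_HOST void ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::SetWorkSpacePointer (StreamKKernelArgs &kargs, void *workspace_ptr) |
inline static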
Sets the kargs' current workspace_ptr to the given workspace_ptr.
- Note
- Assumes that the given workspace_ptr points to allocated device memory.
◆ SignalStorePartialDone()
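| CK_TILE_DEVICE void ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::SignalStorePartialDone (const StreamKKernelArgs &kargs, index_t cta_idx) const |
inline
Signals that the current thread block (CTA) has completed storing its partial results.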
- Parameters
  - kargs: Kernel arguments, including the workspace pointer.
  - cta_idx: The index of the current thread block (CTA).
- Note
- This function utilizes a workgroup barrier to set a synchronization flag for the given CTA index.
◆ StorePartial()
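template<typename OAccTile >
| CK_TILE_DEVICE void ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::StorePartial (const StreamKKernelArgs &kargs, index_t cta_idx, const OAccTile &c_block_tile) const |
inline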
Stores a partial block tile to the workspace buffer.
- Parameters
  - kargs: Kernel arguments, including the workspace pointer.
  - cta_idx: The index of the thread block (CTA).
  - c_block_tile: The block tile to be stored.
- Note
- This function calculates the buffer pointer and uses the tile window for storing the partial block tile.
◆ StreamKGemm()
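| CK_TILE_DEVICE void ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::StreamKGemm (StreamKKernelArgs &kargs, index_t cta_idx, void *smem_ptr_0) const |
inline
Runs the main Stream-K algorithm.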
- Parameters
  - kargs: Stream-K kernel arguments.
  - cta_idx: The current Stream-K workgroup's index.
  - smem_ptr_0: Pointer to LDS.
- Note
- It is assumed that the first Stream-K workgroup has a cta_idx of zero. If a non-persistent data-parallel (DP) section is used, a Stream-K workgroup's cta_idx should be something like blockIdx.x minus the number of DP workgroups.
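The Stream-K loop can be modeled on the host. This sketch assumes the K iterations of all Stream-K tiles are laid end to end and split as evenly as possible across the Stream-K workgroups, in the spirit of the Stream-K decomposition; ck_tile's partitioner may divide the work differently. Each returned ToySegment (a hypothetical type) corresponds to one BaseGemm() call with num_loop = iter_end - iter_begin.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// One BaseGemm() call made by a Stream-K workgroup (hypothetical model).
struct ToySegment
{
    int tile_idx;
    int iter_begin; // first K iteration within the tile (inclusive)
    int iter_end;   // last K iteration within the tile (exclusive)
    bool owns_tile_start() const { return iter_begin == 0; }
};

std::vector<ToySegment> ToyStreamKSegments(int cta_idx, int num_sk_ctas,
                                           int num_sk_tiles, int iters_per_tile)
{
    assert(num_sk_ctas > 0 && cta_idx >= 0 && cta_idx < num_sk_ctas);
    const int total_iters = num_sk_tiles * iters_per_tile;
    const int base  = total_iters / num_sk_ctas;
    const int extra = total_iters % num_sk_ctas; // the first `extra` CTAs get one more
    const int begin = cta_idx * base + std::min(cta_idx, extra);
    const int end   = begin + base + (cta_idx < extra ? 1 : 0);

    std::vector<ToySegment> segments;
    for (int iter = begin; iter < end;)
    {
        const int tile        = iter / iters_per_tile;
        const int tile_first  = tile * iters_per_tile;
        const int local_begin = iter - tile_first;
        const int local_end   = std::min(iters_per_tile, end - tile_first);
        segments.push_back({tile, local_begin, local_end}); // one BaseGemm() call
        iter = tile_first + local_end;
    }
    return segments;
}
```

With 3 workgroups, 2 tiles and 4 iterations per tile, workgroup 1 gets global iterations 3 to 5: the tail of tile 0 (local [3, 4)) and the head of tile 1 (local [0, 2)). Tiles split this way need their partials combined through StorePartial()/SignalStorePartialDone() and WaitStorePartialDone()/LoadPartial()/AddBlockTile(); which workgroup performs the final combine is not modeled here.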
◆ WaitStorePartialDone()
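| CK_TILE_DEVICE void ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::WaitStorePartialDone (const StreamKKernelArgs &kargs, index_t cta_idx) const |
inline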
Waits for the thread block (cta_idx) to complete storing its partial results.
- Parameters
  - kargs: Kernel arguments, including the workspace pointer.
  - cta_idx: The index of the thread block (CTA).
- Note
- This function utilizes a workgroup barrier to wait for the synchronization flag to be set by the given CTA index.
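The store/signal and wait/load pairs form a producer/consumer handshake. This host-side analogy uses std::atomic with release/acquire ordering in place of the kernel's workgroup barrier. ToyWorkspace, its per-CTA slot layout, and ToyCombineTwoCtas are hypothetical; they only illustrate the ordering requirement: a partial must be fully written before its flag is set, and read only after the flag is observed.

```cpp
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Host-side analogy only; the kernel synchronizes through device memory.
struct ToyWorkspace
{
    std::size_t tile_elems;
    std::vector<float> partials;         // one partial-tile slot per CTA
    std::vector<std::atomic<int>> flags; // one ready flag per CTA

    ToyWorkspace(std::size_t num_ctas, std::size_t elems)
        : tile_elems(elems), partials(num_ctas * elems), flags(num_ctas)
    {
        for (auto& f : flags)
            f.store(0, std::memory_order_relaxed); // 0 = partial not stored yet
    }

    // Analogue of StorePartial() followed by SignalStorePartialDone().
    void StoreAndSignal(std::size_t cta_idx, const std::vector<float>& tile)
    {
        assert(tile.size() == tile_elems);
        const auto offset = static_cast<std::ptrdiff_t>(cta_idx * tile_elems);
        std::copy(tile.begin(), tile.end(), partials.begin() + offset);
        flags[cta_idx].store(1, std::memory_order_release); // publish the data
    }

    // Analogue of WaitStorePartialDone() followed by LoadPartial().
    std::vector<float> WaitAndLoad(std::size_t cta_idx) const
    {
        while (flags[cta_idx].load(std::memory_order_acquire) == 0)
            std::this_thread::yield(); // spin until the producer signals
        const auto first = partials.begin() + static_cast<std::ptrdiff_t>(cta_idx * tile_elems);
        return std::vector<float>(first, first + static_cast<std::ptrdiff_t>(tile_elems));
    }
};

// CTA 1 stores its partial on another thread; the calling thread (standing in
// for CTA 0) waits, loads it, and accumulates it as AddBlockTile() would.
std::vector<float> ToyCombineTwoCtas(std::vector<float> mine, const std::vector<float>& theirs)
{
    ToyWorkspace ws(2, mine.size());
    std::thread producer([&ws, &theirs] { ws.StoreAndSignal(1, theirs); });
    const std::vector<float> loaded = ws.WaitAndLoad(1);
    producer.join();
    for (std::size_t i = 0; i < mine.size(); ++i)
        mine[i] += loaded[i];
    return mine;
}
```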
Member Data Documentation
◆ is_tuple_v
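template<typename T >
| bool ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::is_tuple_v = is_detected<is_tuple, T>::value |
static constexpr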
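is_tuple_v follows the detection idiom: is_detected<Op, T>::value is true exactly when Op<T> names a well-formed type. A minimal standard C++17 sketch of the idiom follows. The toy namespace and, in particular, its tuple_size-based is_tuple probe are assumptions for illustration; ck_tile defines its own is_detected and is_tuple, which may test for tuples differently.

```cpp
#include <tuple>
#include <type_traits>

namespace toy {

// Primary template: selected when Op<Args...> is ill-formed.
template <typename AlwaysVoid, template <typename...> class Op, typename... Args>
struct detector : std::false_type {};

// Partial specialization: selected when Op<Args...> is a valid type.
template <template <typename...> class Op, typename... Args>
struct detector<std::void_t<Op<Args...>>, Op, Args...> : std::true_type {};

template <template <typename...> class Op, typename... Args>
using is_detected = detector<void, Op, Args...>;

// Hypothetical probe: well-formed only for tuple-like types, i.e. types for
// which std::tuple_size<T> is defined (std::tuple, std::pair, std::array).
template <typename T>
using is_tuple = decltype(std::tuple_size<T>::value);

template <typename T>
inline constexpr bool is_tuple_v = is_detected<is_tuple, T>::value;

} // namespace toy
```

A check such as static_assert(!is_tuple_v<ADataType>) is one way to enforce the stated expectation that ALayout and ADataType are scalars rather than tuples.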
◆ kBlockSize
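| index_t ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::kBlockSize = UniversalGemmKernel::kBlockSize |
static constexpr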
◆ PersistentDP
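| bool ck_tile::StreamKKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::PersistentDP = UniversalGemmKernel::PersistentKernel |
static constexpr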
The documentation for this struct was generated from the following file:
- /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp