/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp Source File
tile_gemm_shape.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
10 template <typename BlockTile_, typename BlockWarps_, typename WarpTile_>
12 {
16 
17  static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
18 
19  static constexpr index_t kM = BlockTile::at(number<0>{});
20  static constexpr index_t kN = BlockTile::at(number<1>{});
21  static constexpr index_t kK = BlockTile::at(number<2>{});
22 };
23 
24 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:973
ck_tile::TileGemmShape
Definition: tile_gemm_shape.hpp:12
static constexpr index_t kN
Definition: tile_gemm_shape.hpp:20
remove_cvref_t< BlockWarps_ > BlockWarps
Definition: tile_gemm_shape.hpp:14
remove_cvref_t< BlockTile_ > BlockTile
Definition: tile_gemm_shape.hpp:13
static constexpr index_t NumWarps
Definition: tile_gemm_shape.hpp:17
remove_cvref_t< WarpTile_ > WarpTile
Definition: tile_gemm_shape.hpp:15
static constexpr index_t kM
Definition: tile_gemm_shape.hpp:19
static constexpr index_t kK
Definition: tile_gemm_shape.hpp:21
Definition: integral_constant.hpp:13
Definition: math.hpp:98