include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp Source File

include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp Source File

Composable Kernel: include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp Source File
device_grouped_gemm_splitk.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 #pragma once
4 
6 
7 namespace ck {
8 namespace tensor_operation {
9 namespace device {
10 
11 template <typename ALayout,
12  typename BLayout,
13  typename DsLayout,
14  typename ELayout,
15  typename ADataType,
16  typename BDataType,
17  typename DsDataType,
18  typename EDataType,
19  typename AElementwiseOperation,
20  typename BElementwiseOperation,
21  typename CElementwiseOperation>
23  BLayout,
24  DsLayout,
25  ELayout,
26  ADataType,
27  BDataType,
28  DsDataType,
29  EDataType,
30  AElementwiseOperation,
31  BElementwiseOperation,
32  CElementwiseOperation>
33 {
34  //----------------------------------------------------------------------------------------------
    /// @brief      Sets the k batch size.
    ///
    /// Pure virtual: every concrete device operation derived from this interface must
    /// provide an implementation.
    /// NOTE(review): kbatch is presumably the split-K factor (number of partitions of the
    /// GEMM K dimension), given the file name; confirm against the implementations.
    ///
    /// @param      p_arg   Argument object to update. Presumably must point to the concrete
    ///                     argument type of the implementing device op (assumption; verify).
    /// @param[in]  kbatch  The k batch value (index_t, i.e. int32_t).
40  virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
41  //----------------------------------------------------------------------------------------------
    /// @brief      Sets the k batch size.
    ///
    /// Convenience entry point that forwards both arguments unchanged to SetKBatchSize().
    /// Because this method has a default body and SetKBatchSize() is pure virtual, derived
    /// classes must implement SetKBatchSize() and may leave this method as is.
    ///
    /// @param      p_arg   Argument object to update; passed through to SetKBatchSize().
    /// @param[in]  kbatch  The k batch value; passed through to SetKBatchSize().
47  virtual void SetKBatch(BaseArgument* p_arg, index_t kbatch) const
48  {
49  this->SetKBatchSize(p_arg, kbatch);
50  }
51 };
52 
53 } // namespace device
54 } // namespace tensor_operation
55 } // namespace ck
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:50
Definition: device_grouped_gemm.hpp:105
Definition: device_grouped_gemm_splitk.hpp:33
virtual void SetKBatchSize(BaseArgument *p_arg, index_t kbatch) const =0
Sets the k batch size.
virtual void SetKBatch(BaseArgument *p_arg, index_t kbatch) const
Sets the k batch size by forwarding to SetKBatchSize().
Definition: device_grouped_gemm_splitk.hpp:47