/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp Source File
device_gemm_streamk.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <vector>
8 
9 #include "device_base.hpp"
10 
11 namespace ck {
12 namespace tensor_operation {
13 namespace device {
14 
15 template <typename ALayout,
16  typename BLayout,
17  typename CLayout,
18  typename ADataType,
19  typename BDataType,
20  typename CDataType,
21  typename AElementwiseOperation,
22  typename BElementwiseOperation,
23  typename CElementwiseOperation>
25 {
26  virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
27  const void* p_b,
28  void* p_c,
29  ck::index_t M,
30  ck::index_t N,
31  ck::index_t K,
32  ck::index_t StrideA,
33  ck::index_t StrideB,
34  ck::index_t StrideC,
35  AElementwiseOperation a_element_op,
36  BElementwiseOperation b_element_op,
37  CElementwiseOperation c_element_op,
38  ck::index_t NumSKBlocks = 0) = 0;
39 
40  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
41 };
42 
43 template <typename ALayout,
44  typename BLayout,
45  typename CLayout,
46  typename ADataType,
47  typename BDataType,
48  typename CDataType,
49  typename AElementwiseOperation,
50  typename BElementwiseOperation,
51  typename CElementwiseOperation>
52 using DeviceGemmStreamKPtr = std::unique_ptr<DeviceGemmStreamK<ALayout,
53  BLayout,
54  CLayout,
55  ADataType,
56  BDataType,
57  CDataType,
58  AElementwiseOperation,
59  BElementwiseOperation,
60  CElementwiseOperation>>;
61 
62 } // namespace device
63 } // namespace tensor_operation
64 } // namespace ck
std::unique_ptr< DeviceGemmStreamK< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation > > DeviceGemmStreamKPtr
Definition: device_gemm_streamk.hpp:60
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
Definition: device_gemm_streamk.hpp:25
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, ck::index_t NumSKBlocks=0)=0