/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp Source File
device_batched_gemm_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <vector>
8 
9 #include "device_base.hpp"
10 
11 namespace ck {
12 namespace tensor_operation {
13 namespace device {
14 
15 template <typename ALayout,
16  typename B0Layout,
17  typename B1Layout,
18  typename CLayout,
19  typename ADataType,
20  typename B0DataType,
21  typename B1DataType,
22  typename CDataType,
23  typename AElementwiseOperation,
24  typename B0ElementwiseOperation,
25  typename Acc0ElementwiseOperation,
26  typename B1ElementwiseOperation,
27  typename CElementwiseOperation>
29 {
30  virtual std::unique_ptr<BaseArgument>
31  MakeArgumentPointer(const void* p_a,
32  const void* p_b0,
33  const void* p_b1,
34  void* p_c,
35  ck::index_t M,
36  ck::index_t N,
37  ck::index_t K,
38  ck::index_t O,
39  ck::index_t Batch,
40  ck::index_t StrideA,
41  ck::index_t StrideB0,
42  ck::index_t StrideB1,
43  ck::index_t StrideC,
44  ck::index_t BatchStrideA,
45  ck::index_t BatchStrideB0,
46  ck::index_t BatchStrideB1,
47  ck::index_t BatchStrideC,
48  AElementwiseOperation a_element_op,
49  B0ElementwiseOperation b0_element_op,
50  Acc0ElementwiseOperation acc0_element_op,
51  B1ElementwiseOperation b1_element_op,
52  CElementwiseOperation c_element_op) = 0;
53 
54  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
55 };
56 
57 } // namespace device
58 } // namespace tensor_operation
59 } // namespace ck
Cross-references (Doxygen link tooltips):
- Definition: ck.hpp:264 (presumably namespace `ck`)
- `index_t` — alias of `int32_t`; Definition: ck.hpp:289
- Definition: device_base.hpp:76 (presumably the `BaseOperator` base class)
- Definition: device_batched_gemm_gemm.hpp:29 (the interface struct shown above)
- `virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0`
- `virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void *p_a, const void *p_b0, const void *p_b1, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t O, ck::index_t Batch, ck::index_t StrideA, ck::index_t StrideB0, ck::index_t StrideB1, ck::index_t StrideC, ck::index_t BatchStrideA, ck::index_t BatchStrideB0, ck::index_t BatchStrideB1, ck::index_t BatchStrideC, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, Acc0ElementwiseOperation acc0_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op) = 0`