/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp Source File
device_batched_gemm_multi_d.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <vector>
8 
9 #include "device_base.hpp"
10 
11 namespace ck {
12 namespace tensor_operation {
13 namespace device {
14 
15 template <typename ALayout,
16  typename BLayout,
17  typename DsLayout,
18  typename ELayout,
19  typename ADataType,
20  typename BDataType,
21  typename DsDataType,
22  typename EDataType,
23  typename AElementwiseOperation,
24  typename BElementwiseOperation,
25  typename CDEElementwiseOperation>
27 {
28  static constexpr index_t NumDTensor = DsDataType::Size();
29 
30  static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor");
31 
32  virtual std::unique_ptr<BaseArgument>
33  MakeArgumentPointer(const void* p_a,
34  const void* p_b,
35  const std::array<const void*, NumDTensor>& p_ds,
36  void* p_e,
37  index_t M,
38  index_t N,
39  index_t K,
40  index_t Batch,
41  index_t StrideA,
42  index_t StrideB,
43  const std::array<ck::index_t, NumDTensor>& StrideDs,
44  index_t StrideE,
45  index_t BatchStrideA,
46  index_t BatchStrideB,
47  const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
48  index_t BatchStrideE,
49  AElementwiseOperation a_element_op,
50  BElementwiseOperation b_element_op,
51  CDEElementwiseOperation cde_element_op) = 0;
52 
53  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
54 };
55 
56 template <typename ALayout,
57  typename BLayout,
58  typename DsLayout,
59  typename ELayout,
60  typename ADataType,
61  typename BDataType,
62  typename DsDataType,
63  typename EDataType,
64  typename AElementwiseOperation,
65  typename BElementwiseOperation,
66  typename CDEElementwiseOperation>
68 {
69  static constexpr index_t NumDTensor = DsDataType::Size();
70 
71  static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor");
72 
73  virtual std::unique_ptr<BaseArgument>
74  MakeArgumentPointer(const void* p_a,
75  const void* p_b,
76  const std::array<const void*, NumDTensor>& p_ds,
77  void* p_e,
78  index_t M,
79  index_t N,
80  index_t K,
81  index_t Batch,
82  index_t StrideA,
83  index_t StrideB,
84  const std::array<ck::index_t, NumDTensor>& StrideDs,
85  index_t StrideE,
86  index_t BatchStrideA,
87  index_t BatchStrideB,
88  const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
89  index_t BatchStrideE,
90  AElementwiseOperation a_element_op,
91  BElementwiseOperation b_element_op,
92  CDEElementwiseOperation cde_element_op,
93  index_t KBatch) = 0;
94 
95  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
96 };
97 
98 } // namespace device
99 } // namespace tensor_operation
100 } // namespace ck
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
ck::tensor_operation::device::DeviceBatchedGemmMultiD
Definition: device_batched_gemm_multi_d.hpp:27
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, const std::array< ck::index_t, NumDTensor > &StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
static constexpr index_t NumDTensor
Definition: device_batched_gemm_multi_d.hpp:28
ck::tensor_operation::device::DeviceBatchedGemmV2MultiD
Definition: device_batched_gemm_multi_d.hpp:68
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, const std::array< ck::index_t, NumDTensor > &StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, index_t KBatch)=0
static constexpr index_t NumDTensor
Definition: device_batched_gemm_multi_d.hpp:69
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0