/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp Source File
device_batched_gemm_e_permute.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 #include <iostream>
6 #include <vector>
7 
8 #include "device_base.hpp"
9 
10 namespace ck {
11 namespace tensor_operation {
12 namespace device {
13 
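14 struct BatchedGemmEPermuteDesc
15 {
16  ck::index_t G0_, G1_, M_, N_;
17  ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_;
18 };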
19 
20 template <typename ALayout,
21  typename BLayout,
22  typename DELayout,
23  typename ADataType,
24  typename BDataType,
25  typename EDataType,
26  typename AElementwiseOperation,
27  typename BElementwiseOperation,
28  typename CDEElementwiseOperation>
29 struct DeviceBatchedGemmEPermute : public BaseOperator
30 {
31  virtual std::unique_ptr<BaseArgument>
32  MakeArgumentPointer(const void* p_a,
33  const void* p_b,
34  void* p_e,
35  index_t M,
36  index_t N,
37  index_t K,
38  index_t stride_A,
39  index_t stride_B,
40  index_t batch_stride_A,
41  index_t batch_stride_B,
42  BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
43  index_t BatchCount,
44  AElementwiseOperation a_element_op,
45  BElementwiseOperation b_element_op,
46  CDEElementwiseOperation cde_element_op) = 0;
47 
48  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
49 };
50 
51 } // namespace device
52 } // namespace tensor_operation
53 } // namespace ck
Definition: ck.hpp:270
int32_t index_t
Definition: ck.hpp:301
Definition: device_base.hpp:223
Definition: device_batched_gemm_e_permute.hpp:15
ck::index_t N_
Definition: device_batched_gemm_e_permute.hpp:16
ck::index_t stride_N_
Definition: device_batched_gemm_e_permute.hpp:17
ck::index_t G1_
Definition: device_batched_gemm_e_permute.hpp:16
ck::index_t stride_G1_
Definition: device_batched_gemm_e_permute.hpp:17
ck::index_t M_
Definition: device_batched_gemm_e_permute.hpp:16
ck::index_t stride_M_
Definition: device_batched_gemm_e_permute.hpp:17
ck::index_t G0_
Definition: device_batched_gemm_e_permute.hpp:16
ck::index_t stride_G0_
Definition: device_batched_gemm_e_permute.hpp:17
Definition: device_batched_gemm_e_permute.hpp:30
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_e, index_t M, index_t N, index_t K, index_t stride_A, index_t stride_B, index_t batch_stride_A, index_t batch_stride_B, BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, index_t BatchCount, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0