/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp Source File
device_grouped_gemm_softmax_gemm_permute.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <vector>
8 
9 #include "device_base.hpp"
11 
12 namespace ck {
13 namespace tensor_operation {
14 namespace device {
15 
// Abstract, templated interface: the body below declares a nested ProblemDesc
// aggregate and two pure virtual factory functions, so a concrete type must
// override both before it can be instantiated.
//
// Template parameters, in order: five index_t non-type parameters
// (NumDimG/M/N/K/O), eleven type parameters (six data types followed by five
// element-wise operation types), and one MaskingSpecialization value.
// NOTE(review): the NumDim* values presumably give how many G/M/N/K/O
// dimensions each ProblemDesc lengths/strides vector describes -- confirm
// against the implementing device operation.
16 template <index_t NumDimG,
17  index_t NumDimM,
18  index_t NumDimN,
19  index_t NumDimK,
20  index_t NumDimO,
21  typename ADataType,
22  typename B0DataType,
23  typename B1DataType,
24  typename CDataType,
25  typename Acc0BiasDataType,
26  typename Acc1BiasDataType,
27  typename AElementwiseOperation,
28  typename B0ElementwiseOperation,
29  typename Acc0ElementwiseOperation,
30  typename B1ElementwiseOperation,
31  typename CElementwiseOperation,
32  MaskingSpecialization MaskingSpec>
// NOTE(review): listing line 33 (the struct declaration that belongs between
// the template header and the opening brace) is missing from this extract;
// restore it from the upstream header before treating this text as code.
34 {
// Plain aggregate of per-tensor lengths and strides. Each A/B0/B1/C entry is
// a flat std::vector<index_t>; the bias entries are vectors of such vectors.
// NOTE(review): the name suffixes (gs_ms_ks, gs_ns_ks, gs_os_ns, gs_ms_os,
// gs_ms_ns) presumably spell out the dimension-group order of each tensor --
// confirm against a caller.
35  struct ProblemDesc
36  {
37  std::vector<index_t> a_gs_ms_ks_lengths;
38  std::vector<index_t> a_gs_ms_ks_strides;
39 
40  std::vector<index_t> b0_gs_ns_ks_lengths;
41  std::vector<index_t> b0_gs_ns_ks_strides;
42 
43  std::vector<index_t> b1_gs_os_ns_lengths;
44  std::vector<index_t> b1_gs_os_ns_strides;
45 
46  std::vector<index_t> c_gs_ms_os_lengths;
47  std::vector<index_t> c_gs_ms_os_strides;
48 
49  std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
50  std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
51 
52  std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths;
53  std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides;
54  };
55 
// Pure virtual factory returning an owning std::unique_ptr<BaseArgument>.
// All parameters are taken by value: vectors of const input pointers for A,
// B0 and B1, a vector of mutable pointers for C, nested vectors of const
// pointers for the two bias sets, a vector of ProblemDesc, and one instance
// of each of the five element-wise operation types.
// NOTE(review): presumably every *_vec is indexed by group and must have the
// same length as problem_desc_vec -- nothing in this interface enforces it;
// verify in the implementation.
56  virtual std::unique_ptr<BaseArgument>
57  MakeArgumentPointer(std::vector<const void*> p_a_vec,
58  std::vector<const void*> p_b0_vec,
59  std::vector<const void*> p_b1_vec,
60  std::vector<void*> p_c_vec,
61  std::vector<std::vector<const void*>> p_acc0_biases_vec,
62  std::vector<std::vector<const void*>> p_acc1_biases_vec,
63  std::vector<ProblemDesc> problem_desc_vec,
64  AElementwiseOperation a_element_op,
65  B0ElementwiseOperation b0_element_op,
66  Acc0ElementwiseOperation acc0_element_op,
67  B1ElementwiseOperation b1_element_op,
68  CElementwiseOperation c_element_op) = 0;
69 
// Pure virtual factory returning an owning std::unique_ptr<BaseInvoker>;
// takes no arguments.
70  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
71 };
72 
73 } // namespace device
74 } // namespace tensor_operation
75 } // namespace ck
MaskingSpecialization
Definition: masking_specialization.hpp:11
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
Definition: device_grouped_gemm_softmax_gemm_permute.hpp:36
std::vector< index_t > b1_gs_os_ns_strides
Definition: device_grouped_gemm_softmax_gemm_permute.hpp:44
std::vector< index_t > c_gs_ms_os_strides
Definition: device_grouped_gemm_softmax_gemm_permute.hpp:47
std::vector< index_t > b1_gs_os_ns_lengths
Definition: device_grouped_gemm_softmax_gemm_permute.hpp:43
std::vector< std::vector< index_t > > acc1_biases_gs_ms_os_strides
Definition: device_grouped_gemm_softmax_gemm_permute.hpp:53
std::vector< std::vector< index_t > > acc1_biases_gs_ms_os_lengths
Definition: device_grouped_gemm_softmax_gemm_permute.hpp:52
std::vector< index_t > b0_gs_ns_ks_lengths
Definition: device_grouped_gemm_softmax_gemm_permute.hpp:40
std::vector< index_t > c_gs_ms_os_lengths
Definition: device_grouped_gemm_softmax_gemm_permute.hpp:46
std::vector< std::vector< index_t > > acc0_biases_gs_ms_ns_strides
Definition: device_grouped_gemm_softmax_gemm_permute.hpp:50
std::vector< index_t > a_gs_ms_ks_strides
Definition: device_grouped_gemm_softmax_gemm_permute.hpp:38
std::vector< index_t > a_gs_ms_ks_lengths
Definition: device_grouped_gemm_softmax_gemm_permute.hpp:37
std::vector< index_t > b0_gs_ns_ks_strides
Definition: device_grouped_gemm_softmax_gemm_permute.hpp:41
std::vector< std::vector< index_t > > acc0_biases_gs_ms_ns_lengths
Definition: device_grouped_gemm_softmax_gemm_permute.hpp:49
Definition: device_grouped_gemm_softmax_gemm_permute.hpp:34
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > p_a_vec, std::vector< const void * > p_b0_vec, std::vector< const void * > p_b1_vec, std::vector< void * > p_c_vec, std::vector< std::vector< const void * >> p_acc0_biases_vec, std::vector< std::vector< const void * >> p_acc1_biases_vec, std::vector< ProblemDesc > problem_desc_vec, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, Acc0ElementwiseOperation acc0_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)=0