/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp Source File
device_gemm_reduce.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 
8 #include "device_base.hpp"
9 
10 namespace ck {
11 namespace tensor_operation {
12 namespace device {
13 
14 // FIXME: DeviceGemmReduce type need to well define the problem
15 template <ck::index_t NumDTensor, ck::index_t NumReduce>
17 {
18  virtual std::unique_ptr<BaseArgument>
19  MakeArgumentPointer(const void* p_a,
20  const void* p_b,
21  const void* p_bias,
22  std::array<const void*, NumDTensor> p_ds,
23  void* p_c,
24  std::array<void*, NumReduce> p_reduces,
25  ck::index_t M,
26  ck::index_t N,
27  ck::index_t K,
28  ck::index_t StrideA,
29  ck::index_t StrideB,
30  ck::index_t StrideC,
31  std::array<ck::index_t, NumDTensor> StrideDs,
32  std::array<void*, 3> gemm_element_ops,
33  std::array<void*, NumDTensor> d_element_ops,
34  std::array<void*, NumReduce> reduce_in_element_ops,
35  std::array<void*, NumReduce> reduce_out_element_ops,
36  ck::index_t BatchCount = 1) = 0;
37 
38  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
39 };
40 
41 template <ck::index_t NumDTensor, ck::index_t NumReduce>
42 using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<NumDTensor, NumReduce>>;
43 
44 } // namespace device
45 } // namespace tensor_operation
46 } // namespace ck
std::unique_ptr< DeviceGemmReduce< NumDTensor, NumReduce > > DeviceGemmReducePtr
Definition: device_gemm_reduce.hpp:42
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
Definition: device_gemm_reduce.hpp:17
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_bias, std::array< const void *, NumDTensor > p_ds, void *p_c, std::array< void *, NumReduce > p_reduces, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, std::array< ck::index_t, NumDTensor > StrideDs, std::array< void *, 3 > gemm_element_ops, std::array< void *, NumDTensor > d_element_ops, std::array< void *, NumReduce > reduce_in_element_ops, std::array< void *, NumReduce > reduce_out_element_ops, ck::index_t BatchCount=1)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0