/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/masking_specialization.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/masking_specialization.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/masking_specialization.hpp Source File
masking_specialization.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 namespace ck {
7 namespace tensor_operation {
8 namespace device {
9 
11 {
14 };
15 
17 {
18  switch(s)
19  {
20  case MaskingSpecialization::MaskDisabled: return "MaskDisabled";
21  case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle";
22  default: return "Unrecognized specialization!";
23  }
24 }
25 
27 {
28  __host__ __device__ constexpr bool operator()(index_t /*m*/, index_t /*n*/) const
29  {
30  return false;
31  };
32 
33  __host__ __device__ constexpr bool
34  IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const
35  {
36  return false;
37  }
38 };
39 
41 {
42  __host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; }
43 
44  __host__ __device__ constexpr bool
45  IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t /*n_tile*/) const
46  {
47  return operator()(m + m_tile - 1, n);
48  }
49 };
50 
51 // Tracks the elements of C0 that need to be set to -inf.
52 // Note: values in the M padding region do not need to be reset, because they are never stored out.
53 template <typename MaskOutPredicate>
55 {
56  __host__ __device__ C0MatrixMask_impl(index_t NRaw)
57  : NRaw_(NRaw), predicate_(MaskOutPredicate{})
58  {
59  }
60 
61  __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
62  {
63  return n >= NRaw_;
64  }
65 
66  __host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const
67  {
68  return predicate_(m, n) || IsNOutOfBound(n);
69  }
70 
71  __host__ __device__ constexpr bool
72  IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const
73  {
74  return predicate_.IsTileSkippable(m, n, m_tile, n_tile);
75  }
76 
77  private:
78  // index_t MRaw_;
79  index_t NRaw_;
80  MaskOutPredicate predicate_;
81 };
82 
83 } // namespace device
84 } // namespace tensor_operation
85 } // namespace ck
std::string getMaskingSpecializationString(const MaskingSpecialization &s)
Definition: masking_specialization.hpp:16
MaskingSpecialization
Definition: masking_specialization.hpp:11
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: masking_specialization.hpp:55
__host__ constexpr __device__ bool IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const
Definition: masking_specialization.hpp:72
__host__ constexpr __device__ bool IsNOutOfBound(index_t n) const
Definition: masking_specialization.hpp:61
__host__ __device__ C0MatrixMask_impl(index_t NRaw)
Definition: masking_specialization.hpp:56
__host__ constexpr __device__ bool IsMaskedElement(index_t m, index_t n) const
Definition: masking_specialization.hpp:66
Definition: masking_specialization.hpp:27
__host__ constexpr __device__ bool IsTileSkippable(index_t, index_t, index_t, index_t) const
Definition: masking_specialization.hpp:34
__host__ constexpr __device__ bool operator()(index_t, index_t) const
Definition: masking_specialization.hpp:28
Definition: masking_specialization.hpp:41
__host__ constexpr __device__ bool IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t) const
Definition: masking_specialization.hpp:45
__host__ constexpr __device__ bool operator()(index_t m, index_t n) const
Definition: masking_specialization.hpp:42