13 template <
typename BlockGemm,
bool IsFwd = true,
typename RandValDramBlockWindowTmp>
14 __host__ __device__
static constexpr
auto
18 (void)randval_dram_block_window_tmp;
19 (void)seqlen_qk_start;
30 unsigned long long seed,
33 uint8_t p_undrop_in_uint8_t_,
34 bool is_store_randval_)
42 template <
typename BlockGemm,
bool IsFwd = true,
typename RandValDramBlockWindowTmp>
47 constexpr
auto config =
48 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
50 constexpr
index_t MWarp = config.template at<1>();
51 constexpr
index_t NWarp = config.template at<2>();
52 constexpr
index_t kMPerStep = MWarp * WG::kM;
53 constexpr
index_t kNPerStep = NWarp * WG::kN;
55 const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
56 auto randval_dram_window = [&]() {
60 randval_dram_block_window_tmp.get_bottom_tensor_view(),
62 {block_origin.at(number<0>{}), seqlen_qk_start});
67 randval_dram_block_window_tmp.get_bottom_tensor_view(),
69 {seqlen_qk_start, block_origin.at(number<1>{})});
73 return randval_dram_window;
76 template <
typename BlockGemm>
79 constexpr
auto config =
80 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
82 constexpr
index_t MWarp = config.template at<1>();
83 constexpr
index_t kMPerStep = MWarp * WG::kM;
84 constexpr
index_t kNPerStep = WG::kN;
86 constexpr
index_t kN0 = kNPerStep / kN1;
95 randval_lds_block_desc_0,
102 return randval_lds_block_desc;
105 template <
typename BlockGemm>
108 constexpr
auto config =
109 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
110 constexpr
index_t MWarp = config.template at<1>();
111 constexpr
index_t NWarp = config.template at<2>();
113 constexpr
index_t MIterPerWarp = 1;
114 constexpr
index_t NIterPerWarp = 1;
125 constexpr
auto randval_block_inner_part_dstr_encoding = []() {
126 if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
127 std::is_same_v<typename BlockGemm::BDataType, half_t> &&
128 std::is_same_v<typename BlockGemm::CDataType, float>)
138 constexpr
auto randval_block_part_dstr_encode =
140 randval_block_inner_part_dstr_encoding);
145 template <
typename BlockGemm>
148 constexpr
auto config =
149 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
151 constexpr
index_t MWarp = config.template at<1>();
152 constexpr
index_t NWarp = config.template at<2>();
154 constexpr
index_t MIterPerWarp = 1;
155 constexpr
index_t NIterPerWarp = 1;
165 constexpr
auto randval_block_part_dstr_encode =
167 typename WG::CWarpDstrEncoding{});
172 template <
typename BlockGemm,
173 typename PComputeDataType,
174 typename RandValOutputDataType,
175 typename PComputeWindow,
176 typename RandValDramWindow>
179 PComputeWindow& p_compute,
180 RandValDramWindow& randval_dram_window)
const
182 constexpr
auto config =
183 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
185 constexpr
index_t MWarp = config.template at<1>();
186 constexpr
index_t NWarp = config.template at<2>();
188 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
189 constexpr
index_t kNPerBlock = BlockGemmShape::kN;
190 constexpr
index_t kMPerStep = MWarp * WG::kM;
191 constexpr
index_t kNPerStep = NWarp * WG::kN;
194 auto randval_lds = make_tensor_view<address_space_enum::lds>(
195 reinterpret_cast<uint8_t*
>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
198 randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
201 auto randval_dist_generated =
202 make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
203 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
205 auto randval_lds_read_window =
207 randval_lds_window.get_window_lengths(),
208 randval_lds_window.get_window_origin(),
209 MakeRandValLdsShuffleTileDistribution<BlockGemm>());
211 const int start_m0_idx = randval_dram_window.get_window_origin().at(
number<0>{});
214 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
215 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
216 int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) +
get_warp_id();
217 int block_col_start = (start_n0_idx / WG::kN) + i_n0;
218 uint2 rowcol = make_uint2(block_row_start, block_col_start);
221 uint8_t random_uint8_t[16];
222 ph.get_random_16x8(random_uint8_t,
223 reinterpret_cast<unsigned long long&
>(rowcol));
225 constexpr
auto randval_dist_generated_spans =
226 decltype(randval_dist_generated)::get_distributed_spans();
227 int i_random_idx = 0;
231 randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
235 store_tile(randval_lds_window, randval_dist_generated);
238 auto randval =
load_tile(randval_lds_read_window);
240 const auto randval_store = cast_tile<RandValOutputDataType>(randval);
241 store_tile(randval_dram_window, randval_store);
248 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
249 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
250 int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) +
get_warp_id();
251 int block_col_start = (start_n0_idx / WG::kN) + i_n0;
252 uint2 rowcol = make_uint2(block_row_start, block_col_start);
255 uint8_t random_uint8_t[16];
256 ph.get_random_16x8(random_uint8_t,
reinterpret_cast<unsigned long long&
>(rowcol));
258 constexpr
auto randval_dist_generated_spans =
259 decltype(randval_dist_generated)::get_distributed_spans();
260 int i_random_idx = 0;
264 randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
268 store_tile(randval_lds_window, randval_dist_generated);
271 auto randval =
load_tile(randval_lds_read_window);
272 constexpr
auto randval_spans = decltype(randval)::get_distributed_spans();
276 constexpr
auto p_idx1 =
280 p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
281 ? p_compute[p_idx] * rp_undrop
282 : PComputeDataType(0);
295 template <
bool IsDropout_,
bool IsWG32_,
bool IsStoreRandval_>
298 template <
bool IsWG32_,
bool IsStoreRandval_>
301 static constexpr
bool IsDropout =
false;
302 static constexpr
bool IsStoreRandval = IsStoreRandval_;
304 template <
typename BlockGemm,
bool IsFwd = true,
typename RandValDramBlockWindowTmp>
305 __host__ __device__
static constexpr
auto
309 (void)randval_dram_block_window_tmp;
310 (void)seqlen_qk_start;
316 template <
bool IsWG32_,
bool IsStoreRandval_>
319 static constexpr
bool IsDropout =
true;
322 static constexpr
bool IsWG32 = IsWG32_;
323 static constexpr
bool IsStoreRandval = IsStoreRandval_;
328 unsigned long long seed,
329 unsigned long long offset,
331 uint8_t p_undrop_in_uint8_t_)
335 rp_undrop(rp_undrop_),
336 p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
340 template <
typename BlockGemm,
bool IsFwd = true,
typename RandValDramBlockWindowTmp>
345 constexpr
auto config =
346 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
349 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
350 constexpr
index_t MWarp = config.template at<1>();
351 constexpr
index_t NWarp = config.template at<2>();
352 constexpr
bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
353 constexpr
index_t kMPerStep = [&]() {
354 if constexpr(MBwdWG16MultiIterCheck)
356 return MWarp * WG::kM * 2;
360 return MWarp * WG::kM;
363 constexpr
index_t kNPerStep = NWarp * WG::kN;
365 const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
366 auto randval_dram_window = [&]() {
370 randval_dram_block_window_tmp.get_bottom_tensor_view(),
372 {block_origin.at(number<0>{}), seqlen_qk_start});
377 randval_dram_block_window_tmp.get_bottom_tensor_view(),
379 {seqlen_qk_start, block_origin.at(number<1>{})});
383 return randval_dram_window;
386 template <
typename BlockGemm>
389 constexpr
auto config =
390 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
392 constexpr
index_t MWarp = config.template at<1>();
393 constexpr
index_t kMPerStep = MWarp * WG::kM;
394 constexpr
index_t kNPerStep = WG::kN;
396 constexpr
index_t kN0 = kNPerStep / kN1;
405 randval_lds_block_desc_0,
412 return randval_lds_block_desc;
415 template <
typename BlockGemm,
bool IsFwd = true>
418 constexpr
auto config =
419 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
421 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
422 constexpr
index_t MWarp = config.template at<1>();
423 constexpr
index_t NWarp = config.template at<2>();
424 constexpr
bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
426 constexpr
index_t MIterPerWarp = [&]() {
427 if constexpr(MBwdWG16MultiIterCheck)
436 constexpr
index_t NIterPerWarp = 1;
448 constexpr
auto randval_block_inner_part_dstr_encoding = []() {
449 if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
450 std::is_same_v<typename BlockGemm::BDataType, half_t> &&
451 std::is_same_v<typename BlockGemm::CDataType, float>)
467 constexpr
auto randval_block_part_dstr_encode =
469 randval_block_inner_part_dstr_encoding);
474 template <
typename BlockGemm>
477 constexpr
auto config =
478 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
480 constexpr
index_t MWarp = config.template at<1>();
481 constexpr
index_t NWarp = config.template at<2>();
483 constexpr
index_t MIterPerWarp = 1;
484 constexpr
index_t NIterPerWarp = 1;
494 constexpr
auto randval_block_part_dstr_encode =
496 typename WG::CWarpDstrEncoding{});
501 template <
typename BlockGemm,
502 typename PComputeDataType,
503 typename RandValOutputDataType,
504 typename PComputeWindow,
505 typename RandValDramWindow>
509 PComputeWindow& p_compute,
510 RandValDramWindow& randval_dram_window)
const
512 constexpr
auto config =
513 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
515 constexpr
index_t MWarp = config.template at<1>();
516 constexpr
index_t NWarp = config.template at<2>();
518 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
519 constexpr
index_t kNPerBlock = BlockGemmShape::kN;
520 constexpr
index_t kMPerStep = MWarp * WG::kM;
521 constexpr
index_t kNPerStep = NWarp * WG::kN;
524 auto randval_lds = make_tensor_view<address_space_enum::lds>(
525 reinterpret_cast<uint8_t*
>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
528 randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
531 auto randval_dist_generated =
532 make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
533 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
535 auto randval_lds_read_window =
537 randval_lds_window.get_window_lengths(),
538 randval_lds_window.get_window_origin(),
539 MakeRandValLdsShuffleTileDistribution<BlockGemm>());
541 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
542 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
543 int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) +
get_warp_id();
544 int block_col_start = (start_n0_idx / WG::kN) + i_n0;
545 uint2 rowcol = make_uint2(block_row_start, block_col_start);
548 uint8_t random_uint8_t[16];
549 ph.
get_random_16x8(random_uint8_t,
reinterpret_cast<unsigned long long&
>(rowcol));
551 constexpr
auto randval_dist_generated_spans =
552 decltype(randval_dist_generated)::get_distributed_spans();
553 int i_random_idx = 0;
557 randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
561 store_tile(randval_lds_window, randval_dist_generated);
564 auto randval =
load_tile(randval_lds_read_window);
565 constexpr
auto randval_spans = decltype(randval)::get_distributed_spans();
569 constexpr
auto p_idx1 =
573 p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
574 ? p_compute[p_idx] * rp_undrop
575 : PComputeDataType(0);
579 if constexpr(IsStoreRandval)
581 const auto randval_store = cast_tile<RandValOutputDataType>(randval);
582 store_tile(randval_dram_window, randval_store);
586 if constexpr(IsStoreRandval)
591 if constexpr(IsStoreRandval)
597 template <
typename BlockGemm,
598 typename RandValOutputDataType,
599 typename PComputeWindow,
600 typename RandValDramWindow>
603 PComputeWindow& p_compute,
604 RandValDramWindow& randval_dram_window)
const
606 constexpr
auto config =
607 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
609 constexpr
index_t MWarp = config.template at<1>();
610 constexpr
index_t NWarp = config.template at<2>();
612 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
613 constexpr
index_t kNPerBlock = BlockGemmShape::kN;
614 constexpr
bool MBwdWG16MultiIterCheck = (!IsWG32) && (kMPerBlock > 16);
615 constexpr
bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16);
616 constexpr
index_t kMPerStep = [&]() {
617 if constexpr(MBwdWG16MultiIterCheck)
619 return MWarp * WG::kM * 2;
623 return MWarp * WG::kM;
626 constexpr
index_t kNPerStep = NWarp * WG::kN;
629 auto randval = make_static_distributed_tensor<uint8_t>(
630 MakeRandValTileDistribution<BlockGemm, false>());
632 static_assert(randval.kThreadElementSpaceSize == 16);
634 static_assert(randval.kThreadElementSpaceSize == 4 ||
635 randval.kThreadElementSpaceSize == 8);
637 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
638 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
639 int block_row_start, block_col_start;
642 block_row_start = (start_m0_idx / WG::kM) + i_m0;
643 block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) +
get_warp_id();
647 block_row_start = start_m0_idx / 32 + i_m0;
648 block_col_start = (start_n0_idx / 32) +
get_warp_id() / 2 + i_n0 * 2;
650 uint2 rowcol = make_uint2(block_row_start, block_col_start);
653 uint8_t* random_uint8_t_;
654 if constexpr(MBwdWG16SingleIterCheck)
656 uint8_t random_uint8_t[4];
662 ((
get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
664 random_uint8_t,
reinterpret_cast<unsigned long long&
>(rowcol), start_idx);
665 random_uint8_t_ = random_uint8_t;
667 else if constexpr(MBwdWG16MultiIterCheck)
669 uint8_t random_uint8_t[8];
674 random_uint8_t,
reinterpret_cast<unsigned long long&
>(rowcol), start_idx);
675 random_uint8_t_ = random_uint8_t;
679 uint8_t random_uint8_t[16];
681 reinterpret_cast<unsigned long long&
>(rowcol));
682 random_uint8_t_ = random_uint8_t;
685 constexpr
auto randval_spans = decltype(randval)::get_distributed_spans();
686 int i_random_idx = 0;
690 randval(r_idx) = random_uint8_t_[i_random_idx++];
696 p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
702 if constexpr(IsStoreRandval)
704 const auto randval_store = cast_tile<RandValOutputDataType>(randval);
705 store_tile(randval_dram_window, randval_store);
709 if constexpr(IsStoreRandval)
714 if constexpr(IsStoreRandval)
Definition: philox_rand.hpp:12
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t *out, const unsigned long long subsequence) const
Definition: philox_rand.hpp:42
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t *out, const unsigned long long subsequence, const index_t start_idx) const
Definition: philox_rand.hpp:56
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t *out, const unsigned long long subsequence, const index_t start_idx) const
Definition: philox_rand.hpp:73
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:420
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE index_t get_warp_size()
Definition: arch.hpp:51
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:63
CK_TILE_DEVICE index_t get_lane_id()
Definition: arch.hpp:69
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:27
CK_TILE_DEVICE index_t get_warp_id()
Definition: arch.hpp:71
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:92
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
__host__ static constexpr __device__ auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:306
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_)
Definition: block_dropout.hpp:325
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsBlockDescriptor()
Definition: block_dropout.hpp:387
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:342
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsShuffleTileDistribution()
Definition: block_dropout.hpp:475
const uint8_t p_undrop_in_uint8_t
Definition: block_dropout.hpp:722
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValTileDistribution()
Definition: block_dropout.hpp:416
ck_tile::philox ph
Definition: block_dropout.hpp:720
CK_TILE_HOST_DEVICE void Run(void *randval_ptr, const index_t start_m0_idx, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:506
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:601
const float rp_undrop
Definition: block_dropout.hpp:721
Definition: block_dropout.hpp:296
Definition: block_dropout.hpp:26
const uint8_t p_undrop_in_uint8_t
Definition: block_dropout.hpp:291
CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_, bool is_store_randval_)
Definition: block_dropout.hpp:27
ck_tile::philox ph
Definition: block_dropout.hpp:289
const float rp_undrop
Definition: block_dropout.hpp:290
const bool is_store_randval
Definition: block_dropout.hpp:292
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:44
CK_TILE_HOST_DEVICE void Run(void *randval_ptr, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:177
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValTileDistribution()
Definition: block_dropout.hpp:106
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsShuffleTileDistribution()
Definition: block_dropout.hpp:146
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsBlockDescriptor()
Definition: block_dropout.hpp:77
Definition: block_dropout.hpp:12
__host__ static constexpr __device__ auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:15
typename WarpGemmAttribute::CWarpDstrEncoding CWarpDstrEncoding
Definition: warp_gemm_impl.hpp:29
Definition: integral_constant.hpp:13
Definition: coordinate_transform.hpp:1443
Definition: sequence.hpp:52
Definition: functional.hpp:43
Definition: tile_distribution.hpp:42
static constexpr auto impl_
Definition: tile_distribution.hpp:45
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192