40 template <
typename BlockGemm,
bool IsFwd = true,
typename RandValDramBlockWindowTmp>
45 (void)randval_dram_block_window_tmp;
46 (void)seqlen_qk_start;
57 unsigned long long seed,
61 bool is_store_randval_)
71 template <
typename BlockGemm,
bool IsFwd = true,
typename RandValDramBlockWindowTmp>
76 constexpr
auto config =
77 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
79 constexpr
bool IsWG32 = WG::kM == 32;
80 constexpr
index_t MWarp = config.template at<1>();
81 constexpr
index_t NWarp = config.template at<2>();
83 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
84 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
85 constexpr
index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
86 constexpr
index_t kNPerStep = NWarp * WG::kN;
88 const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
89 auto randval_dram_window = [&]() {
93 randval_dram_block_window_tmp.get_bottom_tensor_view(),
95 {block_origin.at(number<0>{}), seqlen_qk_start});
100 randval_dram_block_window_tmp.get_bottom_tensor_view(),
102 {seqlen_qk_start, block_origin.at(number<1>{})});
106 return randval_dram_window;
109 template <
typename BlockGemm>
112 constexpr
auto config =
113 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
115 constexpr
bool IsWG32 = WG::kM == 32;
116 constexpr
index_t MWarp = config.template at<1>();
117 constexpr
index_t NWarp = config.template at<2>();
119 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
120 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
121 constexpr
index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
122 constexpr
index_t kNPerStep = NWarp * WG::kN;
124 constexpr
index_t kN0 = kNPerStep / kN1;
133 randval_lds_block_desc_0,
140 return randval_lds_block_desc;
143 template <
typename BlockGemm>
146 constexpr
auto config =
147 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
149 constexpr
bool IsWG32 = WG::kM == 32;
150 constexpr
index_t MWarp = config.template at<1>();
151 constexpr
index_t NWarp = config.template at<2>();
153 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
154 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
155 constexpr
index_t NIterPerWarp = 1;
168 constexpr
auto randval_block_inner_part_dstr_encoding =
170 typename WG::BDataType,
171 typename WG::CDataType,
176 IsWG32>::CWarpDstrEncoding{};
178 constexpr
auto randval_block_part_dstr_encode =
180 randval_block_inner_part_dstr_encoding);
185 template <
typename BlockGemm>
188 constexpr
auto config =
189 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
191 constexpr
bool IsWG32 = WG::kM == 32;
192 constexpr
index_t MWarp = config.template at<1>();
193 constexpr
index_t NWarp = config.template at<2>();
195 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
196 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
197 constexpr
index_t NIterPerWarp = 1;
207 constexpr
auto randval_block_part_dstr_encode =
209 typename WG::CWarpDstrEncoding{});
214 template <
typename BlockGemm,
215 typename PComputeDataType,
216 typename RandValOutputDataType,
217 typename PComputeWindow,
218 typename RandValDramWindow>
221 PComputeWindow& p_compute,
222 RandValDramWindow& randval_dram_window)
const
224 constexpr
auto config =
225 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
227 constexpr
bool IsWG32 = WG::kM == 32;
228 constexpr
index_t MWarp = config.template at<1>();
229 constexpr
index_t NWarp = config.template at<2>();
231 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
232 constexpr
index_t kNPerBlock = BlockGemmShape::kN;
233 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
234 constexpr
index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
235 constexpr
index_t kNPerStep = NWarp * WG::kN;
238 auto randval_lds = make_tensor_view<address_space_enum::lds>(
239 reinterpret_cast<uint8_t*
>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
242 randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
245 auto randval_dist_generated =
246 make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
248 const auto randval_lds_read_window =
250 randval_lds_window.get_window_lengths(),
251 randval_lds_window.get_window_origin(),
252 MakeRandValLdsShuffleTileDistribution<BlockGemm>());
254 const index_t start_m0_idx = randval_dram_window.get_window_origin().at(
number<0>{});
255 const index_t iMWarp = get_warp_id() / NWarp;
256 const index_t iNWarp = get_warp_id() % NWarp;
258 auto generate_randval = [&](
auto i_m0,
auto i_n0) {
260 uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
261 const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
262 const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
267 const unsigned long long ph_subsequence =
268 bit_cast<unsigned long long>(make_uint2(wg_m0, wg_n0));
269 const index_t ph_offset = get_lane_id();
271 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
278 const unsigned long long ph_subsequence =
279 bit_cast<unsigned long long>(make_uint2(wg_m0 / 2, wg_n0 / 2));
280 const index_t subtile_m0 = wg_m0 % 2;
283 const index_t ph_offset = (get_lane_id() & 15) +
284 (((get_lane_id() >> 4) & 1) << 5) +
287 if constexpr(MIterPerWarp == 1)
289 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
291 random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
295 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
301 const index_t subtile_n0 = (get_lane_id() >> 4) & 1;
302 const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4);
304 if constexpr(MIterPerWarp == 1)
306 static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
308 random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
312 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
314 random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
319 constexpr
auto randval_dist_generated_spans =
320 decltype(randval_dist_generated)::get_distributed_spans();
321 int i_random_idx = 0;
325 randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
329 store_tile(randval_lds_window, randval_dist_generated);
331 const auto randval =
load_tile(randval_lds_read_window);
336 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
337 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
338 const auto randval = generate_randval(i_m0, i_n0);
341 const auto randval_store = cast_tile<RandValOutputDataType>(randval);
342 store_tile(randval_dram_window, randval_store);
346 constexpr
auto randval_spans = decltype(randval)::get_distributed_spans();
349 constexpr
auto p_idx0 =
351 idx0.
impl_.template at<0>()>{};
352 constexpr
auto p_idx1 =
354 idx1.
impl_.template at<1>(),
355 idx1.impl_.template at<2>()>{};
358 p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
359 ? p_compute[p_idx] * rp_undrop
360 : PComputeDataType(0);
378 template <
bool IsDropout_,
bool IsWG32_,
bool IsStoreRandval_>
381 template <
bool IsWG32_,
bool IsStoreRandval_>
384 static constexpr
bool IsDropout =
false;
385 static constexpr
bool IsStoreRandval = IsStoreRandval_;
387 template <
typename BlockGemm,
bool IsFwd = false,
typename RandValDramBlockWindowTmp>
392 (void)randval_dram_block_window_tmp;
393 (void)seqlen_qk_start;
399 template <
bool IsWG32_,
bool IsStoreRandval_>
402 static constexpr
bool IsDropout =
true;
403 static constexpr
bool IsStoreRandval = IsStoreRandval_;
408 unsigned long long seed,
409 unsigned long long offset,
415 rp_undrop(rp_undrop_),
416 p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
420 template <
typename BlockGemm,
bool IsFwd = false,
typename RandValDramBlockWindowTmp>
425 constexpr
auto config =
426 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
428 constexpr
bool IsWG32 = WG::kM == 32;
429 constexpr
index_t MWarp = config.template at<1>();
430 constexpr
index_t NWarp = config.template at<2>();
432 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
433 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
434 constexpr
index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
435 constexpr
index_t kNPerStep = NWarp * WG::kN;
437 const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
438 auto randval_dram_window = [&]() {
442 randval_dram_block_window_tmp.get_bottom_tensor_view(),
444 {block_origin.at(number<0>{}), seqlen_qk_start});
449 randval_dram_block_window_tmp.get_bottom_tensor_view(),
451 {seqlen_qk_start, block_origin.at(number<1>{})});
455 return randval_dram_window;
458 template <
typename BlockGemm>
461 constexpr
auto config =
462 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
464 constexpr
bool IsWG32 = WG::kM == 32;
465 constexpr
index_t MWarp = config.template at<1>();
466 constexpr
index_t NWarp = config.template at<2>();
468 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
469 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
470 constexpr
index_t NIterPerWarp = 1;
480 constexpr
auto randval_block_inner_part_dstr_encoding =
482 typename WG::BDataType,
483 typename WG::CDataType,
488 IsWG32>::CWarpDstrEncoding{};
491 typename WG::CWarpDstrEncoding>);
493 constexpr
auto randval_block_part_dstr_encode =
495 randval_block_inner_part_dstr_encoding);
500 template <
typename BlockGemm,
501 typename RandValOutputDataType,
502 typename PComputeWindow,
503 typename RandValDramWindow>
506 PComputeWindow& p_compute,
507 RandValDramWindow& randval_dram_window)
const
509 constexpr
auto config =
510 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
512 constexpr
bool IsWG32 = WG::kM == 32;
513 constexpr
index_t MWarp = config.template at<1>();
514 constexpr
index_t NWarp = config.template at<2>();
516 constexpr
index_t kMPerBlock = BlockGemmShape::kM;
517 constexpr
index_t kNPerBlock = BlockGemmShape::kN;
518 constexpr
index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
519 constexpr
index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
520 constexpr
index_t kNPerStep = NWarp * WG::kN;
523 auto randval_dist_generated =
524 make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
526 const index_t iMWarp = get_warp_id() / NWarp;
527 const index_t iNWarp = get_warp_id() % NWarp;
529 auto generate_randval = [&](
auto i_m0,
auto i_n0) {
531 uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
532 const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
533 const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
538 const unsigned long long ph_subsequence =
539 bit_cast<unsigned long long>(make_uint2(wg_m0, wg_n0));
540 const index_t ph_offset = get_lane_id();
542 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
549 const unsigned long long ph_subsequence =
550 bit_cast<unsigned long long>(make_uint2(wg_m0 / 2, wg_n0 / 2));
551 const index_t subtile_m0 = wg_m0 % 2;
554 const index_t ph_offset = (get_lane_id() & 15) +
555 (((get_lane_id() >> 4) & 1) << 5) +
558 if constexpr(MIterPerWarp == 1)
560 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
562 random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
566 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
572 const index_t subtile_n0 = (get_lane_id() >> 4) & 1;
573 const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4);
575 if constexpr(MIterPerWarp == 1)
577 static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
579 random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
583 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
585 random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
590 constexpr
auto randval_dist_generated_spans =
591 decltype(randval_dist_generated)::get_distributed_spans();
592 int i_random_idx = 0;
596 randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
599 return randval_dist_generated;
602 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
603 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
604 const auto randval = generate_randval(i_m0, i_n0);
607 constexpr
auto randval_spans = decltype(randval)::get_distributed_spans();
611 constexpr
auto p_idx0 =
613 idx0.
impl_.template at<0>(),
614 idx0.impl_.template at<1>(),
615 idx0.impl_.template at<2>()>{};
618 p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
624 if constexpr(IsStoreRandval)
626 const auto randval_store = cast_tile<RandValOutputDataType>(randval);
627 store_tile(randval_dram_window, randval_store);
631 if constexpr(IsStoreRandval)
636 if constexpr(IsStoreRandval)
Definition: philox_rand.hpp:12
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t *out, const unsigned long long subsequence, const index_t idx) const
Definition: philox_rand.hpp:75
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t *out, const unsigned long long subsequence, const index_t idx0, const index_t idx1) const
Definition: philox_rand.hpp:56
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t *out, const unsigned long long subsequence) const
Definition: philox_rand.hpp:42
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
constexpr index_t philox_per_tile
Definition: block_dropout.hpp:35
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename impl::warp_gemm_dispatcher::Dispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition: warp_gemm_dispatcher.hpp:177
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:24
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:36
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:495
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
constexpr bool is_same_v
Definition: type.hpp:283
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
unsigned char uint8_t
Definition: stdint.h:124
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:389
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_)
Definition: block_dropout.hpp:405
const unsigned long long ph_seed
Definition: block_dropout.hpp:642
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValTileDistribution()
Definition: block_dropout.hpp:459
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:422
const uint8_t p_undrop_in_uint8_t
Definition: block_dropout.hpp:645
const unsigned long long ph_head_offset
Definition: block_dropout.hpp:643
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:504
const float rp_undrop
Definition: block_dropout.hpp:644
Definition: block_dropout.hpp:379
Definition: block_dropout.hpp:53
const uint8_t p_undrop_in_uint8_t
Definition: block_dropout.hpp:372
CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_, bool is_store_randval_)
Definition: block_dropout.hpp:54
const float rp_undrop
Definition: block_dropout.hpp:371
const unsigned long long ph_head_offset
Definition: block_dropout.hpp:370
const bool is_store_randval
Definition: block_dropout.hpp:373
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:73
CK_TILE_HOST_DEVICE void Run(void *randval_ptr, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:219
const unsigned long long ph_seed
Definition: block_dropout.hpp:369
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValTileDistribution()
Definition: block_dropout.hpp:144
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsShuffleTileDistribution()
Definition: block_dropout.hpp:186
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsBlockDescriptor()
Definition: block_dropout.hpp:110
Definition: block_dropout.hpp:39
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:42
Definition: integral_constant.hpp:13
Definition: coordinate_transform.hpp:1392
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: tile_distribution.hpp:40
static constexpr auto impl_
Definition: tile_distribution.hpp:43
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192