13     template <
typename BlockGemm, 
bool IsFwd = true, 
typename RandValDramBlockWindowTmp>
 
   14     __host__ __device__ 
static constexpr 
auto 
   18         (void)randval_dram_block_window_tmp;
 
   19         (void)seqlen_qk_start;
 
   30                                      unsigned long long seed,
 
   33                                      uint8_t p_undrop_in_uint8_t_,
 
   34                                      bool is_store_randval_)
 
   42     template <
typename BlockGemm, 
bool IsFwd = true, 
typename RandValDramBlockWindowTmp>
 
   47         constexpr 
auto config =
 
   48             BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
 
   50         constexpr 
index_t MWarp     = config.template at<1>();
 
   51         constexpr 
index_t NWarp     = config.template at<2>();
 
   52         constexpr 
index_t kMPerStep = MWarp * WG::kM;
 
   53         constexpr 
index_t kNPerStep = NWarp * WG::kN;
 
   55         const auto block_origin  = randval_dram_block_window_tmp.get_window_origin();
 
   56         auto randval_dram_window = [&]() {
 
   60                     randval_dram_block_window_tmp.get_bottom_tensor_view(),
 
   62                     {block_origin.at(number<0>{}), seqlen_qk_start}); 
 
   67                     randval_dram_block_window_tmp.get_bottom_tensor_view(),
 
   69                     {seqlen_qk_start, block_origin.at(number<1>{})}); 
 
   73         return randval_dram_window;
 
   76     template <
typename BlockGemm>
 
   79         constexpr 
auto config =
 
   80             BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
 
   82         constexpr 
index_t MWarp     = config.template at<1>();
 
   83         constexpr 
index_t kMPerStep = MWarp * WG::kM;
 
   84         constexpr 
index_t kNPerStep = WG::kN;
 
   86         constexpr 
index_t kN0       = kNPerStep / kN1;
 
   95             randval_lds_block_desc_0,
 
  102         return randval_lds_block_desc;
 
  105     template <
typename BlockGemm>
 
  108         constexpr 
auto config =
 
  109             BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
 
  110         constexpr 
index_t MWarp = config.template at<1>();
 
  111         constexpr 
index_t NWarp = config.template at<2>();
 
  113         constexpr 
index_t MIterPerWarp = 1;
 
  114         constexpr 
index_t NIterPerWarp = 1;
 
  125         constexpr 
auto randval_block_inner_part_dstr_encoding = []() {
 
  126             if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
 
  127                          std::is_same_v<typename BlockGemm::BDataType, half_t> &&
 
  128                          std::is_same_v<typename BlockGemm::CDataType, float>)
 
  138         constexpr 
auto randval_block_part_dstr_encode =
 
  140                                                           randval_block_inner_part_dstr_encoding);
 
  145     template <
typename BlockGemm>
 
  148         constexpr 
auto config =
 
  149             BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
 
  151         constexpr 
index_t MWarp = config.template at<1>();
 
  152         constexpr 
index_t NWarp = config.template at<2>();
 
  154         constexpr 
index_t MIterPerWarp = 1;
 
  155         constexpr 
index_t NIterPerWarp = 1;
 
  165         constexpr 
auto randval_block_part_dstr_encode =
 
  167                                                           typename WG::CWarpDstrEncoding{});
 
  172     template <
typename BlockGemm,
 
  173               typename PComputeDataType,
 
  174               typename RandValOutputDataType,
 
  175               typename PComputeWindow,
 
  176               typename RandValDramWindow>
 
  179                                  PComputeWindow& p_compute,
 
  180                                  RandValDramWindow& randval_dram_window)
 const 
  182         constexpr 
auto config =
 
  183             BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
 
  185         constexpr 
index_t MWarp      = config.template at<1>();
 
  186         constexpr 
index_t NWarp      = config.template at<2>();
 
  188         constexpr 
index_t kMPerBlock = BlockGemmShape::kM;
 
  189         constexpr 
index_t kNPerBlock = BlockGemmShape::kN;
 
  190         constexpr 
index_t kMPerStep  = MWarp * WG::kM;
 
  191         constexpr 
index_t kNPerStep  = NWarp * WG::kN;
 
  194         auto randval_lds = make_tensor_view<address_space_enum::lds>(
 
  195             reinterpret_cast<uint8_t*
>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
 
  198             randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
 
  201         auto randval_dist_generated =
 
  202             make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
 
  203         static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
 
  205         auto randval_lds_read_window =
 
  207                              randval_lds_window.get_window_lengths(),
 
  208                              randval_lds_window.get_window_origin(),
 
  209                              MakeRandValLdsShuffleTileDistribution<BlockGemm>());
 
  211         const int start_m0_idx = randval_dram_window.get_window_origin().at(
number<0>{});
 
  214             static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
 
  215                 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
 
  216                     int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
 
  217                     int block_col_start = (start_n0_idx / WG::kN) + i_n0;
 
  218                     uint2 rowcol        = make_uint2(block_row_start, block_col_start);
 
  221                     uint8_t random_uint8_t[16];
 
  222                     ph.get_random_16x8(random_uint8_t,
 
  223                                        reinterpret_cast<unsigned long long&
>(rowcol));
 
  225                     constexpr 
auto randval_dist_generated_spans =
 
  226                         decltype(randval_dist_generated)::get_distributed_spans();
 
  227                     int i_random_idx = 0;
 
  231                             randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
 
  235                     store_tile(randval_lds_window, randval_dist_generated);
 
  238                     auto randval = 
load_tile(randval_lds_read_window);
 
  240                     const auto randval_store = cast_tile<RandValOutputDataType>(randval);
 
  241                     store_tile(randval_dram_window, randval_store);
 
  248         static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
 
  249             static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
 
  250                 int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
 
  251                 int block_col_start = (start_n0_idx / WG::kN) + i_n0;
 
  252                 uint2 rowcol        = make_uint2(block_row_start, block_col_start);
 
  255                 uint8_t random_uint8_t[16];
 
  256                 ph.get_random_16x8(random_uint8_t, 
reinterpret_cast<unsigned long long&
>(rowcol));
 
  258                 constexpr 
auto randval_dist_generated_spans =
 
  259                     decltype(randval_dist_generated)::get_distributed_spans();
 
  260                 int i_random_idx = 0;
 
  264                         randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
 
  268                 store_tile(randval_lds_window, randval_dist_generated);
 
  271                 auto randval                 = 
load_tile(randval_lds_read_window);
 
  272                 constexpr 
auto randval_spans = decltype(randval)::get_distributed_spans();
 
  276                         constexpr 
auto p_idx1 =
 
  280                         p_compute(p_idx)     = randval[r_idx] <= p_undrop_in_uint8_t
 
  281                                                    ? p_compute[p_idx] * rp_undrop
 
  282                                                    : PComputeDataType(0);
 
  295 template <
bool IsDropout_, 
bool IsWG32_, 
bool IsStoreRandval_>
 
  298 template <
bool IsWG32_, 
bool IsStoreRandval_>
 
  301     static constexpr 
bool IsDropout      = 
false;
 
  302     static constexpr 
bool IsStoreRandval = IsStoreRandval_;
 
  304     template <
typename BlockGemm, 
bool IsFwd = true, 
typename RandValDramBlockWindowTmp>
 
  305     __host__ __device__ 
static constexpr 
auto 
  309         (void)randval_dram_block_window_tmp;
 
  310         (void)seqlen_qk_start;
 
  316 template <
bool IsWG32_, 
bool IsStoreRandval_>
 
  319     static constexpr 
bool IsDropout = 
true;
 
  322     static constexpr 
bool IsWG32         = IsWG32_;
 
  323     static constexpr 
bool IsStoreRandval = IsStoreRandval_;
 
  328                                         unsigned long long seed,
 
  329                                         unsigned long long offset,
 
  331                                         uint8_t p_undrop_in_uint8_t_)
 
  334                  (IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))),
 
  335           rp_undrop(rp_undrop_),
 
  336           p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
 
  340     template <
typename BlockGemm, 
bool IsFwd = true, 
typename RandValDramBlockWindowTmp>
 
  345         constexpr 
auto config =
 
  346             BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
 
  349         constexpr 
index_t kMPerBlock          = BlockGemmShape::kM;
 
  350         constexpr 
index_t MWarp               = config.template at<1>();
 
  351         constexpr 
index_t NWarp               = config.template at<2>();
 
  352         constexpr 
bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
 
  353         constexpr 
index_t kMPerStep           = [&]() {
 
  354             if constexpr(MBwdWG16MultiIterCheck)
 
  356                 return MWarp * WG::kM * 2;
 
  360                 return MWarp * WG::kM;
 
  363         constexpr 
index_t kNPerStep = NWarp * WG::kN;
 
  365         const auto block_origin  = randval_dram_block_window_tmp.get_window_origin();
 
  366         auto randval_dram_window = [&]() {
 
  370                     randval_dram_block_window_tmp.get_bottom_tensor_view(),
 
  372                     {block_origin.at(number<0>{}), seqlen_qk_start}); 
 
  377                     randval_dram_block_window_tmp.get_bottom_tensor_view(),
 
  379                     {seqlen_qk_start, block_origin.at(number<1>{})}); 
 
  383         return randval_dram_window;
 
  386     template <
typename BlockGemm>
 
  389         constexpr 
auto config =
 
  390             BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
 
  392         constexpr 
index_t MWarp     = config.template at<1>();
 
  393         constexpr 
index_t kMPerStep = MWarp * WG::kM;
 
  394         constexpr 
index_t kNPerStep = WG::kN;
 
  396         constexpr 
index_t kN0       = kNPerStep / kN1;
 
  405             randval_lds_block_desc_0,
 
  412         return randval_lds_block_desc;
 
  415     template <
typename BlockGemm, 
bool IsFwd = true>
 
  418         constexpr 
auto config =
 
  419             BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
 
  421         constexpr 
index_t kMPerBlock          = BlockGemmShape::kM;
 
  422         constexpr 
index_t MWarp               = config.template at<1>();
 
  423         constexpr 
index_t NWarp               = config.template at<2>();
 
  424         constexpr 
bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
 
  426         constexpr 
index_t MIterPerWarp = [&]() {
 
  427             if constexpr(MBwdWG16MultiIterCheck)
 
  436         constexpr 
index_t NIterPerWarp = 1;
 
  448         constexpr 
auto randval_block_inner_part_dstr_encoding = []() {
 
  449             if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
 
  450                          std::is_same_v<typename BlockGemm::BDataType, half_t> &&
 
  451                          std::is_same_v<typename BlockGemm::CDataType, float>)
 
  467         constexpr 
auto randval_block_part_dstr_encode =
 
  469                                                           randval_block_inner_part_dstr_encoding);
 
  474     template <
typename BlockGemm>
 
  477         constexpr 
auto config =
 
  478             BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
 
  480         constexpr 
index_t MWarp = config.template at<1>();
 
  481         constexpr 
index_t NWarp = config.template at<2>();
 
  483         constexpr 
index_t MIterPerWarp = 1;
 
  484         constexpr 
index_t NIterPerWarp = 1;
 
  494         constexpr 
auto randval_block_part_dstr_encode =
 
  496                                                           typename WG::CWarpDstrEncoding{});
 
  501     template <
typename BlockGemm,
 
  502               typename PComputeDataType,
 
  503               typename RandValOutputDataType,
 
  504               typename PComputeWindow,
 
  505               typename RandValDramWindow>
 
  509                                  PComputeWindow& p_compute,
 
  510                                  RandValDramWindow& randval_dram_window)
 const 
  512         constexpr 
auto config =
 
  513             BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
 
  515         constexpr 
index_t MWarp      = config.template at<1>();
 
  516         constexpr 
index_t NWarp      = config.template at<2>();
 
  518         constexpr 
index_t kMPerBlock = BlockGemmShape::kM;
 
  519         constexpr 
index_t kNPerBlock = BlockGemmShape::kN;
 
  520         constexpr 
index_t kMPerStep  = MWarp * WG::kM;
 
  521         constexpr 
index_t kNPerStep  = NWarp * WG::kN;
 
  524         auto randval_lds = make_tensor_view<address_space_enum::lds>(
 
  525             reinterpret_cast<uint8_t*
>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
 
  528             randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
 
  531         auto randval_dist_generated =
 
  532             make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
 
  533         static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
 
  535         auto randval_lds_read_window =
 
  537                              randval_lds_window.get_window_lengths(),
 
  538                              randval_lds_window.get_window_origin(),
 
  539                              MakeRandValLdsShuffleTileDistribution<BlockGemm>());
 
  541         static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
 
  542             static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
 
  543                 int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
 
  544                 int block_col_start = (start_n0_idx / WG::kN) + i_n0;
 
  545                 uint2 rowcol        = make_uint2(block_row_start, block_col_start);
 
  548                 uint8_t random_uint8_t[16];
 
  549                 ph.
get_random_16x8(random_uint8_t, 
reinterpret_cast<unsigned long long&
>(rowcol));
 
  551                 constexpr 
auto randval_dist_generated_spans =
 
  552                     decltype(randval_dist_generated)::get_distributed_spans();
 
  553                 int i_random_idx = 0;
 
  557                         randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
 
  561                 store_tile(randval_lds_window, randval_dist_generated);
 
  564                 auto randval                 = 
load_tile(randval_lds_read_window);
 
  565                 constexpr 
auto randval_spans = decltype(randval)::get_distributed_spans();
 
  569                         constexpr 
auto p_idx1 =
 
  573                         p_compute(p_idx)     = randval[r_idx] <= p_undrop_in_uint8_t
 
  574                                                    ? p_compute[p_idx] * rp_undrop
 
  575                                                    : PComputeDataType(0);
 
  579                 if constexpr(IsStoreRandval)
 
  581                     const auto randval_store = cast_tile<RandValOutputDataType>(randval);
 
  582                     store_tile(randval_dram_window, randval_store);
 
  586             if constexpr(IsStoreRandval)
 
  591         if constexpr(IsStoreRandval)
 
  597     template <
typename BlockGemm,
 
  598               typename RandValOutputDataType,
 
  599               typename PComputeWindow,
 
  600               typename RandValDramWindow>
 
  603                                  PComputeWindow& p_compute,
 
  604                                  RandValDramWindow& randval_dram_window)
 const 
  606         constexpr 
auto config =
 
  607             BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
 
  609         constexpr 
index_t MWarp                = config.template at<1>();
 
  610         constexpr 
index_t NWarp                = config.template at<2>();
 
  612         constexpr 
index_t kMPerBlock           = BlockGemmShape::kM;
 
  613         constexpr 
index_t kNPerBlock           = BlockGemmShape::kN;
 
  614         constexpr 
bool MBwdWG16MultiIterCheck  = (!IsWG32) && (kMPerBlock > 16);
 
  615         constexpr 
bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16);
 
  616         constexpr 
index_t kMPerStep            = [&]() {
 
  617             if constexpr(MBwdWG16MultiIterCheck)
 
  619                 return MWarp * WG::kM * 2;
 
  623                 return MWarp * WG::kM;
 
  626         constexpr 
index_t kNPerStep = NWarp * WG::kN;
 
  629         auto randval = make_static_distributed_tensor<uint8_t>(
 
  630             MakeRandValTileDistribution<BlockGemm, false>());
 
  632             static_assert(randval.kThreadElementSpaceSize == 16);
 
  634             static_assert(randval.kThreadElementSpaceSize == 4 ||
 
  635                           randval.kThreadElementSpaceSize == 8);
 
  637         static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
 
  638             static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
 
  639                 int block_row_start, block_col_start;
 
  642                     block_row_start = (start_m0_idx / WG::kM) + i_m0;
 
  643                     block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
 
  647                     block_row_start = start_m0_idx / 32 + i_m0;
 
  648                     block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2;
 
  650                 uint2 rowcol = make_uint2(block_row_start, block_col_start);
 
  653                 uint8_t* random_uint8_t_;
 
  654                 if constexpr(MBwdWG16SingleIterCheck)
 
  656                     uint8_t random_uint8_t[4];
 
  662                         ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
 
  664                         random_uint8_t, 
reinterpret_cast<unsigned long long&
>(rowcol), start_idx);
 
  665                     random_uint8_t_ = random_uint8_t;
 
  667                 else if constexpr(MBwdWG16MultiIterCheck)
 
  669                     uint8_t random_uint8_t[8];
 
  672                     const index_t start_idx = (get_lane_id() >> 4) & 1;
 
  674                         random_uint8_t, 
reinterpret_cast<unsigned long long&
>(rowcol), start_idx);
 
  675                     random_uint8_t_ = random_uint8_t;
 
  679                     uint8_t random_uint8_t[16];
 
  681                                        reinterpret_cast<unsigned long long&
>(rowcol));
 
  682                     random_uint8_t_ = random_uint8_t;
 
  685                 constexpr 
auto randval_spans = decltype(randval)::get_distributed_spans();
 
  686                 int i_random_idx             = 0;
 
  690                         randval(r_idx)        = random_uint8_t_[i_random_idx++];
 
  696                         p_compute(p_idx)      = randval[r_idx] <= p_undrop_in_uint8_t
 
  702                 if constexpr(IsStoreRandval)
 
  704                     const auto randval_store = cast_tile<RandValOutputDataType>(randval);
 
  705                     store_tile(randval_dram_window, randval_store);
 
  709             if constexpr(IsStoreRandval)
 
  714         if constexpr(IsStoreRandval)
 
Definition: philox_rand.hpp:12
 
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t *out, const unsigned long long subsequence) const
Definition: philox_rand.hpp:42
 
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t *out, const unsigned long long subsequence, const index_t start_idx) const
Definition: philox_rand.hpp:56
 
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t *out, const unsigned long long subsequence, const index_t start_idx) const
Definition: philox_rand.hpp:73
 
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
 
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:559
 
Definition: cluster_descriptor.hpp:13
 
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:63
 
int32_t index_t
Definition: integer.hpp:9
 
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
 
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
 
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
 
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:92
 
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:343
 
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
 
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
 
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
 
__host__ static constexpr __device__ auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:306
 
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_)
Definition: block_dropout.hpp:325
 
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsBlockDescriptor()
Definition: block_dropout.hpp:387
 
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:342
 
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsShuffleTileDistribution()
Definition: block_dropout.hpp:475
 
const uint8_t p_undrop_in_uint8_t
Definition: block_dropout.hpp:722
 
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValTileDistribution()
Definition: block_dropout.hpp:416
 
ck_tile::philox ph
Definition: block_dropout.hpp:720
 
CK_TILE_HOST_DEVICE void Run(void *randval_ptr, const index_t start_m0_idx, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:506
 
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:601
 
const float rp_undrop
Definition: block_dropout.hpp:721
 
Definition: block_dropout.hpp:296
 
Definition: block_dropout.hpp:26
 
const uint8_t p_undrop_in_uint8_t
Definition: block_dropout.hpp:291
 
CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_, bool is_store_randval_)
Definition: block_dropout.hpp:27
 
ck_tile::philox ph
Definition: block_dropout.hpp:289
 
const float rp_undrop
Definition: block_dropout.hpp:290
 
const bool is_store_randval
Definition: block_dropout.hpp:292
 
static constexpr CK_TILE_HOST_DEVICE auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:44
 
CK_TILE_HOST_DEVICE void Run(void *randval_ptr, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition: block_dropout.hpp:177
 
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValTileDistribution()
Definition: block_dropout.hpp:106
 
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsShuffleTileDistribution()
Definition: block_dropout.hpp:146
 
static constexpr CK_TILE_HOST_DEVICE auto MakeRandValLdsBlockDescriptor()
Definition: block_dropout.hpp:77
 
Definition: block_dropout.hpp:12
 
__host__ static constexpr __device__ auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition: block_dropout.hpp:15
 
typename WarpGemmAttribute::CWarpDstrEncoding CWarpDstrEncoding
Definition: warp_gemm_impl.hpp:30
 
Definition: integral_constant.hpp:13
 
Definition: coordinate_transform.hpp:1443
 
Definition: sequence.hpp:52
 
Definition: functional.hpp:43
 
Definition: tile_distribution.hpp:42
 
static constexpr auto impl_
Definition: tile_distribution.hpp:45
 
Definition: tile_distribution_encoding.hpp:26
 
Definition: tuple.hpp:192