25 template <
typename SrcData,
29 typename ElementwiseOperation,
30 typename SliceLengths,
31 typename DimAccessOrder,
35 index_t DstScalarStrideInVector,
36 bool DstResetCoordinateAfterRun,
37 typename enable_if<SrcDesc::IsKnownAtCompileTime(),
bool>::type =
false>
49 const Index& dst_slice_origin_idx,
50 const ElementwiseOperation& element_op)
52 element_op_{element_op}
54 static_assert(SrcDesc::IsKnownAtCompileTime(),
55 "wrong! SrcDesc need to known at compile-time");
57 "wrong! Not divisible");
65 template <
typename SrcSliceOriginIdx,
typename SrcBuffer,
typename DstBuffer>
66 __device__
void Run(
const SrcDesc&,
67 const SrcSliceOriginIdx&,
68 const SrcBuffer& src_buf,
69 const DstDesc& dst_desc,
72 static_assert(SrcDesc::IsKnownAtCompileTime(),
73 "wrong! SrcDesc need to known at compile-time");
76 "wrong! SrcSliceOrigin need to known at compile-time");
78 static_assert(SrcBuffer::IsStaticBuffer(),
"wrong! SrcBuffer need to be StaticBuffer");
82 constexpr
auto src_slice_origin_idx =
to_multi_index(SrcSliceOriginIdx{});
89 constexpr
auto dst_scalar_step_in_vector =
98 "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
110 constexpr
index_t src_offset = src_desc.CalculateOffset(
111 src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
118 dst_vector.template AsType<DstData>()(i) = v;
121 const bool is_dst_valid =
125 dst_buf.template Update<DstInMemOp, dst_vector_t>(
126 dst_coord_.GetOffset(),
128 dst_vector.template AsType<dst_vector_t>()[
Number<0>{}]);
130 if constexpr(idx_1d.value != num_access - 1)
140 if constexpr(DstResetCoordinateAfterRun)
142 const auto dst_reset_step =
159 if constexpr(num_access == 0)
165 constexpr
auto reset_step =
174 const Index& dst_slice_origin_step_idx)
177 const auto adjusted_step_idx =
178 DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
189 const ElementwiseOperation element_op_;
201 template <
typename SrcData,
205 typename SliceLengths,
206 typename DimAccessOrder,
209 index_t SrcScalarStrideInVector,
210 bool SrcResetCoordinateAfterRun,
211 bool InvalidElementAsNaN =
false,
212 typename enable_if<DstDesc::IsKnownAtCompileTime(),
bool>::type =
false>
216 (!InvalidElementAsNaN),
217 "Filling invalid element as NaN is only for floating point types");
228 const Index& src_slice_origin_idx)
231 static_assert(DstDesc::IsKnownAtCompileTime(),
232 "wrong! SrcDesc need to known at compile-time");
234 "wrong! Not divisible");
242 template <
typename SrcBuffer,
typename DstBuffer,
typename DstSliceOriginIdx>
243 __device__
void Run(
const SrcDesc& src_desc,
244 const SrcBuffer& src_buf,
246 const DstSliceOriginIdx&,
249 static_assert(DstDesc::IsKnownAtCompileTime(),
250 "wrong! DstDesc need to known at compile-time");
253 "wrong! DstSliceOrigin need to known at compile-time");
257 "wrong! inconsistent type");
261 constexpr
auto dst_slice_origin_idx = DstSliceOriginIdx{};
268 constexpr
auto src_scalar_step_in_vector =
285 const bool is_src_valid =
289 src_vector.template AsType<src_vector_t>()(
Number<0>{}) =
290 src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
295 dst_desc.CalculateOffset(
to_multi_index(dst_slice_origin_idx) + src_data_idx +
296 i * src_scalar_step_in_vector);
298 if constexpr(InvalidElementAsNaN)
302 ? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
308 type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
312 if constexpr(idx_1d.value != num_access - 1)
322 if constexpr(SrcResetCoordinateAfterRun)
324 const auto src_reset_step =
341 if constexpr(num_access == 0)
347 constexpr
auto reset_step =
356 const Index& src_slice_origin_step_idx)
359 const auto adjusted_step_idx =
360 SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
370 template <
typename SrcMoveSliceWindowStepHack>
373 const Index& src_slice_origin_step_idx,
374 const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
377 const auto adjusted_step_idx =
378 SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
383 src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
397 template <
typename SliceLengths,
403 typename SrcDimAccessOrder,
404 typename DstDimAccessOrder,
409 index_t SrcScalarStrideInVector,
410 index_t DstScalarStrideInVector,
411 bool SrcResetCoordinateAfterRun,
414 bool DstResetCoordinateAfterRun>
429 const Index& src_slice_origin,
430 const DstDesc& dst_desc,
431 const Index& dst_slice_origin)
436 "wrong! Not divisible");
438 "wrong! Not divisible");
451 template <
typename SrcBuffer,
typename SrcStepHacks>
453 RunRead(
const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const SrcStepHacks& src_step_hacks)
461 "wrong! SrcBuffer and SrcData data type are inconsistent");
471 constexpr
auto src_scalar_step_in_vector =
474 constexpr
auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
476 constexpr
auto src_dim_access_order = SrcDimAccessOrder{};
478 constexpr
auto ordered_src_access_lengths =
484 Index forward_step_idx;
487 forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
491 src_desc, forward_step_idx, src_step_hacks[I0][i]);
498 Index backward_step_idx;
501 backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
505 src_desc, backward_step_idx, src_step_hacks[I1][i]);
510 static_ford<decltype(ordered_src_access_lengths)>{}([&](
auto ordered_src_access_idx) {
512 constexpr
auto forward_sweep = [&]() {
515 forward_sweep_(I0) =
true;
518 index_t tmp = ordered_src_access_idx[I0];
521 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
524 forward_sweep_(i) = tmp % 2 == 0;
527 return forward_sweep_;
531 constexpr
auto src_data_idx = [&]() {
535 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
536 : ordered_src_access_lengths[i] - 1 -
537 ordered_src_access_idx[i];
541 src_scalar_per_access;
546 using src_vector_t =
typename decltype(src_tmp_vector)::type;
548 const bool is_src_valid =
552 src_tmp_vector.template AsType<src_vector_t>()(
Number<0>{}) =
553 src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
557 constexpr
index_t buffer_offset =
558 buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);
563 constexpr
auto move_on_dim = [&]() constexpr
568 move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
572 ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
582 if constexpr(move_on_dim[i])
584 if constexpr(forward_sweep[i])
587 src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
592 src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
599 if constexpr(SrcResetCoordinateAfterRun)
601 const auto src_reset_step =
608 template <
typename DstBuffer,
typename DstStepHacks>
610 RunWrite(
const DstDesc& dst_desc, DstBuffer& dst_buf,
const DstStepHacks& dst_step_hacks)
618 "wrong! SrcBuffer or DstBuffer data type is wrong");
628 constexpr
auto dst_scalar_step_in_vector =
631 constexpr
auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
633 constexpr
auto dst_dim_access_order = DstDimAccessOrder{};
635 constexpr
auto ordered_dst_access_lengths =
641 Index forward_step_idx;
644 forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
648 dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
655 Index backward_step_idx;
658 backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
662 dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
667 static_ford<decltype(ordered_dst_access_lengths)>{}([&](
auto ordered_dst_access_idx) {
669 constexpr
auto forward_sweep = [&]() {
672 forward_sweep_(I0) =
true;
675 index_t tmp = ordered_dst_access_idx[I0];
678 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
681 forward_sweep_(i) = tmp % 2 == 0;
684 return forward_sweep_;
688 constexpr
auto dst_data_idx = [&]() {
692 ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
693 : ordered_dst_access_lengths[i] - 1 -
694 ordered_dst_access_idx[i];
698 dst_scalar_per_access;
705 constexpr
index_t buffer_offset =
706 buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
708 dst_tmp_vector.template AsType<DstData>()(i) =
712 using dst_vector_t =
typename decltype(dst_tmp_vector)::type;
715 const bool is_dst_valid =
718 dst_buf.template Set<dst_vector_t>(
719 dst_coord_.GetOffset(),
721 dst_tmp_vector.template AsType<dst_vector_t>()[
Number<0>{}]);
723 constexpr
auto move_on_dim = [&]() constexpr
728 move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
732 ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
742 if constexpr(move_on_dim[i])
744 if constexpr(forward_sweep[i])
747 dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
752 dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
759 if constexpr(DstResetCoordinateAfterRun)
761 const auto dst_reset_step =
768 template <
typename SrcBuffer>
769 __device__
void RunRead(
const SrcDesc& src_desc,
const SrcBuffer& src_buf)
771 constexpr
index_t ntransform_src = SrcDesc::GetNumOfTransform();
775 constexpr
auto src_step_hacks =
779 RunRead(src_desc, src_buf, src_step_hacks);
782 template <
typename DstBuffer>
783 __device__
void RunWrite(
const DstDesc& dst_desc, DstBuffer& dst_buf)
785 constexpr
index_t ntransform_dst = DstDesc::GetNumOfTransform();
789 constexpr
auto dst_step_hacks =
793 RunWrite(dst_desc, dst_buf, dst_step_hacks);
805 constexpr
auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
807 constexpr
auto src_dim_access_order = SrcDimAccessOrder{};
809 constexpr
auto ordered_src_access_lengths =
813 constexpr
auto forward_sweep = [&]() {
816 forward_sweep_(I0) =
true;
819 index_t tmp = ordered_src_access_lengths[I0] - 1;
822 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
825 forward_sweep_(i) = tmp % 2 == 0;
828 return forward_sweep_;
833 constexpr
auto src_data_idx = [&]() {
837 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
841 src_scalar_per_access;
845 constexpr
auto reset_src_data_step = [&]() {
846 Index reset_src_data_step_;
850 return reset_src_data_step_;
853 return reset_src_data_step;
865 constexpr
auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
867 constexpr
auto dst_dim_access_order = DstDimAccessOrder{};
869 constexpr
auto ordered_dst_access_lengths =
873 constexpr
auto forward_sweep = [&]() {
876 forward_sweep_(I0) =
true;
879 index_t tmp = ordered_dst_access_lengths[I0] - 1;
882 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
885 forward_sweep_(i) = tmp % 2 == 0;
888 return forward_sweep_;
893 constexpr
auto dst_data_idx = [&]() {
897 ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
901 dst_scalar_per_access;
905 constexpr
auto reset_dst_data_step = [&]() {
906 Index reset_dst_data_step_;
910 return reset_dst_data_step_;
913 return reset_dst_data_step;
918 const Index& src_slice_origin_step_idx)
921 const auto adjusted_step_idx =
922 SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
932 template <
typename SrcMoveSliceWindowStepHack>
935 const Index& src_slice_origin_step_idx,
936 const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
939 const auto adjusted_step_idx =
940 SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
945 src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
951 const Index& dst_slice_origin_step_idx)
954 const auto adjusted_step_idx =
955 DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
965 static constexpr
auto buffer_desc_ =
968 static constexpr
auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
970 StaticBuffer<AddressSpaceEnum::Vgpr, SrcData, buffer_size_, true> buffer_;
989 template <
typename SrcData,
993 typename SliceLengths,
994 typename DimAccessOrder,
997 index_t SrcScalarStrideInVector,
998 typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1020 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1021 "wrong! SrcDesc and DstDesc need to known at compile-time");
1024 "wrong! Not divisible");
1028 static_assert(SrcScalarPerVector %
PackedSize == 0,
"pk data N cannot be 1");
1032 template <
typename SrcRefToOriginDisplacement,
1033 typename DstOriginIdx,
1036 __device__
void Run(
const SrcDesc&,
1037 const SrcRefToOriginDisplacement&,
1038 const SrcBuffer& src_buf,
1040 const DstOriginIdx&,
1041 DstBuffer& dst_buf)
const
1043 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1044 "wrong! SrcDesc and DstDesc need to known at compile-time");
1049 "wrong! SrcBuffer or DstBuffer data type is wrong");
1051 static_assert(DstBuffer::IsStaticBuffer(),
"wrong! DstBuffer need to be StaticBuffer");
1055 "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
1063 constexpr
auto src_ref_to_origin_disp_idx =
to_multi_index(SrcRefToOriginDisplacement{});
1068 [&](
auto i) constexpr {
1069 if constexpr(i == SrcVectorDim)
1082 [&](
auto i) constexpr {
1083 if constexpr(i == SrcVectorDim)
1094 constexpr
auto access_lengths = SliceLengths{} / src_scalar_per_access;
1096 constexpr
auto dim_access_order = DimAccessOrder{};
1098 constexpr
auto ordered_access_lengths =
1101 static_ford<decltype(ordered_access_lengths)>{}([&](
auto ordered_access_idx) {
1105 constexpr
auto data_to_origin_disp_idx =
1107 src_scalar_per_access;
1110 constexpr
auto data_to_origin_disp_idx =
1111 ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
1114 constexpr
auto src_ref_to_data_disp_idx =
1115 src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
1117 constexpr
auto src_ref_to_data_disp_coord_step =
1120 auto src_data_coord = src_ref_coord_;
1126 using src_vector_t =
typename decltype(src_tmp_vector)::type;
1129 src_desc, src_data_coord);
1132 if constexpr(SrcBuffer::IsDynamicBuffer())
1134 src_tmp_vector.template AsType<src_vector_t>()(
Number<0>{}) =
1135 src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() /
PackedSize,
1138 else if constexpr(SrcBuffer::IsStaticBuffer())
1141 constexpr
index_t src_offset = src_desc.CalculateOffset(
1142 src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
1143 i * src_scalar_step_in_vector);
1155 constexpr
index_t pack_size = 8;
1157 static_assert(SrcScalarPerVector % pack_size == 0,
"");
1162 static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](
auto i) {
1164 dst_tmp_vector.template AsType<dst_v_t>()(i),
1165 src_tmp_vector.template AsType<src_v_t>()[i]);
1170 constexpr
index_t dst_offset = dst_desc.CalculateOffset(
1171 dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
1178 SrcScalarPerVector % 2 == 0)
1184 constexpr
index_t pack_size = 2;
1188 static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](
auto i) {
1190 dst_tmp_vector.template AsType<dst_v_t>()(i),
1191 src_tmp_vector.template AsType<src_v_t>()[i]);
1196 constexpr
index_t dst_offset = dst_desc.CalculateOffset(
1197 dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
1210 dst_tmp_vector.template AsType<DstData>()(i) =
1211 type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
1216 constexpr
index_t dst_offset = dst_desc.CalculateOffset(
1217 dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
1226 template <
typename SrcRefToOriginDisplacement,
1227 typename DstOriginIdx,
1230 __device__
void Run(
const SrcDesc&,
1231 const SrcRefToOriginDisplacement&,
1232 const SrcBuffer& src_buf,
1233 const DstData& scale,
1235 const DstOriginIdx&,
1236 DstBuffer& dst_buf)
const
1238 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1239 "wrong! SrcDesc and DstDesc need to known at compile-time");
1244 "wrong! SrcBuffer or DstBuffer data type is wrong");
1246 static_assert(DstBuffer::IsStaticBuffer(),
"wrong! DstBuffer need to be StaticBuffer");
1250 "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
1258 constexpr
auto src_ref_to_origin_disp_idx =
to_multi_index(SrcRefToOriginDisplacement{});
1263 [&](
auto i) constexpr {
1264 if constexpr(i == SrcVectorDim)
1277 [&](
auto i) constexpr {
1278 if constexpr(i == SrcVectorDim)
1289 constexpr
auto access_lengths = SliceLengths{} / src_scalar_per_access;
1291 constexpr
auto dim_access_order = DimAccessOrder{};
1293 constexpr
auto ordered_access_lengths =
1296 static_ford<decltype(ordered_access_lengths)>{}([&](
auto ordered_access_idx) {
1300 constexpr
auto data_to_origin_disp_idx =
1302 src_scalar_per_access;
1305 constexpr
auto data_to_origin_disp_idx =
1306 ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
1309 constexpr
auto src_ref_to_data_disp_idx =
1310 src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
1312 constexpr
auto src_ref_to_data_disp_coord_step =
1315 auto src_data_coord = src_ref_coord_;
1321 using src_vector_t =
typename decltype(src_tmp_vector)::type;
1324 src_desc, src_data_coord);
1327 if constexpr(SrcBuffer::IsDynamicBuffer())
1329 src_tmp_vector.template AsType<src_vector_t>()(
Number<0>{}) =
1330 src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() /
PackedSize,
1333 else if constexpr(SrcBuffer::IsStaticBuffer())
1336 constexpr
index_t src_offset = src_desc.CalculateOffset(
1337 src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
1338 i * src_scalar_step_in_vector);
1350 scale_vector.template AsType<DstData>()(
Number<0>{}) = scale;
1351 scale_vector.template AsType<DstData>()(
Number<1>{}) = scale;
1353 constexpr
index_t pack_size = 8;
1355 static_assert(SrcScalarPerVector % pack_size == 0,
"");
1361 static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](
auto i) {
1363 dst_tmp_vector.template AsType<dst_v_t>()(i),
1364 src_tmp_vector.template AsType<src_v_t>()[i],
1365 scale_vector.template AsType<scale_v_t>()[
Number<0>{}]);
1370 constexpr
index_t dst_offset = dst_desc.CalculateOffset(
1371 dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
1378 SrcScalarPerVector % 2 == 0)
1384 constexpr
index_t pack_size = 2;
1388 static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](
auto i) {
1390 dst_tmp_vector.template AsType<dst_v_t>()(i),
1391 src_tmp_vector.template AsType<src_v_t>()[i]);
1396 constexpr
index_t dst_offset = dst_desc.CalculateOffset(
1397 dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
1410 dst_tmp_vector.template AsType<DstData>()(i) =
1411 type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
1416 constexpr
index_t dst_offset = dst_desc.CalculateOffset(
1417 dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
1425 template <
typename SrcSliceMoveStepIdx>
1427 const SrcSliceMoveStepIdx& src_slice_move_step_idx)
1429 constexpr
auto src_desc = SrcDesc{};
1431 const auto src_slice_move_step_iter =
1451 template <
typename SrcData,
1455 typename ElementwiseOperation,
1456 typename SliceLengths,
1457 typename DimAccessOrder,
1460 typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1461 bool>::type =
false>
1469 const ElementwiseOperation& element_op)
1472 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1473 "wrong! Desc need to known at compile-time");
1476 "wrong! Not divisible");
1479 template <
typename SrcSliceOriginIdx,
1480 typename DstSliceOriginIdx,
1483 __device__
void Run(
const SrcDesc&,
1484 const SrcSliceOriginIdx&,
1485 const SrcBuffer& src_buf,
1487 const DstSliceOriginIdx&,
1490 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1491 "wrong! Desc need to known at compile-time");
1495 "wrong! SliceOrigin need to known at compile-time");
1497 static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
1498 "wrong! Buffer need to be StaticBuffer");
1503 constexpr
auto src_slice_origin_idx =
to_multi_index(SrcSliceOriginIdx{});
1504 constexpr
auto dst_slice_origin_idx =
to_multi_index(DstSliceOriginIdx{});
1510 constexpr
auto dst_scalar_step_in_vector =
1518 "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
1527 constexpr
index_t src_offset = src_desc.CalculateOffset(
1528 src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1530 constexpr
index_t dst_offset = dst_desc.CalculateOffset(
1531 dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1553 template <
typename SrcData,
1557 typename ElementwiseOperation,
1558 typename SliceLengths,
1559 typename DimAccessOrder,
1562 uint32_t LowEightRowlaneIdx,
1563 uint32_t HighEightRowLaneIdx,
1564 bool IntraRowSwizzlePerm,
1565 typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1566 bool>::type =
false>
1575 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1576 "wrong! Desc need to known at compile-time");
1579 "wrong! Not divisible");
1583 template <
typename SrcSliceOriginIdx,
1584 typename DstSliceOriginIdx,
1587 __device__
void Run(
const SrcDesc&,
1588 const SrcSliceOriginIdx&,
1589 const SrcBuffer& src_buf,
1591 const DstSliceOriginIdx&,
1592 DstBuffer& dst_buf)
const
1594 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1595 "wrong! Desc need to known at compile-time");
1599 "wrong! SliceOrigin need to known at compile-time");
1601 static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
1602 "wrong! Buffer need to be StaticBuffer");
1607 constexpr
auto src_slice_origin_idx =
to_multi_index(SrcSliceOriginIdx{});
1608 constexpr
auto dst_slice_origin_idx =
to_multi_index(DstSliceOriginIdx{});
1614 constexpr
auto dst_scalar_step_in_vector =
1622 "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
1632 constexpr
index_t src_offset = src_desc.CalculateOffset(
1633 src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1635 constexpr
index_t dst_offset = dst_desc.CalculateOffset(
1636 dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1638 SrcData v_this_row, v_theother_row;
1646 if constexpr(IntraRowSwizzlePerm)
1648 temp = __builtin_amdgcn_permlane16(
1649 temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
1650 v_this_row = type_convert_sp<SrcData>(temp);
1654 temp = __builtin_amdgcn_permlanex16(temp,
1655 type_convert_sp<int>(v_this_row),
1657 HighEightRowLaneIdx,
1660 v_theother_row = type_convert_sp<SrcData>(temp);
1667 type_convert_sp<DstData>(v_theother_row);
1673 type_convert_sp<DstData>(v_this_row);
1683 template <
typename SrcData,
1687 typename ElementwiseOperation,
1688 typename SliceLengths,
1689 typename DimAccessOrder,
1692 bool IntraRowSwizzlePerm,
1693 typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1694 bool>::type =
false>
1703 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1704 "wrong! Desc need to known at compile-time");
1707 "wrong! Not divisible");
1711 template <
typename SrcSliceOriginIdx,
1712 typename DstSliceOriginIdx,
1715 __device__
void Run(
const SrcDesc&,
1716 const SrcSliceOriginIdx&,
1717 const SrcBuffer& src_buf,
1719 const DstSliceOriginIdx&,
1720 DstBuffer& dst_buf)
const
1722 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1723 "wrong! Desc need to known at compile-time");
1727 "wrong! SliceOrigin need to known at compile-time");
1729 static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
1730 "wrong! Buffer need to be StaticBuffer");
1735 constexpr
auto src_slice_origin_idx =
to_multi_index(SrcSliceOriginIdx{});
1736 constexpr
auto dst_slice_origin_idx =
to_multi_index(DstSliceOriginIdx{});
1742 constexpr
auto dst_scalar_step_in_vector =
1750 "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
1760 constexpr
index_t src_offset = src_desc.CalculateOffset(
1761 src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1763 constexpr
index_t dst_offset = dst_desc.CalculateOffset(
1764 dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1774 if constexpr(IntraRowSwizzlePerm)
1776 temp = __builtin_amdgcn_permlane16(
1777 temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
1778 v_this_row = type_convert_sp<SrcData>(temp);
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition: tensor_descriptor.hpp:560
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
InMemoryDataOperationEnum
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:990
__host__ constexpr __device__ auto to_multi_index(const T &x)
Definition: array_multi_index.hpp:28
_Float16 half_t
Definition: data_type.hpp:25
__host__ constexpr __device__ auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition: tensor_descriptor.hpp:407
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__host__ constexpr __device__ auto generate_sequence(F, Number< N >)
Definition: sequence_helper.hpp:18
__host__ constexpr __device__ auto generate_sequence_v2(F &&f, Number< N >)
Definition: sequence_helper.hpp:25
__host__ constexpr __device__ auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition: container_helper.hpp:380
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:10
constexpr bool is_same_v
Definition: type.hpp:283
__host__ constexpr __device__ auto container_reorder_given_new2old(const Array< TData, NSize > &old_array, Sequence< IRs... >)
Definition: container_helper.hpp:43
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition: tensor_descriptor.hpp:508
__host__ constexpr __device__ auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition: tensor_descriptor.hpp:444
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:16
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:298
__host__ constexpr __device__ auto container_reorder_given_old2new(const Array< TData, NSize > &old_array, Sequence< IRs... > old2new)
Definition: container_helper.hpp:54
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition: data_type.hpp:384
__host__ static constexpr __device__ T QuietNaN()
Definition: data_type.hpp:2835
Definition: tensor_space_filling_curve.hpp:20
static __device__ constexpr __host__ auto GetForwardStep(Number< AccessIdx1d >)
Definition: tensor_space_filling_curve.hpp:66
__host__ static constexpr __device__ index_t GetNumOfAccess()
Definition: tensor_space_filling_curve.hpp:41
static constexpr index_t ScalarPerVector
Definition: tensor_space_filling_curve.hpp:25
static __device__ constexpr __host__ Index GetIndex(Number< AccessIdx1d >)
Definition: tensor_space_filling_curve.hpp:81
static __device__ constexpr __host__ auto GetStepBetween(Number< AccessIdx1dBegin >, Number< AccessIdx1dEnd >)
Definition: tensor_space_filling_curve.hpp:52
Definition: threadwise_tensor_slice_transfer.hpp:1568
static constexpr index_t nDim
Definition: threadwise_tensor_slice_transfer.hpp:1569
constexpr __device__ ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index &src_idx)
Definition: threadwise_tensor_slice_transfer.hpp:1573
ElementwiseOperation element_op_
Definition: threadwise_tensor_slice_transfer.hpp:1679
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf) const
Definition: threadwise_tensor_slice_transfer.hpp:1587
Definition: threadwise_tensor_slice_transfer.hpp:1696
static constexpr index_t nDim
Definition: threadwise_tensor_slice_transfer.hpp:1697
constexpr __device__ ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index &src_idx)
Definition: threadwise_tensor_slice_transfer.hpp:1701
ElementwiseOperation element_op_
Definition: threadwise_tensor_slice_transfer.hpp:1786
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf) const
Definition: threadwise_tensor_slice_transfer.hpp:1715
Threadwise data transfer.
Definition: threadwise_tensor_slice_transfer.hpp:1463
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf)
Definition: threadwise_tensor_slice_transfer.hpp:1483
static constexpr index_t nDim
Definition: threadwise_tensor_slice_transfer.hpp:1464
ElementwiseOperation element_op_
Definition: threadwise_tensor_slice_transfer.hpp:1544
constexpr __device__ ThreadwiseTensorSliceTransfer_StaticToStatic(const ElementwiseOperation &element_op)
Definition: threadwise_tensor_slice_transfer.hpp:1468
Definition: threadwise_tensor_slice_transfer.hpp:39
static constexpr __device__ auto GetDstCoordinateResetStep()
Definition: threadwise_tensor_slice_transfer.hpp:149
static constexpr index_t nDim
Definition: threadwise_tensor_slice_transfer.hpp:40
MultiIndex< nDim > Index
Definition: threadwise_tensor_slice_transfer.hpp:42
decltype(make_tensor_coordinate(DstDesc{}, Index{})) DstCoord
Definition: threadwise_tensor_slice_transfer.hpp:44
constexpr __device__ ThreadwiseTensorSliceTransfer_v1r3(const DstDesc &dst_desc, const Index &dst_slice_origin_idx, const ElementwiseOperation &element_op)
Definition: threadwise_tensor_slice_transfer.hpp:48
decltype(make_tensor_coordinate_step(DstDesc{}, Index{})) DstCoordStep
Definition: threadwise_tensor_slice_transfer.hpp:46
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &dst_slice_origin_step_idx)
Definition: threadwise_tensor_slice_transfer.hpp:173
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_slice_origin_idx)
Definition: threadwise_tensor_slice_transfer.hpp:60
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition: threadwise_tensor_slice_transfer.hpp:66
Definition: threadwise_tensor_slice_transfer.hpp:214
constexpr __device__ ThreadwiseTensorSliceTransfer_v2(const SrcDesc &src_desc, const Index &src_slice_origin_idx)
Definition: threadwise_tensor_slice_transfer.hpp:227
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf)
Definition: threadwise_tensor_slice_transfer.hpp:243
MultiIndex< nDim > Index
Definition: threadwise_tensor_slice_transfer.hpp:221
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx)
Definition: threadwise_tensor_slice_transfer.hpp:355
static constexpr __device__ auto GetSrcCoordinateResetStep()
Definition: threadwise_tensor_slice_transfer.hpp:331
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx, const SrcMoveSliceWindowStepHack &src_move_slice_window_step_hack)
Definition: threadwise_tensor_slice_transfer.hpp:372
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_slice_origin_idx)
Definition: threadwise_tensor_slice_transfer.hpp:237
static constexpr index_t nDim
Definition: threadwise_tensor_slice_transfer.hpp:219
decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})) SrcCoordStep
Definition: threadwise_tensor_slice_transfer.hpp:225
decltype(make_tensor_coordinate(SrcDesc{}, Index{})) SrcCoord
Definition: threadwise_tensor_slice_transfer.hpp:223
Definition: threadwise_tensor_slice_transfer.hpp:418
decltype(make_tensor_coordinate(DstDesc{}, Index{})) DstCoord
Definition: threadwise_tensor_slice_transfer.hpp:423
decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})) SrcCoordStep
Definition: threadwise_tensor_slice_transfer.hpp:425
MultiIndex< nDim > Index
Definition: threadwise_tensor_slice_transfer.hpp:420
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx, const SrcMoveSliceWindowStepHack &src_move_slice_window_step_hack)
Definition: threadwise_tensor_slice_transfer.hpp:934
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, const SrcStepHacks &src_step_hacks)
Definition: threadwise_tensor_slice_transfer.hpp:453
decltype(make_tensor_coordinate_step(DstDesc{}, Index{})) DstCoordStep
Definition: threadwise_tensor_slice_transfer.hpp:426
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &dst_slice_origin_step_idx)
Definition: threadwise_tensor_slice_transfer.hpp:950
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_slice_origin_idx)
Definition: threadwise_tensor_slice_transfer.hpp:446
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition: threadwise_tensor_slice_transfer.hpp:783
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_slice_origin_idx)
Definition: threadwise_tensor_slice_transfer.hpp:441
static constexpr __device__ auto GetSrcCoordinateResetStep()
Definition: threadwise_tensor_slice_transfer.hpp:796
static constexpr __device__ auto GetDstCoordinateResetStep()
Definition: threadwise_tensor_slice_transfer.hpp:856
constexpr __device__ ThreadwiseTensorSliceTransfer_v3(const SrcDesc &src_desc, const Index &src_slice_origin, const DstDesc &dst_desc, const Index &dst_slice_origin)
Definition: threadwise_tensor_slice_transfer.hpp:428
decltype(make_tensor_coordinate(SrcDesc{}, Index{})) SrcCoord
Definition: threadwise_tensor_slice_transfer.hpp:422
static constexpr index_t nDim
Definition: threadwise_tensor_slice_transfer.hpp:419
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf)
Definition: threadwise_tensor_slice_transfer.hpp:769
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx)
Definition: threadwise_tensor_slice_transfer.hpp:917
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, const DstStepHacks &dst_step_hacks)
Definition: threadwise_tensor_slice_transfer.hpp:610
Definition: threadwise_tensor_slice_transfer.hpp:1001
static constexpr index_t nDim
Definition: threadwise_tensor_slice_transfer.hpp:1002
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition: threadwise_tensor_slice_transfer.hpp:1036
static constexpr index_t PackedSize
Definition: threadwise_tensor_slice_transfer.hpp:1010
decltype(make_tensor_coordinate(SrcDesc{}, Index{})) SrcCoord
Definition: threadwise_tensor_slice_transfer.hpp:1006
constexpr __device__ ThreadwiseTensorSliceTransfer_v4(const Index &src_ref_idx)
Definition: threadwise_tensor_slice_transfer.hpp:1017
decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})) SrcCoordStep
Definition: threadwise_tensor_slice_transfer.hpp:1008
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstData &scale, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition: threadwise_tensor_slice_transfer.hpp:1230
__device__ void SetSrcCoord(const Index &src_ref_idx)
Definition: threadwise_tensor_slice_transfer.hpp:1436
MultiIndex< nDim > Index
Definition: threadwise_tensor_slice_transfer.hpp:1004
__device__ void MoveSrcSliceWindow(const SrcDesc &, const SrcSliceMoveStepIdx &src_slice_move_step_idx)
Definition: threadwise_tensor_slice_transfer.hpp:1426
Definition: threadwise_tensor_slice_transfer_util.hpp:20
Definition: threadwise_tensor_slice_transfer_util.hpp:29
Definition: integral_constant.hpp:10
Definition: is_known_at_compile_time.hpp:14
Definition: data_type.hpp:320
Definition: functional2.hpp:31
Definition: functional3.hpp:97
Definition: unary_element_wise_operation.hpp:174
Definition: unary_element_wise_operation.hpp:210
Definition: unary_element_wise_operation.hpp:115
Definition: data_type.hpp:367
Definition: data_type.hpp:347