diff --git a/c/parallel/src/scan.cu b/c/parallel/src/scan.cu index a2507bb523b..7eb1377b8d6 100644 --- a/c/parallel/src/scan.cu +++ b/c/parallel/src/scan.cu @@ -234,11 +234,13 @@ struct scan_kernel_source return arg; } - static auto lookahead_make_tile_state_kernel_arg(void* ts) + static auto lookahead_make_tile_state_kernel_arg(void* ts, ::cuda::std::uint32_t* atomic_counter = nullptr) { // we can ignore passing a wrong AccumT, since we only store a pointer, and the kernel will have the right type cub::detail::scan::tile_state_kernel_arg_t arg; - ::cuda::std::__construct_at(&arg.lookahead, static_cast*>(ts)); + ::cuda::std::__construct_at(&arg.lookahead, + cub::detail::scan::lookahead_tile_state_arg_t{ + static_cast*>(ts), atomic_counter}); return arg; } }; diff --git a/cub/cub/detail/warpspeed/look_ahead.cuh b/cub/cub/detail/warpspeed/look_ahead.cuh index d80d234a794..4a559f250bb 100644 --- a/cub/cub/detail/warpspeed/look_ahead.cuh +++ b/cub/cub/detail/warpspeed/look_ahead.cuh @@ -72,10 +72,10 @@ struct alignas(_Alignment) tile_state_t : tile_state_unaligned_t template _CCCL_DEVICE_API void -storeTileAggregate(tile_state_t* ptrTileStates, scan_state scanState, AccumT aggr, int index) +storeTileAggregate(tile_state_t* ptrTileStates, scan_state scanState, AccumT aggr, int index, int num_tiles) { _CCCL_ASSERT(::cuda::is_aligned(ptrTileStates, alignof(tile_state_t)), ""); - _CCCL_ASSERT(index >= 0 && index < gridDim.x, "Reading out of bounds tile state"); + _CCCL_ASSERT(index >= 0 && index < num_tiles, "Reading out of bounds tile state"); if constexpr (sizeof(tile_state_t) <= cub::detail::warpspeed::max_native_atomic_size() && ::cuda::is_trivially_copyable_v>) @@ -99,10 +99,10 @@ storeTileAggregate(tile_state_t* ptrTileStates, scan_state scanState, Ac } template -_CCCL_DEVICE_API tile_state_t loadTileAggregate(tile_state_t* ptrTileStates, int index) +_CCCL_DEVICE_API tile_state_t loadTileAggregate(tile_state_t* ptrTileStates, int index, int num_tiles) { 
_CCCL_ASSERT(::cuda::is_aligned(ptrTileStates, alignof(tile_state_t)), ""); - _CCCL_ASSERT(index >= 0 && index < gridDim.x, "Reading out of bounds tile state"); + _CCCL_ASSERT(index >= 0 && index < num_tiles, "Reading out of bounds tile state"); tile_state_t res; if constexpr (sizeof(tile_state_t) <= cub::detail::warpspeed::max_native_atomic_size() @@ -149,14 +149,15 @@ _CCCL_DEVICE_API void warpLoadLookahead( tile_state_t (&outTileStates)[numTileStatesPerThread], tile_state_t* ptrTileStates, int idxTileCur, - int idxTileNext) + int idxTileNext, + int num_tiles) { for (int i = 0; i < numTileStatesPerThread; ++i) { const int idxTileLookahead = idxTileCur + 32 * i + laneIdx; if (idxTileLookahead < idxTileNext) { - outTileStates[i] = loadTileAggregate(ptrTileStates, idxTileLookahead); + outTileStates[i] = loadTileAggregate(ptrTileStates, idxTileLookahead, num_tiles); } else { @@ -182,7 +183,8 @@ template const int idxTilePrev, const AccumT aggrExclusiveCtaPrev, const int idxTileNext, - ScanOpT& scan_op) + ScanOpT& scan_op, + const int num_tiles) { const int laneIdx = specialRegisters.laneIdx; const ::cuda::std::uint32_t lanemaskEq = ::cuda::ptx::get_sreg_lanemask_eq(); @@ -203,7 +205,7 @@ template while (idxTileCur < idxTileNext) { tile_state_t regTmpStates[numTileStatesPerThread]; - warpLoadLookahead(laneIdx, regTmpStates, ptrTileStates, idxTileCur, idxTileNext); + warpLoadLookahead(laneIdx, regTmpStates, ptrTileStates, idxTileCur, idxTileNext, num_tiles); for (int idx = 0; idx < numTileStatesPerThread; ++idx) { @@ -274,7 +276,8 @@ template int& idxTilePrev, AccumT& aggrExclusiveCtaPrev, const int idxTileNext, - ScanOpT& scan_op) + ScanOpT& scan_op, + const int num_tiles) { const int laneIdx = specialRegisters.laneIdx; const ::cuda::std::uint32_t lanemaskEq = ::cuda::ptx::get_sreg_lanemask_eq(); @@ -290,7 +293,7 @@ template while (idxTileCur < idxTileNext) { tile_state_t regTmpStates[numTileStatesPerThread]; - warpLoadLookahead(laneIdx, regTmpStates, ptrTileStates, 
idxTileCur, idxTileNext); + warpLoadLookahead(laneIdx, regTmpStates, ptrTileStates, idxTileCur, idxTileNext, num_tiles); for (int idx = 0; idx < numTileStatesPerThread; ++idx) { diff --git a/cub/cub/detail/warpspeed/squad/load_store.cuh b/cub/cub/detail/warpspeed/squad/load_store.cuh index 8567c35d2c4..0143bdcb1ef 100644 --- a/cub/cub/detail/warpspeed/squad/load_store.cuh +++ b/cub/cub/detail/warpspeed/squad/load_store.cuh @@ -25,6 +25,8 @@ #include #include +#include + CUB_NAMESPACE_BEGIN namespace detail::warpspeed @@ -280,16 +282,20 @@ squadStoreBulkSync(Squad squad, CpAsyncOobInfo cpAsyncOobInfo, const :: asm volatile("" : "+l"(srcSmem)); # endif // _CCCL_CUDA_COMPILER(NVCC, <, 13, 3) // Copy a subset of the first 16 bytes - if (::cuda::ptx::elect_sync(~0)) - { - ::cuda::ptx::cp_async_bulk_cp_mask( - ::cuda::ptx::space_global, - ::cuda::ptx::space_shared, - cpAsyncOobInfo.ptrGmemStartAlignDown, - srcSmem, - /*size*/ 16, - byteMaskStart); - } + NV_IF_ELSE_TARGET( + NV_PROVIDES_SM_100, + (if (::cuda::ptx::elect_sync(~0)) { + ::cuda::ptx::cp_async_bulk_cp_mask( + ::cuda::ptx::space_global, + ::cuda::ptx::space_shared, + cpAsyncOobInfo.ptrGmemStartAlignDown, + srcSmem, + /*size*/ 16, + byteMaskStart); + }), + (const int rank = squad.threadRank(); if (rank < 16 && ((byteMaskStart >> rank) & 1u)) { + cpAsyncOobInfo.ptrGmemStartAlignDown[rank] = srcSmem[rank]; + })); } if (doEndCopy) { @@ -299,32 +305,42 @@ squadStoreBulkSync(Squad squad, CpAsyncOobInfo cpAsyncOobInfo, const :: asm volatile("" : "+l"(cpAsyncOobInfo.ptrGmemEndAlignDown)); # endif // _CCCL_CUDA_COMPILER(NVHPC) - // Copy a subset of the last 16 bytes - if (::cuda::ptx::elect_sync(~0)) - { - ::cuda::ptx::cp_async_bulk_cp_mask( - ::cuda::ptx::space_global, - ::cuda::ptx::space_shared, - cpAsyncOobInfo.ptrGmemEndAlignDown, - ptrSmemMiddle + cpAsyncOobInfo.underCopySizeBytes, - /*size*/ 16, - byteMaskEnd); - } + // Copy a subset of the last 16 bytes + NV_IF_ELSE_TARGET( + NV_PROVIDES_SM_100, + (if
(::cuda::ptx::elect_sync(~0)) { + ::cuda::ptx::cp_async_bulk_cp_mask( + ::cuda::ptx::space_global, + ::cuda::ptx::space_shared, + cpAsyncOobInfo.ptrGmemEndAlignDown, + ptrSmemMiddle + cpAsyncOobInfo.underCopySizeBytes, + /*size*/ 16, + byteMaskEnd); + }), + (const int rank = squad.threadRank(); + const ::cuda::std::byte* tail_smem_source = ptrSmemMiddle + cpAsyncOobInfo.underCopySizeBytes; + if (rank < 16 && ((byteMaskEnd >> rank) & 1u)) { + cpAsyncOobInfo.ptrGmemEndAlignDown[rank] = tail_smem_source[rank]; + })); } } else { // Copy a subset of the first 16 bytes - if (::cuda::ptx::elect_sync(~0)) - { - ::cuda::ptx::cp_async_bulk_cp_mask( - ::cuda::ptx::space_global, - ::cuda::ptx::space_shared, - cpAsyncOobInfo.ptrGmemStartAlignDown, - srcSmem, - /*size*/ 16, - byteMaskSmall); - } + NV_IF_ELSE_TARGET( + NV_PROVIDES_SM_100, + (if (::cuda::ptx::elect_sync(~0)) { + ::cuda::ptx::cp_async_bulk_cp_mask( + ::cuda::ptx::space_global, + ::cuda::ptx::space_shared, + cpAsyncOobInfo.ptrGmemStartAlignDown, + srcSmem, + /*size*/ 16, + byteMaskSmall); + }), + (const int rank = squad.threadRank(); if (rank < 16 && ((byteMaskSmall >> rank) & 1u)) { + cpAsyncOobInfo.ptrGmemStartAlignDown[rank] = srcSmem[rank]; + })); } // Commit and wait for store to have completed reading from shared memory ::cuda::ptx::cp_async_bulk_commit_group(); diff --git a/cub/cub/device/dispatch/dispatch_scan.cuh b/cub/cub/device/dispatch/dispatch_scan.cuh index 0d50fc3a263..21da1421bb8 100644 --- a/cub/cub/device/dispatch/dispatch_scan.cuh +++ b/cub/cub/device/dispatch/dispatch_scan.cuh @@ -140,10 +140,13 @@ struct DeviceScanKernelSource return arg; } - CUB_RUNTIME_FUNCTION static constexpr auto lookahead_make_tile_state_kernel_arg(void* ts) + CUB_RUNTIME_FUNCTION static constexpr auto + lookahead_make_tile_state_kernel_arg(void* ts, ::cuda::std::uint32_t* atomic_counter = nullptr) { tile_state_kernel_arg_t arg; - ::cuda::std::__construct_at(&arg.lookahead, static_cast*>(ts)); + 
::cuda::std::__construct_at( + &arg.lookahead, + lookahead_tile_state_arg_t{static_cast*>(ts), atomic_counter}); return arg; } }; @@ -1083,6 +1086,7 @@ CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t invoke_lookahead( OffsetT num_items, cudaStream_t stream, bool dependent_launch, + bool atomic_scheduling, KernelSource kernel_source, KernelLauncherFactory launcher_factory) { @@ -1101,25 +1105,33 @@ CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t invoke_lookahead( CUB_DETAIL_STATIC_ISH_ASSERT(lookahead_policy.lookahead_items_per_thread >= 1, "Lookahead scan policy must look ahead at least 1 item per thread"); - const int grid_dim = + const int num_tiles = static_cast(::cuda::ceil_div(num_items, static_cast(lookahead_policy.tile_size()))); - if (d_temp_storage == nullptr) + size_t allocation_sizes[2] = { + static_cast(num_tiles) * kernel_source.lookahead_tile_state_size(), sizeof(::cuda::std::uint32_t)}; + void* allocations[2] = {}; + if (const auto error = + CubDebug(detail::alias_temporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) { - temp_storage_bytes = static_cast(grid_dim) * kernel_source.lookahead_tile_state_size(); - return cudaSuccess; + return error; } - if (num_items == 0) + if (d_temp_storage == nullptr) { return cudaSuccess; } + void* d_tile_state = allocations[0]; + ::cuda::std::uint32_t* d_atomic_counter = static_cast<::cuda::std::uint32_t*>(allocations[1]); + int sm_count = 0; if (const auto error = CubDebug(launcher_factory.MultiProcessorCount(sm_count))) { return error; } + + const int scan_grid_dim = atomic_scheduling ? ::cuda::std::min(sm_count, num_tiles) : num_tiles; // Maximum dynamic shared memory size that we can use for temporary storage. 
int max_dynamic_smem_size{}; if (const auto error = @@ -1129,7 +1141,7 @@ CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t invoke_lookahead( } // TODO(bgruber): we probably need to ensure alignment of d_temp_storage - _CCCL_ASSERT(::cuda::is_aligned(d_temp_storage, kernel_source.lookahead_tile_state_alignment()), ""); + _CCCL_ASSERT(::cuda::is_aligned(d_tile_state, kernel_source.lookahead_tile_state_alignment()), ""); auto scan_kernel = kernel_source.ScanKernel(); [[maybe_unused]] auto kernel_src = kernel_source; // need to pull a copy to not access `this` during const. eval. @@ -1188,7 +1200,7 @@ CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t invoke_lookahead( // Invoke init kernel { constexpr auto init_kernel_threads = 128; - const auto init_grid_size = ::cuda::ceil_div(grid_dim, init_kernel_threads); + const auto init_grid_size = ::cuda::ceil_div(num_tiles, init_kernel_threads); # ifdef CUB_DEBUG_LOG _CubLog("Invoking DeviceScanInitKernel<<<%d, %d, 0, %lld>>>()\n", @@ -1200,8 +1212,8 @@ CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t invoke_lookahead( if (const auto error = CubDebug( launcher_factory(init_grid_size, init_kernel_threads, 0, stream, dependent_launch) .doit(kernel_source.InitKernel(), - kernel_source.lookahead_make_tile_state_kernel_arg(d_temp_storage), - grid_dim))) + kernel_source.lookahead_make_tile_state_kernel_arg(d_tile_state, d_atomic_counter), + num_tiles))) { return error; } @@ -1223,15 +1235,16 @@ CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t invoke_lookahead( { const int block_dim = detail::scan::num_total_threads(lookahead_policy); # ifdef CUB_DEBUG_LOG - _CubLog("Invoking DeviceScanKernel<<<%d, %d, %d, %lld>>>()\n", grid_dim, block_dim, smem_size, (long long) stream); + _CubLog( + "Invoking DeviceScanKernel<<<%d, %d, %d, %lld>>>()\n", scan_grid_dim, block_dim, smem_size, (long long) stream); # endif // CUB_DEBUG_LOG if (const auto error = CubDebug( - launcher_factory(grid_dim, 
block_dim, smem_size, stream, dependent_launch) + launcher_factory(scan_grid_dim, block_dim, smem_size, stream, dependent_launch) .doit(scan_kernel, THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(d_in), THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(d_out), - kernel_source.lookahead_make_tile_state_kernel_arg(d_temp_storage), + kernel_source.lookahead_make_tile_state_kernel_arg(d_tile_state, d_atomic_counter), /* start_tile, unused */ 0, ::cuda::std::move(scan_op), init_value, @@ -1285,6 +1298,7 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t invoke( const bool dependent_launch = cc >= ::cuda::compute_capability{9, 0}; if CUB_DETAIL_CONSTEXPR_ISH (policy_getter().algorithm == ScanAlgorithm::lookahead) { + const bool atomic_scheduling = cc == ::cuda::compute_capability{9, 0}; return invoke_lookahead( policy_getter, d_temp_storage, @@ -1296,6 +1310,7 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t invoke( num_items, stream, dependent_launch, + atomic_scheduling, kernel_source, launcher_factory); } diff --git a/cub/cub/device/dispatch/kernels/kernel_scan.cuh b/cub/cub/device/dispatch/kernels/kernel_scan.cuh index b03b60c6f10..834cffab134 100644 --- a/cub/cub/device/dispatch/kernels/kernel_scan.cuh +++ b/cub/cub/device/dispatch/kernels/kernel_scan.cuh @@ -25,14 +25,23 @@ #include +#include + CUB_NAMESPACE_BEGIN namespace detail::scan { +template +struct lookahead_tile_state_arg_t +{ + warpspeed::tile_state_t* tile_states; + ::cuda::std::uint32_t* atomic_counter; +}; + template union tile_state_kernel_arg_t { - warpspeed::tile_state_t* lookahead; + lookahead_tile_state_arg_t lookahead; ScanTileState lookback; // ScanTileState is not trivially [default|copy]-constructible, so because of @@ -69,7 +78,11 @@ _CCCL_KERNEL_ATTRIBUTES __launch_bounds__(128) void DeviceScanInitKernel( constexpr ScanPolicy policy = current_policy(); if constexpr (policy.algorithm == ScanAlgorithm::lookahead) { - device_scan_init_lookahead_body(tile_state.lookahead, 
num_tiles); + device_scan_init_lookahead_body(tile_state.lookahead.tile_states, num_tiles); + if (tile_state.lookahead.atomic_counter != nullptr && blockIdx.x == 0 && threadIdx.x == 0) + { + *tile_state.lookahead.atomic_counter = 0; + } } else #endif // _CCCL_CUDACC_AT_LEAST(12, 8) @@ -205,12 +218,13 @@ __launch_bounds__(device_scan_launch_bounds, 1) _CCCL_KERNEL_ATT if constexpr (active_policy.algorithm == ScanAlgorithm::lookahead) { #if _CCCL_CUDACC_AT_LEAST(12, 8) - NV_IF_TARGET(NV_PROVIDES_SM_100, ({ - auto scan_params = scanKernelParams, it_value_t, AccumT>{ - d_in, d_out, tile_state.lookahead, num_items, num_stages}; - device_scan_lookahead_body( - scan_params, scan_op, init_value); - })); + NV_IF_TARGET( + NV_PROVIDES_SM_90, ({ + auto scan_params = scanKernelParams, it_value_t, AccumT>{ + d_in, d_out, tile_state.lookahead.tile_states, tile_state.lookahead.atomic_counter, num_items, num_stages}; + device_scan_lookahead_body( + scan_params, scan_op, init_value); + })); #else static_assert(sizeof(d_in) == 0, "Implementation bug: Tuning policy selected lookahead, but CUDA compiler does not support it"); diff --git a/cub/cub/device/dispatch/kernels/kernel_scan_lookahead.cuh b/cub/cub/device/dispatch/kernels/kernel_scan_lookahead.cuh index 39e745c9b8f..8fffb9b9ffc 100644 --- a/cub/cub/device/dispatch/kernels/kernel_scan_lookahead.cuh +++ b/cub/cub/device/dispatch/kernels/kernel_scan_lookahead.cuh @@ -37,6 +37,8 @@ #include #include +#include + CUB_NAMESPACE_BEGIN namespace detail::scan @@ -57,6 +59,7 @@ struct scanKernelParams const InputT* ptrIn; OutputT* ptrOut; warpspeed::tile_state_t* ptrTileStates; + ::cuda::std::uint32_t* atomicCounter; ::cuda::std::size_t numElem; int numStages; }; @@ -132,6 +135,15 @@ _CCCL_DEVICE_API inline void squadGetNextBlockIdx(const warpspeed::Squad& squad, refDestSmem.squadIncreaseTxCount(squad, refDestSmem.sizeBytes()); } +_CCCL_DEVICE_API inline void squadGetNextBlockIdxAtomic( + const warpspeed::Squad& squad, warpspeed::SmemRef& 
refDestSmem, ::cuda::std::uint32_t* atomicCounter) +{ + if (squad.isLeaderThread()) + { + refDestSmem.data().x = ::atomicAdd(atomicCounter, 1u); + } +} + template _CCCL_DEVICE_API Tp warpReduce(const Tp input, ScanOpT& scan_op) { @@ -289,7 +301,9 @@ struct lookahead_scan_closure load_next_tile_index(const warpspeed::Squad& squad, warpspeed::SmemPhase& phaseNextBlockIdxW) const { warpspeed::SmemRef refNextBlockIdxW = phaseNextBlockIdxW.acquireRef(); - squadGetNextBlockIdx(squad, refNextBlockIdxW); + NV_IF_ELSE_TARGET(NV_PROVIDES_SM_100, + (squadGetNextBlockIdx(squad, refNextBlockIdxW);), + (squadGetNextBlockIdxAtomic(squad, refNextBlockIdxW, params.atomicCounter);)); } _CCCL_DEVICE_API _CCCL_FORCEINLINE void load_current_tile( @@ -307,7 +321,8 @@ struct lookahead_scan_closure bool is_first_tile, int& idxTilePrev, AccumT& AggrExclusiveCtaPrev, - int idxTile) /*const*/ // FIXME(bgruber): this const causes a large SASS diff + int idxTile, + int num_tiles) /*const*/ // FIXME(bgruber): this const causes a large SASS diff { warpspeed::SmemRef refAggrExclusiveCtaW = phaseAggrExclusiveCtaW.acquireRef(); @@ -317,7 +332,7 @@ struct lookahead_scan_closure { // The stable-order version updates idxTilePrev/AggrExclusiveCtaPrev itself AccumT regAggrExclusiveCta = warpspeed::warpIncrementalLookaheadStable( - specialRegisters, params.ptrTileStates, idxTilePrev, AggrExclusiveCtaPrev, idxTile, scan_op); + specialRegisters, params.ptrTileStates, idxTilePrev, AggrExclusiveCtaPrev, idxTile, scan_op, num_tiles); if (squad.isLeaderThread()) { refAggrExclusiveCtaW.data() = regAggrExclusiveCta; @@ -326,7 +341,7 @@ struct lookahead_scan_closure else { AccumT regAggrExclusiveCta = warpspeed::warpIncrementalLookahead( - specialRegisters, params.ptrTileStates, idxTilePrev, AggrExclusiveCtaPrev, idxTile, scan_op); + specialRegisters, params.ptrTileStates, idxTilePrev, AggrExclusiveCtaPrev, idxTile, scan_op, num_tiles); if (squad.isLeaderThread()) { refAggrExclusiveCtaW.data() = 
regAggrExclusiveCta; @@ -345,7 +360,8 @@ struct lookahead_scan_closure bool is_first_tile, bool is_last_tile, // TODO(bgruber): should we dispatch on is_last_tile outside this function and compile it twice? const warpspeed::CpAsyncOobInfo& loadInfo, - int idxTile) const + int idxTile, + int num_tiles) const { const int valid_items_this_thread = cuda::std::clamp(valid_items - squad.threadRank() * elemPerThread, 0, +elemPerThread); @@ -430,7 +446,8 @@ struct lookahead_scan_closure // Store tile aggregate for lookahead if (squad.isLeaderThread()) { - warpspeed::storeTileAggregate(params.ptrTileStates, warpspeed::scan_state::tile_aggregate, regSquadAggr, idxTile); + warpspeed::storeTileAggregate( + params.ptrTileStates, warpspeed::scan_state::tile_aggregate, regSquadAggr, idxTile, num_tiles); } // Store thread aggregate @@ -707,17 +724,29 @@ struct lookahead_scan_closure // hot loops (even if that may seem the case from a first glance at the code). _CCCL_DEVICE_API _CCCL_FORCEINLINE void dispatch_squad(warpspeed::Squad squad) // const // TODO(bgruber): enable const { - // Start with the tile indicated by blockIdx.x - int idxTile = specialRegisters.blockIdxX; + const int numTiles = static_cast(::cuda::ceil_div(params.numElem, ::cuda::std::size_t(tile_size))); + // Lookahead-specific variables: int idxTilePrev = 0; AccumT AggrExclusiveCtaPrev; // only valid in squadLookahead lane_0 _CCCL_PDL_GRID_DEPENDENCY_SYNC(); - // Loop over tiles + // Start with the tile indicated by blockIdx.x for sm100, but for sm90+, we use an atomic counter to determine the + // first tile + int idxTile; + NV_IF_ELSE_TARGET(NV_PROVIDES_SM_100, (idxTile = specialRegisters.blockIdxX;), ({ + __shared__ int s_first_tile; + if (specialRegisters.threadIdxX == 0) + { + s_first_tile = static_cast(::atomicAdd(params.atomicCounter, 1u)); + } + __syncthreads(); + idxTile = s_first_tile; + })); + # pragma unroll 1 - while (true) + while (idxTile < numTiles) { // Get stages. 
When these objects go out of scope, the stage of the resource is automatically incremented. warpspeed::SmemStage stageNextBlockIdx = res.smemNextBlockIdx.nextStage(); @@ -761,17 +790,29 @@ struct lookahead_scan_closure regNextBlockIdx = refNextBlockIdxR.data(); refNextBlockIdxR.setFenceLdsToAsyncProxy(); } - bool nextIdxTileValid = ::cuda::ptx::clusterlaunchcontrol_query_cancel_is_canceled(regNextBlockIdx); + bool nextIdxTileValid = false; + NV_IF_ELSE_TARGET( + NV_PROVIDES_SM_100, + (nextIdxTileValid = ::cuda::ptx::clusterlaunchcontrol_query_cancel_is_canceled(regNextBlockIdx);), + (nextIdxTileValid = static_cast(regNextBlockIdx.x) < numTiles;)); if (squad == squadReduce) { reduce_tile( - squad, phaseInOutRW, phaseThreadAndWarpAggrW, valid_items, is_first_tile, is_last_tile, loadInfo, idxTile); + squad, + phaseInOutRW, + phaseThreadAndWarpAggrW, + valid_items, + is_first_tile, + is_last_tile, + loadInfo, + idxTile, + numTiles); } if (squad == squadLookahead) { - lookahead(squad, phaseAggrExclusiveCtaW, is_first_tile, idxTilePrev, AggrExclusiveCtaPrev, idxTile); + lookahead(squad, phaseAggrExclusiveCtaW, is_first_tile, idxTilePrev, AggrExclusiveCtaPrev, idxTile, numTiles); } if (squad == squadScanStore) @@ -808,7 +849,10 @@ struct lookahead_scan_closure { break; } - idxTile = ::cuda::ptx::clusterlaunchcontrol_query_cancel_get_first_ctaid_x(regNextBlockIdx); + NV_IF_ELSE_TARGET( + NV_PROVIDES_SM_100, + (idxTile = ::cuda::ptx::clusterlaunchcontrol_query_cancel_get_first_ctaid_x(regNextBlockIdx);), + (idxTile = static_cast(regNextBlockIdx.x);)); } // epilogue: after the load squad finished, we can start ramping up the next kernel diff --git a/cub/cub/device/dispatch/tuning/tuning_scan.cuh b/cub/cub/device/dispatch/tuning/tuning_scan.cuh index 615b74d86a2..978530cd5b1 100644 --- a/cub/cub/device/dispatch/tuning/tuning_scan.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_scan.cuh @@ -978,6 +978,10 @@ struct policy_selector return get_sm100_fallback_lookahead_policy(); } 
+ if (cc >= ::cuda::compute_capability{9, 0} && require_stable_reduction_order) + { + return get_sm100_fallback_lookahead_policy(); + } return {}; } @@ -1027,9 +1031,9 @@ struct policy_selector [[nodiscard]] _CCCL_HOST_DEVICE_API constexpr auto operator()(::cuda::compute_capability cc) const -> ScanPolicy { // we first try to get the valid lookahead implementation. if we can't run it, fall back to the old scan impl. - // For stable reduction order (fp + plus), lookahead can only be used on sm_100+, Older arches fall back to classic + // For stable reduction order (fp + plus), lookahead can only be used on sm_90+. Older arches fall back to classic // lookback stable reduction order implementation below. - if (!require_stable_reduction_order || cc >= ::cuda::compute_capability{10, 0}) + if (!require_stable_reduction_order || cc >= ::cuda::compute_capability{9, 0}) { auto lookahead_policy_opt = get_lookahead_policy(cc); if (lookahead_policy_opt && can_use_lookahead(cc, *lookahead_policy_opt))