Is this a duplicate?
Area
CUB
Is your feature request related to a problem? Please describe.
In the worst case, our device-level TopK implementations can read each input element from global memory ~sizeof(T) times. For device-level TopK, this happens because we use separate kernel launches to ensure that every CTA has updated the global histogram. Segmented TopK will show similar behavior once we start supporting large segment sizes.
From first principles, clusters let us synchronize CTAs without launching extra kernels while keeping candidate items in shared memory, which avoids global memory traffic on SM90+. A naive implementation that does N cluster-level histogram passes shows a ~30% speedup over device-level TopK. Other attempts show a ~1.4x (up to 4x) speedup on larger segment sizes.
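To make the ~sizeof(T) figure concrete, here is a minimal host-side sketch of radix-based selection. It is not the CUB implementation: it assumes 8-bit digits and 32-bit unsigned keys, and the names `radix_select_kth_largest` and `RadixSelectResult` are hypothetical. Each pass builds one histogram over the full input, which models the worst case where every pass re-reads every element from global memory.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result type for this sketch.
struct RadixSelectResult
{
  std::uint32_t kth_largest; // value of the k-th largest key
  std::size_t full_passes;   // how many times the whole input was read
};

// Host-side model of radix selection (assumption: 8-bit digits, MSB first).
// On the device, each pass below would need all CTAs to finish updating a
// shared histogram before the next digit can be chosen.
RadixSelectResult radix_select_kth_largest(const std::vector<std::uint32_t>& keys, std::size_t k)
{
  assert(k >= 1 && k <= keys.size());

  constexpr int bits_per_pass = 8;
  constexpr int num_passes    = sizeof(std::uint32_t) * 8 / bits_per_pass;

  std::uint32_t prefix      = 0; // digits selected so far
  std::uint32_t prefix_mask = 0; // which bits of `prefix` are already fixed
  std::size_t passes        = 0;

  for (int pass = 0; pass < num_passes; ++pass)
  {
    const int shift = 32 - bits_per_pass * (pass + 1);
    std::array<std::size_t, 256> histogram{};

    // One full read of the input: only keys matching the prefix are counted.
    for (std::uint32_t key : keys)
    {
      if ((key & prefix_mask) == prefix)
      {
        ++histogram[(key >> shift) & 0xFFu];
      }
    }
    ++passes;

    // Walk buckets from the largest digit down to find the one holding rank k.
    for (int digit = 255; digit >= 0; --digit)
    {
      if (k <= histogram[digit])
      {
        prefix |= static_cast<std::uint32_t>(digit) << shift;
        prefix_mask |= 0xFFu << shift;
        break;
      }
      k -= histogram[digit];
    }
  }

  return {prefix, passes};
}
```

For 32-bit keys this model always performs four full passes, i.e. sizeof(T) reads per element. Keeping the surviving candidates in shared memory across a cluster is what would let later passes skip those global reads.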
Describe the solution you'd like
We should attempt optimizing segmented TopK with SM90+ clusters.
Describe alternatives you've considered
No response
Additional context
No response