
[FEA]: Optimize segmented TopK with clusters for sm100f #9075

@gevtushenko

Description



Area

CUB

Is your feature request related to a problem? Please describe.

In the worst case, our device-level TopK implementations can read each input element from global memory ~sizeof(T) times. For device-level TopK, this happens because we launch a separate kernel for each pass to guarantee that every CTA has finished updating the global histogram before the next pass reads it. For segmented TopK, similar behavior will appear once we start supporting large segment sizes.
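To make the ~sizeof(T) figure concrete, here is a minimal host-side C++ sketch of radix-based selection of the k-th largest 32-bit key. It is not CUB's implementation: it assumes 8-bit digits and assumes every pass rescans the full input, as a pass-per-kernel device implementation without candidate compaction would. `radix_select_kth_largest` and `RadixSelectResult` are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result type for this sketch.
struct RadixSelectResult {
  std::uint32_t kth_largest;  // value of the k-th largest key
  int passes_over_input;      // full reads of the input ("global memory")
};

// Finds the k-th largest key (1-based k) by building one 256-bin histogram
// per 8-bit digit, most significant digit first. Each pass scans the whole
// input, mimicking one kernel launch per pass (assumption, see lead-in).
RadixSelectResult radix_select_kth_largest(const std::vector<std::uint32_t>& keys,
                                           std::size_t k) {
  assert(k >= 1 && k <= keys.size());
  constexpr int kBitsPerPass = 8;
  constexpr int kNumPasses = static_cast<int>(sizeof(std::uint32_t)) * 8 / kBitsPerPass;
  std::uint32_t prefix = 0;       // digits of the answer fixed so far
  std::uint32_t prefix_mask = 0;  // bits covered by `prefix`
  int passes = 0;

  for (int pass = 0; pass < kNumPasses; ++pass) {
    const int shift = 32 - kBitsPerPass * (pass + 1);
    std::array<std::size_t, 256> histogram{};
    // One full read of the input per pass; only keys matching the prefix count.
    for (std::uint32_t key : keys) {
      if ((key & prefix_mask) == prefix) {
        ++histogram[(key >> shift) & 0xFFu];
      }
    }
    ++passes;
    // Walk bins from the largest digit down until the k-th key is covered.
    for (int bin = 255; bin >= 0; --bin) {
      if (histogram[bin] >= k) {
        prefix |= static_cast<std::uint32_t>(bin) << shift;
        break;
      }
      k -= histogram[bin];
    }
    prefix_mask |= 0xFFu << shift;
  }
  return {prefix, passes};
}
```

With 8-bit digits, a 4-byte key needs 4 passes and an 8-byte key needs 8. When each pass is a separate kernel reading from global memory, every element is loaded once per pass, which is where the ~sizeof(T) factor comes from.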

From first principles, clusters let us synchronize CTAs without launching extra kernels while keeping candidate items in shared memory, which avoids repeated global memory traffic on SM90+. A naive implementation that performs N cluster-level histogram passes shows a ~30% speedup compared to device-level TopK. Other attempts show a ~1.4x (up to 4x) speedup on larger segment sizes.
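The cluster-based idea can be illustrated with a similar host-side simulation. This is a hypothetical sketch, not a proposed CUB API: `cluster_select_kth_largest` and `ctas_per_cluster` are illustrative names, and plain vectors stand in for per-CTA shared memory and the cluster-wide histogram. Each simulated CTA loads its slice of the segment once; all radix passes then read only the cached tiles, so global loads stay at one per element regardless of sizeof(T).

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result type: answer plus number of simulated global loads.
struct ClusterSelectResult {
  std::uint32_t kth_largest;
  std::size_t global_loads;  // element reads from the "global memory" segment
};

// Host-side simulation of a cluster that finds the k-th largest key of one
// segment. Vectors stand in for shared memory; no real CUDA APIs are used.
ClusterSelectResult cluster_select_kth_largest(const std::vector<std::uint32_t>& segment,
                                               std::size_t k, std::size_t ctas_per_cluster) {
  assert(ctas_per_cluster >= 1 && k >= 1 && k <= segment.size());
  // Stage 1: each CTA copies a contiguous slice into its tile (one load per element).
  std::vector<std::vector<std::uint32_t>> tiles(ctas_per_cluster);
  const std::size_t tile_size = (segment.size() + ctas_per_cluster - 1) / ctas_per_cluster;
  std::size_t global_loads = 0;
  for (std::size_t i = 0; i < segment.size(); ++i) {
    tiles[i / tile_size].push_back(segment[i]);
    ++global_loads;
  }

  std::uint32_t prefix = 0;
  std::uint32_t prefix_mask = 0;
  // Stage 2: radix passes over the cached tiles only, MSB digit first.
  for (int shift = 24; shift >= 0; shift -= 8) {
    std::array<std::size_t, 256> cluster_histogram{};  // stands in for a cluster-wide histogram
    for (const auto& tile : tiles) {                    // CTAs would run concurrently on hardware
      std::array<std::size_t, 256> local{};             // per-CTA histogram
      for (std::uint32_t key : tile) {
        if ((key & prefix_mask) == prefix) ++local[(key >> shift) & 0xFFu];
      }
      for (int b = 0; b < 256; ++b) cluster_histogram[b] += local[b];  // merge step
    }
    // On hardware, a cluster-wide barrier would separate the merge from the scan below.
    for (int bin = 255; bin >= 0; --bin) {
      if (cluster_histogram[bin] >= k) {
        prefix |= static_cast<std::uint32_t>(bin) << shift;
        break;
      }
      k -= cluster_histogram[bin];
    }
    prefix_mask |= 0xFFu << shift;
  }
  return {prefix, global_loads};
}
```

On real SM90+ hardware, the merge step would presumably use atomics on distributed shared memory followed by a cluster barrier (for example `cooperative_groups::this_cluster().sync()`), and the single-load property only holds if the segment, or its surviving candidates, fits in the combined shared memory of the cluster.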

Describe the solution you'd like

We should attempt to optimize segmented TopK with SM90+ clusters.

Describe alternatives you've considered

No response

Additional context

No response

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    Status

    Todo

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
