Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 82 additions & 139 deletions op_tests/test_topk_plain.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,51 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.

import gc

import pandas as pd
import torch
from aiter.test_common import (
checkAllclose,
benchmark,
run_perftest,
)

import aiter
from aiter import dtypes
from aiter.ops.topk_plain import topk_plain
import pandas as pd
from aiter.test_common import benchmark, checkAllclose, run_perftest

torch.set_default_device("cuda")
torch.set_printoptions(sci_mode=False)

# Correctness sweep: a few timed iters is plenty (was 1000). This is a
# correctness check, not a perf gate; the high iter count made the file the
# slowest in its CI shard for no benefit.
NUM_ITERS = 100
NUM_WARMUP = 10

# checkAllclose returns the fraction of mismatching elements (0 == exact) and
# does NOT raise; assert on it so an incorrect topk_plain actually fails CI.
TOL_ERR_RATIO = 0.05


@benchmark()
def test_topk(
batch_size,
hiddensize,
topk,
largest,
dtype,
):
output = torch.randn((batch_size, hiddensize), dtype=dtype)
device = output.device
def run_topk_case(batch_size, hiddensize, topk, largest, dtype):
device = "cuda"
# Each row is a permutation of [0, hiddensize) -> distinct values for topk.
# Vectorised; replaces the per-row Python randperm loop (batch_size iters).
x = torch.rand(batch_size, hiddensize, device=device).argsort(dim=1).to(dtype)

topk_ids = torch.zeros((batch_size, topk), dtype=dtypes.i32, device=device)
topk_value = torch.zeros((batch_size, topk), dtype=dtype, device=device)

x = torch.arange(hiddensize, dtype=dtype).repeat(batch_size, 1)
for b in range(batch_size):
x[b] = x[b, torch.randperm(hiddensize)]

(ref_value, ref_index), us_ref = run_perftest(
torch.topk,
x,
topk,
largest=largest,
num_iters=1000,
num_warmup=100,
num_iters=NUM_ITERS,
num_warmup=NUM_WARMUP,
)

id_ref, _ref = torch.sort(ref_index)

# Try Triton, but handle resource errors gracefully
# try:
# (res_triton_value, res_triton_index), us_triton = run_perftest(
# triton_topk,
# x,
# topk,
# largest=largest,
# num_iters=1000,
# num_warmup=100,
# )

# id_triton, _triton = torch.sort(res_triton_index)
# checkAllclose(
# ref_value.gather(1, _ref),
# res_triton_value.gather(1, _triton),
# msg="topk_values [golden vs triton]",
# )
# checkAllclose(
# id_ref,
# id_triton,
# msg=(
# f"topk_ids Performance Comparison:\n"
# f" {'Method':<10} {'Time (us)':>12}\n"
# f" {'-'*10} {'-'*12}\n"
# f" {'golden':<10} {us_ref:>12.2f}\n"
# f" {'triton':<10} {us_triton:>12.2f}\n"
# ),
# )
# except Exception as e:
# print(f"Triton failed: {e}")
# print("Setting triton time to 0 and continuing...")
# us_triton = 0.0

# TODO: uncomment this when the triton topk return in a resonalbe execution time
# TODO: re-enable triton topk comparison when it returns in a reasonable time.
us_triton = 0.0

_, us_aiter = run_perftest(
Expand All @@ -88,51 +55,30 @@ def test_topk(
topk_value,
topk,
largest,
torch.tensor(
[], dtype=torch.int32, device=device
), # rowStarts - empty int32 tensor
torch.tensor(
[], dtype=torch.int32, device=device
), # rowEnds - empty int32 tensor
-1, # stride0
1, # stride1
torch.tensor([], dtype=torch.int32, device=device), # rowStarts
torch.tensor([], dtype=torch.int32, device=device), # rowEnds
-1,
1, # stride0, stride1
num_iters=NUM_ITERS,
num_warmup=NUM_WARMUP,
)

id_aiter, _aiter = torch.sort(topk_ids.to(torch.long))

# Skip for float16 as it would has duplicates in topk_ids
if dtype != torch.float16 and dtype != torch.bfloat16:
# TODO: uncomment this when the aiter topk supports value return
# err = checkAllclose(
# ref_value.gather(1, _ref),
# topk_value.gather(1, _aiter),
# msg="topk_values [golden vs aiter]",
# )
err = checkAllclose(
id_ref,
id_aiter,
msg=(
f"topk_ids Performance Comparison:\n"
f" {'Method':<10} {'Time (us)':>12}\n"
f" {'-'*10} {'-'*12}\n"
f" {'golden':<10} {us_ref:>12.2f}\n"
f" {'triton':<10} {us_triton:>12.2f}\n"
f" {'aiter':<10} {us_aiter:>12.2f}\n"
),
)
if dtype not in (torch.float16, torch.bfloat16):
err = checkAllclose(id_ref, id_aiter, msg="topk_ids [golden vs aiter]")
else:
err = checkAllclose(
ref_value,
topk_value,
msg=(
f"topk_values [golden vs aiter]:\n"
f" {'Method':<10} {'Time (us)':>12}\n"
f" {'-'*10} {'-'*12}\n"
f" {'golden':<10} {us_ref:>12.2f}\n"
f" {'triton':<10} {us_triton:>12.2f}\n"
f" {'aiter':<10} {us_aiter:>12.2f}\n"
),
)
# fp16/bf16 can tie within the top-k -> compare values, not indices.
err = checkAllclose(ref_value, topk_value, msg="topk_values [golden vs aiter]")
assert err <= TOL_ERR_RATIO, (
f"topk_plain mismatch: err ratio {err:.4f} > {TOL_ERR_RATIO} "
f"(batch_size={batch_size}, hiddensize={hiddensize}, topk={topk}, dtype={dtype})"
)

# Release this case's buffers before the next (larger) case allocates, so a
# memory-pressured runner does not OOM accumulating the whole sweep.
del x, topk_ids, topk_value, ref_value, ref_index, id_ref, id_aiter, _ref, _aiter
gc.collect()
torch.cuda.empty_cache()

return {
"err": err,
Expand All @@ -142,50 +88,47 @@ def test_topk(
}


# BATCH_SIZES = [100, 1000, 10000, 32679]
# HIDDENSIZES = [10000, 100000]
# topk = 64
# BATCH_SIZES = [3072, 3072, 3072]
# HIDDENSIZES = [3072, 4096, 8192]
BATCH_SIZES = [3072]
HIDDENSIZES = [3072, 4096, 8192, 16384, 32768, 65536, 131072]
# HIDDENSIZES = [32768]
TOPKS = [2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1]
largest = True

df = []
for batch_size in BATCH_SIZES:
for hiddensize in HIDDENSIZES:
for topk in TOPKS:
print(f"\n{'='*60}")
print(
f"Testing: batch_size={batch_size}, hiddensize={hiddensize}, topk={topk}"
)
print(f"{'='*60}")
ret = test_topk(
batch_size,
hiddensize,
topk,
largest,
dtypes.fp32,
)
df.append(
{
"batch_size": batch_size,
"hiddensize": hiddensize,
"topk": topk,
"error": ret["err"],
"time_us (aiter)": ret["us_aiter"],
"time_us (torch)": ret["us_torch"],
"time_us (triton)": ret["us_triton"],
}
)

df = pd.DataFrame(df)

# Add speedup columns
df["speedup (aiter vs torch)"] = df["time_us (torch)"] / df["time_us (aiter)"]
df["speedup (aiter vs triton)"] = df["time_us (triton)"] / df["time_us (aiter)"]

df_md = df.to_markdown(index=False)
aiter.logger.info("topk_plain summary (markdown):\n%s", df_md)

def main():
rows = []
for batch_size in BATCH_SIZES:
for hiddensize in HIDDENSIZES:
for topk in TOPKS:
if topk > hiddensize:
continue
print(f"\n{'='*60}")
print(
f"Testing: batch_size={batch_size}, hiddensize={hiddensize}, topk={topk}"
)
print(f"{'='*60}")
ret = run_topk_case(batch_size, hiddensize, topk, largest, dtypes.fp32)
rows.append(
{
"batch_size": batch_size,
"hiddensize": hiddensize,
"topk": topk,
"error": ret["err"],
"time_us (aiter)": ret["us_aiter"],
"time_us (torch)": ret["us_torch"],
"time_us (triton)": ret["us_triton"],
}
)

df = pd.DataFrame(rows)
df["speedup (aiter vs torch)"] = df["time_us (torch)"] / df["time_us (aiter)"]
df["speedup (aiter vs triton)"] = df["time_us (triton)"] / df["time_us (aiter)"]
# Summary is informational only -- never let rendering fail the run.
try:
table = df.to_markdown(index=False)
except ImportError:
table = df.to_string(index=False) # `tabulate` not installed
aiter.logger.info("topk_plain summary:\n%s", table)


if __name__ == "__main__":
main()
Loading