diff --git a/op_tests/test_topk_plain.py b/op_tests/test_topk_plain.py index 6249475b46..b04d573820 100644 --- a/op_tests/test_topk_plain.py +++ b/op_tests/test_topk_plain.py @@ -1,84 +1,51 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +import gc + +import pandas as pd import torch -from aiter.test_common import ( - checkAllclose, - benchmark, - run_perftest, -) + import aiter from aiter import dtypes from aiter.ops.topk_plain import topk_plain -import pandas as pd +from aiter.test_common import benchmark, checkAllclose, run_perftest torch.set_default_device("cuda") torch.set_printoptions(sci_mode=False) +# Correctness sweep: a few timed iters is plenty (was 1000). This is a +# correctness check, not a perf gate; the high iter count made the file the +# slowest in its CI shard for no benefit. +NUM_ITERS = 100 +NUM_WARMUP = 10 + +# checkAllclose returns the fraction of mismatching elements (0 == exact) and +# does NOT raise; assert on it so an incorrect topk_plain actually fails CI. +TOL_ERR_RATIO = 0.05 + @benchmark() -def test_topk( - batch_size, - hiddensize, - topk, - largest, - dtype, -): - output = torch.randn((batch_size, hiddensize), dtype=dtype) - device = output.device +def run_topk_case(batch_size, hiddensize, topk, largest, dtype): + device = "cuda" + # Each row is a permutation of [0, hiddensize) -> distinct values for topk. + # Vectorised; replaces the per-row Python randperm loop (batch_size iters). + x = torch.rand(batch_size, hiddensize, device=device).argsort(dim=1).to(dtype) topk_ids = torch.zeros((batch_size, topk), dtype=dtypes.i32, device=device) topk_value = torch.zeros((batch_size, topk), dtype=dtype, device=device) - x = torch.arange(hiddensize, dtype=dtype).repeat(batch_size, 1) - for b in range(batch_size): - x[b] = x[b, torch.randperm(hiddensize)] - (ref_value, ref_index), us_ref = run_perftest( torch.topk, x, topk, largest=largest, - num_iters=1000, - num_warmup=100, + num_iters=NUM_ITERS, + num_warmup=NUM_WARMUP, ) - id_ref, _ref = torch.sort(ref_index) - # Try Triton, but handle resource errors gracefully - # try: - # (res_triton_value, res_triton_index), us_triton = run_perftest( - # triton_topk, - # x, - # topk, - # largest=largest, - # num_iters=1000, - # num_warmup=100, - # ) - - # id_triton, _triton = torch.sort(res_triton_index) - # checkAllclose( - # ref_value.gather(1, _ref), - # res_triton_value.gather(1, _triton), - # msg="topk_values [golden vs triton]", - # ) - # checkAllclose( - # id_ref, - # id_triton, - # msg=( - # f"topk_ids Performance Comparison:\n" - # f" {'Method':<10} {'Time (us)':>12}\n" - # f" {'-'*10} {'-'*12}\n" - # f" {'golden':<10} {us_ref:>12.2f}\n" - # f" {'triton':<10} {us_triton:>12.2f}\n" - # ), - # ) - # except Exception as e: - # print(f"Triton failed: {e}") - # print("Setting triton time to 0 and continuing...") - # us_triton = 0.0 - - # TODO: uncomment this when the triton topk return in a resonalbe execution time + # TODO: re-enable triton topk comparison when it returns in a reasonable time. us_triton = 0.0 _, us_aiter = run_perftest( @@ -88,51 +55,30 @@ def test_topk( topk_value, topk, largest, - torch.tensor( - [], dtype=torch.int32, device=device - ), # rowStarts - empty int32 tensor - torch.tensor( - [], dtype=torch.int32, device=device - ), # rowEnds - empty int32 tensor - -1, # stride0 - 1, # stride1 + torch.tensor([], dtype=torch.int32, device=device), # rowStarts + torch.tensor([], dtype=torch.int32, device=device), # rowEnds + -1, + 1, # stride0, stride1 + num_iters=NUM_ITERS, + num_warmup=NUM_WARMUP, ) - id_aiter, _aiter = torch.sort(topk_ids.to(torch.long)) - # Skip for float16 as it would has duplicates in topk_ids - if dtype != torch.float16 and dtype != torch.bfloat16: - # TODO: uncomment this when the aiter topk supports value return - # err = checkAllclose( - # ref_value.gather(1, _ref), - # topk_value.gather(1, _aiter), - # msg="topk_values [golden vs aiter]", - # ) - err = checkAllclose( - id_ref, - id_aiter, - msg=( - f"topk_ids Performance Comparison:\n" - f" {'Method':<10} {'Time (us)':>12}\n" - f" {'-'*10} {'-'*12}\n" - f" {'golden':<10} {us_ref:>12.2f}\n" - f" {'triton':<10} {us_triton:>12.2f}\n" - f" {'aiter':<10} {us_aiter:>12.2f}\n" - ), - ) + if dtype not in (torch.float16, torch.bfloat16): + err = checkAllclose(id_ref, id_aiter, msg="topk_ids [golden vs aiter]") else: - err = checkAllclose( - ref_value, - topk_value, - msg=( - f"topk_values [golden vs aiter]:\n" - f" {'Method':<10} {'Time (us)':>12}\n" - f" {'-'*10} {'-'*12}\n" - f" {'golden':<10} {us_ref:>12.2f}\n" - f" {'triton':<10} {us_triton:>12.2f}\n" - f" {'aiter':<10} {us_aiter:>12.2f}\n" - ), - ) + # fp16/bf16 can tie within the top-k -> compare values, not indices. + err = checkAllclose(ref_value, topk_value, msg="topk_values [golden vs aiter]") + assert err <= TOL_ERR_RATIO, ( + f"topk_plain mismatch: err ratio {err:.4f} > {TOL_ERR_RATIO} " + f"(batch_size={batch_size}, hiddensize={hiddensize}, topk={topk}, dtype={dtype})" + ) + + # Release this case's buffers before the next (larger) case allocates, so a + # memory-pressured runner does not OOM accumulating the whole sweep. + del x, topk_ids, topk_value, ref_value, ref_index, id_ref, id_aiter, _ref, _aiter + gc.collect() + torch.cuda.empty_cache() return { "err": err, @@ -142,50 +88,47 @@ def test_topk( } -# BATCH_SIZES = [100, 1000, 10000, 32679] -# HIDDENSIZES = [10000, 100000] -# topk = 64 -# BATCH_SIZES = [3072, 3072, 3072] -# HIDDENSIZES = [3072, 4096, 8192] BATCH_SIZES = [3072] HIDDENSIZES = [3072, 4096, 8192, 16384, 32768, 65536, 131072] -# HIDDENSIZES = [32768] TOPKS = [2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1] largest = True -df = [] -for batch_size in BATCH_SIZES: - for hiddensize in HIDDENSIZES: - for topk in TOPKS: - print(f"\n{'='*60}") - print( - f"Testing: batch_size={batch_size}, hiddensize={hiddensize}, topk={topk}" - ) - print(f"{'='*60}") - ret = test_topk( - batch_size, - hiddensize, - topk, - largest, - dtypes.fp32, - ) - df.append( - { - "batch_size": batch_size, - "hiddensize": hiddensize, - "topk": topk, - "error": ret["err"], - "time_us (aiter)": ret["us_aiter"], - "time_us (torch)": ret["us_torch"], - "time_us (triton)": ret["us_triton"], - } - ) - -df = pd.DataFrame(df) - -# Add speedup columns -df["speedup (aiter vs torch)"] = df["time_us (torch)"] / df["time_us (aiter)"] -df["speedup (aiter vs triton)"] = df["time_us (triton)"] / df["time_us (aiter)"] - -df_md = df.to_markdown(index=False) -aiter.logger.info("topk_plain summary (markdown):\n%s", df_md) + +def main(): + rows = [] + for batch_size in BATCH_SIZES: + for hiddensize in HIDDENSIZES: + for topk in TOPKS: + if topk > hiddensize: + continue + print(f"\n{'='*60}") + print( + f"Testing: batch_size={batch_size}, hiddensize={hiddensize}, topk={topk}" + ) + print(f"{'='*60}") + ret = run_topk_case(batch_size, hiddensize, topk, largest, dtypes.fp32) + rows.append( + { + "batch_size": batch_size, + "hiddensize": hiddensize, + "topk": topk, + "error": ret["err"], + "time_us (aiter)": ret["us_aiter"], + "time_us (torch)": ret["us_torch"], + "time_us (triton)": ret["us_triton"], + } + ) + + df = pd.DataFrame(rows) + df["speedup (aiter vs torch)"] = df["time_us (torch)"] / df["time_us (aiter)"] + df["speedup (aiter vs triton)"] = df["time_us (triton)"] / df["time_us (aiter)"] + # Summary is informational only -- never let rendering fail the run. + try: + table = df.to_markdown(index=False) + except ImportError: + table = df.to_string(index=False) # `tabulate` not installed + aiter.logger.info("topk_plain summary:\n%s", table) + + +if __name__ == "__main__": + main()