[test] test_topk_plain: parametrize sweep to fix collection-time OOM #3934
Open
JohnQinAMD wants to merge 1 commit into
Pull request overview
This PR refactors op_tests/test_topk_plain.py to avoid collection/import-time GPU OOMs caused by a module-level sweep that allocates tensors for all cases up front. It converts the sweep into isolated, lazy-per-case execution using pytest parameterization, adds teardown to free GPU memory between cases, and preserves the end-of-run performance summary.
Changes:
- Replace the module-level sweep loop with `@pytest.mark.parametrize` test cases to prevent collection-time allocations and speed up collection.
- Add an autouse fixture to run GC + `torch.cuda.empty_cache()` between cases and a session-scoped fixture to emit a single markdown perf summary (with a safe fallback when `tabulate` is missing).
- Vectorize the input permutation generation and reduce iteration counts (1000 → 20) to keep this as a correctness-focused test.
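The `tabulate` fallback mentioned in the list above could look roughly like the following. This is a minimal sketch, not the code in the PR: `format_summary` and the row fields are hypothetical names, and it formats plain dictionaries with the standard library rather than a pandas DataFrame.

```python
def format_summary(rows):
    """Render a list of dict rows as a table without ever failing on a missing optional dependency."""
    if not rows:
        return "(no results)"
    try:
        # Optional dependency: used only when it happens to be installed.
        from tabulate import tabulate
        return tabulate(rows, headers="keys", tablefmt="github")
    except ImportError:
        # Plain-text fallback: pad each column to its widest cell.
        headers = list(rows[0].keys())
        table = [headers] + [[str(r[h]) for h in headers] for r in rows]
        widths = [max(len(row[i]) for row in table) for i in range(len(headers))]
        return "\n".join(
            "  ".join(cell.ljust(w) for cell, w in zip(row, widths)).rstrip()
            for row in table
        )
```

The point of the pattern is that the summary is informational, so a missing formatter should degrade the output rather than change the exit code.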
b427fd5 to 7c7c7f6
CI runs each op test as a script (`python3 <file>` in aiter_test.sh), not via pytest. test_topk_plain.py sweeps 84 cases up to 3072x131072 fp32; with no cleanup between cases, a memory-pressured gfx950 runner OOMs mid-sweep and the whole file exits non-zero -> intermittent shard failures unrelated to the code under test.

Fix (keeps the `python3 <file>` contract):
- free each case's tensors (del + gc + torch.cuda.empty_cache()) so peak memory is one case, not the whole sweep
- guard the sweep under `if __name__ == "__main__":` (clean import)
- num_iters 1000 -> 100 (correctness check, not a perf gate; was the slowest file in its shard)
- vectorize the per-row permutation (drops a batch_size-long Python loop)
- assert on checkAllclose's returned error ratio (it does NOT raise) so an incorrect topk_plain actually fails CI instead of silently passing
- summary table falls back to df.to_string when `tabulate` is absent

Validated: `python3 op_tests/test_topk_plain.py` runs all 84 cases, all pass (err 0), summary prints, exit 0 on MI355X (gfx950).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
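The per-case cleanup and the `__main__` guard described in the commit message can be sketched as follows. This is an illustration under stated assumptions, not the PR's code: `release_gpu_memory`, `run_sweep`, and `demo_case` are hypothetical names, and `torch` is imported optionally so the sketch also runs on a machine without PyTorch.

```python
import gc

try:
    import torch  # optional here; the real test always has PyTorch available
except ImportError:
    torch = None


def release_gpu_memory():
    """Collect unreachable Python objects, then release unused cached GPU memory."""
    gc.collect()
    if torch is not None and torch.cuda.is_available():
        torch.cuda.empty_cache()


def run_sweep(cases, run_case):
    """Run cases one at a time so only one case's tensors are alive at once."""
    results = []
    for case in cases:
        try:
            # run_case should return small Python values (for example an error
            # ratio), not tensors, so nothing large stays referenced by `results`.
            results.append(run_case(*case))
        finally:
            # Runs even if the case raises, so a failure does not leak memory
            # into the next case.
            release_gpu_memory()
    return results


def demo_case(batch_size, hidden_size):
    # Hypothetical stand-in for the real per-case test body.
    data = [0.0] * min(batch_size * hidden_size, 1024)
    err = 0.0
    del data  # drop the large buffer before returning
    return (batch_size, hidden_size, err)


if __name__ == "__main__":
    # Guarded so that importing this module allocates nothing.
    for row in run_sweep([(1, 1024), (3072, 131072)], demo_case):
        print(row)
```

Putting the cleanup in `finally` is a design choice of this sketch; the effect is that peak memory is bounded by the largest single case rather than by the sum of all cases.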
7c7c7f6 to 9a2655e
Problem
CI runs each op test as a script (`python3 <file>` in `.github/scripts/aiter_test.sh`), not via pytest. `test_topk_plain.py` sweeps 84 cases up to `3072 × 131072` fp32 and never frees tensors between cases, so on a memory-pressured gfx950 runner it OOMs mid-sweep and the whole file exits non-zero, producing intermittent "Standard Tests (MI35X, 8, 5)" failures unrelated to the code under test. On a clean GPU it passes; the failure depends purely on residual memory.

Fix (keeps the `python3 <file>` contract)
- Free each case's tensors (`del` + `gc.collect()` + `torch.cuda.empty_cache()`) so peak memory is one case, not the whole sweep. This is the actual OOM fix.
- Guard the sweep under `if __name__ == "__main__":` so import is side-effect-free.
- `num_iters` 1000 → 100: this is a correctness check, not a perf gate; the high iteration count made it the slowest file in its shard for no benefit.
- Vectorize the per-row permutation (drops a `batch_size`-long Python loop).
- Assert on `checkAllclose`'s returned error ratio: it does not raise, so the test previously would have passed even if `topk_plain` were incorrect.
- Summary table falls back to `df.to_string` when `tabulate` is absent, so the informational summary can never fail the run.

Validation
`python3 op_tests/test_topk_plain.py` on MI355X (gfx950): all 84 cases run, all `passed~` (err 0), summary prints, exit 0.

Addresses both Copilot review comments (script-style no-op under `python3`; assert on the error ratio).
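The need to assert on a returned error ratio can be illustrated with a small sketch. `mismatch_ratio` and `check_case` are hypothetical stand-ins written for this example; they are not aiter's `checkAllclose` and make no claim about its signature. The only property carried over from the description is that the checker reports a ratio instead of raising.

```python
def mismatch_ratio(actual, expected, atol=1e-6):
    """Hypothetical checker that reports rather than raises.

    Returns the fraction of positions where the two sequences differ by more than atol.
    """
    if len(actual) != len(expected):
        raise ValueError("length mismatch")
    if not actual:
        return 0.0
    bad = sum(1 for a, e in zip(actual, expected) if abs(a - e) > atol)
    return bad / len(actual)


def check_case(actual, expected, max_ratio=0.0):
    """Turn the reported ratio into a hard failure."""
    err = mismatch_ratio(actual, expected)
    # Without this assert, a wrong result would only be reported and a
    # script-style test would still exit 0.
    assert err <= max_ratio, f"mismatch ratio {err:.4f} exceeds {max_ratio}"
    return err
```

Because CI judges a script-style test by its exit code, an uncaught `AssertionError` is what converts a numerical mismatch into a failed run.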