From f88e1c1a4d8d037326b8ca779475e842898b4973 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Fri, 26 Jun 2026 17:38:02 -0700 Subject: [PATCH] Add Eigh profile mode --- problems/linalg/eigh_py/eval.py | 36 +++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/problems/linalg/eigh_py/eval.py b/problems/linalg/eigh_py/eval.py index c0dd353a..49004c70 100644 --- a/problems/linalg/eigh_py/eval.py +++ b/problems/linalg/eigh_py/eval.py @@ -9,6 +9,7 @@ from typing import Any, Optional import torch +from torch.cuda.nvtx import range as nvtx_range from reference import check_implementation, generate_input from utils import clear_l2_cache, set_seed @@ -251,6 +252,39 @@ def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: l return 0 if passed else 112 +def _run_single_profile(test: TestCase): + from submission import custom_kernel + + with nvtx_range("generate input"): + data = generate_input(**test.args) + torch.cuda.synchronize() + + cloned = _clone_data(data) + with nvtx_range("custom_kernel"): + output = custom_kernel(cloned) + torch.cuda.synchronize() + + return check_implementation(data, output) + + +def run_single_profile(pool: multiprocessing.Pool, test: TestCase): + return pool.apply(_run_single_profile, (test,)) + + +def run_profiling(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + test = tests[0] + logger.log("benchmark.0.spec", test.spec) + good, message = run_single_profile(pool, test) + if not good: + logger.log("benchmark.0.status", "fail") + logger.log("benchmark.0.error", message) + logger.log("check", "fail") + return 112 + logger.log("check", "pass") + return 0 + + def main(): fd = os.getenv("POPCORN_FD") if not fd: @@ -290,6 +324,8 @@ def main(): break logger.log("check", "pass" if passed else "fail") return 0 if passed else 112 + if mode == "profile": + return run_profiling(logger, pool, tests) return 2