[JAX] Remove xla deterministic arg for MNIST test to not timeout L2_jax_unittest CI#2952
[JAX] Remove xla deterministic arg for MNIST test to not timeout L2_jax_unittest CI#2952tdophung wants to merge 1 commit intoNVIDIA:mainfrom
Conversation
Signed-off-by: tdophung <tdophung@nvidia.com>
|
/te_ci L2 jax |
Greptile SummaryThis PR surgically removes Confidence Score: 5/5Safe to merge — a targeted, well-commented removal of a single environment variable export that was causing CI timeouts. The change is minimal (one flag removed, a comment added, the flag re-scoped to encoder tests only). No logic changes, no new code paths. The only finding is a pre-existing P2 cosmetic issue with a duplicate flag append that doesn't affect correctness. No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Start test.sh] --> B[Run jax not-distributed tests]
B --> C[pip install mnist requirements]
C --> D["Run MNIST pytest\n(no --xla_gpu_deterministic_ops)"]
D --> E[pip install encoder requirements]
E --> F["export XLA_FLAGS +=\n--xla_gpu_deterministic_ops"]
F --> G[Run encoder pytest\n- first run]
G --> H["export XLA_FLAGS +=\n--xla_gpu_deterministic_ops\n(flag now duplicated in env)"]
H --> I["Run encoder pytest\n- NVTE_JAX_CUSTOM_CALLS=false"]
I --> J{Any failures?}
J -- Yes --> K[Print failed cases & exit 1]
J -- No --> L[Print All tests passed & exit 0]
Reviews (1): Last reviewed commit: "remove xla deterministic arg to not time..." | Re-trigger Greptile |
| export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" | ||
| NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" | ||
| # Test without custom calls | ||
| export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" |
There was a problem hiding this comment.
Duplicate
--xla_gpu_deterministic_ops accumulation in XLA_FLAGS
XLA_FLAGS is appended with --xla_gpu_deterministic_ops on line 42 and then again on line 45, so the second encoder run (NVTE_JAX_CUSTOM_CALLS=false) sees the flag twice in the environment variable. This was a pre-existing pattern carried forward by this PR. XLA likely tolerates duplicate flags, but it's worth using a guard to avoid the accumulation:
| export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" | |
| export XLA_FLAGS="${XLA_FLAGS:+$XLA_FLAGS }--xla_gpu_deterministic_ops" |
Or simply remove the redundant second export since the flag is already set from line 42.
KshitijLakhani
left a comment
There was a problem hiding this comment.
Changes look good to me (assuming the CI passes for the L2 tests and this PR is identified to indeed avoid L2 time outs) as it is basically reverting the change in 2933.
Approved to merge (once CI fully passes)
Thanks @tdophung for quick TAT
|
/te-ci jax L2 |
Description
PR #2933 introduced the xla deterministic flag to help stabilize the L2 jax unittest result for MNIST test, as there were unpredictable noise near loss convergence. However this makes the test too slow and time out our CI. This PR will remove it. The mechanism to evaluate the loss closeness to desired has changed to be min of a window of the last 10% of steps (or min of 2) so this should already been more robust in L2 jax CI, without needing the xla deterministic flag
Fixes # (issue)
Type of change
Changes
remove xla deterministic flag
Checklist: