
[JAX] Remove xla deterministic arg for MNIST test to not timeout L2_jax_unittest CI #2952

Open
tdophung wants to merge 1 commit into NVIDIA:main from tdophung:tdophung/remove_xla_deterministic

@tdophung tdophung commented May 1, 2026

Description

PR #2933 introduced the XLA deterministic flag to stabilize the L2 JAX unit-test results for the MNIST test, because there was unpredictable noise near loss convergence. However, the flag makes the test too slow and causes our CI to time out. This PR removes it. The check for how close the loss is to the desired value now takes the min over a window of the last 10% of steps (or min of 2), so the test should already be more robust in L2 JAX CI without the XLA deterministic flag.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Remove the XLA deterministic flag from the MNIST test step.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: tdophung <tdophung@nvidia.com>

tdophung commented May 1, 2026

/te_ci L2 jax

@tdophung tdophung marked this pull request as ready for review May 1, 2026 19:06

greptile-apps Bot commented May 1, 2026

Greptile Summary

This PR surgically removes --xla_gpu_deterministic_ops from the MNIST test step (where it was causing CI timeouts on small conv/GEMM kernels) while preserving it for the encoder tests, and adds a clear comment explaining the intentional asymmetry. The change is straightforward and well-justified given that the MNIST verify() already uses a tail-window min/max strategy for noise tolerance.

Confidence Score: 5/5

Safe to merge — a targeted, well-commented removal of a single environment variable export that was causing CI timeouts.

The change is minimal (one flag removed, a comment added, the flag re-scoped to encoder tests only). No logic changes, no new code paths. The only finding is a pre-existing P2 cosmetic issue with a duplicate flag append that doesn't affect correctness.

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| qa/L2_jax_unittest/test.sh | Removes --xla_gpu_deterministic_ops for the MNIST test to prevent CI timeouts, moves the flag to encoder-only scope with an explanatory comment; the encoder section still exports the flag twice (lines 42 and 45), duplicating it in XLA_FLAGS for the second encoder run, which is a pre-existing cosmetic issue carried forward. |
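
The encoder-only scoping described above could also be expressed without `export`, by prefixing a single command with the variable assignment so the flag never leaks into later steps. The sketch below is an illustration only and is not what `test.sh` does: the `sh -c` calls are hypothetical stand-ins for the pytest invocations, and they simply report the `XLA_FLAGS` value the child process sees.

```shell
# Illustration only: `sh -c` stands in for the real pytest commands in test.sh.
unset XLA_FLAGS

# MNIST step: no deterministic flag in the child's environment.
mnist_seen=$(sh -c 'printf "%s" "${XLA_FLAGS:-}"')

# Encoder step: the VAR=value prefix applies to this one command only.
encoder_seen=$(XLA_FLAGS="${XLA_FLAGS:+${XLA_FLAGS} }--xla_gpu_deterministic_ops" \
    sh -c 'printf "%s" "${XLA_FLAGS:-}"')

# Any later step: the parent shell's XLA_FLAGS is still unset.
later_seen=$(sh -c 'printf "%s" "${XLA_FLAGS:-}"')

echo "mnist=[${mnist_seen}] encoder=[${encoder_seen}] later=[${later_seen}]"
```

With this pattern only the encoder stand-in sees the flag, and nothing accumulates in the parent shell between runs. The trade-off is that the flag has to be repeated on every command that needs it.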

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Start test.sh] --> B[Run jax not-distributed tests]
    B --> C[pip install mnist requirements]
    C --> D["Run MNIST pytest\n(no --xla_gpu_deterministic_ops)"]
    D --> E[pip install encoder requirements]
    E --> F["export XLA_FLAGS +=\n--xla_gpu_deterministic_ops"]
    F --> G[Run encoder pytest\n- first run]
    G --> H["export XLA_FLAGS +=\n--xla_gpu_deterministic_ops\n(flag now duplicated in env)"]
    H --> I["Run encoder pytest\n- NVTE_JAX_CUSTOM_CALLS=false"]
    I --> J{Any failures?}
    J -- Yes --> K[Print failed cases & exit 1]
    J -- No --> L[Print All tests passed & exit 0]

Reviews (1): Last reviewed commit: "remove xla deterministic arg to not time..."

export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"

P2 Duplicate --xla_gpu_deterministic_ops accumulation in XLA_FLAGS

XLA_FLAGS is appended with --xla_gpu_deterministic_ops on line 42 and then again on line 45, so the second encoder run (NVTE_JAX_CUSTOM_CALLS=false) sees the flag twice in the environment variable. This was a pre-existing pattern carried forward by this PR. XLA likely tolerates duplicate flags, but it's worth using a guard to avoid the accumulation:

Suggested change
- export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
+ export XLA_FLAGS="${XLA_FLAGS:+$XLA_FLAGS }--xla_gpu_deterministic_ops"

Or simply remove the redundant second export since the flag is already set from line 42.
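
Note that the `${XLA_FLAGS:+$XLA_FLAGS }` expansion only avoids a leading space when `XLA_FLAGS` is empty or unset; on its own it still appends the flag a second time. A guard that actually skips the append when the flag is already present might look like the sketch below. `append_xla_flag` is a hypothetical helper, not something that exists in `test.sh`, and it assumes flags are separated by single spaces and compared as whole words (so `--flag` and `--flag=true` would count as different entries).

```shell
# Hypothetical helper (not part of test.sh): append a flag to XLA_FLAGS
# only if it is not already present as a whole, space-delimited word.
append_xla_flag() {
    flag="$1"
    case " ${XLA_FLAGS:-} " in
        *" ${flag} "*)
            # Already present: leave XLA_FLAGS untouched.
            ;;
        *)
            # ${VAR:+...} adds the separating space only when XLA_FLAGS is non-empty.
            export XLA_FLAGS="${XLA_FLAGS:+${XLA_FLAGS} }${flag}"
            ;;
    esac
}

unset XLA_FLAGS
append_xla_flag "--xla_gpu_deterministic_ops"
append_xla_flag "--xla_gpu_deterministic_ops"   # second call is a no-op
echo "XLA_FLAGS=${XLA_FLAGS}"
```

Calling the helper before each encoder run would keep the two exports symmetric while leaving a single copy of the flag in the environment.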


@KshitijLakhani KshitijLakhani left a comment


Changes look good to me (assuming CI passes for the L2 tests and this PR is confirmed to avoid the L2 timeouts), as it basically reverts the change in #2933.

Approved to merge (once CI fully passes).
Thanks @tdophung for the quick turnaround.

@KshitijLakhani KshitijLakhani changed the title from "Remove xla deterministic arg for MNIST test to not timeout L2_jax_unittest CI" to "[JAX] Remove xla deterministic arg for MNIST test to not timeout L2_jax_unittest CI" May 1, 2026

tdophung commented May 1, 2026

/te-ci jax L2
