diff --git a/kfac_jax/_src/utils/parallel.py b/kfac_jax/_src/utils/parallel.py index 20d26d1..cef7beb 100644 --- a/kfac_jax/_src/utils/parallel.py +++ b/kfac_jax/_src/utils/parallel.py @@ -35,15 +35,7 @@ def in_pmap(axis_name: str | None) -> bool: if axis_name is None: return False - try: - # The only way to know if we are under `jax.pmap` is to check if the - # function call below raises a `NameError` or not. - core.axis_frame(axis_name) - - return True - - except NameError: - return False + return axis_name in core.unsafe_get_axis_names_DO_NOT_USE() def wrap_if_pmap(