From a1930fc049b6c420b37bb7123ed01a48a5d98d9c Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 2 Oct 2024 14:01:18 -0700 Subject: [PATCH] Stackless yashful PiperOrigin-RevId: 681582933 --- kfac_jax/_src/utils/parallel.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/kfac_jax/_src/utils/parallel.py b/kfac_jax/_src/utils/parallel.py index 20d26d1..cef7beb 100644 --- a/kfac_jax/_src/utils/parallel.py +++ b/kfac_jax/_src/utils/parallel.py @@ -35,15 +35,7 @@ def in_pmap(axis_name: str | None) -> bool: if axis_name is None: return False - try: - # The only way to know if we are under `jax.pmap` is to check if the - # function call below raises a `NameError` or not. - core.axis_frame(axis_name) - - return True - - except NameError: - return False + return axis_name in core.unsafe_get_axis_names_DO_NOT_USE() def wrap_if_pmap(