From b27b893c78c5064ccf803fac7993c32016ec71d1 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 23 Sep 2024 10:14:31 -0700 Subject: [PATCH] Update libraries to use JAX's limited (and ill-advised) trace-state-querying APIs rather than depending on JAX's deeper internals, which are about to change. PiperOrigin-RevId: 677843398 --- haiku/_src/base.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/haiku/_src/base.py b/haiku/_src/base.py index f44c9d5b5..c23579540 100644 --- a/haiku/_src/base.py +++ b/haiku/_src/base.py @@ -70,12 +70,15 @@ class JaxTraceLevel(NamedTuple): @classmethod def current(cls): - # TODO(tomhennigan): Remove once a version of JAX is released incl PR#9423. - trace_stack = jax_core.thread_local_state.trace_state.trace_stack.stack - top_type = trace_stack[0].trace_type - level = trace_stack[-1].level - sublevel = jax_core.cur_sublevel() - return JaxTraceLevel(opaque=(top_type, level, sublevel)) + if jax.__version_info__ <= (0, 4, 33): + trace_stack = jax_core.thread_local_state.trace_state.trace_stack.stack + top_type = trace_stack[0].trace_type + level = trace_stack[-1].level + sublevel = jax_core.cur_sublevel() + return JaxTraceLevel(opaque=(top_type, level, sublevel)) + + ts = jax_core.get_opaque_trace_state(convention="haiku") + return JaxTraceLevel(opaque=ts) frame_ids = it.count()