Skip to content

Commit

Permalink
Update libraries to use JAX's limited (and ill-advised) trace-state-q…
Browse files Browse the repository at this point in the history
…uerying APIs rather than depending on JAX's deeper internals, which are about to change.

PiperOrigin-RevId: 677843398
  • Loading branch information
dougalm authored and copybara-github committed Sep 24, 2024
1 parent 4773949 commit b27b893
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions haiku/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,15 @@ class JaxTraceLevel(NamedTuple):

@classmethod
def current(cls):
# TODO(tomhennigan): Remove once a version of JAX is released incl PR#9423.
trace_stack = jax_core.thread_local_state.trace_state.trace_stack.stack
top_type = trace_stack[0].trace_type
level = trace_stack[-1].level
sublevel = jax_core.cur_sublevel()
return JaxTraceLevel(opaque=(top_type, level, sublevel))
if jax.__version_info__ <= (0, 4, 33):
trace_stack = jax_core.thread_local_state.trace_state.trace_stack.stack
top_type = trace_stack[0].trace_type
level = trace_stack[-1].level
sublevel = jax_core.cur_sublevel()
return JaxTraceLevel(opaque=(top_type, level, sublevel))

ts = jax_core.get_opaque_trace_state(convention="haiku")
return JaxTraceLevel(opaque=ts)

frame_ids = it.count()

Expand Down

0 comments on commit b27b893

Please sign in to comment.