diff --git a/haiku/_src/dot.py b/haiku/_src/dot.py index ab961ada5..e655d9b35 100644 --- a/haiku/_src/dot.py +++ b/haiku/_src/dot.py @@ -134,7 +134,7 @@ def to_graph(fun): @functools.wraps(fun) def wrapped_fun(*args): """See `fun`.""" - f = jax.linear_util.wrap_init(fun) + f = jax.extend.linear_util.wrap_init(fun) args_flat, in_tree = jax.tree_util.tree_flatten((args, {})) flat_fun, out_tree = jax.api_util.flatten_fun(f, in_tree) graph = Graph.create(title=name_or_str(fun)) @@ -202,7 +202,7 @@ def process_primitive(self, primitive, tracers, params): if primitive is pjit.pjit_p: f = jax.core.jaxpr_as_fun(params['jaxpr']) f.__name__ = params['name'] - fun = jax.linear_util.wrap_init(f) + fun = jax.extend.linear_util.wrap_init(f) return self.process_call(primitive, fun, tracers, params) inputs = [t.val for t in tracers]