diff --git a/acme/agents/jax/bc/agent_test.py b/acme/agents/jax/bc/agent_test.py
index e266b49a02..67e3e4cb9f 100644
--- a/acme/agents/jax/bc/agent_test.py
+++ b/acme/agents/jax/bc/agent_test.py
@@ -21,7 +21,6 @@
 from acme.jax import types as jax_types
 from acme.jax import utils
 from acme.testing import fakes
-import chex
 import haiku as hk
 import jax
 import jax.numpy as jnp
@@ -103,7 +102,7 @@ class BCTest(parameterized.TestCase):
       ('peerbc',)
   )
   def test_continuous_actions(self, loss_name):
-    with chex.fake_pmap_and_jit():
+    with jax.disable_jit():
       num_sgd_steps_per_step = 1
       num_steps = 5
 
@@ -145,7 +144,7 @@
       ('logp',),
       ('rcal',))
   def test_discrete_actions(self, loss_name):
-    with chex.fake_pmap_and_jit():
+    with jax.disable_jit():
       num_sgd_steps_per_step = 1
       num_steps = 5
 
diff --git a/acme/agents/jax/bc/learning.py b/acme/agents/jax/bc/learning.py
index 46eb2607a9..edc9ac48e7 100644
--- a/acme/agents/jax/bc/learning.py
+++ b/acme/agents/jax/bc/learning.py
@@ -150,20 +150,32 @@ def sgd_step(
     # Split the input batch to `num_sgd_steps_per_step` minibatches in order
     # to achieve better performance on accelerators.
     sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step)
-    self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)
-
-    random_key, init_key = jax.random.split(random_key)
-    policy_params = networks.policy_network.init(init_key)
-    optimizer_state = optimizer.init(policy_params)
-
-    # Create initial state.
-    state = TrainingState(
-        optimizer_state=optimizer_state,
-        policy_params=policy_params,
-        key=random_key,
-        steps=0,
+    self._sgd_step = jax.pmap(
+        sgd_step,
+        axis_name=_PMAP_AXIS_NAME,
+        in_axes=(None, 0),
+        out_axes=(None, 0),
     )
-    self._state = utils.replicate_in_all_devices(state)
+
+    def init_fn(random_key):
+      random_key, init_key = jax.random.split(random_key)
+      policy_params = networks.policy_network.init(init_key)
+      optimizer_state = optimizer.init(policy_params)
+
+      # Create initial state.
+      state = TrainingState(
+          optimizer_state=optimizer_state,
+          policy_params=policy_params,
+          key=random_key,
+          steps=0,
+      )
+      return state
+
+    state = jax.pmap(init_fn, out_axes=None)(
+        utils.replicate_in_all_devices(random_key)
+    )
+    self._state = state
+    self._state_sharding = jax.tree.map(lambda x: x.sharding, state)
 
     self._timestamp = None
 
@@ -188,13 +200,13 @@ def step(self):
 
   def get_variables(self, names: List[str]) -> List[networks_lib.Params]:
     variables = {
-        'policy': utils.get_from_first_device(self._state.policy_params),
+        'policy': self._state.policy_params,
     }
     return [variables[name] for name in names]
 
   def save(self) -> TrainingState:
     # Serialize only the first replica of parameters and optimizer state.
-    return jax.tree.map(utils.get_from_first_device, self._state)
+    return self._state
 
   def restore(self, state: TrainingState):
-    self._state = utils.replicate_in_all_devices(state)
+    self._state = jax.device_put(state, self._state_sharding)
diff --git a/acme/agents/jax/mbop/agent_test.py b/acme/agents/jax/mbop/agent_test.py
index db0fcadc3e..54124e4419 100644
--- a/acme/agents/jax/mbop/agent_test.py
+++ b/acme/agents/jax/mbop/agent_test.py
@@ -23,7 +23,6 @@
 from acme.agents.jax.mbop import networks as mbop_networks
 from acme.testing import fakes
 from acme.utils import loggers
-import chex
 import jax
 import optax
 import rlds
@@ -34,7 +33,7 @@ class MBOPTest(absltest.TestCase):
 
   def test_learner(self):
-    with chex.fake_pmap_and_jit():
+    with jax.disable_jit():
       num_sgd_steps_per_step = 1
       num_steps = 5
       num_networks = 7
 
diff --git a/acme/agents/jax/mpo/builder.py b/acme/agents/jax/mpo/builder.py
index 1f6f30df82..2ce7816d32 100644
--- a/acme/agents/jax/mpo/builder.py
+++ b/acme/agents/jax/mpo/builder.py
@@ -38,7 +38,6 @@
 from acme.jax import variable_utils
 from acme.utils import counting
 from acme.utils import loggers
-import chex
 import jax
 import optax
 import reverb
@@ -162,8 +161,7 @@ def make_learner(self,
         'learner',
         steps_key=counter.get_steps_key() if counter else 'learner_steps')
 
-    with chex.fake_pmap_and_jit(not self.config.jit_learner,
-                                not self.config.jit_learner):
+    with jax.disable_jit(not self.config.jit_learner):
       learner = learning.MPOLearner(
           iterator=dataset,
           networks=networks,
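Not part of the diff above: a minimal, standalone sketch of the pmap pattern the BC learner now relies on, where parameters stay unreplicated via in_axes=(None, 0) / out_axes=(None, 0) and only the batch carries a device axis. All names here (loss_fn, sgd_step, the toy data) are illustrative, not Acme APIs.

import jax
import jax.numpy as jnp

_PMAP_AXIS_NAME = 'devices'


def loss_fn(params, batch):
  x, y = batch
  pred = x @ params['w'] + params['b']
  return jnp.mean((pred - y) ** 2)


def sgd_step(params, batch):
  loss, grads = jax.value_and_grad(loss_fn)(params, batch)
  # Average gradients across devices; every replica then applies the same
  # update, which is what makes out_axes=None for the parameters valid.
  grads = jax.lax.pmean(grads, axis_name=_PMAP_AXIS_NAME)
  params = jax.tree.map(lambda p, g: p - 1e-3 * g, params, grads)
  return params, loss


# Parameters are broadcast (axis None); only the batch is split over devices.
p_sgd_step = jax.pmap(
    sgd_step,
    axis_name=_PMAP_AXIS_NAME,
    in_axes=(None, 0),
    out_axes=(None, 0),
)

n = jax.local_device_count()
params = {'w': jnp.zeros((3, 1)), 'b': jnp.zeros((1,))}
batch = (jnp.ones((n, 8, 3)), jnp.ones((n, 8, 1)))  # leading axis = devices
params, per_device_loss = p_sgd_step(params, batch)  # params return unreplicated

Because out_axes=None hands back a single unreplicated copy of the updated parameters, the diff's removal of utils.get_from_first_device from get_variables() and save(), and the use of jax.device_put onto the recorded shardings in restore(), follow directly from this pattern.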