Skip to content

Commit

Permalink
Make AtariTorso and RecurrentActor work without batch dimensions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 322760491
Change-Id: I13ef2a252cc323a7da898954f774b549f4dbe1fd
  • Loading branch information
katebaumli authored and copybara-github committed Jul 23, 2020
1 parent de0c14f commit 3162b59
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 11 deletions.
7 changes: 4 additions & 3 deletions acme/agents/jax/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(

def select_action(self, observation: types.NestedArray) -> types.NestedArray:
  """Runs the policy on a single observation and returns an unbatched action.

  Args:
    observation: a single (unbatched) observation from the environment.

  Returns:
    The selected action as numpy values, with the batch dimension removed.
  """
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  batched_observation = utils.add_batch_dim(observation)
  action = self._policy(
      self._client.params, next(self._rng), batched_observation)
  return utils.to_numpy_squeeze(action)
Expand Down Expand Up @@ -101,11 +102,11 @@ def select_action(self, observation: types.NestedArray) -> types.NestedArray:
action, new_state = self._recurrent_policy(
self._client.params,
key=next(self._rng),
observation=utils.add_batch_dim(observation),
observation=observation,
core_state=self._state)
self._prev_state = self._state # Keep previous state to save in replay.
self._state = new_state # Keep new state for next policy call.
return utils.to_numpy_squeeze(action)
return utils.to_numpy(action)

def observe_first(self, timestep: dm_env.TimeStep):
if self._adder:
Expand All @@ -115,7 +116,7 @@ def observe_first(self, timestep: dm_env.TimeStep):

def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep):
  """Records a transition, saving the pre-step recurrent state as extras.

  Args:
    action: the action taken in response to the previous observation.
    next_timestep: the environment timestep produced by taking `action`.
  """
  if self._adder:
    # The recurrent state is unbatched here, so convert it to numpy without
    # squeezing a (nonexistent) batch axis. The stale duplicate assignment
    # using `to_numpy_squeeze` (diff residue) has been removed: it was dead,
    # immediately overwritten, and would squeeze a dimension that is absent.
    numpy_state = utils.to_numpy(self._prev_state)
    self._adder.add(action, next_timestep, extras=(numpy_state,))

def update(self):
Expand Down
12 changes: 7 additions & 5 deletions acme/agents/jax/actors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

"""Tests for actors."""
from typing import Tuple
from typing import Optional, Tuple

from absl.testing import absltest
from acme import environment_loop
Expand Down Expand Up @@ -85,14 +85,16 @@ def test_recurrent(self):

@_transform_without_rng
def network(inputs: jnp.ndarray, state: hk.LSTMState):
  """Single-step LSTM network operating on a single unbatched input.

  Removed the stale pre-change line (`hk.Flatten()` variant, diff residue)
  whose early `return` made the intended implementation unreachable.
  """
  # Flatten to a rank-1 vector; hk.Flatten would preserve a leading batch
  # dimension, which this unbatched test deliberately avoids.
  return hk.DeepRNN([lambda x: jnp.reshape(x, [-1]),
                     hk.LSTM(output_size)])(inputs, state)

@_transform_without_rng
def initial_state(batch_size: Optional[int] = None):
  """Returns the LSTM initial state; unbatched when `batch_size` is None.

  Removed the interleaved pre-change definition (required `batch_size: int`
  and used `hk.Flatten()`, diff residue) so only one definition remains.
  """
  network = hk.DeepRNN([lambda x: jnp.reshape(x, [-1]),
                        hk.LSTM(output_size)])
  return network.initial_state(batch_size)

# Build an unbatched initial state (no batch_size argument).
initial_state = initial_state.apply(initial_state.init(next(rng)))
params = network.init(next(rng), obs, initial_state)

def policy(
Expand Down
14 changes: 11 additions & 3 deletions acme/jax/networks/atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,20 @@ def __init__(self):
hk.Conv2D(64, [4, 4], 2),
jax.nn.relu,
hk.Conv2D(64, [3, 3], 1),
jax.nn.relu,
hk.Flatten(),
jax.nn.relu
])

def __call__(self, inputs: Images) -> jnp.ndarray:
  """Embeds `inputs`, accepting batched [B,H,W,C] or unbatched [H,W,C] images.

  The stale pre-change line `return self._network(inputs)` (diff residue) has
  been removed: it returned immediately, making the rank validation and the
  flattening reshapes below unreachable.

  Args:
    inputs: image tensor of rank 4 ([B,H,W,C]) or rank 3 ([H,W,C]).

  Returns:
    A flat embedding: shape [B, D] for batched input, [D] for unbatched.

  Raises:
    ValueError: if `inputs` is not rank 3 or rank 4.
  """
  inputs_rank = jnp.ndim(inputs)
  batched_inputs = inputs_rank == 4
  if inputs_rank < 3 or inputs_rank > 4:
    raise ValueError('Expected input BHWC or HWC. Got rank %d' % inputs_rank)

  outputs = self._network(inputs)

  if batched_inputs:
    return jnp.reshape(outputs, [outputs.shape[0], -1])  # [B, D]
  return jnp.reshape(outputs, [-1])  # [D]


def dqn_atari_network(num_actions: int) -> base.QNetwork:
Expand Down
4 changes: 4 additions & 0 deletions acme/jax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def to_numpy_squeeze(values: types.Nest) -> types.NestedArray:
return tree_util.tree_map(lambda x: np.array(x).squeeze(axis=0), values)


def to_numpy(values: types.Nest) -> types.NestedArray:
  """Converts every leaf of `values` to a numpy array, without squeezing."""
  return tree_util.tree_map(lambda leaf: np.array(leaf), values)


def fetch_devicearray(values: types.Nest) -> types.Nest:
"""Fetches and converts any DeviceArrays in `values`."""

Expand Down

0 comments on commit 3162b59

Please sign in to comment.