From 262c3cf6b09297ea3a8dd4e8cec7828ceb204cb3 Mon Sep 17 00:00:00 2001 From: John Aslanides Date: Thu, 25 Jun 2020 05:01:23 -0700 Subject: [PATCH] Fix ZeroDiscountonLifeLoss wrapper and add test covereage. PiperOrigin-RevId: 318249207 Change-Id: I01c2f0b76d1b15fc42bf83538fbb1bc8c0757776 --- acme/wrappers/atari_wrapper.py | 2 +- acme/wrappers/atari_wrapper_test.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/acme/wrappers/atari_wrapper.py b/acme/wrappers/atari_wrapper.py index cb27fa6cb2..0a3043fe92 100644 --- a/acme/wrappers/atari_wrapper.py +++ b/acme/wrappers/atari_wrapper.py @@ -353,7 +353,7 @@ def __init__(self, environment: dm_env.Environment): self._last_num_lives = None def reset(self) -> dm_env.TimeStep: - timestep = self._env.reset() + timestep = self._environment.reset() self._reset_next_step = False self._last_num_lives = timestep.observation[LIVES_INDEX] return timestep diff --git a/acme/wrappers/atari_wrapper_test.py b/acme/wrappers/atari_wrapper_test.py index bd1815f7b3..0775b96f4d 100644 --- a/acme/wrappers/atari_wrapper_test.py +++ b/acme/wrappers/atari_wrapper_test.py @@ -17,6 +17,7 @@ import unittest from absl.testing import absltest +from absl.testing import parameterized from acme.wrappers import atari_wrapper from dm_env import specs import numpy as np @@ -33,12 +34,14 @@ @unittest.skipIf(SKIP_GYM_TESTS, SKIP_GYM_MESSAGE) -class AtariWrapperTest(absltest.TestCase): +class AtariWrapperTest(parameterized.TestCase): - def test_pong(self): + @parameterized.parameters(True, False) + def test_pong(self, zero_discount_on_life_loss: bool): env = gym.make('PongNoFrameskip-v4', full_action_space=True) env = gym_wrapper.GymAtariAdapter(env) - env = atari_wrapper.AtariWrapper(env) + env = atari_wrapper.AtariWrapper( + env, zero_discount_on_life_loss=zero_discount_on_life_loss) # Test converted observation spec. observation_spec = env.observation_spec()