A POC implementation of a wrapper on top of OpenSpiel. #2293

Closed
wants to merge 1 commit into from
1 change: 1 addition & 0 deletions .github/unittest/linux/scripts/environment.yml
@@ -25,6 +25,7 @@ dependencies:
- imageio==2.26.0
- wandb
- dm_control
- open_spiel
- mlflow
- av
- coverage
@@ -24,5 +24,6 @@ dependencies:
- scipy
- hydra-core
- patchelf
- open_spiel
- mujoco==2.3.3
- dm_control==1.0.11
1 change: 1 addition & 0 deletions setup.py
@@ -204,6 +204,7 @@ def _main(argv):
],
"dm_control": ["dm_control"],
"gym_continuous": ["gymnasium", "mujoco"],
"open_spiel": ["open_spiel"],
"rendering": ["moviepy"],
"tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"],
"utils": [
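With this extra registered, the new dependency can be installed alongside TorchRL via pip install "torchrl[open_spiel]" (assuming the standard setuptools extras flow).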
10 changes: 9 additions & 1 deletion test/test_libs.py
@@ -106,6 +106,7 @@
from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv
from torchrl.envs.libs.meltingpot import MeltingpotEnv, MeltingpotWrapper
from torchrl.envs.libs.open_spiel import _has_open_spiel, OpenSpielEnv
from torchrl.envs.libs.openml import OpenMLEnv
from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv
from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv
@@ -183,7 +184,6 @@ def get_gym_pixel_wrapper():
if _has_vmas:
import vmas


if _has_envpool:
import envpool

@@ -3860,6 +3860,14 @@ def test_render(self, rollout_steps):
assert not torch.equal(rollout_penultimate_image, image_from_env)


@pytest.mark.skipif(not _has_open_spiel, reason="OpenSpiel not found")
class TestOpenSpiel:
@pytest.mark.parametrize("game", ["tic_tac_toe"])
def test_spec(self, game: str):
env = OpenSpielEnv(game=game)
check_env_specs(env)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
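For reference, a minimal usage sketch mirroring the new test (a sketch only, assuming OpenSpiel is installed; check_env_specs comes from torchrl.envs.utils):

from torchrl.envs.libs.open_spiel import OpenSpielEnv
from torchrl.envs.utils import check_env_specs

env = OpenSpielEnv(game="tic_tac_toe")
check_env_specs(env)  # the same validation the test performs
td = env.reset()
print(td["observation"].shape)  # (num_players, *info_state_shape)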
151 changes: 151 additions & 0 deletions torchrl/envs/libs/open_spiel.py
@@ -0,0 +1,151 @@
import importlib
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
from torchrl.data import (
CompositeSpec,
DiscreteTensorSpec,
TensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import GymLikeEnv

_has_open_spiel = importlib.util.find_spec("open_spiel") is not None

if _has_open_spiel:
    import open_spiel
    from open_spiel.python import rl_environment


class OpenSpielEnv(GymLikeEnv):
"""OpenSpiel environment wrapper.

    OpenSpiel can be found here: https://github.com/google-deepmind/open_spiel/.

Paper: https://arxiv.org/abs/1908.09453
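
    A minimal usage sketch (assuming OpenSpiel is installed):

    Examples:
        >>> env = OpenSpielEnv(game="tic_tac_toe")
        >>> td = env.reset()
        >>> td["observation"].shape  # one info_state row per player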

"""

git_url = "https://github.com/google-deepmind/open_spiel/"
libname = "open_spiel"

def __init__(self, env=None, **kwargs):
if env is not None:
kwargs["env"] = env
super().__init__(**kwargs)
        # Cache the latest TimeStep for read_action/_get_observation.
        self.state: Optional["rl_environment.TimeStep"] = None

@property
def lib(self):
import open_spiel

return open_spiel

    def _output_transform(
        self, step_output: "rl_environment.TimeStep"
    ) -> Tuple[
        Any,
        Union[float, np.ndarray],
        Union[bool, np.ndarray, None],
        Union[bool, np.ndarray, None],
        Union[bool, np.ndarray, None],
        dict,
    ]:
        self.state = step_output

        rewards = np.asarray(step_output.rewards)
        obs = self._get_observation()
        # OpenSpiel does not distinguish termination from truncation, so the
        # LAST step type is mapped to terminated, truncated and done alike.
        done = step_output.last()

        return (
            obs,
            rewards,
            done,
            done,
            done,
            {},
        )

    def _reset_output_transform(self, step_output: "rl_environment.TimeStep") -> Tuple:
        self.state = step_output
        return self._get_observation(), None

    def _get_observation(self):
        # "info_state" holds one observation vector per player.
        state = self.state.observations["info_state"]
        return torch.as_tensor(
            state,
            dtype=torch.float32,
            device=self.device,
        )

    def read_action(self, action):
        action_np = super().read_action(action)
        # The "action" entry carries one action per player; in turn-based
        # games only the current player's entry is forwarded to OpenSpiel.
        return action_np[
            self.state.current_player() : self.state.current_player() + 1, ...
        ]

    def _check_kwargs(self, kwargs: Dict):
        # Nothing to validate: kwargs are forwarded to rl_environment.Environment.
        pass

    def _init_env(self) -> Optional[int]:
        pass

    def _build_env(self, **kwargs) -> "rl_environment.Environment":
        return rl_environment.Environment(**kwargs)

    def _make_specs(self, env: "rl_environment.Environment") -> None:
spec = env.observation_spec()
num_players = env.num_players

self.observation_spec = self._make_observation_spec(
spec["info_state"], num_players
)
self.reward_spec = self._make_reward_spec(num_players)
self.action_spec = self._make_action_spec(env.action_spec(), num_players)
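        # A single done flag, shared by all players, is exposed at the root.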
        self.done_spec = DiscreteTensorSpec(
            n=2,
            shape=torch.Size((1,)),
            dtype=torch.bool,
        )

def _make_observation_spec(
self, info_state: Tuple[int, ...], num_players: int
) -> TensorSpec:
return CompositeSpec(
{
"observation": UnboundedContinuousTensorSpec(
shape=(num_players,) + info_state,
device=self.device,
)
},
shape=(),
)

def _make_reward_spec(self, num_players: int):
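        # OpenSpiel reports one reward per player at every step.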
return CompositeSpec(
{
"reward": UnboundedContinuousTensorSpec(
shape=(num_players,),
device=self.device,
)
},
shape=(),
)

def _make_action_spec(self, org_action_spec: Dict[str, int], num_players: int):
dtype = org_action_spec["dtype"]
        if dtype is not int:
            raise ValueError(f"Action dtype {dtype} is not supported yet")
return CompositeSpec(
{
"action": DiscreteTensorSpec(
n=org_action_spec["num_actions"],
shape=(num_players,),
device=self.device,
)
},
shape=(),
)

def _set_seed(self, seed: Optional[int]):
if seed is not None:
self._env.seed(seed)
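For context, a hedged sketch (not part of this PR) of how the cached TimeStep could drive a loop that samples only legal moves; legal_actions is part of OpenSpiel's rl_environment observations:

import numpy as np
import torch

from torchrl.envs.libs.open_spiel import OpenSpielEnv

env = OpenSpielEnv(game="tic_tac_toe")
td = env.reset()
while not td["done"].item():
    # The pre-step TimeStep says whose turn it is and which moves are legal.
    player = env.state.current_player()
    legal = env.state.observations["legal_actions"][player]
    action = torch.zeros(env.action_spec.shape, dtype=torch.int64)
    action[player] = int(np.random.choice(legal))
    td["action"] = action
    td = env.step(td)["next"]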