-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
122 lines (103 loc) · 4.18 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gymnasium
import wandb
import hydra
from omegaconf import DictConfig, OmegaConf, open_dict
import os
import numpy as np
from pprint import pprint # noqa: F401
import git
from src.utils import dict_to_id
from src.experiment import Experiment
from src.wrappers import monitor_wrappers
import src.actor
import src.critic
# Multiplier applied to cfg.experiment.training_steps for each monitor id
# (Mon-MDPs get more training steps). Monitors not listed keep the base value.
_TRAINING_STEPS_MULTIPLIER = {
    "NMonitor": 10,
    "ButtonMonitor": 2,
    "StatelessBinaryMonitor": 3,
    "LevelMonitor": 20,
    "BatteryMonitor": 10,
}


def _build_run_id(cfg: DictConfig) -> str:
    """Return the run identifier: key hyperparameters and the seed, joined by '_'.

    The values are read from ``cfg`` as given, i.e. before ``_apply_overrides``
    runs, so the id reflects the configured (not the adjusted) settings.
    """
    parts = (
        cfg.algorithm.id,
        cfg.agent.critic.q0_min,
        cfg.agent.critic.q0_max,
        cfg.agent.critic.q0_visit_min,
        cfg.agent.critic.q0_visit_max,
        cfg.agent.actor.eps.init_value,
        cfg.agent.actor.eps.min_value,
        cfg.agent.actor.beta_bar,
        cfg.experiment.rng_seed,
    )
    return "_".join(str(part) for part in parts)


def _apply_overrides(cfg: DictConfig) -> None:
    """Adjust ``cfg`` in place with environment- and monitor-specific settings."""
    env_id = cfg.environment.id
    mon_id = cfg.monitor.id

    if env_id == "Gym-Grid/Gridworld-TwoRoom-Distract-Middle-2x11-v0" and mon_id == "ButtonMonitor":
        # open_dict lifts struct mode so the key can be set even if the
        # monitor config does not already define it.
        with open_dict(cfg):
            cfg.monitor.button_cell_id = 16

    # Fix max Q for infinite horizon MDPs
    if env_id == "RiverSwim":
        if cfg.agent.critic.q0_max == 1.0:  # optimistic
            cfg.agent.critic.q0_max = 50.0
        if cfg.agent.critic.q0_min == 1.0:
            cfg.agent.critic.q0_min = 50.0

    # More training steps for Mon-MDPs
    multiplier = _TRAINING_STEPS_MULTIPLIER.get(mon_id)
    if multiplier is not None:
        cfg.experiment.training_steps *= multiplier

    # Decay learning rate in stochastic monitors
    if mon_id == "NMonitor":
        cfg.agent.critic.lr.min_value = min(0.1, cfg.agent.critic.lr.min_value)
        cfg.agent.critic.lr_visit.min_value = min(0.1, cfg.agent.critic.lr_visit.min_value)


def _make_envs(cfg: DictConfig):
    """Build the monitored training and test environments described by ``cfg``.

    Returns ``(env, env_test)``. Side effect: when the environment config has a
    ``reward_noise_std`` key, it is set to 0.0 in ``cfg`` before the test
    environment is created.
    """
    env = gymnasium.make(**cfg.environment)
    if "reward_noise_std" in cfg.environment.keys():
        cfg.environment.reward_noise_std = 0.0  # test without noise so we need 1 episode only
    env_test = gymnasium.make(**cfg.environment)

    monitor_cls = getattr(monitor_wrappers, cfg.monitor.id)
    return monitor_cls(env, **cfg.monitor), monitor_cls(env_test, **cfg.monitor)


@hydra.main(version_base=None, config_path="configs", config_name="default")
def run(cfg: DictConfig) -> None:
    """Train one agent as configured by Hydra, log to wandb and save the results.

    If ``cfg.experiment.datadir`` is set, the results are saved to
    ``<datadir>/<git sha>/<env id>/<monitor id>/<run_id>.npz``. The run is
    skipped entirely when that file already exists. If
    ``cfg.experiment.debugdir`` is set, agent plots are written under a
    matching folder.
    """
    # Snapshot for wandb taken before _apply_overrides mutates cfg, so wandb
    # logs the configured values rather than the adjusted ones.
    config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    # pprint(config)
    group = dict_to_id(cfg.environment) + "/" + dict_to_id(cfg.monitor)
    sha = git.Repo(search_parent_directories=True).head.object.hexsha
    base_folder = os.path.join(sha, group)
    run_id = _build_run_id(cfg)

    if cfg.experiment.datadir is not None:
        save_dir = os.path.join(cfg.experiment.datadir, base_folder)
        os.makedirs(save_dir, exist_ok=True)
        filepath = os.path.join(save_dir, run_id)
        if os.path.isfile(filepath + ".npz"):
            print(" [RUN ALREADY DONE]")
            return

    wandb.init(
        group=group,
        config=config,
        settings=wandb.Settings(
            start_method="thread",
            _disable_stats=True,
            _disable_meta=True,
        ),
        **cfg.wandb,
    )

    try:
        _apply_overrides(cfg)
        env, env_test = _make_envs(cfg)

        # The observation and action spaces are indexed by "env" and "mon";
        # their .n sizes dimension the critic.
        sizes = (
            env.observation_space["env"].n,
            env.observation_space["mon"].n,
            env.action_space["env"].n,
            env.action_space["mon"].n,
        )
        critic = getattr(src.critic, cfg.agent.critic.id)(*sizes, **cfg.agent.critic)
        actor = getattr(src.actor, cfg.agent.actor.id)(critic, **cfg.agent.actor)
        experiment = Experiment(env, env_test, actor, critic, **cfg.experiment)
        data = experiment.train()

        if cfg.experiment.datadir is not None:
            np.savez(filepath, **data)

        if cfg.experiment.debugdir is not None:
            # Imported lazily: only needed when debug plots are requested.
            from plot_gridworld_agent import plot_agent

            debug_path = os.path.join(cfg.experiment.debugdir, base_folder, run_id)
            os.makedirs(debug_path, exist_ok=True)
            plot_agent(actor, critic, debug_path)
    except BaseException:
        # Report the failure and close the run so it does not stay active in
        # this process (Hydra multirun executes jobs one after another in the
        # same process, and the next job calls wandb.init again).
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()
if __name__ == "__main__":
    # The @hydra.main decorator assembles ``cfg`` from the config files and the
    # command-line overrides, so run() is invoked without arguments here.
    run()