Stable Baseline Running

steffen-illium 2021-05-19 18:27:22 +02:00
parent 134f06b3d7
commit c1cb7a4ffc
3 changed files with 15 additions and 13 deletions

View File

@@ -1,4 +1,3 @@
import abc
from typing import List, Union, Iterable
import gym
@@ -43,10 +42,6 @@ class BaseFactory(gym.Env):
    def observation_space(self):
        return spaces.Box(low=-1, high=1, shape=self.state.shape, dtype=np.float32)

    @property
    def monitor_as_df_list(self):
        return [x.to_pd_dataframe() for x in self._monitor_list]

    @property
    def movement_actions(self):
        return (int(self.allow_vertical_movement) + int(self.allow_horizontal_movement)) * 4
@@ -61,7 +56,6 @@ class BaseFactory(gym.Env):
        self.allow_vertical_movement = True
        self.allow_horizontal_movement = True
        self.allow_no_OP = True
        self._monitor_list = list()
        self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
        self.level = h.one_hot_level(
            h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
@@ -77,7 +71,6 @@ class BaseFactory(gym.Env):
        self.steps = 0
        self.cumulative_reward = 0
        self.monitor = FactoryMonitor(self)
        self._monitor_list.append(self.monitor)
        self.agent_states = []
        # Agent placement ...
        agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
@@ -92,7 +85,6 @@ class BaseFactory(gym.Env):
        # state.shape = level, agent 1,..., agent n,
        self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
        # Returns State
        return self.state

    def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
@@ -122,11 +114,11 @@ class BaseFactory(gym.Env):
        self.agent_states = states
        reward, info = self.calculate_reward(states)
        self.cumulative_reward += reward
        if self.steps >= self.max_steps:
            self.done = True
        return self.state, self.cumulative_reward, self.done, info
        self.monitor.add('step_reward', reward)
        return self.state, reward, self.done, info

    def _is_moving_action(self, action):
        return action < self.movement_actions

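For reference, a minimal rollout sketch against the gym API touched above. Only `movement_actions`, `allow_no_OP`, `observation_space`, and the `reset`/`step` return values come from the diff; the `BaseFactory(...)` constructor arguments are hypothetical, and a single-agent setup where `step` accepts one integer action is assumed.

import numpy as np

# Hypothetical single-agent smoke test; the constructor arguments
# 'level' and 'n_agents' are assumed from names visible in the diff.
env = BaseFactory(level='simple', n_agents=1)
state = env.reset()                            # reset() returns self.state
assert env.observation_space.shape == state.shape
done, episode_return = False, 0.0
while not done:
    # Sample from the movement actions plus the optional no-op.
    action = np.random.randint(env.movement_actions + int(env.allow_no_OP))
    # step() now returns the per-step reward; the running sum is still
    # tracked internally in self.cumulative_reward.
    state, reward, done, info = env.step(action)
    episode_return += reward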
View File

@@ -120,8 +120,8 @@ class SimpleFactory(BaseFactory):
        dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX]))
        try:
            this_step_reward = -(dirty_tiles / current_dirt_amount)
        except ZeroDivisionError:
            this_step_reward = (dirty_tiles / current_dirt_amount)
        except (ZeroDivisionError, RuntimeWarning):
            this_step_reward = 0
        for agent_state in agent_states:
@@ -136,6 +136,7 @@ class SimpleFactory(BaseFactory):
                self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
        self.monitor.set('dirt_amount', current_dirt_amount)
        self.monitor.set('dirty_tiles', dirty_tiles)
        print(f"reward is {this_step_reward}")
        return this_step_reward, {}

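As a side note on the hunk above, a standalone sketch of the dirt-ratio reward; `dirt_slice` stands in for `self.state[DIRT_INDEX]`, and the function name is illustrative:

import numpy as np

def dirt_reward(dirt_slice: np.ndarray) -> float:
    # np.nonzero returns a tuple of index arrays, so len(np.nonzero(x))
    # equals x.ndim; np.count_nonzero counts the dirty tiles themselves.
    dirty_tiles = np.count_nonzero(dirt_slice)
    # float() so an empty level raises ZeroDivisionError instead of
    # emitting a NumPy RuntimeWarning and producing nan.
    current_dirt_amount = float(dirt_slice.sum())
    try:
        return -(dirty_tiles / current_dirt_amount)
    except ZeroDivisionError:  # no dirt at all: neutral reward
        return 0.0

Worth noting on the `except (ZeroDivisionError, RuntimeWarning)` clause in the diff: dividing by a NumPy scalar zero does not raise, it emits a RuntimeWarning and yields nan or inf, so that branch only fires for plain Python numbers or when warnings are escalated to errors (e.g. via warnings.simplefilter('error')).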
View File

@@ -60,17 +60,26 @@ class MonitorCallback(BaseCallback):
        super(MonitorCallback, self).__init__()
        self._outpath = Path(outpath)
        self._filename = filename
        self._monitor_list = list()
        self.out_file = self._outpath / f'{self._filename.split(".")[0]}.pick'
        self.env = env
        self.started = False
        self.closed = False

    @property
    def monitor_as_df_list(self):
        return [x.to_pd_dataframe() for x in self._monitor_list]

    def __enter__(self):
        self._on_training_start()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._on_training_end()

    def _on_rollout_end(self) -> None:
        self._monitor_list.append(self.env.monitor)
        pass

    def _on_training_start(self) -> None:
        if self.started:
            pass
@@ -85,7 +94,7 @@ class MonitorCallback(BaseCallback):
        else:
            # self.out_file.unlink(missing_ok=True)
            with self.out_file.open('wb') as f:
                pickle.dump(self.env.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
                pickle.dump(self.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
            self.closed = True

    def _on_step(self) -> bool:
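Finally, a sketch of how the reworked callback might be wired into a stable-baselines3 run. Only `MonitorCallback`, its context-manager hooks, and the pickled `monitor_as_df_list` come from the diff above; `PPO`, the keyword-argument order, and the paths are assumptions.

import pickle
from pathlib import Path

from stable_baselines3 import PPO  # assumed; the diff only shows BaseCallback

env = SimpleFactory()  # hypothetical constructor, as above
model = PPO('MlpPolicy', env)

# __enter__ runs _on_training_start and __exit__ runs _on_training_end,
# which pickles monitor_as_df_list to <outpath>/<filename stem>.pick.
# Note that __enter__ returns None, so bind the callback before the block.
callback = MonitorCallback(env=env, outpath='debug_out', filename='monitor')
with callback:
    model.learn(total_timesteps=10_000, callback=callback)

# One pandas DataFrame per rollout, collected by _on_rollout_end.
with (Path('debug_out') / 'monitor.pick').open('rb') as f:
    df_list = pickle.load(f)

Keeping `_monitor_list` on the callback rather than on the env (as this commit does) means a fresh env can be swapped in without dragging old monitor history along, and the env no longer needs to know about serialization at all.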