Stable Baseline Running
This commit is contained in:
parent
134f06b3d7
commit
c1cb7a4ffc
@ -1,4 +1,3 @@
|
||||
import abc
|
||||
from typing import List, Union, Iterable
|
||||
|
||||
import gym
|
||||
@ -43,10 +42,6 @@ class BaseFactory(gym.Env):
|
||||
def observation_space(self):
|
||||
return spaces.Box(low=-1, high=1, shape=self.state.shape, dtype=np.float32)
|
||||
|
||||
@property
|
||||
def monitor_as_df_list(self):
|
||||
return [x.to_pd_dataframe() for x in self._monitor_list]
|
||||
|
||||
@property
|
||||
def movement_actions(self):
|
||||
return (int(self.allow_vertical_movement) + int(self.allow_horizontal_movement)) * 4
|
||||
@ -61,7 +56,6 @@ class BaseFactory(gym.Env):
|
||||
self.allow_vertical_movement = True
|
||||
self.allow_horizontal_movement = True
|
||||
self.allow_no_OP = True
|
||||
self._monitor_list = list()
|
||||
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
|
||||
self.level = h.one_hot_level(
|
||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||
@ -77,7 +71,6 @@ class BaseFactory(gym.Env):
|
||||
self.steps = 0
|
||||
self.cumulative_reward = 0
|
||||
self.monitor = FactoryMonitor(self)
|
||||
self._monitor_list.append(self.monitor)
|
||||
self.agent_states = []
|
||||
# Agent placement ...
|
||||
agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
|
||||
@ -92,7 +85,6 @@ class BaseFactory(gym.Env):
|
||||
# state.shape = level, agent 1,..., agent n,
|
||||
self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
|
||||
# Returns State
|
||||
|
||||
return self.state
|
||||
|
||||
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||
@ -122,11 +114,11 @@ class BaseFactory(gym.Env):
|
||||
|
||||
self.agent_states = states
|
||||
reward, info = self.calculate_reward(states)
|
||||
self.cumulative_reward += reward
|
||||
|
||||
if self.steps >= self.max_steps:
|
||||
self.done = True
|
||||
return self.state, self.cumulative_reward, self.done, info
|
||||
self.monitor.add('step_reward', reward)
|
||||
return self.state, reward, self.done, info
|
||||
|
||||
def _is_moving_action(self, action):
|
||||
return action < self.movement_actions
|
||||
|
@ -120,8 +120,8 @@ class SimpleFactory(BaseFactory):
|
||||
dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX]))
|
||||
|
||||
try:
|
||||
this_step_reward = -(dirty_tiles / current_dirt_amount)
|
||||
except ZeroDivisionError:
|
||||
this_step_reward = (dirty_tiles / current_dirt_amount)
|
||||
except (ZeroDivisionError, RuntimeWarning):
|
||||
this_step_reward = 0
|
||||
|
||||
for agent_state in agent_states:
|
||||
@ -136,6 +136,7 @@ class SimpleFactory(BaseFactory):
|
||||
self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
|
||||
self.monitor.set('dirt_amount', current_dirt_amount)
|
||||
self.monitor.set('dirty_tiles', dirty_tiles)
|
||||
print(f"reward is {this_step_reward}")
|
||||
return this_step_reward, {}
|
||||
|
||||
|
||||
|
@ -60,17 +60,26 @@ class MonitorCallback(BaseCallback):
|
||||
super(MonitorCallback, self).__init__()
|
||||
self._outpath = Path(outpath)
|
||||
self._filename = filename
|
||||
self._monitor_list = list()
|
||||
self.out_file = self._outpath / f'{self._filename.split(".")[0]}.pick'
|
||||
self.env = env
|
||||
self.started = False
|
||||
self.closed = False
|
||||
|
||||
@property
|
||||
def monitor_as_df_list(self):
|
||||
return [x.to_pd_dataframe() for x in self._monitor_list]
|
||||
|
||||
def __enter__(self):
|
||||
self._on_training_start()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._on_training_end()
|
||||
|
||||
def _on_rollout_end(self) -> None:
|
||||
self._monitor_list.append(self.env.monitor)
|
||||
pass
|
||||
|
||||
def _on_training_start(self) -> None:
|
||||
if self.started:
|
||||
pass
|
||||
@ -85,7 +94,7 @@ class MonitorCallback(BaseCallback):
|
||||
else:
|
||||
# self.out_file.unlink(missing_ok=True)
|
||||
with self.out_file.open('wb') as f:
|
||||
pickle.dump(self.env.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
pickle.dump(self.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.closed = True
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
|
Loading…
x
Reference in New Issue
Block a user