Logging Monitor Callback

This commit is contained in:
steffen-illium 2021-05-20 09:49:08 +02:00
parent c1cb7a4ffc
commit e7d31aa272
4 changed files with 83 additions and 30 deletions

View File

@ -56,6 +56,7 @@ class BaseFactory(gym.Env):
self.allow_vertical_movement = True
self.allow_horizontal_movement = True
self.allow_no_OP = True
self.done_at_collision = True
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
self.level = h.one_hot_level(
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
@ -96,7 +97,7 @@ class BaseFactory(gym.Env):
self.steps += 1
# Move this in a seperate function?
states = list()
agent_states = list()
for agent_i, action in enumerate(actions):
agent_i_state = AgentState(agent_i, action)
if self._is_moving_action(action):
@ -107,13 +108,15 @@ class BaseFactory(gym.Env):
pos, valid = self.additional_actions(agent_i, action)
# Update state accordingly
agent_i_state.update(pos=pos, action_valid=valid)
states.append(agent_i_state)
agent_states.append(agent_i_state)
for i, collision_vec in enumerate(self.check_all_collisions(states, self.state.shape[0])):
states[i].update(collision_vector=collision_vec)
for i, collision_vec in enumerate(self.check_all_collisions(agent_states, self.state.shape[0])):
agent_states[i].update(collision_vector=collision_vec)
if self.done_at_collision and collision_vec.any():
self.done = True
self.agent_states = states
reward, info = self.calculate_reward(states)
self.agent_states = agent_states
reward, info = self.calculate_reward(agent_states)
if self.steps >= self.max_steps:
self.done = True

View File

@ -17,7 +17,7 @@ DIRT_INDEX = -1
@dataclass
class DirtProperties:
clean_amount = 0.25
clean_amount = 10
max_spawn_ratio = 0.1
gain_amount = 0.1
spawn_frequency = 5
@ -31,8 +31,9 @@ class SimpleFactory(BaseFactory):
def _is_clean_up_action(self, action):
return self.action_space.n - 1 == action
def __init__(self, *args, dirt_properties: DirtProperties, **kwargs):
def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs):
self._dirt_properties = dirt_properties
self.verbose = verbose
super(SimpleFactory, self).__init__(*args, **kwargs)
self.slice_strings.update({self.state.shape[0]-1: 'dirt'})
self.renderer = None # expensive - dont use it when not required !
@ -81,6 +82,9 @@ class SimpleFactory(BaseFactory):
return pos, cleanup_was_sucessfull
def step(self, actions):
if self.state[h.LEVEL_IDX][self.agent_i_position(0)] == h.IS_OCCUPIED_CELL:
print(f'fAgent placed on wall!!!!, step is :{self.steps}')
raise Exception('Agent placed on wall!!!!')
_, _, _, info = super(SimpleFactory, self).step(actions)
if not self.next_dirt_spawn:
self.spawn_dirt()
@ -94,12 +98,6 @@ class SimpleFactory(BaseFactory):
if self._is_clean_up_action(action):
agent_i_pos = self.agent_i_position(agent_i)
_, valid = self.clean_up(agent_i_pos)
if valid:
print(f'Agent {agent_i} did just clean up some dirt at {agent_i_pos}.')
self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount)
else:
print(f'Agent {agent_i} just tried to clean up some dirt at {agent_i_pos}, but was unsucsessfull.')
self.monitor.add('failed_cleanup_attempt', 1)
return agent_i_pos, valid
else:
raise RuntimeError('This should not happen!!!')
@ -120,25 +118,44 @@ class SimpleFactory(BaseFactory):
dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX]))
try:
this_step_reward = (dirty_tiles / current_dirt_amount)
# penalty = current_dirt_amount
penalty = 0
except (ZeroDivisionError, RuntimeWarning):
this_step_reward = 0
penalty = 0
inforcements = 0
for agent_state in agent_states:
collisions = agent_state.collisions
print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
f'{[self.slice_strings[entity] for entity in collisions if entity != self.string_slices["dirt"]]}')
if self._is_clean_up_action(agent_state.action) and agent_state.action_valid:
this_step_reward += 1
cols = agent_state.collisions
self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}')
if self._is_clean_up_action(agent_state.action):
if agent_state.action_valid:
inforcements += 10
self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.')
self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount)
else:
self.print(f'Agent {agent_state.i} just tried to clean up some dirt '
f'at {agent_state.pos}, but was unsucsessfull.')
self.monitor.add('failed_cleanup_attempt', 1)
elif self._is_moving_action(agent_state.action):
if not agent_state.action_valid:
penalty += 10
else:
inforcements += 1
for entity in collisions:
for entity in cols:
if entity != self.string_slices["dirt"]:
self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
this_step_reward = max(0, inforcements-penalty)
self.monitor.set('dirt_amount', current_dirt_amount)
self.monitor.set('dirty_tiles', dirty_tiles)
print(f"reward is {this_step_reward}")
self.print(f"reward is {this_step_reward}")
return this_step_reward, {}
def print(self, string):
if self.verbose:
print(string)
if __name__ == '__main__':
render = True

View File

@ -56,12 +56,10 @@ class FactoryMonitor:
class MonitorCallback(BaseCallback):
def __init__(self, env, outpath='debug_out', filename='monitor'):
def __init__(self, env, filepath=Path('debug_out/monitor.pick')):
super(MonitorCallback, self).__init__()
self._outpath = Path(outpath)
self._filename = filename
self.filepath = Path(filepath)
self._monitor_list = list()
self.out_file = self._outpath / f'{self._filename.split(".")[0]}.pick'
self.env = env
self.started = False
self.closed = False
@ -84,7 +82,7 @@ class MonitorCallback(BaseCallback):
if self.started:
pass
else:
self.out_file.parent.mkdir(exist_ok=True, parents=True)
self.filepath.parent.mkdir(exist_ok=True, parents=True)
self.started = True
pass
@ -93,7 +91,7 @@ class MonitorCallback(BaseCallback):
pass
else:
# self.out_file.unlink(missing_ok=True)
with self.out_file.open('wb') as f:
with self.filepath.open('wb') as f:
pickle.dump(self.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
self.closed = True

View File

@ -0,0 +1,35 @@
from pathlib import Path
import pandas as pd
from stable_baselines3.common.callbacks import BaseCallback
class TraningMonitor(BaseCallback):
def __init__(self, filepath, flush_interval=None):
super(TraningMonitor, self).__init__()
self.values = dict()
self.filepath = Path(filepath)
self.flush_interval = flush_interval
pass
def _on_training_start(self) -> None:
self.flush_interval = self.flush_interval or (self.locals['total_timesteps'] * 0.1)
def _flush(self):
df = pd.DataFrame.from_dict(self.values)
if not self.filepath.exists():
df.to_csv(self.filepath, mode='wb', header=True)
else:
df.to_csv(self.filepath, mode='a', header=False)
self.values = dict()
def _on_step(self) -> bool:
self.values[self.num_timesteps] = dict(reward=self.locals['rewards'].item())
if self.num_timesteps % self.flush_interval == 0:
self._flush()
return True
def on_training_end(self) -> None:
self._flush()