mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 15:26:43 +02:00
Logging Monitor Callback
This commit is contained in:
parent
c1cb7a4ffc
commit
e7d31aa272
@ -56,6 +56,7 @@ class BaseFactory(gym.Env):
|
|||||||
self.allow_vertical_movement = True
|
self.allow_vertical_movement = True
|
||||||
self.allow_horizontal_movement = True
|
self.allow_horizontal_movement = True
|
||||||
self.allow_no_OP = True
|
self.allow_no_OP = True
|
||||||
|
self.done_at_collision = True
|
||||||
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
|
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
|
||||||
self.level = h.one_hot_level(
|
self.level = h.one_hot_level(
|
||||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||||
@ -96,7 +97,7 @@ class BaseFactory(gym.Env):
|
|||||||
self.steps += 1
|
self.steps += 1
|
||||||
|
|
||||||
# Move this in a seperate function?
|
# Move this in a seperate function?
|
||||||
states = list()
|
agent_states = list()
|
||||||
for agent_i, action in enumerate(actions):
|
for agent_i, action in enumerate(actions):
|
||||||
agent_i_state = AgentState(agent_i, action)
|
agent_i_state = AgentState(agent_i, action)
|
||||||
if self._is_moving_action(action):
|
if self._is_moving_action(action):
|
||||||
@ -107,13 +108,15 @@ class BaseFactory(gym.Env):
|
|||||||
pos, valid = self.additional_actions(agent_i, action)
|
pos, valid = self.additional_actions(agent_i, action)
|
||||||
# Update state accordingly
|
# Update state accordingly
|
||||||
agent_i_state.update(pos=pos, action_valid=valid)
|
agent_i_state.update(pos=pos, action_valid=valid)
|
||||||
states.append(agent_i_state)
|
agent_states.append(agent_i_state)
|
||||||
|
|
||||||
for i, collision_vec in enumerate(self.check_all_collisions(states, self.state.shape[0])):
|
for i, collision_vec in enumerate(self.check_all_collisions(agent_states, self.state.shape[0])):
|
||||||
states[i].update(collision_vector=collision_vec)
|
agent_states[i].update(collision_vector=collision_vec)
|
||||||
|
if self.done_at_collision and collision_vec.any():
|
||||||
|
self.done = True
|
||||||
|
|
||||||
self.agent_states = states
|
self.agent_states = agent_states
|
||||||
reward, info = self.calculate_reward(states)
|
reward, info = self.calculate_reward(agent_states)
|
||||||
|
|
||||||
if self.steps >= self.max_steps:
|
if self.steps >= self.max_steps:
|
||||||
self.done = True
|
self.done = True
|
||||||
|
@ -17,7 +17,7 @@ DIRT_INDEX = -1
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DirtProperties:
|
class DirtProperties:
|
||||||
clean_amount = 0.25
|
clean_amount = 10
|
||||||
max_spawn_ratio = 0.1
|
max_spawn_ratio = 0.1
|
||||||
gain_amount = 0.1
|
gain_amount = 0.1
|
||||||
spawn_frequency = 5
|
spawn_frequency = 5
|
||||||
@ -31,8 +31,9 @@ class SimpleFactory(BaseFactory):
|
|||||||
def _is_clean_up_action(self, action):
|
def _is_clean_up_action(self, action):
|
||||||
return self.action_space.n - 1 == action
|
return self.action_space.n - 1 == action
|
||||||
|
|
||||||
def __init__(self, *args, dirt_properties: DirtProperties, **kwargs):
|
def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs):
|
||||||
self._dirt_properties = dirt_properties
|
self._dirt_properties = dirt_properties
|
||||||
|
self.verbose = verbose
|
||||||
super(SimpleFactory, self).__init__(*args, **kwargs)
|
super(SimpleFactory, self).__init__(*args, **kwargs)
|
||||||
self.slice_strings.update({self.state.shape[0]-1: 'dirt'})
|
self.slice_strings.update({self.state.shape[0]-1: 'dirt'})
|
||||||
self.renderer = None # expensive - dont use it when not required !
|
self.renderer = None # expensive - dont use it when not required !
|
||||||
@ -81,6 +82,9 @@ class SimpleFactory(BaseFactory):
|
|||||||
return pos, cleanup_was_sucessfull
|
return pos, cleanup_was_sucessfull
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
|
if self.state[h.LEVEL_IDX][self.agent_i_position(0)] == h.IS_OCCUPIED_CELL:
|
||||||
|
print(f'fAgent placed on wall!!!!, step is :{self.steps}')
|
||||||
|
raise Exception('Agent placed on wall!!!!')
|
||||||
_, _, _, info = super(SimpleFactory, self).step(actions)
|
_, _, _, info = super(SimpleFactory, self).step(actions)
|
||||||
if not self.next_dirt_spawn:
|
if not self.next_dirt_spawn:
|
||||||
self.spawn_dirt()
|
self.spawn_dirt()
|
||||||
@ -94,12 +98,6 @@ class SimpleFactory(BaseFactory):
|
|||||||
if self._is_clean_up_action(action):
|
if self._is_clean_up_action(action):
|
||||||
agent_i_pos = self.agent_i_position(agent_i)
|
agent_i_pos = self.agent_i_position(agent_i)
|
||||||
_, valid = self.clean_up(agent_i_pos)
|
_, valid = self.clean_up(agent_i_pos)
|
||||||
if valid:
|
|
||||||
print(f'Agent {agent_i} did just clean up some dirt at {agent_i_pos}.')
|
|
||||||
self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount)
|
|
||||||
else:
|
|
||||||
print(f'Agent {agent_i} just tried to clean up some dirt at {agent_i_pos}, but was unsucsessfull.')
|
|
||||||
self.monitor.add('failed_cleanup_attempt', 1)
|
|
||||||
return agent_i_pos, valid
|
return agent_i_pos, valid
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('This should not happen!!!')
|
raise RuntimeError('This should not happen!!!')
|
||||||
@ -120,25 +118,44 @@ class SimpleFactory(BaseFactory):
|
|||||||
dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX]))
|
dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX]))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
this_step_reward = (dirty_tiles / current_dirt_amount)
|
# penalty = current_dirt_amount
|
||||||
|
penalty = 0
|
||||||
except (ZeroDivisionError, RuntimeWarning):
|
except (ZeroDivisionError, RuntimeWarning):
|
||||||
this_step_reward = 0
|
penalty = 0
|
||||||
|
inforcements = 0
|
||||||
for agent_state in agent_states:
|
for agent_state in agent_states:
|
||||||
collisions = agent_state.collisions
|
cols = agent_state.collisions
|
||||||
print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
||||||
f'{[self.slice_strings[entity] for entity in collisions if entity != self.string_slices["dirt"]]}')
|
f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}')
|
||||||
if self._is_clean_up_action(agent_state.action) and agent_state.action_valid:
|
if self._is_clean_up_action(agent_state.action):
|
||||||
this_step_reward += 1
|
if agent_state.action_valid:
|
||||||
|
inforcements += 10
|
||||||
|
self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.')
|
||||||
|
self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount)
|
||||||
|
else:
|
||||||
|
self.print(f'Agent {agent_state.i} just tried to clean up some dirt '
|
||||||
|
f'at {agent_state.pos}, but was unsucsessfull.')
|
||||||
|
self.monitor.add('failed_cleanup_attempt', 1)
|
||||||
|
elif self._is_moving_action(agent_state.action):
|
||||||
|
if not agent_state.action_valid:
|
||||||
|
penalty += 10
|
||||||
|
else:
|
||||||
|
inforcements += 1
|
||||||
|
|
||||||
for entity in collisions:
|
for entity in cols:
|
||||||
if entity != self.string_slices["dirt"]:
|
if entity != self.string_slices["dirt"]:
|
||||||
self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
|
self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
|
||||||
|
|
||||||
|
this_step_reward = max(0, inforcements-penalty)
|
||||||
self.monitor.set('dirt_amount', current_dirt_amount)
|
self.monitor.set('dirt_amount', current_dirt_amount)
|
||||||
self.monitor.set('dirty_tiles', dirty_tiles)
|
self.monitor.set('dirty_tiles', dirty_tiles)
|
||||||
print(f"reward is {this_step_reward}")
|
self.print(f"reward is {this_step_reward}")
|
||||||
return this_step_reward, {}
|
return this_step_reward, {}
|
||||||
|
|
||||||
|
def print(self, string):
|
||||||
|
if self.verbose:
|
||||||
|
print(string)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
render = True
|
render = True
|
||||||
|
@ -56,12 +56,10 @@ class FactoryMonitor:
|
|||||||
|
|
||||||
class MonitorCallback(BaseCallback):
|
class MonitorCallback(BaseCallback):
|
||||||
|
|
||||||
def __init__(self, env, outpath='debug_out', filename='monitor'):
|
def __init__(self, env, filepath=Path('debug_out/monitor.pick')):
|
||||||
super(MonitorCallback, self).__init__()
|
super(MonitorCallback, self).__init__()
|
||||||
self._outpath = Path(outpath)
|
self.filepath = Path(filepath)
|
||||||
self._filename = filename
|
|
||||||
self._monitor_list = list()
|
self._monitor_list = list()
|
||||||
self.out_file = self._outpath / f'{self._filename.split(".")[0]}.pick'
|
|
||||||
self.env = env
|
self.env = env
|
||||||
self.started = False
|
self.started = False
|
||||||
self.closed = False
|
self.closed = False
|
||||||
@ -84,7 +82,7 @@ class MonitorCallback(BaseCallback):
|
|||||||
if self.started:
|
if self.started:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
self.out_file.parent.mkdir(exist_ok=True, parents=True)
|
self.filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
self.started = True
|
self.started = True
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -93,7 +91,7 @@ class MonitorCallback(BaseCallback):
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# self.out_file.unlink(missing_ok=True)
|
# self.out_file.unlink(missing_ok=True)
|
||||||
with self.out_file.open('wb') as f:
|
with self.filepath.open('wb') as f:
|
||||||
pickle.dump(self.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
|
pickle.dump(self.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
self.closed = True
|
self.closed = True
|
||||||
|
|
||||||
|
35
environments/logging/training.py
Normal file
35
environments/logging/training.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
|
|
||||||
|
|
||||||
|
class TraningMonitor(BaseCallback):
|
||||||
|
|
||||||
|
def __init__(self, filepath, flush_interval=None):
|
||||||
|
super(TraningMonitor, self).__init__()
|
||||||
|
self.values = dict()
|
||||||
|
self.filepath = Path(filepath)
|
||||||
|
self.flush_interval = flush_interval
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _on_training_start(self) -> None:
|
||||||
|
self.flush_interval = self.flush_interval or (self.locals['total_timesteps'] * 0.1)
|
||||||
|
|
||||||
|
def _flush(self):
|
||||||
|
df = pd.DataFrame.from_dict(self.values)
|
||||||
|
if not self.filepath.exists():
|
||||||
|
df.to_csv(self.filepath, mode='wb', header=True)
|
||||||
|
else:
|
||||||
|
df.to_csv(self.filepath, mode='a', header=False)
|
||||||
|
self.values = dict()
|
||||||
|
|
||||||
|
def _on_step(self) -> bool:
|
||||||
|
self.values[self.num_timesteps] = dict(reward=self.locals['rewards'].item())
|
||||||
|
if self.num_timesteps % self.flush_interval == 0:
|
||||||
|
self._flush()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def on_training_end(self) -> None:
|
||||||
|
self._flush()
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user