mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
rewards and monitors
This commit is contained in:
parent
3114cdffc3
commit
962f914ff9
@ -50,13 +50,13 @@ class BaseFactory(gym.Env):
|
|||||||
def string_slices(self):
|
def string_slices(self):
|
||||||
return {value: key for key, value in self.slice_strings.items()}
|
return {value: key for key, value in self.slice_strings.items()}
|
||||||
|
|
||||||
def __init__(self, level='simple', n_agents=1, max_steps=1e3):
|
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2)):
|
||||||
self.n_agents = n_agents
|
self.n_agents = n_agents
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
self.allow_vertical_movement = True
|
self.allow_vertical_movement = True
|
||||||
self.allow_horizontal_movement = True
|
self.allow_horizontal_movement = True
|
||||||
self.allow_no_OP = True
|
self.allow_no_OP = True
|
||||||
self.done_at_collision = True
|
self.done_at_collision = False
|
||||||
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
|
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
|
||||||
self.level = h.one_hot_level(
|
self.level = h.one_hot_level(
|
||||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||||
|
@ -120,38 +120,38 @@ class SimpleFactory(BaseFactory):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# penalty = current_dirt_amount
|
# penalty = current_dirt_amount
|
||||||
penalty = 0
|
reward = 0
|
||||||
except (ZeroDivisionError, RuntimeWarning):
|
except (ZeroDivisionError, RuntimeWarning):
|
||||||
penalty = 0
|
reward = 0
|
||||||
inforcements = 0
|
|
||||||
for agent_state in agent_states:
|
for agent_state in agent_states:
|
||||||
cols = agent_state.collisions
|
cols = agent_state.collisions
|
||||||
self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
||||||
f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}')
|
f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}')
|
||||||
if self._is_clean_up_action(agent_state.action):
|
if self._is_clean_up_action(agent_state.action):
|
||||||
if agent_state.action_valid:
|
if agent_state.action_valid:
|
||||||
inforcements += 10
|
reward += 2
|
||||||
self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.')
|
self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.')
|
||||||
self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount)
|
self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount)
|
||||||
else:
|
else:
|
||||||
self.print(f'Agent {agent_state.i} just tried to clean up some dirt '
|
self.print(f'Agent {agent_state.i} just tried to clean up some dirt '
|
||||||
f'at {agent_state.pos}, but was unsucsessfull.')
|
f'at {agent_state.pos}, but was unsucsessfull.')
|
||||||
self.monitor.add('failed_cleanup_attempt', 1)
|
self.monitor.add('failed_cleanup_attempt', 1)
|
||||||
|
reward -= 0.05
|
||||||
elif self._is_moving_action(agent_state.action):
|
elif self._is_moving_action(agent_state.action):
|
||||||
if not agent_state.action_valid:
|
if not agent_state.action_valid:
|
||||||
penalty += 10
|
reward -= 0.1
|
||||||
else:
|
else:
|
||||||
inforcements += 1
|
reward += 0
|
||||||
|
|
||||||
for entity in cols:
|
for entity in cols:
|
||||||
if entity != self.string_slices["dirt"]:
|
if entity != self.string_slices["dirt"]:
|
||||||
self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
|
self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
|
||||||
|
|
||||||
this_step_reward = max(0, inforcements-penalty)
|
|
||||||
self.monitor.set('dirt_amount', current_dirt_amount)
|
self.monitor.set('dirt_amount', current_dirt_amount)
|
||||||
self.monitor.set('dirty_tiles', dirty_tiles)
|
self.monitor.set('dirty_tiles', dirty_tiles)
|
||||||
self.print(f"reward is {this_step_reward}")
|
self.print(f"reward is {reward}")
|
||||||
return this_step_reward, {}
|
return reward, {}
|
||||||
|
|
||||||
def print(self, string):
|
def print(self, string):
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
@ -166,7 +166,7 @@ if __name__ == '__main__':
|
|||||||
with MonitorCallback(factory):
|
with MonitorCallback(factory):
|
||||||
for epoch in range(100):
|
for epoch in range(100):
|
||||||
random_actions = [(random.randint(0, 8), random.randint(0, 8)) for _ in range(200)]
|
random_actions = [(random.randint(0, 8), random.randint(0, 8)) for _ in range(200)]
|
||||||
env_state, reward, done_bool, _ = factory.reset()
|
env_state, this_reward, done_bool, _ = factory.reset()
|
||||||
for agent_i_action in random_actions:
|
for agent_i_action in random_actions:
|
||||||
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
||||||
if render:
|
if render:
|
||||||
|
@ -48,7 +48,7 @@ class FactoryMonitor:
|
|||||||
except IndexError:
|
except IndexError:
|
||||||
return None
|
return None
|
||||||
df = df.fillna(method='ffill')
|
df = df.fillna(method='ffill')
|
||||||
return df
|
return
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
raise RuntimeError("DO NOT DO THIS! Always initalize a new Monitor per Env-Run.")
|
raise RuntimeError("DO NOT DO THIS! Always initalize a new Monitor per Env-Run.")
|
||||||
@ -74,10 +74,6 @@ class MonitorCallback(BaseCallback):
|
|||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self._on_training_end()
|
self._on_training_end()
|
||||||
|
|
||||||
def _on_rollout_end(self) -> None:
|
|
||||||
self._monitor_list.append(self.env.monitor)
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _on_training_start(self) -> None:
|
def _on_training_start(self) -> None:
|
||||||
if self.started:
|
if self.started:
|
||||||
pass
|
pass
|
||||||
@ -96,4 +92,7 @@ class MonitorCallback(BaseCallback):
|
|||||||
self.closed = True
|
self.closed = True
|
||||||
|
|
||||||
def _on_step(self) -> bool:
|
def _on_step(self) -> bool:
|
||||||
pass
|
if self.locals['dones'].item():
|
||||||
|
self._monitor_list.append(self.env.monitor)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user