rewards and monitors

This commit is contained in:
steffen-illium 2021-05-20 14:27:46 +02:00
parent 3114cdffc3
commit 962f914ff9
3 changed files with 17 additions and 18 deletions

View File

@ -50,13 +50,13 @@ class BaseFactory(gym.Env):
def string_slices(self): def string_slices(self):
return {value: key for key, value in self.slice_strings.items()} return {value: key for key, value in self.slice_strings.items()}
def __init__(self, level='simple', n_agents=1, max_steps=1e3): def __init__(self, level='simple', n_agents=1, max_steps=int(5e2)):
self.n_agents = n_agents self.n_agents = n_agents
self.max_steps = max_steps self.max_steps = max_steps
self.allow_vertical_movement = True self.allow_vertical_movement = True
self.allow_horizontal_movement = True self.allow_horizontal_movement = True
self.allow_no_OP = True self.allow_no_OP = True
self.done_at_collision = True self.done_at_collision = False
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions() self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
self.level = h.one_hot_level( self.level = h.one_hot_level(
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt') h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')

View File

@ -120,38 +120,38 @@ class SimpleFactory(BaseFactory):
try: try:
# penalty = current_dirt_amount # penalty = current_dirt_amount
penalty = 0 reward = 0
except (ZeroDivisionError, RuntimeWarning): except (ZeroDivisionError, RuntimeWarning):
penalty = 0 reward = 0
inforcements = 0
for agent_state in agent_states: for agent_state in agent_states:
cols = agent_state.collisions cols = agent_state.collisions
self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with ' self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}') f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}')
if self._is_clean_up_action(agent_state.action): if self._is_clean_up_action(agent_state.action):
if agent_state.action_valid: if agent_state.action_valid:
inforcements += 10 reward += 2
self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.') self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.')
self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount) self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount)
else: else:
self.print(f'Agent {agent_state.i} just tried to clean up some dirt ' self.print(f'Agent {agent_state.i} just tried to clean up some dirt '
f'at {agent_state.pos}, but was unsucsessfull.') f'at {agent_state.pos}, but was unsucsessfull.')
self.monitor.add('failed_cleanup_attempt', 1) self.monitor.add('failed_cleanup_attempt', 1)
reward -= 0.05
elif self._is_moving_action(agent_state.action): elif self._is_moving_action(agent_state.action):
if not agent_state.action_valid: if not agent_state.action_valid:
penalty += 10 reward -= 0.1
else: else:
inforcements += 1 reward += 0
for entity in cols: for entity in cols:
if entity != self.string_slices["dirt"]: if entity != self.string_slices["dirt"]:
self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1) self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
this_step_reward = max(0, inforcements-penalty)
self.monitor.set('dirt_amount', current_dirt_amount) self.monitor.set('dirt_amount', current_dirt_amount)
self.monitor.set('dirty_tiles', dirty_tiles) self.monitor.set('dirty_tiles', dirty_tiles)
self.print(f"reward is {this_step_reward}") self.print(f"reward is {reward}")
return this_step_reward, {} return reward, {}
def print(self, string): def print(self, string):
if self.verbose: if self.verbose:
@ -166,7 +166,7 @@ if __name__ == '__main__':
with MonitorCallback(factory): with MonitorCallback(factory):
for epoch in range(100): for epoch in range(100):
random_actions = [(random.randint(0, 8), random.randint(0, 8)) for _ in range(200)] random_actions = [(random.randint(0, 8), random.randint(0, 8)) for _ in range(200)]
env_state, reward, done_bool, _ = factory.reset() env_state, this_reward, done_bool, _ = factory.reset()
for agent_i_action in random_actions: for agent_i_action in random_actions:
env_state, reward, done_bool, info_obj = factory.step(agent_i_action) env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
if render: if render:

View File

@ -48,7 +48,7 @@ class FactoryMonitor:
except IndexError: except IndexError:
return None return None
df = df.fillna(method='ffill') df = df.fillna(method='ffill')
return df return
def reset(self): def reset(self):
raise RuntimeError("DO NOT DO THIS! Always initalize a new Monitor per Env-Run.") raise RuntimeError("DO NOT DO THIS! Always initalize a new Monitor per Env-Run.")
@ -74,10 +74,6 @@ class MonitorCallback(BaseCallback):
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self._on_training_end() self._on_training_end()
def _on_rollout_end(self) -> None:
self._monitor_list.append(self.env.monitor)
pass
def _on_training_start(self) -> None: def _on_training_start(self) -> None:
if self.started: if self.started:
pass pass
@ -96,4 +92,7 @@ class MonitorCallback(BaseCallback):
self.closed = True self.closed = True
def _on_step(self) -> bool: def _on_step(self) -> bool:
pass if self.locals['dones'].item():
self._monitor_list.append(self.env.monitor)
else:
pass