diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 07ed3ea..36972ab 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -50,13 +50,13 @@ class BaseFactory(gym.Env): def string_slices(self): return {value: key for key, value in self.slice_strings.items()} - def __init__(self, level='simple', n_agents=1, max_steps=1e3): + def __init__(self, level='simple', n_agents=1, max_steps=int(5e2)): self.n_agents = n_agents self.max_steps = max_steps self.allow_vertical_movement = True self.allow_horizontal_movement = True self.allow_no_OP = True - self.done_at_collision = True + self.done_at_collision = False self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions() self.level = h.one_hot_level( h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt') diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index bb0e207..d6b8612 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -120,38 +120,38 @@ class SimpleFactory(BaseFactory): try: # penalty = current_dirt_amount - penalty = 0 + reward = 0 except (ZeroDivisionError, RuntimeWarning): - penalty = 0 - inforcements = 0 + reward = 0 + for agent_state in agent_states: cols = agent_state.collisions self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with ' f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}') if self._is_clean_up_action(agent_state.action): if agent_state.action_valid: - inforcements += 10 + reward += 2 self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.') self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount) else: self.print(f'Agent {agent_state.i} just tried to clean up some dirt ' f'at {agent_state.pos}, but was unsucsessfull.') self.monitor.add('failed_cleanup_attempt', 1) + reward -= 0.05 elif self._is_moving_action(agent_state.action): if not agent_state.action_valid: - penalty += 10 + reward -= 0.1 else: - inforcements += 1 + reward += 0 for entity in cols: if entity != self.string_slices["dirt"]: self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1) - this_step_reward = max(0, inforcements-penalty) self.monitor.set('dirt_amount', current_dirt_amount) self.monitor.set('dirty_tiles', dirty_tiles) - self.print(f"reward is {this_step_reward}") - return this_step_reward, {} + self.print(f"reward is {reward}") + return reward, {} def print(self, string): if self.verbose: @@ -166,7 +166,7 @@ if __name__ == '__main__': with MonitorCallback(factory): for epoch in range(100): random_actions = [(random.randint(0, 8), random.randint(0, 8)) for _ in range(200)] - env_state, reward, done_bool, _ = factory.reset() + env_state, this_reward, done_bool, _ = factory.reset() for agent_i_action in random_actions: env_state, reward, done_bool, info_obj = factory.step(agent_i_action) if render: diff --git a/environments/logging/monitor.py b/environments/logging/monitor.py index 5cfa59d..6c9eeb5 100644 --- a/environments/logging/monitor.py +++ b/environments/logging/monitor.py @@ -48,7 +48,7 @@ class FactoryMonitor: except IndexError: return None df = df.fillna(method='ffill') - return df + return def reset(self): raise RuntimeError("DO NOT DO THIS! Always initalize a new Monitor per Env-Run.") @@ -74,10 +74,6 @@ class MonitorCallback(BaseCallback): def __exit__(self, exc_type, exc_val, exc_tb): self._on_training_end() - def _on_rollout_end(self) -> None: - self._monitor_list.append(self.env.monitor) - pass - def _on_training_start(self) -> None: if self.started: pass @@ -96,4 +92,7 @@ class MonitorCallback(BaseCallback): self.closed = True def _on_step(self) -> bool: - pass + if self.locals['dones'].item(): + self._monitor_list.append(self.env.monitor) + else: + pass