diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index d052af8..5ac6674 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -50,12 +50,12 @@ class BaseFactory(gym.Env): def string_slices(self): return {value: key for key, value in self.slice_strings.items()} - def __init__(self, level='simple', n_agents=1, max_steps=int(5e2)): + def __init__(self, level='simple', n_agents=1, max_steps=int(2e2)): self.n_agents = n_agents self.max_steps = max_steps self.allow_square_movement = True - self.allow_diagonal_movement = False - self.allow_no_OP = False + self.allow_diagonal_movement = True + self.allow_no_OP = True self.done_at_collision = False self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions() self.level = h.one_hot_level( @@ -119,7 +119,7 @@ class BaseFactory(gym.Env): if self.steps >= self.max_steps: done = True - self.monitor.add('step_reward', reward) + self.monitor.set('step_reward', reward) return self.state, reward, done, info def _is_moving_action(self, action): diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index d205e04..14a2043 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -64,7 +64,7 @@ class SimpleFactory(BaseFactory): self.renderer.render(OrderedDict(dirt=dirt, wall=walls, **agents)) def spawn_dirt(self) -> None: - if not self.state[DIRT_INDEX].sum() > self.max_dirt: + if not self.state[DIRT_INDEX].sum() > self.max_dirt or not np.argwhere(self.state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0] > 10: free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX) # randomly distribute dirt across the grid @@ -150,6 +150,7 @@ class SimpleFactory(BaseFactory): self.monitor.set('dirt_amount', current_dirt_amount) self.monitor.set('dirty_tiles', dirty_tiles) + self.monitor.set('step', self.steps) self.print(f"reward is {reward}") # Potential based rewards -> # track the last reward , minus the current reward = potential diff --git a/environments/logging/monitor.py b/environments/logging/monitor.py index bdcb094..6b4acae 100644 --- a/environments/logging/monitor.py +++ b/environments/logging/monitor.py @@ -45,11 +45,7 @@ class FactoryMonitor: def to_pd_dataframe(self): import pandas as pd df = pd.DataFrame.from_dict(self.to_dict()) - try: - df.loc[0] = df.iloc[0].fillna(0) - except IndexError: - return None - df = df.fillna(method='ffill') + df.fillna(0) return df def reset(self): diff --git a/environments/logging/plotting.py b/environments/logging/plotting.py index fb4690b..57fb922 100644 --- a/environments/logging/plotting.py +++ b/environments/logging/plotting.py @@ -15,7 +15,7 @@ def plot(filepath, ext='png', tag='monitor', **kwargs): def prepare_plot(filepath, results_df, ext='png', tag=''): # %% - _ = sns.lineplot(data=results_df) + _ = sns.lineplot(data=results_df, ci='sd', x='step') # %% sns.set_theme(palette='husl', style='whitegrid')