plotting
This commit is contained in:
@ -50,12 +50,12 @@ class BaseFactory(gym.Env):
|
|||||||
def string_slices(self):
|
def string_slices(self):
|
||||||
return {value: key for key, value in self.slice_strings.items()}
|
return {value: key for key, value in self.slice_strings.items()}
|
||||||
|
|
||||||
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2)):
|
def __init__(self, level='simple', n_agents=1, max_steps=int(2e2)):
|
||||||
self.n_agents = n_agents
|
self.n_agents = n_agents
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
self.allow_square_movement = True
|
self.allow_square_movement = True
|
||||||
self.allow_diagonal_movement = False
|
self.allow_diagonal_movement = True
|
||||||
self.allow_no_OP = False
|
self.allow_no_OP = True
|
||||||
self.done_at_collision = False
|
self.done_at_collision = False
|
||||||
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
|
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
|
||||||
self.level = h.one_hot_level(
|
self.level = h.one_hot_level(
|
||||||
@ -119,7 +119,7 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
if self.steps >= self.max_steps:
|
if self.steps >= self.max_steps:
|
||||||
done = True
|
done = True
|
||||||
self.monitor.add('step_reward', reward)
|
self.monitor.set('step_reward', reward)
|
||||||
return self.state, reward, done, info
|
return self.state, reward, done, info
|
||||||
|
|
||||||
def _is_moving_action(self, action):
|
def _is_moving_action(self, action):
|
||||||
|
@ -64,7 +64,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
self.renderer.render(OrderedDict(dirt=dirt, wall=walls, **agents))
|
self.renderer.render(OrderedDict(dirt=dirt, wall=walls, **agents))
|
||||||
|
|
||||||
def spawn_dirt(self) -> None:
|
def spawn_dirt(self) -> None:
|
||||||
if not self.state[DIRT_INDEX].sum() > self.max_dirt:
|
if not self.state[DIRT_INDEX].sum() > self.max_dirt or not np.argwhere(self.state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0] > 10:
|
||||||
free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX)
|
free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX)
|
||||||
|
|
||||||
# randomly distribute dirt across the grid
|
# randomly distribute dirt across the grid
|
||||||
@ -150,6 +150,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
|
|
||||||
self.monitor.set('dirt_amount', current_dirt_amount)
|
self.monitor.set('dirt_amount', current_dirt_amount)
|
||||||
self.monitor.set('dirty_tiles', dirty_tiles)
|
self.monitor.set('dirty_tiles', dirty_tiles)
|
||||||
|
self.monitor.set('step', self.steps)
|
||||||
self.print(f"reward is {reward}")
|
self.print(f"reward is {reward}")
|
||||||
# Potential based rewards ->
|
# Potential based rewards ->
|
||||||
# track the last reward , minus the current reward = potential
|
# track the last reward , minus the current reward = potential
|
||||||
|
@ -45,11 +45,7 @@ class FactoryMonitor:
|
|||||||
def to_pd_dataframe(self):
|
def to_pd_dataframe(self):
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
df = pd.DataFrame.from_dict(self.to_dict())
|
df = pd.DataFrame.from_dict(self.to_dict())
|
||||||
try:
|
df.fillna(0)
|
||||||
df.loc[0] = df.iloc[0].fillna(0)
|
|
||||||
except IndexError:
|
|
||||||
return None
|
|
||||||
df = df.fillna(method='ffill')
|
|
||||||
return df
|
return df
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
@ -15,7 +15,7 @@ def plot(filepath, ext='png', tag='monitor', **kwargs):
|
|||||||
def prepare_plot(filepath, results_df, ext='png', tag=''):
|
def prepare_plot(filepath, results_df, ext='png', tag=''):
|
||||||
# %%
|
# %%
|
||||||
|
|
||||||
_ = sns.lineplot(data=results_df)
|
_ = sns.lineplot(data=results_df, ci='sd', x='step')
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
sns.set_theme(palette='husl', style='whitegrid')
|
sns.set_theme(palette='husl', style='whitegrid')
|
||||||
|
Reference in New Issue
Block a user