Getting Dirty

Viz
This commit is contained in:
steffen-illium 2021-05-17 16:50:54 +02:00
parent 2ba095767d
commit 27f5abad64
3 changed files with 78 additions and 57 deletions

View File

@ -0,0 +1,47 @@
from collections import defaultdict
class FactoryMonitor:
def __init__(self, env):
self._env = env
self._monitor = defaultdict(lambda: defaultdict(lambda: 0))
self._last_vals = defaultdict(lambda: 0)
def __iter__(self):
for key, value in self._monitor.items():
yield key, dict(value)
def add(self, key, value, step=None):
assert step is None or step >= 1 # Is this good practice?
step = step or self._env.steps
self._last_vals[key] = self._last_vals[key] + value
self._monitor[key][step] = self._last_vals[key]
return self._last_vals[key]
def set(self, key, value, step=None):
assert step is None or step >= 1 # Is this good practice?
step = step or self._env.steps
self._last_vals[key] = value
self._monitor[key][step] = self._last_vals[key]
return self._last_vals[key]
def remove(self, key, value, step=None):
assert step is None or step >= 1 # Is this good practice?
step = step or self._env.steps
self._last_vals[key] = self._last_vals[key] - value
self._monitor[key][step] = self._last_vals[key]
return self._last_vals[key]
def to_dict(self):
return dict(self)
def to_pd_dataframe(self):
import pandas as pd
df = pd.DataFrame.from_dict(self.to_dict())
df.loc[0] = df.iloc[0].fillna(0)
df = df.fillna(method='ffill')
return df
def reset(self):
raise RuntimeError("DO NOT DO THIS! Always initalize a new Monitor per Env-Run.")

View File

@ -1,10 +1,10 @@
from collections import defaultdict from typing import List, Union
from typing import List
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
from environments import helpers as h from environments import helpers as h
from environments.factory._factory_monitor import FactoryMonitor
class AgentState: class AgentState:
@ -29,51 +29,6 @@ class AgentState:
raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}') raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
class FactoryMonitor:
def __init__(self, env):
self._env = env
self._monitor = defaultdict(lambda: defaultdict(lambda: 0))
self._last_vals = defaultdict(lambda: 0)
def __iter__(self):
for key, value in self._monitor.items():
yield key, dict(value)
def add(self, key, value, step=None):
assert step is None or step >= 1 # Is this good practice?
step = step or self._env.steps
self._last_vals[key] = self._last_vals[key] + value
self._monitor[key][step] = self._last_vals[key]
return self._last_vals[key]
def set(self, key, value, step=None):
assert step is None or step >= 1 # Is this good practice?
step = step or self._env.steps
self._last_vals[key] = value
self._monitor[key][step] = self._last_vals[key]
return self._last_vals[key]
def remove(self, key, value, step=None):
assert step is None or step >= 1 # Is this good practice?
step = step or self._env.steps
self._last_vals[key] = self._last_vals[key] - value
self._monitor[key][step] = self._last_vals[key]
return self._last_vals[key]
def to_dict(self):
return dict(self)
def to_pd_dataframe(self):
import pandas as pd
return pd.DataFrame.from_dict(self.to_dict())
def reset(self):
raise RuntimeError("DO NOT DO THIS! Always initalize a new Monitor per Env-Run.")
class BaseFactory: class BaseFactory:
@property @property
@ -192,9 +147,19 @@ class BaseFactory:
pos_x, pos_y = positions[0] # a.flatten() pos_x, pos_y = positions[0] # a.flatten()
return pos_x, pos_y return pos_x, pos_y
@property def free_cells(self, excluded_slices: Union[None, List, int] = None) -> np.ndarray:
def free_cells(self) -> np.ndarray: excluded_slices = excluded_slices or []
free_cells = self.state.sum(0) assert isinstance(excluded_slices, (int, list))
excluded_slices = excluded_slices if isinstance(excluded_slices, list) else [excluded_slices]
state = self.state
if excluded_slices:
# Todo: Is there a cleaner way?
inds = list(range(self.state.shape[0]))
excluded_slices = [inds[x] if x < 0 else x for x in excluded_slices]
state = self.state[[x for x in inds if x not in excluded_slices]]
free_cells = state.sum(0)
free_cells = np.argwhere(free_cells == h.IS_FREE_CELL) free_cells = np.argwhere(free_cells == h.IS_FREE_CELL)
np.random.shuffle(free_cells) np.random.shuffle(free_cells)
return free_cells return free_cells

View File

@ -26,7 +26,7 @@ class GettingDirty(BaseFactory):
self.slice_strings.update({self.state.shape[0]-1: 'dirt'}) self.slice_strings.update({self.state.shape[0]-1: 'dirt'})
def spawn_dirt(self) -> None: def spawn_dirt(self) -> None:
free_for_dirt = self.free_cells free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX)
# randomly distribute dirt across the grid # randomly distribute dirt across the grid
n_dirt_tiles = int(random.uniform(0, self._dirt_properties.max_spawn_ratio) * len(free_for_dirt)) n_dirt_tiles = int(random.uniform(0, self._dirt_properties.max_spawn_ratio) * len(free_for_dirt))
for x, y in free_for_dirt[:n_dirt_tiles]: for x, y in free_for_dirt[:n_dirt_tiles]:
@ -43,6 +43,11 @@ class GettingDirty(BaseFactory):
self.state[DIRT_INDEX][pos] = max(new_dirt_amount, h.IS_FREE_CELL) self.state[DIRT_INDEX][pos] = max(new_dirt_amount, h.IS_FREE_CELL)
return pos, cleanup_was_sucessfull return pos, cleanup_was_sucessfull
def step(self, actions):
_, _, _, info = super(GettingDirty, self).step(actions)
self.spawn_dirt()
return self.state, self.cumulative_reward, self.done, info
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool): def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
if action != self._is_moving_action(action): if action != self._is_moving_action(action):
if self._is_clean_up_action(action): if self._is_clean_up_action(action):
@ -53,7 +58,7 @@ class GettingDirty(BaseFactory):
self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount) self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount)
else: else:
print(f'Agent {agent_i} just tried to clean up some dirt at {agent_i_pos}, but was unsucsessfull.') print(f'Agent {agent_i} just tried to clean up some dirt at {agent_i_pos}, but was unsucsessfull.')
self.monitor.add('failed_attempts', 1) self.monitor.add('failed_cleanup_attempt', 1)
return agent_i_pos, valid return agent_i_pos, valid
else: else:
raise RuntimeError('This should not happen!!!') raise RuntimeError('This should not happen!!!')
@ -76,6 +81,9 @@ class GettingDirty(BaseFactory):
if self._is_clean_up_action(agent_state.action) and agent_state.action_valid: if self._is_clean_up_action(agent_state.action) and agent_state.action_valid:
this_step_reward += 1 this_step_reward += 1
for entity in collisions:
if entity != self.string_slices["dirt"]:
self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
self.monitor.set('dirt_amount', self.state[DIRT_INDEX].sum()) self.monitor.set('dirt_amount', self.state[DIRT_INDEX].sum())
self.monitor.set('dirty_tiles', len(np.nonzero(self.state[DIRT_INDEX]))) self.monitor.set('dirty_tiles', len(np.nonzero(self.state[DIRT_INDEX])))
return this_step_reward, {} return this_step_reward, {}
@ -87,15 +95,16 @@ if __name__ == '__main__':
factory = GettingDirty(n_agents=1, dirt_properties=dirt_props) factory = GettingDirty(n_agents=1, dirt_properties=dirt_props)
monitor_list = list() monitor_list = list()
for epoch in range(100): for epoch in range(100):
random_actions = [random.randint(0, 7) for _ in range(200)] random_actions = [random.randint(0, 8) for _ in range(200)]
state, r, done, _ = factory.reset() state, r, done, _ = factory.reset()
for action in random_actions: for action in random_actions:
state, r, done, info = factory.step(action) state, r, done, info = factory.step(action)
monitor_list.append(factory.monitor) monitor_list.append(factory.monitor.to_pd_dataframe())
print(f'Factory run done, reward is:\n {r}') print(f'Factory run {epoch} done, reward is:\n {r}')
from pathlib import Path from pathlib import Path
import pickle import pickle
out_path = Path('debug_out') out_path = Path('debug_out')
out_path.mkdir(exist_ok=True, parents=True) out_path.mkdir(exist_ok=True, parents=True)
with (out_path / 'monitor.pick').open('rb') as f: with (out_path / 'monitor.pick').open('wb') as f:
pickle.dump(monitor_list, f) pickle.dump(monitor_list, f, protocol=pickle.HIGHEST_PROTOCOL)