steffen-illium
2021-06-01 12:39:33 +02:00
parent 403d38dc24
commit 55b409c72f
6 changed files with 82 additions and 83 deletions

View File

@@ -1,12 +1,11 @@
+from pathlib import Path
 from typing import List, Union, Iterable
 import gym
-from gym import spaces
 import numpy as np
-from pathlib import Path
+from gym import spaces
 from environments import helpers as h
-from environments.logging.monitor import FactoryMonitor


 class AgentState:
@@ -102,16 +101,29 @@ class BaseFactory(gym.Env):
     @property
     def observation_space(self):
-        return spaces.Box(low=-1, high=1, shape=self.state.shape, dtype=np.float32)
+        if self.pomdp_size:
+            return spaces.Box(low=0, high=1, shape=(self.state.shape[0], self.pomdp_size,
+                                                    self.pomdp_size), dtype=np.float32)
+        else:
+            space = spaces.Box(low=0, high=1, shape=self.state.shape, dtype=np.float32)
+            # space = spaces.MultiBinary(np.prod(self.state.shape))
+            # space = spaces.Dict({
+            #     'level': spaces.MultiBinary(np.prod(self.state[0].shape)),
+            #     'agent_n': spaces.Discrete(np.prod(self.state[1].shape)),
+            #     'dirt': spaces.Box(low=0, high=1, shape=self.state[2].shape, dtype=np.float32)
+            # })
+            return space

     @property
     def movement_actions(self):
         return self._actions.movement_actions

-    def __init__(self, level='simple', n_agents=1, max_steps=int(2e2), **kwargs):
+    def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_size: Union[None, int] = None, **kwargs):
         self.n_agents = n_agents
         self.max_steps = max_steps
+        assert pomdp_size is None or (pomdp_size is not None and pomdp_size % 2 == 1)
+        self.pomdp_size = pomdp_size
         self.done_at_collision = False
         _actions = Actions(allow_square_movement=kwargs.get('allow_square_movement', True),
                            allow_diagonal_movement=kwargs.get('allow_diagonal_movement', True),
@@ -138,7 +150,6 @@ class BaseFactory(gym.Env):
     def reset(self) -> (np.ndarray, int, bool, dict):
         self.steps = 0
-        self.monitor = FactoryMonitor(self)
         self.agent_states = []
         # Agent placement ...
         agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
@@ -153,7 +164,35 @@ class BaseFactory(gym.Env):
         # state.shape = level, agent 1,..., agent n,
         self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
         # Returns State
-        return self.state
+        return self._return_state()
+
+    def _return_state(self):
+        if self.pomdp_size:
+            pos = self.agent_states[0].pos
+            # pos = [agent_state.pos for agent_state in self.agent_states]
+            # obs = [] ... list comprehension... pos per agent
+            offset = self.pomdp_size // 2
+            x0, x1 = max(0, pos[0] - offset), pos[0] + offset + 1
+            y0, y1 = max(0, pos[1] - offset), pos[1] + offset + 1
+            obs = self.state[:, x0:x1, y0:y1]
+            if obs.shape[1] != self.pomdp_size or obs.shape[2] != self.pomdp_size:
+                obs_padded = np.zeros((obs.shape[0], self.pomdp_size, self.pomdp_size))
+                try:
+                    a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
+                except IndexError:
+                    print('Shiiiiiit')
+                try:
+                    obs_padded[:, abs(a_pos[0]-offset):abs(a_pos[0]-offset)+obs.shape[1], abs(a_pos[1]-offset):abs(a_pos[1]-offset)+obs.shape[2]] = obs
+                except ValueError:
+                    print('Shiiiiiit')
+                assert all(np.argwhere(obs_padded[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0] == (3, 3))
+                obs = obs_padded
+        else:
+            obs = self.state
+        return obs

     def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
         raise NotImplementedError
@@ -188,12 +227,11 @@ class BaseFactory(gym.Env):
         if self.steps >= self.max_steps:
             done = True
-        self.monitor.set('step_reward', reward)
-        self.monitor.set('step', self.steps)
-        if done:
-            info.update(monitor=self.monitor)
-        return self.state, reward, done, info
+        info.update(step_reward=reward, step=self.steps)
+
+        obs = self._return_state()
+        return obs, reward, done, info

     def _is_moving_action(self, action):
         return action in self._actions.movement_actions
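
The new `_return_state` above crops an egocentric window of `pomdp_size` cells around the first agent and zero-pads it at the level border. A standalone sketch of the same crop-and-pad idea may be easier to follow than the committed version with its debug prints; the helper name and the `(n_slices, x, y)` state layout are assumptions for illustration, not part of the repository:

import numpy as np

def crop_pomdp_view(state: np.ndarray, pos, pomdp_size: int) -> np.ndarray:
    # Sketch only: cut a (pomdp_size x pomdp_size) window centred on the agent at `pos`
    # out of a (n_slices, x, y) state tensor and zero-pad where the window leaves the level.
    offset = pomdp_size // 2
    x0, x1 = pos[0] - offset, pos[0] + offset + 1
    y0, y1 = pos[1] - offset, pos[1] + offset + 1
    view = state[:, max(0, x0):x1, max(0, y0):y1]                  # numpy clips the far edge for us
    padded = np.zeros((state.shape[0], pomdp_size, pomdp_size), dtype=state.dtype)
    px, py = max(0, -x0), max(0, -y0)                              # shift inside the window if clipped at 0
    padded[:, px:px + view.shape[1], py:py + view.shape[2]] = view
    return padded

With `pomdp_size=7` this always yields an array of shape `(n_slices, 7, 7)` with the agent in the centre cell, which is what the hard-coded `(3, 3)` assertion in the diff checks.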

View File

@@ -99,7 +99,8 @@ class SimpleFactory(BaseFactory):
             self.next_dirt_spawn = self._dirt_properties.spawn_frequency
         else:
             self.next_dirt_spawn -= 1
-        return self.state, r, done, info
+        obs = self._return_state()
+        return obs, r, done, info

     def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
         if action != self._is_moving_action(action):
@@ -118,12 +119,14 @@ class SimpleFactory(BaseFactory):
         self.state = np.concatenate((self.state, dirt_slice))  # dirt is now the last slice
         self.spawn_dirt()
         self.next_dirt_spawn = self._dirt_properties.spawn_frequency
-        return self.state
+        obs = self._return_state()
+        return obs

     def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
         # TODO: What reward to use?
         current_dirt_amount = self.state[DIRT_INDEX].sum()
         dirty_tiles = np.argwhere(self.state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0]
+        info_dict = dict()

         try:
             # penalty = current_dirt_amount
@@ -143,33 +146,35 @@ class SimpleFactory(BaseFactory):
                 if agent_state.action_valid:
                     reward += 1
                     self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.')
-                    self.monitor.set('dirt_cleaned', 1)
+                    info_dict.update(dirt_cleaned=1)
                 else:
-                    reward -= 0.5
+                    reward -= 0.0
                     self.print(f'Agent {agent_state.i} just tried to clean up some dirt '
                                f'at {agent_state.pos}, but was unsucsessfull.')
-                    self.monitor.set('failed_cleanup_attempt', 1)
+                    info_dict.update(failed_cleanup_attempt=1)
             elif self._is_moving_action(agent_state.action):
                 if agent_state.action_valid:
+                    info_dict.update(movement=1)
                     reward -= 0.00
                 else:
-                    reward -= 0.5
+                    info_dict.update(collision=1)
+                    reward -= 0.00
             else:
-                self.monitor.set('no_op', 1)
-                reward -= 0.1
+                info_dict.update(collision=1)
+                reward -= 0.00

             for entity in cols:
                 if entity != self.state_slices.by_name("dirt"):
-                    self.monitor.set(f'agent_{agent_state.i}_vs_{self.state_slices[entity]}', 1)
+                    info_dict.update({f'agent_{agent_state.i}_vs_{self.state_slices[entity]}': 1})

-        self.monitor.set('dirt_amount', current_dirt_amount)
-        self.monitor.set('dirty_tile_count', dirty_tiles)
+        info_dict.update(dirt_amount=current_dirt_amount)
+        info_dict.update(dirty_tile_count=dirty_tiles)
         self.print(f"reward is {reward}")
         # Potential based rewards ->
         #  track the last reward , minus the current reward = potential
-        return reward, {}
+        return reward, info_dict

     def print(self, string):
         if self.verbose:
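
With the `FactoryMonitor` gone, `calculate_reward` now reports its statistics through the info dict returned by `step()` (keys such as `dirt_cleaned`, `failed_cleanup_attempt`, `movement`, `collision`, `dirt_amount`, `dirty_tile_count`, plus the per-collision keys). A rough sketch of how a caller could accumulate those per-step counters; the rollout loop is illustrative and not part of this commit:

from collections import Counter

# Illustrative consumer of the new per-step info dict (assumes `env` is a SimpleFactory instance).
totals = Counter()
obs = env.reset()
done = False
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())
    totals.update({key: val for key, val in info.items() if isinstance(val, (int, float))})
print(dict(totals))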

View File

@@ -1,6 +1,5 @@
 import pickle
 from pathlib import Path
-from collections import defaultdict

 from stable_baselines3.common.callbacks import BaseCallback
@@ -9,51 +8,6 @@ from environments.logging.plotting import prepare_plot
 import pandas as pd


-class FactoryMonitor:
-
-    def __init__(self, env):
-        self._env = env
-        self._monitor = defaultdict(lambda: defaultdict(lambda: 0))
-        self._last_vals = defaultdict(lambda: 0)
-
-    def __iter__(self):
-        for key, value in self._monitor.items():
-            yield key, dict(value)
-
-    def add(self, key, value, step=None):
-        assert step is None or step >= 1  # Is this good practice?
-        step = step or self._env.steps
-        self._last_vals[key] = self._last_vals[key] + value
-        self._monitor[key][step] = self._last_vals[key]
-        return self._last_vals[key]
-
-    def set(self, key, value, step=None):
-        assert step is None or step >= 1  # Is this good practice?
-        step = step or self._env.steps
-        self._last_vals[key] = value
-        self._monitor[key][step] = self._last_vals[key]
-        return self._last_vals[key]
-
-    def remove(self, key, value, step=None):
-        assert step is None or step >= 1  # Is this good practice?
-        step = step or self._env.steps
-        self._last_vals[key] = self._last_vals[key] - value
-        self._monitor[key][step] = self._last_vals[key]
-        return self._last_vals[key]
-
-    def to_dict(self):
-        return dict(self)
-
-    def to_pd_dataframe(self):
-        import pandas as pd
-        df = pd.DataFrame.from_dict(self.to_dict())
-        df.fillna(0)
-        return df
-
-    def reset(self):
-        raise RuntimeError("DO NOT DO THIS! Always initalize a new Monitor per Env-Run.")
-
-
 class MonitorCallback(BaseCallback):

     ext = 'png'
@@ -62,6 +16,7 @@ class MonitorCallback(BaseCallback):
         super(MonitorCallback, self).__init__()
         self.filepath = Path(filepath)
         self._monitor_df = pd.DataFrame()
+        self._monitor_dict = dict()
         self.env = env
         self.plotting = plotting
         self.started = False
@@ -113,12 +68,17 @@ class MonitorCallback(BaseCallback):
         self.closed = True

     def _on_step(self) -> bool:
+        for _, info in enumerate(self.locals.get('infos', [])):
+            self._monitor_dict[self.num_timesteps] = {key: val for key, val in info.items()
+                                                      if key not in ['terminal_observation', 'episode']}
+
         for env_idx, done in enumerate(self.locals.get('dones', [])):
             if done:
-                env_monitor_df = self.locals['infos'][env_idx]['monitor'].to_pd_dataframe()
+                env_monitor_df = pd.DataFrame.from_dict(self._monitor_dict, orient='index')
+                self._monitor_dict = dict()
                 columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
                 env_monitor_df = env_monitor_df.aggregate(
-                    {col: 'mean' if 'amount' in col or 'count' in col else 'sum' for col in columns}
+                    {col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
                 )
                 env_monitor_df['episode'] = len(self._monitor_df)
                 self._monitor_df = self._monitor_df.append([env_monitor_df])
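
`MonitorCallback` now buffers the raw info dicts per timestep in `_monitor_dict` and only turns them into a DataFrame once an episode finishes. A minimal sketch of that aggregation step in isolation; the sample records and the ignored column names are made up, while the mean/sum split mirrors the `col.endswith('ount')` rule from the diff, which matches both `dirt_amount` and `dirty_tile_count`:

import pandas as pd

# Hypothetical per-step records, keyed by timestep, as _monitor_dict would hold them.
monitor_dict = {
    1: {'step_reward': 0.0, 'dirt_amount': 5.0, 'dirty_tile_count': 4, 'dirt_cleaned': 0},
    2: {'step_reward': 1.0, 'dirt_amount': 4.0, 'dirty_tile_count': 3, 'dirt_cleaned': 1},
}
df = pd.DataFrame.from_dict(monitor_dict, orient='index')
ignored = ['step_reward', 'step']  # stands in for IGNORED_DF_COLUMNS, whose contents are not shown here
columns = [col for col in df.columns if col not in ignored]
# Columns ending in 'ount' are averaged over the episode, everything else is summed.
episode_row = df[columns].aggregate({col: 'mean' if col.endswith('ount') else 'sum' for col in columns})
print(episode_row)  # dirt_amount 4.5, dirty_tile_count 3.5, dirt_cleaned 1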

View File

@@ -25,8 +25,9 @@ def plot(filepath, ext='png', **kwargs):
     plt.tight_layout()
     figure = plt.gcf()
-    plt.show()
     figure.savefig(str(filepath), format=ext)
+    plt.show()
+    plt.clf()


 def prepare_plot(filepath, results_df, ext='png'):

View File

@@ -56,16 +56,16 @@ if __name__ == '__main__':
     # combine_runs('debug_out/PPO_1622399010')
     # exit()

-    from stable_baselines3 import PPO, DQN
+    from stable_baselines3 import PPO, DQN, A2C

     dirt_props = DirtProperties()
     time_stamp = int(time.time())
     out_path = None

-    for modeL_type in [PPO]:
+    for modeL_type in [A2C, PPO]:
         for seed in range(5):
-            env = SimpleFactory(n_agents=1, dirt_properties=dirt_props,
+            env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_size=7,
                                 allow_diagonal_movement=False, allow_no_op=False)
             model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu')
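
The training script now builds the environment with `pomdp_size=7`, so the A2C and PPO policies only ever see a 7x7 egocentric window. A quick sanity check of that observation shape before launching a run; the constructor arguments are copied from the diff, while the exact number of state slices depends on the level and the dirt layer:

from environments.factory.simple_factory import DirtProperties, SimpleFactory

# Sketch: confirm the partial-observability switch changes the observation shape as expected.
env = SimpleFactory(n_agents=1, dirt_properties=DirtProperties(), pomdp_size=7,
                    allow_diagonal_movement=False, allow_no_op=False)
obs = env.reset()
print(obs.shape, env.observation_space.shape)  # both should end in (..., 7, 7)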

View File

@@ -1,16 +1,11 @@
 import warnings
 from pathlib import Path
-import time

 from natsort import natsorted
 from stable_baselines3 import PPO
-from stable_baselines3.common.base_class import BaseAlgorithm
-from stable_baselines3.common.callbacks import CallbackList
 from stable_baselines3.common.evaluation import evaluate_policy

 from environments.factory.simple_factory import DirtProperties, SimpleFactory
-from environments.logging.monitor import MonitorCallback
-from environments.logging.training import TraningMonitor

 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
@@ -20,7 +15,7 @@ if __name__ == '__main__':
     dirt_props = DirtProperties()
     env = SimpleFactory(n_agents=1, dirt_properties=dirt_props)

-    out_path = Path('debug_out')
+    out_path = Path(r'C:\Users\steff\projects\f_iks\debug_out\PPO_1622485791\1_PPO_1622485791')

     model_files = list(natsorted(out_path.rglob('*.zip')))
     this_model = model_files[0]