commit 55b409c72f
parent 403d38dc24
Author: steffen-illium
Date:   2021-06-01 12:39:33 +02:00

6 changed files with 82 additions and 83 deletions

View File

@@ -1,12 +1,11 @@
from pathlib import Path
from typing import List, Union, Iterable

import gym
from gym import spaces
import numpy as np

from environments import helpers as h
class AgentState:
@@ -102,16 +101,29 @@ class BaseFactory(gym.Env):
    @property
    def observation_space(self):
        if self.pomdp_size:
            return spaces.Box(low=0, high=1, shape=(self.state.shape[0], self.pomdp_size,
                                                    self.pomdp_size), dtype=np.float32)
        else:
            space = spaces.Box(low=0, high=1, shape=self.state.shape, dtype=np.float32)
            # space = spaces.MultiBinary(np.prod(self.state.shape))
            # space = spaces.Dict({
            #     'level': spaces.MultiBinary(np.prod(self.state[0].shape)),
            #     'agent_n': spaces.Discrete(np.prod(self.state[1].shape)),
            #     'dirt': spaces.Box(low=0, high=1, shape=self.state[2].shape, dtype=np.float32)
            # })
            return space
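For orientation, a short sketch of what the two branches amount to; the shapes here are invented for illustration and are not taken from the repository:

import numpy as np
from gym import spaces

n_slices, height, width = 3, 10, 12  # hypothetical state: level, agent, dirt
pomdp_size = 7                       # odd, so the agent can be centered

full_space = spaces.Box(low=0, high=1, shape=(n_slices, height, width), dtype=np.float32)
pomdp_space = spaces.Box(low=0, high=1, shape=(n_slices, pomdp_size, pomdp_size), dtype=np.float32)

assert pomdp_space.shape == (3, 7, 7)
assert full_space.contains(np.zeros((n_slices, height, width), dtype=np.float32))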
    @property
    def movement_actions(self):
        return self._actions.movement_actions
    def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_size: Union[None, int] = None, **kwargs):
        self.n_agents = n_agents
        self.max_steps = max_steps
        # pomdp_size must be odd so the observing agent can sit in the exact center of its view
        assert pomdp_size is None or pomdp_size % 2 == 1
        self.pomdp_size = pomdp_size
        self.done_at_collision = False
        _actions = Actions(allow_square_movement=kwargs.get('allow_square_movement', True),
                           allow_diagonal_movement=kwargs.get('allow_diagonal_movement', True),
@@ -138,7 +150,6 @@ class BaseFactory(gym.Env):
    def reset(self) -> (np.ndarray, int, bool, dict):
        self.steps = 0
        self.agent_states = []
        # Agent placement ...
        agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
@@ -153,7 +164,35 @@ class BaseFactory(gym.Env):
        # state.shape = level, agent 1, ..., agent n,
        self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
        # Returns State
        return self._return_state()
    def _return_state(self):
        if self.pomdp_size:
            # Single-agent case for now; with multiple agents this would become a
            # list comprehension over the agents' positions.
            pos = self.agent_states[0].pos
            offset = self.pomdp_size // 2
            x0, x1 = max(0, pos[0] - offset), pos[0] + offset + 1
            y0, y1 = max(0, pos[1] - offset), pos[1] + offset + 1
            obs = self.state[:, x0:x1, y0:y1]
            if obs.shape[1] != self.pomdp_size or obs.shape[2] != self.pomdp_size:
                # The cut-out was clipped at a level border: pad it back up so the
                # agent sits centered in a pomdp_size x pomdp_size window.
                obs_padded = np.zeros((obs.shape[0], self.pomdp_size, self.pomdp_size))
                a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
                x_offset, y_offset = abs(a_pos[0] - offset), abs(a_pos[1] - offset)
                obs_padded[:, x_offset:x_offset + obs.shape[1], y_offset:y_offset + obs.shape[2]] = obs
                assert all(np.argwhere(obs_padded[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0] == (offset, offset))
                obs = obs_padded
        else:
            obs = self.state
        return obs
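To see why the padding branch is needed, here is a self-contained numpy example of the same crop-and-pad idea for an agent standing in a corner; the grid contents and channel layout are made up for the example:

import numpy as np

state = np.zeros((2, 5, 5))   # two hypothetical slices: level, agent
state[1, 0, 0] = 1            # agent in the top-left corner
pomdp_size, offset = 5, 2     # 5x5 window, agent should end up at (2, 2)

pos = (0, 0)
x0, x1 = max(0, pos[0] - offset), pos[0] + offset + 1
y0, y1 = max(0, pos[1] - offset), pos[1] + offset + 1
obs = state[:, x0:x1, y0:y1]  # only 3x3 survives the clipping at the border

padded = np.zeros((obs.shape[0], pomdp_size, pomdp_size))
a_pos = np.argwhere(obs[1] == 1)[0]  # agent position inside the cut-out
x_off, y_off = abs(a_pos[0] - offset), abs(a_pos[1] - offset)
padded[:, x_off:x_off + obs.shape[1], y_off:y_off + obs.shape[2]] = obs

assert tuple(np.argwhere(padded[1] == 1)[0]) == (offset, offset)  # centered again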
    def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
        raise NotImplementedError
@@ -188,12 +227,11 @@ class BaseFactory(gym.Env):
        if self.steps >= self.max_steps:
            done = True
        info.update(step_reward=reward, step=self.steps)
        obs = self._return_state()
        return obs, reward, done, info

    def _is_moving_action(self, action):
        return action in self._actions.movement_actions
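Taken together, a hypothetical interaction loop against the changed API would look roughly like this; the constructor arguments are illustrative, and action_space is assumed to be defined elsewhere in the class:

env = SimpleFactory(n_agents=1, dirt_properties=DirtProperties(), pomdp_size=7)
obs = env.reset()   # already the pomdp_size x pomdp_size cut-out, not the full grid
done = False
while not done:
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)
    # per-step bookkeeping now travels in info: step_reward, step, dirt_amount, ...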

View File

@@ -99,7 +99,8 @@ class SimpleFactory(BaseFactory):
            self.next_dirt_spawn = self._dirt_properties.spawn_frequency
        else:
            self.next_dirt_spawn -= 1
        obs = self._return_state()
        return obs, r, done, info
    def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
        if not self._is_moving_action(action):
@@ -118,12 +119,14 @@ class SimpleFactory(BaseFactory):
        self.state = np.concatenate((self.state, dirt_slice))  # dirt is now the last slice
        self.spawn_dirt()
        self.next_dirt_spawn = self._dirt_properties.spawn_frequency
        obs = self._return_state()
        return obs
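A minimal illustration of how such an extra slice is appended to the state tensor; the shapes are invented:

import numpy as np

state = np.zeros((2, 5, 5))                  # e.g. level + one agent
dirt_slice = np.zeros((1, 5, 5))             # one new channel, same spatial extent
state = np.concatenate((state, dirt_slice))  # concatenates along axis 0
assert state.shape == (3, 5, 5)              # dirt is now the last slice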
    def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
        # TODO: What reward to use?
        current_dirt_amount = self.state[DIRT_INDEX].sum()
        dirty_tiles = np.argwhere(self.state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0]
        info_dict = dict()

        try:
            # penalty = current_dirt_amount
@@ -143,33 +146,35 @@ class SimpleFactory(BaseFactory):
                if agent_state.action_valid:
                    reward += 1
                    self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.')
                    info_dict.update(dirt_cleaned=1)
                else:
                    reward -= 0.0
                    self.print(f'Agent {agent_state.i} just tried to clean up some dirt '
                               f'at {agent_state.pos}, but was unsuccessful.')
                    info_dict.update(failed_cleanup_attempt=1)
            elif self._is_moving_action(agent_state.action):
                if agent_state.action_valid:
                    info_dict.update(movement=1)
                    reward -= 0.00
                else:
                    info_dict.update(collision=1)
                    reward -= 0.00
            else:
                info_dict.update(no_op=1)
                reward -= 0.00

            for entity in cols:
                if entity != self.state_slices.by_name("dirt"):
                    info_dict.update({f'agent_{agent_state.i}_vs_{self.state_slices[entity]}': 1})

        info_dict.update(dirt_amount=current_dirt_amount)
        info_dict.update(dirty_tile_count=dirty_tiles)
        self.print(f"reward is {reward}")
        # Potential-based rewards ->
        # track the last reward, minus the current reward = potential
        return reward, info_dict
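The closing comment gestures at potential-based reward shaping; a generic sketch of that idea, with the potential function left as an assumption (for this environment, e.g. the negative dirt amount):

class PotentialShaping:
    # F(s, s') = gamma * phi(s') - phi(s); adding F to the environment
    # reward leaves the optimal policy unchanged (Ng et al., 1999)
    def __init__(self, gamma=0.99):
        self.gamma = gamma
        self.last_phi = None

    def __call__(self, reward, phi):
        bonus = 0.0 if self.last_phi is None else self.gamma * phi - self.last_phi
        self.last_phi = phi
        return reward + bonus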
    def print(self, string):
        if self.verbose:

View File

@@ -1,6 +1,5 @@
import pickle
from pathlib import Path
from collections import defaultdict
from stable_baselines3.common.callbacks import BaseCallback
@@ -9,51 +8,6 @@ from environments.logging.plotting import prepare_plot
import pandas as pd
class FactoryMonitor:

    def __init__(self, env):
        self._env = env
        self._monitor = defaultdict(lambda: defaultdict(lambda: 0))
        self._last_vals = defaultdict(lambda: 0)

    def __iter__(self):
        for key, value in self._monitor.items():
            yield key, dict(value)

    def add(self, key, value, step=None):
        assert step is None or step >= 1  # Is this good practice?
        step = step or self._env.steps
        self._last_vals[key] = self._last_vals[key] + value
        self._monitor[key][step] = self._last_vals[key]
        return self._last_vals[key]

    def set(self, key, value, step=None):
        assert step is None or step >= 1  # Is this good practice?
        step = step or self._env.steps
        self._last_vals[key] = value
        self._monitor[key][step] = self._last_vals[key]
        return self._last_vals[key]

    def remove(self, key, value, step=None):
        assert step is None or step >= 1  # Is this good practice?
        step = step or self._env.steps
        self._last_vals[key] = self._last_vals[key] - value
        self._monitor[key][step] = self._last_vals[key]
        return self._last_vals[key]

    def to_dict(self):
        return dict(self)

    def to_pd_dataframe(self):
        df = pd.DataFrame.from_dict(self.to_dict())
        df = df.fillna(0)
        return df

    def reset(self):
        raise RuntimeError("DO NOT DO THIS! Always initialize a new Monitor per Env-Run.")
class MonitorCallback(BaseCallback):

    ext = 'png'
@@ -62,6 +16,7 @@ class MonitorCallback(BaseCallback):
        super(MonitorCallback, self).__init__()
        self.filepath = Path(filepath)
        self._monitor_df = pd.DataFrame()
        self._monitor_dict = dict()
        self.env = env
        self.plotting = plotting
        self.started = False
@@ -113,12 +68,17 @@ class MonitorCallback(BaseCallback):
        self.closed = True

    def _on_step(self) -> bool:
        for _, info in enumerate(self.locals.get('infos', [])):
            self._monitor_dict[self.num_timesteps] = {key: val for key, val in info.items()
                                                      if key not in ['terminal_observation', 'episode']}

        for env_idx, done in enumerate(self.locals.get('dones', [])):
            if done:
                env_monitor_df = pd.DataFrame.from_dict(self._monitor_dict, orient='index')
                self._monitor_dict = dict()
                columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
                # '...amount' and '...count' columns are episode averages,
                # everything else (rewards, event flags) is summed per episode
                env_monitor_df = env_monitor_df.aggregate(
                    {col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
                )
                env_monitor_df['episode'] = len(self._monitor_df)
                self._monitor_df = self._monitor_df.append([env_monitor_df])
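The aggregation above, isolated with toy data (keys and values invented), shows how the per-step info dicts collapse into a single episode row:

import pandas as pd

monitor_dict = {1: {'step_reward': 0.0, 'dirt_amount': 4.0, 'dirty_tile_count': 3},
                2: {'step_reward': 1.0, 'dirt_amount': 3.5, 'dirty_tile_count': 2}}
df = pd.DataFrame.from_dict(monitor_dict, orient='index')
episode_row = df.aggregate({col: 'mean' if col.endswith('ount') else 'sum'
                            for col in df.columns})
# step_reward is summed; dirt_amount and dirty_tile_count are averaged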

View File

@@ -25,8 +25,9 @@ def plot(filepath, ext='png', **kwargs):
    plt.tight_layout()
    figure = plt.gcf()
    figure.savefig(str(filepath), format=ext)
    plt.show()
    plt.clf()
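The reordering matters: figure.savefig has to run before plt.show, because with most backends the figure is torn down once the window opened by show is closed, and saving afterwards yields an empty image; plt.clf then clears the current figure so the next plot starts clean.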
def prepare_plot(filepath, results_df, ext='png'):

View File

@@ -56,16 +56,16 @@ if __name__ == '__main__':
    # combine_runs('debug_out/PPO_1622399010')
    # exit()

    from stable_baselines3 import PPO, DQN, A2C

    dirt_props = DirtProperties()
    time_stamp = int(time.time())
    out_path = None

    for model_type in [A2C, PPO]:
        for seed in range(5):
            env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_size=7,
                                allow_diagonal_movement=False, allow_no_op=False)
            model = model_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu')

View File

@@ -1,16 +1,11 @@
import warnings
from pathlib import Path
import time
from natsort import natsorted
from stable_baselines3 import PPO
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import CallbackList
from stable_baselines3.common.evaluation import evaluate_policy
from environments.factory.simple_factory import DirtProperties, SimpleFactory
from environments.logging.monitor import MonitorCallback
from environments.logging.training import TraningMonitor
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
@@ -20,7 +15,7 @@ if __name__ == '__main__':
    dirt_props = DirtProperties()
    env = SimpleFactory(n_agents=1, dirt_properties=dirt_props)

    out_path = Path(r'C:\Users\steff\projects\f_iks\debug_out\PPO_1622485791\1_PPO_1622485791')
    model_files = list(natsorted(out_path.rglob('*.zip')))
    this_model = model_files[0]