pomdp
@@ -1,12 +1,11 @@
+from pathlib import Path
 from typing import List, Union, Iterable
 
 import gym
-from gym import spaces
 import numpy as np
-from pathlib import Path
+from gym import spaces
 
 from environments import helpers as h
-from environments.logging.monitor import FactoryMonitor
 
 
 class AgentState:
@@ -102,16 +101,29 @@ class BaseFactory(gym.Env):
 
     @property
     def observation_space(self):
-        return spaces.Box(low=-1, high=1, shape=self.state.shape, dtype=np.float32)
+        if self.pomdp_size:
+            return spaces.Box(low=0, high=1, shape=(self.state.shape[0], self.pomdp_size,
+                                                    self.pomdp_size), dtype=np.float32)
+        else:
+            space = spaces.Box(low=0, high=1, shape=self.state.shape, dtype=np.float32)
+            # space = spaces.MultiBinary(np.prod(self.state.shape))
+            # space = spaces.Dict({
+            #     'level': spaces.MultiBinary(np.prod(self.state[0].shape)),
+            #     'agent_n': spaces.Discrete(np.prod(self.state[1].shape)),
+            #     'dirt': spaces.Box(low=0, high=1, shape=self.state[2].shape, dtype=np.float32)
+            # })
+            return space
 
     @property
     def movement_actions(self):
         return self._actions.movement_actions
 
-    def __init__(self, level='simple', n_agents=1, max_steps=int(2e2), **kwargs):
+    def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_size: Union[None, int] = None, **kwargs):
         self.n_agents = n_agents
         self.max_steps = max_steps
+        assert pomdp_size is None or (pomdp_size is not None and pomdp_size % 2 == 1)
+        self.pomdp_size = pomdp_size
 
         self.done_at_collision = False
         _actions = Actions(allow_square_movement=kwargs.get('allow_square_movement', True),
                            allow_diagonal_movement=kwargs.get('allow_diagonal_movement', True),
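
Note: the new observation space can be sanity-checked in isolation; a minimal sketch (the concrete values below are illustrative, not from the commit):

    import numpy as np
    from gym import spaces

    n_slices, pomdp_size = 3, 5          # e.g. level + agent + dirt; window size must be odd
    space = spaces.Box(low=0, high=1, shape=(n_slices, pomdp_size, pomdp_size), dtype=np.float32)
    assert space.contains(np.zeros(space.shape, dtype=np.float32))
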
@@ -138,7 +150,6 @@ class BaseFactory(gym.Env):
 
     def reset(self) -> (np.ndarray, int, bool, dict):
         self.steps = 0
-        self.monitor = FactoryMonitor(self)
         self.agent_states = []
         # Agent placement ...
         agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
@@ -153,7 +164,35 @@ class BaseFactory(gym.Env):
         # state.shape = level, agent 1,..., agent n,
         self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
         # Returns State
-        return self.state
+        return self._return_state()
+
+    def _return_state(self):
+        if self.pomdp_size:
+            pos = self.agent_states[0].pos
+            # pos = [agent_state.pos for agent_state in self.agent_states]
+            # obs = [] ... list comprehension... pos per agent
+            offset = self.pomdp_size // 2
+            x0, x1 = max(0, pos[0] - offset), pos[0] + offset + 1
+            y0, y1 = max(0, pos[1] - offset), pos[1] + offset + 1
+            obs = self.state[:, x0:x1, y0:y1]
+            if obs.shape[1] != self.pomdp_size or obs.shape[2] != self.pomdp_size:
+                obs_padded = np.zeros((obs.shape[0], self.pomdp_size, self.pomdp_size))
+                try:
+                    a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
+                except IndexError:
+                    print('Shiiiiiit')
+                try:
+                    obs_padded[:, abs(a_pos[0]-offset):abs(a_pos[0]-offset)+obs.shape[1], abs(a_pos[1]-offset):abs(a_pos[1]-offset)+obs.shape[2]] = obs
+                except ValueError:
+                    print('Shiiiiiit')
+                assert all(np.argwhere(obs_padded[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0] == (3,3))
+                obs = obs_padded
+        else:
+            obs = self.state
+        return obs
 
     def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
         raise NotImplementedError
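
Note: the border handling above can be written without the try/except guards by padding the full state first and then slicing; a hedged sketch of an equivalent egocentric crop (crop_egocentric is a hypothetical helper, not part of the commit):

    import numpy as np

    def crop_egocentric(state: np.ndarray, pos, pomdp_size: int) -> np.ndarray:
        """Cut a pomdp_size x pomdp_size window around pos from every state slice,
        zero-padding at the map border so the agent always sits on the center cell."""
        offset = pomdp_size // 2
        # pad the two spatial dims by `offset`; afterwards the crop can never leave the array
        padded = np.pad(state, ((0, 0), (offset, offset), (offset, offset)), mode='constant')
        x, y = pos[0] + offset, pos[1] + offset      # agent position in padded coordinates
        return padded[:, x - offset:x + offset + 1, y - offset:y + offset + 1]

This would also remove the hardcoded assert against (3,3), which only holds while pomdp_size is 7.
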
@@ -188,12 +227,11 @@ class BaseFactory(gym.Env):
 
         if self.steps >= self.max_steps:
             done = True
-        self.monitor.set('step_reward', reward)
-        self.monitor.set('step', self.steps)
 
-        if done:
-            info.update(monitor=self.monitor)
-        return self.state, reward, done, info
+        info.update(step_reward=reward, step=self.steps)
+        obs = self._return_state()
+        return obs, reward, done, info
 
     def _is_moving_action(self, action):
         return action in self._actions.movement_actions
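
Note: after this change both reset() and step() return the (possibly cropped) observation, and per-step metrics travel through the info dict instead of a monitor object. A minimal rollout sketch, mirroring the constructor arguments used in main.py further down and assuming calculate_reward's info_dict is what step() hands back as info:

    from environments.factory.simple_factory import DirtProperties, SimpleFactory

    env = SimpleFactory(n_agents=1, dirt_properties=DirtProperties(), pomdp_size=7)
    obs = env.reset()                                  # egocentric (n_slices, 7, 7) view
    done = False
    while not done:
        obs, reward, done, info = env.step(env.action_space.sample())
    print(info.get('step_reward'), info.get('dirt_amount'))
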
environments/factory/simple_factory.py

@@ -99,7 +99,8 @@ class SimpleFactory(BaseFactory):
             self.next_dirt_spawn = self._dirt_properties.spawn_frequency
         else:
             self.next_dirt_spawn -= 1
-        return self.state, r, done, info
+        obs = self._return_state()
+        return obs, r, done, info
 
     def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
         if action != self._is_moving_action(action):
@@ -118,12 +119,14 @@ class SimpleFactory(BaseFactory):
         self.state = np.concatenate((self.state, dirt_slice))  # dirt is now the last slice
         self.spawn_dirt()
         self.next_dirt_spawn = self._dirt_properties.spawn_frequency
-        return self.state
+        obs = self._return_state()
+        return obs
 
     def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
         # TODO: What reward to use?
         current_dirt_amount = self.state[DIRT_INDEX].sum()
         dirty_tiles = np.argwhere(self.state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0]
+        info_dict = dict()
 
         try:
             # penalty = current_dirt_amount
@@ -143,33 +146,35 @@ class SimpleFactory(BaseFactory):
                 if agent_state.action_valid:
                     reward += 1
                     self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.')
-                    self.monitor.set('dirt_cleaned', 1)
+                    info_dict.update(dirt_cleaned=1)
                 else:
-                    reward -= 0.5
+                    reward -= 0.0
                     self.print(f'Agent {agent_state.i} just tried to clean up some dirt '
                                f'at {agent_state.pos}, but was unsucsessfull.')
-                    self.monitor.set('failed_cleanup_attempt', 1)
+                    info_dict.update(failed_cleanup_attempt=1)
 
             elif self._is_moving_action(agent_state.action):
                 if agent_state.action_valid:
+                    info_dict.update(movement=1)
                     reward -= 0.00
                 else:
-                    reward -= 0.5
+                    info_dict.update(collision=1)
+                    reward -= 0.00
 
             else:
-                self.monitor.set('no_op', 1)
-                reward -= 0.1
+                info_dict.update(collision=1)
+                reward -= 0.00
 
             for entity in cols:
                 if entity != self.state_slices.by_name("dirt"):
-                    self.monitor.set(f'agent_{agent_state.i}_vs_{self.state_slices[entity]}', 1)
+                    info_dict.update({f'agent_{agent_state.i}_vs_{self.state_slices[entity]}': 1})
 
-        self.monitor.set('dirt_amount', current_dirt_amount)
-        self.monitor.set('dirty_tile_count', dirty_tiles)
+        info_dict.update(dirt_amount=current_dirt_amount)
+        info_dict.update(dirty_tile_count=dirty_tiles)
         self.print(f"reward is {reward}")
         # Potential based rewards ->
         # track the last reward , minus the current reward = potential
-        return reward, {}
+        return reward, info_dict
 
     def print(self, string):
         if self.verbose:
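
Note: the closing comment gestures at potential-based reward shaping. In the standard formulation (Ng et al., 1999) the shaping bonus is F(s, s') = gamma * phi(s') - phi(s), which provably leaves the optimal policy unchanged. A sketch, using the negative dirt amount as one possible potential (my choice of phi, not the commit's):

    def shaped_reward(raw_reward: float, phi_s: float, phi_s_next: float, gamma: float = 0.99) -> float:
        # potential-based shaping: add gamma * phi(s') - phi(s) on top of the raw reward
        return raw_reward + gamma * phi_s_next - phi_s

    # e.g. phi(s) = -current_dirt_amount, so any step that reduces dirt earns a bonus
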
environments/logging/monitor.py

@@ -1,6 +1,5 @@
 import pickle
 from pathlib import Path
-from collections import defaultdict
 
 from stable_baselines3.common.callbacks import BaseCallback
 
@@ -9,51 +8,6 @@ from environments.logging.plotting import prepare_plot
 import pandas as pd
 
 
-class FactoryMonitor:
-
-    def __init__(self, env):
-        self._env = env
-        self._monitor = defaultdict(lambda: defaultdict(lambda: 0))
-        self._last_vals = defaultdict(lambda: 0)
-
-    def __iter__(self):
-        for key, value in self._monitor.items():
-            yield key, dict(value)
-
-    def add(self, key, value, step=None):
-        assert step is None or step >= 1  # Is this good practice?
-        step = step or self._env.steps
-        self._last_vals[key] = self._last_vals[key] + value
-        self._monitor[key][step] = self._last_vals[key]
-        return self._last_vals[key]
-
-    def set(self, key, value, step=None):
-        assert step is None or step >= 1  # Is this good practice?
-        step = step or self._env.steps
-        self._last_vals[key] = value
-        self._monitor[key][step] = self._last_vals[key]
-        return self._last_vals[key]
-
-    def remove(self, key, value, step=None):
-        assert step is None or step >= 1  # Is this good practice?
-        step = step or self._env.steps
-        self._last_vals[key] = self._last_vals[key] - value
-        self._monitor[key][step] = self._last_vals[key]
-        return self._last_vals[key]
-
-    def to_dict(self):
-        return dict(self)
-
-    def to_pd_dataframe(self):
-        import pandas as pd
-        df = pd.DataFrame.from_dict(self.to_dict())
-        df.fillna(0)
-        return df
-
-    def reset(self):
-        raise RuntimeError("DO NOT DO THIS! Always initalize a new Monitor per Env-Run.")
-
-
 class MonitorCallback(BaseCallback):
 
     ext = 'png'
@@ -62,6 +16,7 @@ class MonitorCallback(BaseCallback):
         super(MonitorCallback, self).__init__()
         self.filepath = Path(filepath)
         self._monitor_df = pd.DataFrame()
+        self._monitor_dict = dict()
         self.env = env
         self.plotting = plotting
         self.started = False
@@ -113,12 +68,17 @@ class MonitorCallback(BaseCallback):
         self.closed = True
 
     def _on_step(self) -> bool:
+        for _, info in enumerate(self.locals.get('infos', [])):
+            self._monitor_dict[self.num_timesteps] = {key: val for key, val in info.items()
+                                                      if key not in ['terminal_observation', 'episode']}
+
         for env_idx, done in enumerate(self.locals.get('dones', [])):
             if done:
-                env_monitor_df = self.locals['infos'][env_idx]['monitor'].to_pd_dataframe()
+                env_monitor_df = pd.DataFrame.from_dict(self._monitor_dict, orient='index')
+                self._monitor_dict = dict()
                 columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
                 env_monitor_df = env_monitor_df.aggregate(
-                    {col: 'mean' if 'amount' in col or 'count' in col else 'sum' for col in columns}
+                    {col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
                 )
                 env_monitor_df['episode'] = len(self._monitor_df)
                 self._monitor_df = self._monitor_df.append([env_monitor_df])
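
Note: _on_step now buffers each step's filtered info dict under num_timesteps and collapses the buffer into one row per finished episode; the .endswith('ount') test deliberately matches both 'dirt_amount' and 'dirty_tile_count', so gauges are averaged while event counters are summed. A standalone sketch of that aggregation with made-up data (and note that DataFrame.append was removed in pandas 2.0, where pd.concat is the replacement):

    import pandas as pd

    monitor_dict = {1: {'step_reward': -0.1, 'dirt_amount': 2.0, 'dirty_tile_count': 2},
                    2: {'step_reward': 1.0, 'dirt_amount': 1.5, 'dirty_tile_count': 1, 'dirt_cleaned': 1}}
    df = pd.DataFrame.from_dict(monitor_dict, orient='index')
    episode_row = df.aggregate({col: 'mean' if col.endswith('ount') else 'sum' for col in df.columns})
    print(episode_row)   # one aggregated row describing the whole episode
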
environments/logging/plotting.py

@@ -25,8 +25,9 @@ def plot(filepath, ext='png', **kwargs):
 
     plt.tight_layout()
    figure = plt.gcf()
-    plt.show()
     figure.savefig(str(filepath), format=ext)
+    plt.show()
+    plt.clf()
 
 
 def prepare_plot(filepath, results_df, ext='png'):
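
Note: the reordering matters: with interactive backends plt.show() hands the figure to the GUI event loop, and it may be torn down once the window closes, so calling savefig afterwards can write an empty file; plt.clf() then clears state so the next call starts from a fresh figure. The safe order in isolation:

    import matplotlib.pyplot as plt

    plt.plot([0, 1], [0, 1])
    plt.tight_layout()
    figure = plt.gcf()
    figure.savefig('plot.png', format='png')   # save while the figure still exists
    plt.show()                                 # then hand it to the GUI event loop
    plt.clf()                                  # clear the current figure for the next plot
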
main.py
@@ -56,16 +56,16 @@ if __name__ == '__main__':
     # combine_runs('debug_out/PPO_1622399010')
     # exit()
 
-    from stable_baselines3 import PPO, DQN
+    from stable_baselines3 import PPO, DQN, A2C
     dirt_props = DirtProperties()
     time_stamp = int(time.time())
 
     out_path = None
 
-    for modeL_type in [PPO]:
+    for modeL_type in [A2C, PPO]:
         for seed in range(5):
 
-            env = SimpleFactory(n_agents=1, dirt_properties=dirt_props,
+            env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_size=7,
                                 allow_diagonal_movement=False, allow_no_op=False)
 
             model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu')
@@ -1,16 +1,11 @@
 import warnings
 from pathlib import Path
-import time
 
 from natsort import natsorted
 from stable_baselines3 import PPO
-from stable_baselines3.common.base_class import BaseAlgorithm
-from stable_baselines3.common.callbacks import CallbackList
 from stable_baselines3.common.evaluation import evaluate_policy
 
 from environments.factory.simple_factory import DirtProperties, SimpleFactory
-from environments.logging.monitor import MonitorCallback
-from environments.logging.training import TraningMonitor
 
 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
@@ -20,7 +15,7 @@ if __name__ == '__main__':
     dirt_props = DirtProperties()
     env = SimpleFactory(n_agents=1, dirt_properties=dirt_props)
 
-    out_path = Path('debug_out')
+    out_path = Path(r'C:\Users\steff\projects\f_iks\debug_out\PPO_1622485791\1_PPO_1622485791')
     model_files = list(natsorted(out_path.rglob('*.zip')))
     this_model = model_files[0]