mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-20 03:08:08 +02:00
alles was ich hab
This commit is contained in:
@ -113,15 +113,19 @@ class BaseFactory(gym.Env):
|
|||||||
def movement_actions(self):
|
def movement_actions(self):
|
||||||
return self._actions.movement_actions
|
return self._actions.movement_actions
|
||||||
|
|
||||||
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None, **kwargs):
|
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
|
||||||
|
allow_square_movement=True, allow_diagonal_movement=True, allow_no_op=True, **kwargs):
|
||||||
|
self.allow_no_op = allow_no_op
|
||||||
|
self.allow_diagonal_movement = allow_diagonal_movement
|
||||||
|
self.allow_square_movement = allow_square_movement
|
||||||
self.n_agents = n_agents
|
self.n_agents = n_agents
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
self.pomdp_radius = pomdp_radius
|
self.pomdp_radius = pomdp_radius
|
||||||
|
|
||||||
self.done_at_collision = False
|
self.done_at_collision = False
|
||||||
_actions = Actions(allow_square_movement=kwargs.get('allow_square_movement', True),
|
_actions = Actions(allow_square_movement=self.allow_square_movement,
|
||||||
allow_diagonal_movement=kwargs.get('allow_diagonal_movement', True),
|
allow_diagonal_movement=self.allow_diagonal_movement,
|
||||||
allow_no_op=kwargs.get('allow_no_op', True))
|
allow_no_op=allow_no_op)
|
||||||
self._actions = _actions + self.additional_actions
|
self._actions = _actions + self.additional_actions
|
||||||
|
|
||||||
self._level = h.one_hot_level(
|
self._level = h.one_hot_level(
|
||||||
@ -165,12 +169,16 @@ class BaseFactory(gym.Env):
|
|||||||
pos = self._agent_states[0].pos
|
pos = self._agent_states[0].pos
|
||||||
# pos = [agent_state.pos for agent_state in self.agent_states]
|
# pos = [agent_state.pos for agent_state in self.agent_states]
|
||||||
# obs = [] ... list comprehension... pos per agent
|
# obs = [] ... list comprehension... pos per agent
|
||||||
npad = [(0, 0)] + [(self.pomdp_radius, self.pomdp_radius)] * (self._state.ndim - 1)
|
x0, x1 = max(0, pos[0] - self.pomdp_radius), pos[0] + self.pomdp_radius + 1
|
||||||
x_roll = self.pomdp_radius-pos[0]
|
y0, y1 = max(0, pos[1] - self.pomdp_radius), pos[1] + self.pomdp_radius + 1
|
||||||
y_roll = self.pomdp_radius-pos[1]
|
obs = self._state[:, x0:x1, y0:y1]
|
||||||
padded_state = np.pad(self._state, pad_width=npad, mode='constant', constant_values=0)
|
if obs.shape[1] != self.pomdp_radius ** 2 + 1 or obs.shape[2] != self.pomdp_radius ** 2 + 1:
|
||||||
padded_state = np.roll(np.roll(padded_state, x_roll, axis=1), y_roll, axis=2)
|
obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1)
|
||||||
obs = padded_state[:, :self.pomdp_radius * 2 + 1, :self.pomdp_radius * 2 + 1]
|
a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
|
||||||
|
obs_padded[:,
|
||||||
|
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
|
||||||
|
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
||||||
|
obs = obs_padded
|
||||||
else:
|
else:
|
||||||
obs = self._state
|
obs = self._state
|
||||||
return obs
|
return obs
|
||||||
@ -211,8 +219,7 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
info.update(step_reward=reward, step=self.steps)
|
info.update(step_reward=reward, step=self.steps)
|
||||||
|
|
||||||
obs = self._return_state()
|
return None, reward, done, info
|
||||||
return obs, reward, done, info
|
|
||||||
|
|
||||||
def _is_moving_action(self, action):
|
def _is_moving_action(self, action):
|
||||||
return action in self._actions.movement_actions
|
return action in self._actions.movement_actions
|
||||||
|
@ -150,7 +150,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.')
|
self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.')
|
||||||
info_dict.update(dirt_cleaned=1)
|
info_dict.update(dirt_cleaned=1)
|
||||||
else:
|
else:
|
||||||
reward -= 0.0
|
reward -= 0.01
|
||||||
self.print(f'Agent {agent_state.i} just tried to clean up some dirt '
|
self.print(f'Agent {agent_state.i} just tried to clean up some dirt '
|
||||||
f'at {agent_state.pos}, but was unsucsessfull.')
|
f'at {agent_state.pos}, but was unsucsessfull.')
|
||||||
info_dict.update(failed_cleanup_attempt=1)
|
info_dict.update(failed_cleanup_attempt=1)
|
||||||
@ -162,14 +162,14 @@ class SimpleFactory(BaseFactory):
|
|||||||
else:
|
else:
|
||||||
# info_dict.update(collision=1)
|
# info_dict.update(collision=1)
|
||||||
# self.print('collision')
|
# self.print('collision')
|
||||||
reward -= 0.00
|
reward -= 0.01
|
||||||
|
|
||||||
else:
|
else:
|
||||||
info_dict.update(no_op=1)
|
info_dict.update(no_op=1)
|
||||||
reward -= 0.00
|
reward -= 0.00
|
||||||
|
|
||||||
for entity in list_of_collisions:
|
for entity in list_of_collisions:
|
||||||
info_dict.update({f'agent_{agent_state.i}_vs_{self._state_slices.by_name(entity)}': 1})
|
info_dict.update({f'agent_{agent_state.i}_vs_{entity}': 1})
|
||||||
|
|
||||||
info_dict.update(dirt_amount=current_dirt_amount)
|
info_dict.update(dirt_amount=current_dirt_amount)
|
||||||
info_dict.update(dirty_tile_count=dirty_tiles)
|
info_dict.update(dirty_tile_count=dirty_tiles)
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
PALETTE = 10 * (
|
PALETTE = 10 * (
|
||||||
"#377eb8",
|
"#377eb8",
|
||||||
"#4daf4a",
|
"#4daf4a",
|
||||||
|
6
main.py
6
main.py
@ -12,7 +12,6 @@ from environments.factory.simple_factory import DirtProperties, SimpleFactory
|
|||||||
from environments.helpers import IGNORED_DF_COLUMNS
|
from environments.helpers import IGNORED_DF_COLUMNS
|
||||||
from environments.logging.monitor import MonitorCallback
|
from environments.logging.monitor import MonitorCallback
|
||||||
from environments.logging.plotting import prepare_plot
|
from environments.logging.plotting import prepare_plot
|
||||||
from environments.logging.training import TraningMonitor
|
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
@ -55,7 +54,7 @@ if __name__ == '__main__':
|
|||||||
for seed in range(5):
|
for seed in range(5):
|
||||||
|
|
||||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400,
|
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400,
|
||||||
allow_diagonal_movement=False, allow_no_op=False, verbose=True)
|
allow_diagonal_movement=False, allow_no_op=False, verbose=False)
|
||||||
|
|
||||||
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu')
|
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu')
|
||||||
|
|
||||||
@ -65,8 +64,7 @@ if __name__ == '__main__':
|
|||||||
out_path /= identifier
|
out_path /= identifier
|
||||||
|
|
||||||
callbacks = CallbackList(
|
callbacks = CallbackList(
|
||||||
[TraningMonitor(out_path / f'train_logging_{identifier}.csv'),
|
[MonitorCallback(env, filepath=out_path / f'monitor_{identifier}.pick', plotting=False)]
|
||||||
MonitorCallback(env, filepath=out_path / f'monitor_{identifier}.pick', plotting=False)]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model.learn(total_timesteps=int(2e5), callback=callbacks)
|
model.learn(total_timesteps=int(2e5), callback=callbacks)
|
||||||
|
@ -14,13 +14,13 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
out_path = Path(r'C:\Users\steff\projects\f_iks\debug_out\A2C_1622557712')
|
out_path = Path(r'C:\Users\steff\projects\f_iks\debug_out\A2C_1622558379')
|
||||||
with (out_path / f'env_{out_path.name}.pick').open('rb') as f:
|
with (out_path / f'env_{out_path.name}.pick').open('rb') as f:
|
||||||
env_kwargs = pickle.load(f)
|
env_kwargs = pickle.load(f)
|
||||||
env = SimpleFactory(**env_kwargs)
|
env = SimpleFactory(allow_no_op=False, allow_diagonal_movement=False, allow_square_movement=True, **env_kwargs)
|
||||||
|
|
||||||
# Edit THIS:
|
# Edit THIS:
|
||||||
model_path = out_path
|
model_path = out_path / '1_A2C_1622558379'
|
||||||
|
|
||||||
model_files = list(natsorted(out_path.rglob('*.zip')))
|
model_files = list(natsorted(out_path.rglob('*.zip')))
|
||||||
this_model = model_files[0]
|
this_model = model_files[0]
|
||||||
|
Reference in New Issue
Block a user