mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
Stable Baseline Running
This commit is contained in:
parent
134f06b3d7
commit
c1cb7a4ffc
@ -1,4 +1,3 @@
|
|||||||
import abc
|
|
||||||
from typing import List, Union, Iterable
|
from typing import List, Union, Iterable
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
@ -43,10 +42,6 @@ class BaseFactory(gym.Env):
|
|||||||
def observation_space(self):
|
def observation_space(self):
|
||||||
return spaces.Box(low=-1, high=1, shape=self.state.shape, dtype=np.float32)
|
return spaces.Box(low=-1, high=1, shape=self.state.shape, dtype=np.float32)
|
||||||
|
|
||||||
@property
|
|
||||||
def monitor_as_df_list(self):
|
|
||||||
return [x.to_pd_dataframe() for x in self._monitor_list]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def movement_actions(self):
|
def movement_actions(self):
|
||||||
return (int(self.allow_vertical_movement) + int(self.allow_horizontal_movement)) * 4
|
return (int(self.allow_vertical_movement) + int(self.allow_horizontal_movement)) * 4
|
||||||
@ -61,7 +56,6 @@ class BaseFactory(gym.Env):
|
|||||||
self.allow_vertical_movement = True
|
self.allow_vertical_movement = True
|
||||||
self.allow_horizontal_movement = True
|
self.allow_horizontal_movement = True
|
||||||
self.allow_no_OP = True
|
self.allow_no_OP = True
|
||||||
self._monitor_list = list()
|
|
||||||
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
|
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
|
||||||
self.level = h.one_hot_level(
|
self.level = h.one_hot_level(
|
||||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||||
@ -77,7 +71,6 @@ class BaseFactory(gym.Env):
|
|||||||
self.steps = 0
|
self.steps = 0
|
||||||
self.cumulative_reward = 0
|
self.cumulative_reward = 0
|
||||||
self.monitor = FactoryMonitor(self)
|
self.monitor = FactoryMonitor(self)
|
||||||
self._monitor_list.append(self.monitor)
|
|
||||||
self.agent_states = []
|
self.agent_states = []
|
||||||
# Agent placement ...
|
# Agent placement ...
|
||||||
agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
|
agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
|
||||||
@ -92,7 +85,6 @@ class BaseFactory(gym.Env):
|
|||||||
# state.shape = level, agent 1,..., agent n,
|
# state.shape = level, agent 1,..., agent n,
|
||||||
self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
|
self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
|
||||||
# Returns State
|
# Returns State
|
||||||
|
|
||||||
return self.state
|
return self.state
|
||||||
|
|
||||||
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||||
@ -122,11 +114,11 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
self.agent_states = states
|
self.agent_states = states
|
||||||
reward, info = self.calculate_reward(states)
|
reward, info = self.calculate_reward(states)
|
||||||
self.cumulative_reward += reward
|
|
||||||
|
|
||||||
if self.steps >= self.max_steps:
|
if self.steps >= self.max_steps:
|
||||||
self.done = True
|
self.done = True
|
||||||
return self.state, self.cumulative_reward, self.done, info
|
self.monitor.add('step_reward', reward)
|
||||||
|
return self.state, reward, self.done, info
|
||||||
|
|
||||||
def _is_moving_action(self, action):
|
def _is_moving_action(self, action):
|
||||||
return action < self.movement_actions
|
return action < self.movement_actions
|
||||||
|
@ -120,8 +120,8 @@ class SimpleFactory(BaseFactory):
|
|||||||
dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX]))
|
dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX]))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
this_step_reward = -(dirty_tiles / current_dirt_amount)
|
this_step_reward = (dirty_tiles / current_dirt_amount)
|
||||||
except ZeroDivisionError:
|
except (ZeroDivisionError, RuntimeWarning):
|
||||||
this_step_reward = 0
|
this_step_reward = 0
|
||||||
|
|
||||||
for agent_state in agent_states:
|
for agent_state in agent_states:
|
||||||
@ -136,6 +136,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
|
self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
|
||||||
self.monitor.set('dirt_amount', current_dirt_amount)
|
self.monitor.set('dirt_amount', current_dirt_amount)
|
||||||
self.monitor.set('dirty_tiles', dirty_tiles)
|
self.monitor.set('dirty_tiles', dirty_tiles)
|
||||||
|
print(f"reward is {this_step_reward}")
|
||||||
return this_step_reward, {}
|
return this_step_reward, {}
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,17 +60,26 @@ class MonitorCallback(BaseCallback):
|
|||||||
super(MonitorCallback, self).__init__()
|
super(MonitorCallback, self).__init__()
|
||||||
self._outpath = Path(outpath)
|
self._outpath = Path(outpath)
|
||||||
self._filename = filename
|
self._filename = filename
|
||||||
|
self._monitor_list = list()
|
||||||
self.out_file = self._outpath / f'{self._filename.split(".")[0]}.pick'
|
self.out_file = self._outpath / f'{self._filename.split(".")[0]}.pick'
|
||||||
self.env = env
|
self.env = env
|
||||||
self.started = False
|
self.started = False
|
||||||
self.closed = False
|
self.closed = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def monitor_as_df_list(self):
|
||||||
|
return [x.to_pd_dataframe() for x in self._monitor_list]
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self._on_training_start()
|
self._on_training_start()
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self._on_training_end()
|
self._on_training_end()
|
||||||
|
|
||||||
|
def _on_rollout_end(self) -> None:
|
||||||
|
self._monitor_list.append(self.env.monitor)
|
||||||
|
pass
|
||||||
|
|
||||||
def _on_training_start(self) -> None:
|
def _on_training_start(self) -> None:
|
||||||
if self.started:
|
if self.started:
|
||||||
pass
|
pass
|
||||||
@ -85,7 +94,7 @@ class MonitorCallback(BaseCallback):
|
|||||||
else:
|
else:
|
||||||
# self.out_file.unlink(missing_ok=True)
|
# self.out_file.unlink(missing_ok=True)
|
||||||
with self.out_file.open('wb') as f:
|
with self.out_file.open('wb') as f:
|
||||||
pickle.dump(self.env.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
|
pickle.dump(self.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
self.closed = True
|
self.closed = True
|
||||||
|
|
||||||
def _on_step(self) -> bool:
|
def _on_step(self) -> bool:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user