Stable Baseline Running

This commit is contained in:
steffen-illium
2021-05-19 16:50:42 +02:00
parent 575eec9ee6
commit b979a47b6f
4 changed files with 147 additions and 197 deletions

View File

@@ -1,3 +1,4 @@
import abc
from typing import List, Union, Iterable
import gym
@@ -61,17 +62,14 @@ class BaseFactory(gym.Env):
self.allow_horizontal_movement = True
self.allow_no_OP = True
self._monitor_list = list()
self._registered_actions = self.movement_actions + int(self.allow_no_OP)
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
self.level = h.one_hot_level(
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
)
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
self.reset()
def __init_subclass__(cls):
    """Subclass-creation hook; prints the newly defined subclass.

    Python invokes this automatically whenever a class deriving from the
    enclosing class is created. The only effect here is writing the class
    object to stdout.
    """
    print(cls)
def register_additional_actions(self) -> int:
    """Return the number of additional actions this environment provides.

    The constructor adds the returned value to ``self.movement_actions``
    and ``int(self.allow_no_OP)`` when it computes
    ``self._registered_actions``, so the result must be an ``int``.

    This base implementation is a placeholder that must be overridden.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError('Please register additional actions ')
def reset(self) -> (np.ndarray, int, bool, dict):
@@ -111,6 +109,8 @@ class BaseFactory(gym.Env):
agent_i_state = AgentState(agent_i, action)
if self._is_moving_action(action):
pos, valid = self.move_or_colide(agent_i, action)
elif self._is_no_op(action):
pos, valid = self.agent_i_position(agent_i), True
else:
pos, valid = self.additional_actions(agent_i, action)
# Update state accordingly
@@ -129,10 +129,10 @@ class BaseFactory(gym.Env):
return self.state, self.cumulative_reward, self.done, info
def _is_moving_action(self, action):
    """Return whether ``action`` is a movement action.

    An action counts as a movement when it is strictly smaller than
    ``self.movement_actions``.

    Args:
        action: the action value to classify.

    Returns:
        The result of ``action < self.movement_actions``.
    """
    # The comparison already yields the boolean; no if/else is needed.
    return action < self.movement_actions
def _is_no_op(self, action):
    """Return whether ``action`` is the no-op action.

    When ``self.allow_no_OP`` is falsy, that value itself is returned.
    Otherwise the action is the no-op exactly when it equals
    ``self.movement_actions``.
    """
    no_op_enabled = self.allow_no_OP
    if not no_op_enabled:
        # Mirror ``and`` short-circuiting: hand back the falsy flag itself.
        return no_op_enabled
    offset = action - self.movement_actions
    return offset == 0
def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray:
collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices