Stable Baseline Running
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import abc
|
||||
from typing import List, Union, Iterable
|
||||
|
||||
import gym
|
||||
@@ -61,17 +62,14 @@ class BaseFactory(gym.Env):
|
||||
self.allow_horizontal_movement = True
|
||||
self.allow_no_OP = True
|
||||
self._monitor_list = list()
|
||||
self._registered_actions = self.movement_actions + int(self.allow_no_OP)
|
||||
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
|
||||
self.level = h.one_hot_level(
|
||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||
)
|
||||
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
|
||||
self.reset()
|
||||
|
||||
def __init_subclass__(cls):
|
||||
print(cls)
|
||||
|
||||
def register_additional_actions(self):
|
||||
def register_additional_actions(self) -> int:
|
||||
raise NotImplementedError('Please register additional actions ')
|
||||
|
||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||
@@ -111,6 +109,8 @@ class BaseFactory(gym.Env):
|
||||
agent_i_state = AgentState(agent_i, action)
|
||||
if self._is_moving_action(action):
|
||||
pos, valid = self.move_or_colide(agent_i, action)
|
||||
elif self._is_no_op(action):
|
||||
pos, valid = self.agent_i_position(agent_i), True
|
||||
else:
|
||||
pos, valid = self.additional_actions(agent_i, action)
|
||||
# Update state accordingly
|
||||
@@ -129,10 +129,10 @@ class BaseFactory(gym.Env):
|
||||
return self.state, self.cumulative_reward, self.done, info
|
||||
|
||||
def _is_moving_action(self, action):
|
||||
if action < self.movement_actions:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return action < self.movement_actions
|
||||
|
||||
def _is_no_op(self, action):
|
||||
return self.allow_no_OP and (action - self.movement_actions) == 0
|
||||
|
||||
def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray:
|
||||
collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices
|
||||
|
||||
Reference in New Issue
Block a user