Accounting Zones for basefactory, mark them in level.txt by int, x for danger zones
This commit is contained in:
@ -7,20 +7,15 @@ import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
import yaml
|
||||
from gym.wrappers import FrameStack
|
||||
|
||||
from environments import helpers as h
|
||||
from environments.utility_classes import Actions, StateSlice, AgentState, MovementProperties
|
||||
from environments.utility_classes import Actions, StateSlice, AgentState, MovementProperties, Zones
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
class BaseFactory(gym.Env):
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if isinstance(value, dict):
|
||||
super(BaseFactory, self).__setattr__(key, Namespace(**value))
|
||||
else:
|
||||
super(BaseFactory, self).__setattr__(key, value)
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
return spaces.Discrete(self._actions.n)
|
||||
@ -40,11 +35,20 @@ class BaseFactory(gym.Env):
|
||||
def movement_actions(self):
|
||||
return self._actions.movement_actions
|
||||
|
||||
def __enter__(self):
|
||||
return self if self.frames_to_stack == 0 else FrameStack(self, self.frames_to_stack)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = 0,
|
||||
movement_properties: MovementProperties = MovementProperties(),
|
||||
combin_agent_slices_in_obs: bool = False,
|
||||
combin_agent_slices_in_obs: bool = False, frames_to_stack=0,
|
||||
omit_agent_slice_in_obs=False, **kwargs):
|
||||
assert combin_agent_slices_in_obs != omit_agent_slice_in_obs, 'Both options are exclusive'
|
||||
assert (combin_agent_slices_in_obs != omit_agent_slice_in_obs) or \
|
||||
(not combin_agent_slices_in_obs and not omit_agent_slice_in_obs), \
|
||||
'Both options are exclusive'
|
||||
assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
|
||||
|
||||
self.movement_properties = movement_properties
|
||||
self.level_name = level_name
|
||||
@ -54,17 +58,19 @@ class BaseFactory(gym.Env):
|
||||
self.pomdp_radius = pomdp_radius
|
||||
self.combin_agent_slices_in_obs = combin_agent_slices_in_obs
|
||||
self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
|
||||
self.frames_to_stack = frames_to_stack
|
||||
|
||||
self.done_at_collision = False
|
||||
_actions = Actions(self.movement_properties)
|
||||
self._actions = _actions + self.additional_actions
|
||||
|
||||
self._level = h.one_hot_level(
|
||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{self.level_name}.txt')
|
||||
)
|
||||
level_filepath = Path(__file__).parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
||||
parsed_level = h.parse_level(level_filepath)
|
||||
self._level = h.one_hot_level(parsed_level)
|
||||
self._state_slices = StateSlice(n_agents)
|
||||
if 'additional_slices' in kwargs:
|
||||
self._state_slices.register_additional_items(kwargs.get('additional_slices'))
|
||||
self._zones = Zones(parsed_level)
|
||||
self.reset()
|
||||
|
||||
@property
|
||||
@ -259,7 +265,7 @@ class BaseFactory(gym.Env):
|
||||
# d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items()
|
||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
super(BaseFactory, self).save_params()
|
||||
with filepath.open('w') as f:
|
||||
yaml.dump(d, f)
|
||||
# pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
Reference in New Issue
Block a user