Restructuring and Testing Done
This commit is contained in:
0
environments/factory/base/__init__.py
Normal file
0
environments/factory/base/__init__.py
Normal file
370
environments/factory/base/base_factory.py
Normal file
370
environments/factory/base/base_factory.py
Normal file
@ -0,0 +1,370 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Iterable
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
import yaml
|
||||
from gym.wrappers import FrameStack
|
||||
|
||||
from environments.helpers import Constants as c, Constants
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Slice, Agent, Tile, Action, MoveableEntity
|
||||
from environments.factory.base.registers import StateSlices, Actions, Entities, Agents, Doors, FloorTiles
|
||||
from environments.utility_classes import MovementProperties
|
||||
|
||||
REC_TAC = 'rec'
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
class BaseFactory(gym.Env):
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
return spaces.Discrete(self._actions.n)
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
agent_slice = self.n_agents if self.omit_agent_slice_in_obs else 0
|
||||
agent_slice = (self.n_agents - 1) if self.combin_agent_slices_in_obs else agent_slice
|
||||
if self.pomdp_radius:
|
||||
shape = (self._obs_cube.shape[0] - agent_slice, self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1)
|
||||
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||
return space
|
||||
else:
|
||||
shape = [x-agent_slice if idx == 0 else x for idx, x in enumerate(self._obs_cube.shape)]
|
||||
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||
return space
|
||||
|
||||
@property
|
||||
def pomdp_diameter(self):
|
||||
return self.pomdp_radius * 2 + 1
|
||||
|
||||
@property
|
||||
def movement_actions(self):
|
||||
return self._actions.movement_actions
|
||||
|
||||
@property
|
||||
def additional_actions(self) -> Union[str, List[str]]:
|
||||
"""
|
||||
When heriting from this Base Class, you musst implement this methode!!!
|
||||
|
||||
:return: A list of Actions-object holding all additional actions.
|
||||
:rtype: List[Action]
|
||||
"""
|
||||
raise NotImplementedError('Please register additional actions ')
|
||||
|
||||
@property
|
||||
def additional_entities(self) -> Union[Entities, List[Entities]]:
|
||||
"""
|
||||
When heriting from this Base Class, you musst implement this methode!!!
|
||||
|
||||
:return: A single Entites collection or a list of such.
|
||||
:rtype: Union[Entities, List[Entities]]
|
||||
"""
|
||||
raise NotImplementedError('Please register additional entities.')
|
||||
|
||||
@property
|
||||
def additional_slices(self) -> Union[Slice, List[Slice]]:
|
||||
"""
|
||||
When heriting from this Base Class, you musst implement this methode!!!
|
||||
|
||||
:return: A list of Slice-objects.
|
||||
:rtype: List[Slice]
|
||||
"""
|
||||
raise NotImplementedError('Please register additional slices.')
|
||||
|
||||
def __enter__(self):
|
||||
return self if self.frames_to_stack == 0 else FrameStack(self, self.frames_to_stack)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = 0,
|
||||
movement_properties: MovementProperties = MovementProperties(), parse_doors=False,
|
||||
combin_agent_slices_in_obs: bool = False, frames_to_stack=0, record_episodes=False,
|
||||
omit_agent_slice_in_obs=False, done_at_collision=False, **kwargs):
|
||||
assert (combin_agent_slices_in_obs != omit_agent_slice_in_obs) or \
|
||||
(not combin_agent_slices_in_obs and not omit_agent_slice_in_obs), \
|
||||
'Both options are exclusive'
|
||||
assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
|
||||
|
||||
# Attribute Assignment
|
||||
self.movement_properties = movement_properties
|
||||
self.level_name = level_name
|
||||
self._level_shape = None
|
||||
|
||||
self.n_agents = n_agents
|
||||
self.max_steps = max_steps
|
||||
self.pomdp_radius = pomdp_radius
|
||||
self.combin_agent_slices_in_obs = combin_agent_slices_in_obs
|
||||
self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
|
||||
self.frames_to_stack = frames_to_stack
|
||||
|
||||
self.done_at_collision = done_at_collision
|
||||
self.record_episodes = record_episodes
|
||||
self.parse_doors = parse_doors
|
||||
|
||||
# Actions
|
||||
self._actions = Actions(self.movement_properties, can_use_doors=self.parse_doors)
|
||||
if additional_actions := self.additional_actions:
|
||||
self._actions.register_additional_items(additional_actions)
|
||||
|
||||
self.reset()
|
||||
|
||||
def _init_state_slices(self) -> StateSlices:
|
||||
state_slices = StateSlices()
|
||||
|
||||
# Objects
|
||||
# Level
|
||||
level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
||||
parsed_level = h.parse_level(level_filepath)
|
||||
level = [Slice(c.LEVEL.name, h.one_hot_level(parsed_level))]
|
||||
self._level_shape = level[0].shape
|
||||
|
||||
# Doors
|
||||
parsed_doors = h.one_hot_level(parsed_level, c.DOOR)
|
||||
doors = [Slice(c.DOORS.value, parsed_doors)] if parsed_doors.any() and self.parse_doors else []
|
||||
|
||||
# Agents
|
||||
agents = []
|
||||
for i in range(self.n_agents):
|
||||
agents.append(Slice(f'{c.AGENT.name}#{i}', np.zeros_like(level[0].slice)))
|
||||
state_slices.register_additional_items(level+doors+agents)
|
||||
|
||||
# Additional Slices from SubDomains
|
||||
if additional_slices := self.additional_slices:
|
||||
state_slices.register_additional_items(additional_slices)
|
||||
return state_slices
|
||||
|
||||
def _init_obs_cube(self) -> np.ndarray:
|
||||
x, y = self._slices.by_enum(c.LEVEL).shape
|
||||
state = np.zeros((len(self._slices), x, y))
|
||||
state[0] = self._slices.by_enum(c.LEVEL).slice
|
||||
if r := self.pomdp_radius:
|
||||
self._padded_obs_cube = np.full((len(self._slices), x + r*2, y + r*2), c.FREE_CELL.value)
|
||||
self._padded_obs_cube[0] = c.OCCUPIED_CELL.value
|
||||
self._padded_obs_cube[:, r:r+x, r:r+y] = state
|
||||
return state
|
||||
|
||||
def _init_entities(self):
|
||||
# Tile Init
|
||||
self._tiles = FloorTiles.from_argwhere_coordinates(self._slices.by_enum(c.LEVEL).free_tiles)
|
||||
|
||||
# Door Init
|
||||
if self.parse_doors:
|
||||
tiles = [self._tiles.by_pos(x) for x in self._slices.by_enum(c.DOORS).occupied_tiles]
|
||||
self._doors = Doors.from_tiles(tiles, context=self._tiles)
|
||||
|
||||
# Agent Init on random positions
|
||||
self._agents = Agents.from_tiles(np.random.choice(self._tiles, self.n_agents))
|
||||
entities = Entities()
|
||||
entities.register_additional_items([self._agents])
|
||||
|
||||
if self.parse_doors:
|
||||
entities.register_additional_items([self._doors])
|
||||
|
||||
if additional_entities := self.additional_entities:
|
||||
entities.register_additional_items([additional_entities])
|
||||
|
||||
return entities
|
||||
|
||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||
self._slices = self._init_state_slices()
|
||||
self._obs_cube = self._init_obs_cube()
|
||||
self._entitites = self._init_entities()
|
||||
self._flush_state()
|
||||
self._steps = 0
|
||||
|
||||
info = self._summarize_state() if self.record_episodes else {}
|
||||
return None, None, None, info
|
||||
|
||||
def pre_step(self) -> None:
|
||||
pass
|
||||
|
||||
def post_step(self) -> dict:
|
||||
pass
|
||||
|
||||
def step(self, actions):
|
||||
actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
|
||||
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
||||
self._steps += 1
|
||||
done = False
|
||||
|
||||
# Pre step Hook for later use
|
||||
self.pre_step()
|
||||
|
||||
# Move this in a seperate function?
|
||||
for action, agent in zip(actions, self._agents):
|
||||
agent.clear_temp_sate()
|
||||
action_name = self._actions[action]
|
||||
if self._actions.is_moving_action(action):
|
||||
valid = self._move_or_colide(agent, action_name)
|
||||
elif self._actions.is_no_op(action):
|
||||
valid = c.VALID.value
|
||||
elif self._actions.is_door_usage(action):
|
||||
# Check if agent raly stands on a door:
|
||||
if door := self._doors.by_pos(agent.pos):
|
||||
door.use()
|
||||
valid = c.VALID.value
|
||||
# When he doesn't...
|
||||
else:
|
||||
valid = c.NOT_VALID.value
|
||||
else:
|
||||
valid = self.do_additional_actions(agent, action)
|
||||
agent.temp_action = action
|
||||
agent.temp_valid = valid
|
||||
|
||||
self._flush_state()
|
||||
|
||||
tiles_with_collisions = self.get_all_tiles_with_collisions()
|
||||
for tile in tiles_with_collisions:
|
||||
guests = tile.guests_that_can_collide
|
||||
for i, guest in enumerate(guests):
|
||||
this_collisions = guests[:]
|
||||
del this_collisions[i]
|
||||
guest.temp_collisions = this_collisions
|
||||
|
||||
if self.done_at_collision and tiles_with_collisions:
|
||||
done = True
|
||||
|
||||
# Step the door close intervall
|
||||
if self.parse_doors:
|
||||
self._doors.tick_doors()
|
||||
|
||||
# Finalize
|
||||
reward, info = self.calculate_reward()
|
||||
if self._steps >= self.max_steps:
|
||||
done = True
|
||||
info.update(step_reward=reward, step=self._steps)
|
||||
if self.record_episodes:
|
||||
info.update(self._summarize_state())
|
||||
|
||||
# Post step Hook for later use
|
||||
info.update(self.post_step())
|
||||
|
||||
obs = self._get_observations()
|
||||
|
||||
return obs, reward, done, info
|
||||
|
||||
def _flush_state(self):
|
||||
self._obs_cube[np.arange(len(self._slices)) != self._slices.get_idx(c.LEVEL)] = c.FREE_CELL.value
|
||||
if self.parse_doors:
|
||||
for door in self._doors:
|
||||
if door.is_open:
|
||||
self._obs_cube[self._slices.get_idx(c.DOORS)][door.pos] = c.IS_OPEN_DOOR.value
|
||||
else:
|
||||
self._obs_cube[self._slices.get_idx(c.DOORS)][door.pos] = c.IS_CLOSED_DOOR.value
|
||||
for agent in self._agents:
|
||||
self._obs_cube[self._slices.get_idx_by_name(agent.name)][agent.pos] = c.OCCUPIED_CELL.value
|
||||
if agent.last_pos != h.NO_POS:
|
||||
self._obs_cube[self._slices.get_idx_by_name(agent.name)][agent.last_pos] = c.FREE_CELL.value
|
||||
|
||||
def _get_observations(self) -> np.ndarray:
|
||||
if self.n_agents == 1:
|
||||
obs = self._build_per_agent_obs(self._agents[0])
|
||||
elif self.n_agents >= 2:
|
||||
obs = np.stack([self._build_per_agent_obs(agent) for agent in self._agents])
|
||||
else:
|
||||
raise ValueError('n_agents cannot be smaller than 1!!')
|
||||
return obs
|
||||
|
||||
def _build_per_agent_obs(self, agent: Agent) -> np.ndarray:
|
||||
first_agent_slice = self._slices.AGENTSTARTIDX
|
||||
if r := self.pomdp_radius:
|
||||
x, y = self._level_shape
|
||||
self._padded_obs_cube[:, r:r + x, r:r + y] = self._obs_cube
|
||||
global_x, global_y = agent.pos
|
||||
global_x += r
|
||||
global_y += r
|
||||
x0, x1 = max(0, global_x - self.pomdp_radius), global_x + self.pomdp_radius + 1
|
||||
y0, y1 = max(0, global_y - self.pomdp_radius), global_y + self.pomdp_radius + 1
|
||||
obs = self._padded_obs_cube[:, x0:x1, y0:y1]
|
||||
else:
|
||||
obs = self._obs_cube
|
||||
if self.omit_agent_slice_in_obs:
|
||||
obs_new = obs[[key for key, val in self._slices.items() if c.AGENT.value not in val]]
|
||||
return obs_new
|
||||
else:
|
||||
if self.combin_agent_slices_in_obs:
|
||||
agent_obs = np.sum(obs[[key for key, slice in self._slices.items() if c.AGENT.name in slice.name]],
|
||||
axis=0, keepdims=True)
|
||||
obs = np.concatenate((obs[:first_agent_slice], agent_obs, obs[first_agent_slice+self.n_agents:]))
|
||||
return obs
|
||||
else:
|
||||
return obs
|
||||
|
||||
def do_additional_actions(self, agent_i: int, action: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_all_tiles_with_collisions(self) -> List[Tile]:
|
||||
tiles_with_collisions = list()
|
||||
for tile in self._tiles:
|
||||
if tile.is_occupied():
|
||||
guests = [guest for guest in tile.guests if guest.can_collide]
|
||||
if len(guests) >= 2:
|
||||
tiles_with_collisions.append(tile)
|
||||
return tiles_with_collisions
|
||||
|
||||
def _move_or_colide(self, agent: Agent, action: Action) -> Constants:
|
||||
new_tile, valid = self._check_agent_move(agent, action)
|
||||
if valid:
|
||||
# Does not collide width level boundaries
|
||||
return agent.move(new_tile)
|
||||
else:
|
||||
# Agent seems to be trying to collide in this step
|
||||
return c.NOT_VALID
|
||||
|
||||
def _check_agent_move(self, agent, action: Action) -> (Tile, bool):
|
||||
# Actions
|
||||
x_diff, y_diff = h.ACTIONMAP[action.name]
|
||||
x_new = agent.x + x_diff
|
||||
y_new = agent.y + y_diff
|
||||
|
||||
new_tile = self._tiles.by_pos((x_new, y_new))
|
||||
if new_tile:
|
||||
valid = c.VALID
|
||||
else:
|
||||
tile = agent.tile
|
||||
valid = c.VALID
|
||||
return tile, valid
|
||||
|
||||
if self.parse_doors and agent.last_pos != h.NO_POS:
|
||||
if door := self._doors.by_pos(agent.pos):
|
||||
if door.is_open:
|
||||
pass
|
||||
else: # door.is_closed:
|
||||
if door.is_linked(agent.last_pos, new_tile.pos):
|
||||
pass
|
||||
else:
|
||||
return agent.tile, c.NOT_VALID
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
return new_tile, valid
|
||||
|
||||
def calculate_reward(self) -> (int, dict):
|
||||
# Returns: Reward, Info
|
||||
raise NotImplementedError
|
||||
|
||||
def render(self, mode='human'):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_params(self, filepath: Path):
|
||||
# noinspection PyProtectedMember
|
||||
# d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items()
|
||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with filepath.open('w') as f:
|
||||
yaml.dump(d, f)
|
||||
# pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def _summarize_state(self):
|
||||
summary = {f'{REC_TAC}_step': self._steps}
|
||||
for entity in self._entitites:
|
||||
if hasattr(entity, 'summarize_state'):
|
||||
summary.update({f'{REC_TAC}_{entity.name}': entity.summarize_state()})
|
||||
return summary
|
266
environments/factory/base/objects.py
Normal file
266
environments/factory/base/objects.py
Normal file
@ -0,0 +1,266 @@
|
||||
import itertools
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from environments import helpers as h
|
||||
from environments.helpers import Constants as c
|
||||
import itertools
|
||||
|
||||
|
||||
def sub(p, q):
|
||||
return p - q
|
||||
|
||||
|
||||
class Object:
|
||||
|
||||
def __bool__(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def i(self):
|
||||
return self._identifier
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._identifier
|
||||
|
||||
def __init__(self, identifier, **kwargs):
|
||||
self._identifier = identifier
|
||||
if kwargs:
|
||||
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self._identifier})'
|
||||
|
||||
|
||||
class Action(Object):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.i
|
||||
|
||||
def __init__(self, *args):
|
||||
super(Action, self).__init__(*args)
|
||||
|
||||
|
||||
class Slice(Object):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.slice.shape
|
||||
|
||||
@property
|
||||
def occupied_tiles(self):
|
||||
return np.argwhere(self.slice == c.OCCUPIED_CELL.value)
|
||||
|
||||
@property
|
||||
def free_tiles(self):
|
||||
return np.argwhere(self.slice == c.FREE_CELL.value)
|
||||
|
||||
def __init__(self, identifier, arrayslice):
|
||||
super(Slice, self).__init__(identifier)
|
||||
self.slice = arrayslice
|
||||
|
||||
|
||||
class Wall(Object):
|
||||
pass
|
||||
|
||||
|
||||
class Tile(Object):
|
||||
|
||||
@property
|
||||
def guests_that_can_collide(self):
|
||||
return [x for x in self.guests if x.can_collide]
|
||||
|
||||
@property
|
||||
def guests(self):
|
||||
return self._guests.values()
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
return self.pos[0]
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
return self.pos[1]
|
||||
|
||||
@property
|
||||
def pos(self):
|
||||
return self._pos
|
||||
|
||||
def __init__(self, i, pos):
|
||||
super(Tile, self).__init__(i)
|
||||
self._guests = dict()
|
||||
self._pos = tuple(pos)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._guests)
|
||||
|
||||
def is_empty(self):
|
||||
return not len(self._guests)
|
||||
|
||||
def is_occupied(self):
|
||||
return len(self._guests)
|
||||
|
||||
def enter(self, guest):
|
||||
if guest.name not in self._guests:
|
||||
self._guests.update({guest.name: guest})
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def leave(self, guest):
|
||||
try:
|
||||
del self._guests[guest.name]
|
||||
except (ValueError, KeyError):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class Entity(Object):
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return 1
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
return self.pos[0]
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
return self.pos[1]
|
||||
|
||||
@property
|
||||
def pos(self):
|
||||
return self._tile.pos
|
||||
|
||||
@property
|
||||
def tile(self):
|
||||
return self._tile
|
||||
|
||||
def __init__(self, identifier, tile: Tile, **kwargs):
|
||||
super(Entity, self).__init__(identifier, **kwargs)
|
||||
self._tile = tile
|
||||
|
||||
def summarize_state(self):
|
||||
return self.__dict__.copy()
|
||||
|
||||
|
||||
class MoveableEntity(Entity):
|
||||
|
||||
@property
|
||||
def last_tile(self):
|
||||
return self._last_tile
|
||||
|
||||
@property
|
||||
def last_pos(self):
|
||||
if self._last_tile:
|
||||
return self._last_tile.pos
|
||||
else:
|
||||
return h.NO_POS
|
||||
|
||||
@property
|
||||
def direction_of_view(self):
|
||||
last_x, last_y = self.last_pos
|
||||
curr_x, curr_y = self.pos
|
||||
return last_x-curr_x, last_y-curr_y
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MoveableEntity, self).__init__(*args, **kwargs)
|
||||
self._last_tile = None
|
||||
|
||||
def move(self, next_tile):
|
||||
curr_tile = self.tile
|
||||
if curr_tile != next_tile:
|
||||
next_tile.enter(self)
|
||||
curr_tile.leave(self)
|
||||
self._tile = next_tile
|
||||
self._last_tile = curr_tile
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class Door(Entity):
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return 1 if self.is_closed else -1
|
||||
|
||||
def __init__(self, *args, context, closed_on_init=True, auto_close_interval=500):
|
||||
super(Door, self).__init__(*args)
|
||||
self._state = c.IS_CLOSED_DOOR
|
||||
self.auto_close_interval = auto_close_interval
|
||||
self.time_to_close = -1
|
||||
neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1]
|
||||
neighbor_tiles = [context.by_pos(tuple([sum(x) for x in zip(self.pos, diff)])) for diff in neighbor_pos]
|
||||
neighbor_pos = [x.pos for x in neighbor_tiles if x]
|
||||
possible_connections = itertools.combinations(neighbor_pos, 2)
|
||||
self.connectivity = nx.Graph()
|
||||
for a, b in possible_connections:
|
||||
if not max(abs(np.subtract(a, b))) > 1:
|
||||
self.connectivity.add_edge(a, b)
|
||||
if not closed_on_init:
|
||||
self._open()
|
||||
|
||||
@property
|
||||
def is_closed(self):
|
||||
return self._state == c.IS_CLOSED_DOOR
|
||||
|
||||
@property
|
||||
def is_open(self):
|
||||
return self._state == c.IS_OPEN_DOOR
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
return self._state
|
||||
|
||||
def use(self):
|
||||
if self._state == c.IS_OPEN_DOOR:
|
||||
self._close()
|
||||
else:
|
||||
self._open()
|
||||
|
||||
def tick(self):
|
||||
if self.is_open and len(self.tile) == 1 and self.time_to_close:
|
||||
self.time_to_close -= 1
|
||||
elif self.is_open and not self.time_to_close and len(self.tile) == 1:
|
||||
self.use()
|
||||
|
||||
def _open(self):
|
||||
self.connectivity.add_edges_from([(self.pos, x) for x in self.connectivity.nodes])
|
||||
self._state = c.IS_OPEN_DOOR
|
||||
self.time_to_close = self.auto_close_interval
|
||||
|
||||
def _close(self):
|
||||
self.connectivity.remove_node(self.pos)
|
||||
self._state = c.IS_CLOSED_DOOR
|
||||
|
||||
def is_linked(self, old_pos, new_pos):
|
||||
try:
|
||||
_ = nx.shortest_path(self.connectivity, old_pos, new_pos)
|
||||
return True
|
||||
except nx.exception.NetworkXNoPath:
|
||||
return False
|
||||
|
||||
|
||||
class Agent(MoveableEntity):
|
||||
|
||||
def __init__(self, *args):
|
||||
super(Agent, self).__init__(*args)
|
||||
self.clear_temp_sate()
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
def clear_temp_sate(self):
|
||||
self.temp_collisions = []
|
||||
self.temp_valid = None
|
||||
self.temp_action = -1
|
292
environments/factory/base/registers.py
Normal file
292
environments/factory/base/registers.py
Normal file
@ -0,0 +1,292 @@
|
||||
import itertools
|
||||
import random
|
||||
from enum import Enum
|
||||
from typing import List, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.base.objects import Entity, Tile, Agent, Door, Slice, Action
|
||||
from environments.utility_classes import MovementProperties
|
||||
from environments import helpers as h
|
||||
from environments.helpers import Constants as c
|
||||
|
||||
|
||||
class Register:
|
||||
_accepted_objects = Entity
|
||||
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, positions: (int, int), tiles):
|
||||
entities = [cls._accepted_objects(i, tiles.by_pos(position)) for i, position in enumerate(positions)]
|
||||
registered_obj = cls()
|
||||
registered_obj.register_additional_items(entities)
|
||||
return registered_obj
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
@property
|
||||
def n(self):
|
||||
return len(self)
|
||||
|
||||
def __init__(self):
|
||||
self._register = dict()
|
||||
self._names = dict()
|
||||
|
||||
def __len__(self):
|
||||
return len(self._register)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.values())
|
||||
|
||||
def __add__(self, other: _accepted_objects):
|
||||
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
|
||||
f'{self._accepted_objects}, ' \
|
||||
f'but were {other.__class__}.,'
|
||||
self._names.update({other.name: len(self._register)})
|
||||
self._register.update({len(self._register): other})
|
||||
return self
|
||||
|
||||
def register_additional_items(self, others: List[_accepted_objects]):
|
||||
for other in others:
|
||||
self + other
|
||||
return self
|
||||
|
||||
def keys(self):
|
||||
return self._register.keys()
|
||||
|
||||
def values(self):
|
||||
return self._register.values()
|
||||
|
||||
def items(self):
|
||||
return self._register.items()
|
||||
|
||||
def __getitem__(self, item):
|
||||
try:
|
||||
return self._register[item]
|
||||
except KeyError:
|
||||
print('NO')
|
||||
raise
|
||||
|
||||
def by_name(self, item):
|
||||
return self[self._names[item]]
|
||||
|
||||
def by_enum(self, enum: Enum):
|
||||
return self[self._names[enum.name]]
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self._register})'
|
||||
|
||||
def get_name(self, item):
|
||||
return self._register[item].name
|
||||
|
||||
def get_idx_by_name(self, item):
|
||||
return self._names[item]
|
||||
|
||||
def get_idx(self, enum: Enum):
|
||||
return self._names[enum.name]
|
||||
|
||||
@classmethod
|
||||
def from_tiles(cls, tiles, **kwargs):
|
||||
entities = [cls._accepted_objects(f'{cls._accepted_objects.__name__.upper()}#{i}', tile, **kwargs)
|
||||
for i, tile in enumerate(tiles)]
|
||||
registered_obj = cls()
|
||||
registered_obj.register_additional_items(entities)
|
||||
return registered_obj
|
||||
|
||||
|
||||
class EntityRegister(Register):
|
||||
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, argwhere_coordinates):
|
||||
tiles = cls()
|
||||
tiles.register_additional_items([cls._accepted_objects(i, pos) for i, pos in enumerate(argwhere_coordinates)])
|
||||
return tiles
|
||||
|
||||
def __init__(self):
|
||||
super(EntityRegister, self).__init__()
|
||||
self._tiles = dict()
|
||||
|
||||
def __add__(self, other):
|
||||
super(EntityRegister, self).__add__(other)
|
||||
self._tiles[other.pos] = other
|
||||
|
||||
def by_pos(self, pos):
|
||||
if isinstance(pos, np.ndarray):
|
||||
pos = tuple(pos)
|
||||
try:
|
||||
return self._tiles[pos]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
|
||||
class Entities(Register):
|
||||
|
||||
_accepted_objects = Register
|
||||
|
||||
def __init__(self):
|
||||
super(Entities, self).__init__()
|
||||
|
||||
def __iter__(self):
|
||||
return iter([x for sublist in self.values() for x in sublist])
|
||||
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, positions):
|
||||
raise AttributeError()
|
||||
|
||||
|
||||
class FloorTiles(EntityRegister):
|
||||
_accepted_objects = Tile
|
||||
|
||||
@property
|
||||
def occupied_tiles(self):
|
||||
tiles = [tile for tile in self if tile.is_occupied()]
|
||||
random.shuffle(tiles)
|
||||
return tiles
|
||||
|
||||
@property
|
||||
def empty_tiles(self):
|
||||
tiles = [tile for tile in self if tile.is_empty()]
|
||||
random.shuffle(tiles)
|
||||
return tiles
|
||||
|
||||
|
||||
class Agents(Register):
|
||||
|
||||
_accepted_objects = Agent
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
return [agent.pos for agent in self]
|
||||
|
||||
|
||||
class Doors(EntityRegister):
|
||||
_accepted_objects = Door
|
||||
|
||||
def tick_doors(self):
|
||||
for door in self:
|
||||
door.tick()
|
||||
|
||||
|
||||
class Actions(Register):
|
||||
|
||||
_accepted_objects = Action
|
||||
|
||||
@property
|
||||
def movement_actions(self):
|
||||
return self._movement_actions
|
||||
|
||||
def __init__(self, movement_properties: MovementProperties, can_use_doors=False):
|
||||
self.allow_no_op = movement_properties.allow_no_op
|
||||
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
|
||||
self.allow_square_movement = movement_properties.allow_square_movement
|
||||
self.can_use_doors = can_use_doors
|
||||
super(Actions, self).__init__()
|
||||
|
||||
if self.allow_square_movement:
|
||||
self.register_additional_items([self._accepted_objects(direction) for direction in h.MANHATTAN_MOVES])
|
||||
if self.allow_diagonal_movement:
|
||||
self.register_additional_items([self._accepted_objects(direction) for direction in h.DIAGONAL_MOVES])
|
||||
self._movement_actions = self._register.copy()
|
||||
if self.can_use_doors:
|
||||
self.register_additional_items([self._accepted_objects('use_door')])
|
||||
if self.allow_no_op:
|
||||
self.register_additional_items([self._accepted_objects('no-op')])
|
||||
|
||||
def is_moving_action(self, action: Union[int]):
|
||||
#if isinstance(action, Action):
|
||||
# return (action.name in h.MANHATTAN_MOVES and self.allow_square_movement) or \
|
||||
# (action.name in h.DIAGONAL_MOVES and self.allow_diagonal_movement)
|
||||
#else:
|
||||
return action in self.movement_actions.keys()
|
||||
|
||||
def is_no_op(self, action: Union[str, int]):
|
||||
if isinstance(action, str):
|
||||
action = self.by_name(action)
|
||||
return self[action].name == 'no-op'
|
||||
|
||||
def is_door_usage(self, action: Union[str, int]):
|
||||
if isinstance(action, str):
|
||||
action = self.by_name(action)
|
||||
return self[action].name == 'use_door'
|
||||
|
||||
|
||||
class StateSlices(Register):
|
||||
|
||||
_accepted_objects = Slice
|
||||
|
||||
@property
|
||||
def AGENTSTARTIDX(self):
|
||||
if self._agent_start_idx:
|
||||
return self._agent_start_idx
|
||||
else:
|
||||
self._agent_start_idx = min([idx for idx, x in self.items() if c.AGENT.name in x.name])
|
||||
return self._agent_start_idx
|
||||
|
||||
def __init__(self):
|
||||
super(StateSlices, self).__init__()
|
||||
self._agent_start_idx = None
|
||||
|
||||
def _gather_occupation(self, excluded_slices):
|
||||
exclusion = excluded_slices or []
|
||||
assert isinstance(exclusion, (int, list))
|
||||
exclusion = exclusion if isinstance(exclusion, list) else [exclusion]
|
||||
|
||||
result = np.sum([x for i, x in self.items() if i not in exclusion], axis=0)
|
||||
return result
|
||||
|
||||
def free_cells(self, excluded_slices: Union[None, List[int], int] = None) -> np.array:
|
||||
occupation = self._gather_occupation(excluded_slices)
|
||||
free_cells = np.argwhere(occupation == c.IS_FREE_CELL)
|
||||
np.random.shuffle(free_cells)
|
||||
return free_cells
|
||||
|
||||
def occupied_cells(self, excluded_slices: Union[None, List[int], int] = None) -> np.array:
|
||||
occupation = self._gather_occupation(excluded_slices)
|
||||
occupied_cells = np.argwhere(occupation == c.IS_OCCUPIED_CELL.value)
|
||||
np.random.shuffle(occupied_cells)
|
||||
return occupied_cells
|
||||
|
||||
|
||||
class Zones(Register):
|
||||
|
||||
@property
|
||||
def danger_zone(self):
|
||||
return self._zone_slices[self.by_enum(c.DANGER_ZONE)]
|
||||
|
||||
@property
|
||||
def accounting_zones(self):
|
||||
return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE.value]
|
||||
|
||||
def __init__(self, parsed_level):
|
||||
raise NotImplementedError('This needs a Rework')
|
||||
super(Zones, self).__init__()
|
||||
slices = list()
|
||||
self._accounting_zones = list()
|
||||
self._danger_zones = list()
|
||||
for symbol in np.unique(parsed_level):
|
||||
if symbol == h.WALL:
|
||||
continue
|
||||
elif symbol == h.DANGER_ZONE:
|
||||
self + symbol
|
||||
slices.append(h.one_hot_level(parsed_level, symbol))
|
||||
self._danger_zones.append(symbol)
|
||||
else:
|
||||
self + symbol
|
||||
slices.append(h.one_hot_level(parsed_level, symbol))
|
||||
self._accounting_zones.append(symbol)
|
||||
|
||||
self._zone_slices = np.stack(slices)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._zone_slices[item]
|
||||
|
||||
def get_name(self, item):
|
||||
return self._register[item]
|
||||
|
||||
def by_name(self, item):
|
||||
return self[super(Zones, self).by_name(item)]
|
||||
|
||||
def register_additional_items(self, other: Union[str, List[str]]):
|
||||
raise AttributeError('You are not allowed to add additional Zones in runtime.')
|
Reference in New Issue
Block a user