mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 13:37:27 +01:00
no more tiles no more floor
This commit is contained in:
@@ -25,10 +25,8 @@ This file is used for:
|
||||
LEVELS_DIR = 'modules/levels' # for use in studies and experiments
|
||||
STEPS_START = 1 # Define where to the stepcount; which is the first step
|
||||
|
||||
# Not used anymore? Clean!
|
||||
# TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
|
||||
'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation',
|
||||
'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation',
|
||||
'episode']
|
||||
|
||||
POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]],
|
||||
@@ -223,7 +221,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
|
||||
mod = importlib.import_module('.'.join(module_parts))
|
||||
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'Floor'
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
|
||||
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
|
||||
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
|
||||
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.groups.agents import Agents
|
||||
from marl_factory_grid.environment.groups.global_entities import Entities
|
||||
from marl_factory_grid.environment.groups.wall_n_floors import Walls, Floors
|
||||
from marl_factory_grid.environment.groups.walls import Walls
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
@@ -34,16 +34,14 @@ class LevelParser(object):
|
||||
|
||||
def do_init(self):
|
||||
# Global Entities
|
||||
list_of_all_floors = ([tuple(floor) for floor in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)])
|
||||
entities = Entities(list_of_all_floors)
|
||||
list_of_all_positions = ([tuple(f) for f in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)])
|
||||
entities = Entities(list_of_all_positions)
|
||||
|
||||
# Walls
|
||||
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
|
||||
entities.add_items({c.WALLS: walls})
|
||||
|
||||
# Floor
|
||||
floor = Floors.from_coordinates(list_of_all_floors, self.size)
|
||||
entities.add_items({c.FLOOR: floor})
|
||||
# Agents
|
||||
entities.add_items({c.AGENT: Agents(self.size)})
|
||||
|
||||
# All other
|
||||
|
||||
@@ -9,6 +9,7 @@ from numba import njit
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.groups.utils import Combined
|
||||
from marl_factory_grid.utils.states import Gamestate
|
||||
from marl_factory_grid.utils.utility_classes import Floor
|
||||
|
||||
|
||||
class OBSBuilder(object):
|
||||
@@ -39,6 +40,7 @@ class OBSBuilder(object):
|
||||
|
||||
self.reset_struc_obs_block(state)
|
||||
self.curr_lightmaps = dict()
|
||||
self._floortiles = defaultdict(list, {pos: [Floor(*pos)] for pos in state.entities.floorlist})
|
||||
|
||||
def reset_struc_obs_block(self, state):
|
||||
self._curr_env_step = state.curr_step
|
||||
@@ -82,19 +84,23 @@ class OBSBuilder(object):
|
||||
self._sort_and_name_observation_conf(agent)
|
||||
agent_want_obs = self.obs_layers[agent.name]
|
||||
|
||||
# Handle in-grid observations aka visible observations
|
||||
visible_entitites = self.ray_caster[agent.name].visible_entities(state.entities)
|
||||
pre_sort_obs = defaultdict(lambda: np.zeros((self.pomdp_d, self.pomdp_d)))
|
||||
for e in set(visible_entitites):
|
||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||
try:
|
||||
pre_sort_obs[e.obs_tag][x, y] += e.encoding
|
||||
except IndexError:
|
||||
# Seemded to be visible but is out or range
|
||||
pass
|
||||
# Handle in-grid observations aka visible observations (Things on the map, with pos)
|
||||
visible_entitites = self.ray_caster[agent.name].visible_entities(state.entities.pos_dict)
|
||||
pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
|
||||
if self.pomdp_r:
|
||||
for e in set(visible_entitites):
|
||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||
try:
|
||||
pre_sort_obs[e.obs_tag][x, y] += e.encoding
|
||||
except IndexError:
|
||||
# Seemded to be visible but is out or range
|
||||
pass
|
||||
else:
|
||||
for e in set(visible_entitites):
|
||||
pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding
|
||||
|
||||
pre_sort_obs = dict(pre_sort_obs)
|
||||
obs = np.zeros((len(agent_want_obs), self.pomdp_d, self.pomdp_d))
|
||||
obs = np.zeros((len(agent_want_obs), self.obs_shape[0], self.obs_shape[1]))
|
||||
|
||||
for idx, l_name in enumerate(agent_want_obs):
|
||||
try:
|
||||
@@ -144,13 +150,26 @@ class OBSBuilder(object):
|
||||
raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.')
|
||||
|
||||
try:
|
||||
self.curr_lightmaps[agent.name] = pre_sort_obs[c.FLOORS].astype(bool)
|
||||
light_map = np.zeros(self.obs_shape)
|
||||
visible_floor = set(self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False))
|
||||
if self.pomdp_r:
|
||||
coords = [((f.x - agent.x) + self.pomdp_r, (f.y - agent.y) + self.pomdp_r) for f in visible_floor]
|
||||
else:
|
||||
coords = [x.pos for x in visible_floor]
|
||||
np.put(light_map, np.ravel_multi_index(np.asarray(coords).T, light_map.shape), 1)
|
||||
self.curr_lightmaps[agent.name] = light_map
|
||||
except KeyError:
|
||||
print()
|
||||
return obs, self.obs_layers[agent.name]
|
||||
|
||||
def _sort_and_name_observation_conf(self, agent):
|
||||
self.ray_caster[agent.name] = RayCaster(agent, self.pomdp_r)
|
||||
'''
|
||||
Builds the useable observation scheme per agent from conf.yaml.
|
||||
:param agent:
|
||||
:return:
|
||||
'''
|
||||
# Fixme: no asymetric shapes possible.
|
||||
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
|
||||
obs_layers = []
|
||||
|
||||
for obs_str in agent.observations:
|
||||
@@ -173,7 +192,7 @@ class OBSBuilder(object):
|
||||
names.extend([x.name for x in agent.collection if x.name != agent.name])
|
||||
else:
|
||||
names.append(val)
|
||||
combined = Combined(names, self.pomdp_r, identifier=agent.name)
|
||||
combined = Combined(names, self.size, identifier=agent.name)
|
||||
self.all_obs[combined.name] = combined
|
||||
obs_layers.append(combined.name)
|
||||
elif obs_str == c.OTHERS:
|
||||
@@ -183,19 +202,18 @@ class OBSBuilder(object):
|
||||
else:
|
||||
obs_layers.append(obs_str)
|
||||
self.obs_layers[agent.name] = obs_layers
|
||||
self.curr_lightmaps[agent.name] = np.zeros((self.pomdp_d or self.level_shape[0],
|
||||
self.pomdp_d or self.level_shape[1]
|
||||
))
|
||||
self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)
|
||||
|
||||
|
||||
class RayCaster:
|
||||
def __init__(self, agent, pomdp_r, degs=360):
|
||||
self.agent = agent
|
||||
self.pomdp_r = pomdp_r
|
||||
self.n_rays = 100 # (self.pomdp_r + 1) * 8
|
||||
self.n_rays = (self.pomdp_r + 1) * 8
|
||||
self.degs = degs
|
||||
self.ray_targets = self.build_ray_targets()
|
||||
self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
|
||||
self._cache_dict = {}
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
@@ -211,30 +229,30 @@ class RayCaster:
|
||||
rot_M = np.unique(np.round(rot_M @ north), axis=0)
|
||||
return rot_M.astype(int)
|
||||
|
||||
def ray_block_cache(self, cache_dict, key, callback):
|
||||
if key not in cache_dict:
|
||||
cache_dict[key] = callback()
|
||||
return cache_dict[key]
|
||||
def ray_block_cache(self, key, callback):
|
||||
if key not in self._cache_dict:
|
||||
self._cache_dict[key] = callback()
|
||||
return self._cache_dict[key]
|
||||
|
||||
def visible_entities(self, entities):
|
||||
def visible_entities(self, pos_dict, reset_cache=True):
|
||||
visible = list()
|
||||
cache_blocking = {}
|
||||
if reset_cache:
|
||||
self._cache_dict = {}
|
||||
|
||||
for ray in self.get_rays():
|
||||
rx, ry = ray[0]
|
||||
for x, y in ray:
|
||||
cx, cy = x - rx, y - ry
|
||||
|
||||
entities_hit = entities.pos_dict[(x, y)]
|
||||
hits = self.ray_block_cache(cache_blocking,
|
||||
(x, y),
|
||||
lambda: any(True for e in entities_hit if e.var_is_blocking_light))
|
||||
entities_hit = pos_dict[(x, y)]
|
||||
hits = self.ray_block_cache((x, y),
|
||||
lambda: any(True for e in entities_hit if e.var_is_blocking_light)
|
||||
)
|
||||
|
||||
diag_hits = all([
|
||||
self.ray_block_cache(
|
||||
cache_blocking,
|
||||
key,
|
||||
lambda: all(False for e in entities.pos_dict[key] if not e.var_is_blocking_light))
|
||||
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(pos_dict[key]))
|
||||
for key in ((x, y-cy), (x-cx, y))
|
||||
]) if (cx != 0 and cy != 0) else False
|
||||
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class RenderEntity:
|
||||
name: str
|
||||
pos: np.array
|
||||
value: float = 1
|
||||
value_operation: str = 'none'
|
||||
state: str = None
|
||||
id: int = 0
|
||||
aux: Any = None
|
||||
real_name: str = 'none'
|
||||
@@ -9,7 +9,7 @@ import pygame
|
||||
from typing import Tuple, Union
|
||||
import time
|
||||
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
|
||||
AGENT: str = 'agent'
|
||||
STATE_IDLE: str = 'idle'
|
||||
|
||||
@@ -3,8 +3,6 @@ from typing import List, Dict, Tuple
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
from marl_factory_grid.environment.groups.global_entities import Entities
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
@@ -112,15 +110,10 @@ class Gamestate(object):
|
||||
results.extend(on_check_done_result)
|
||||
return results
|
||||
|
||||
# def get_all_tiles_with_collisions(self) -> List[Floor]:
|
||||
# tiles = [self[c.FLOORS].by_pos(pos) for pos, e in self.entities.pos_dict.items()
|
||||
# if sum([x.var_can_collide for x in e]) > 1]
|
||||
# # tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
||||
# return tiles
|
||||
|
||||
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
|
||||
positions = [pos for pos, e in self.entities.pos_dict.items()
|
||||
if sum([x.var_can_collide for x in e]) > 1]
|
||||
positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items()
|
||||
if any([e.var_can_collide for e in entity_list_for_position])]
|
||||
return positions
|
||||
|
||||
def check_move_validity(self, moving_entity, position):
|
||||
@@ -128,6 +121,14 @@ class Gamestate(object):
|
||||
# and not (guest.var_is_blocking_pos and self.is_occupied()):
|
||||
if moving_entity.pos != position and not any(
|
||||
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
|
||||
moving_entity.var_is_blocking_pos and moving_entity.is_occupied()):
|
||||
moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)):
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
def check_pos_validity(self, position):
|
||||
if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ ENTITIES = 'Objects'
|
||||
OBSERVATIONS = 'Observations'
|
||||
RULES = 'Rule'
|
||||
ASSETS = 'Assets'
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Floor', 'Agent', 'GlobalPositions', 'Walls',
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls',
|
||||
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MarlFrameStack(gym.ObservationWrapper):
|
||||
@@ -10,3 +14,37 @@ class MarlFrameStack(gym.ObservationWrapper):
|
||||
if isinstance(self.env, gym.wrappers.FrameStack) and self.env.unwrapped.n_agents > 1:
|
||||
return observation[0:].swapaxes(0, 1)
|
||||
return observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class RenderEntity:
|
||||
name: str
|
||||
pos: np.array
|
||||
value: float = 1
|
||||
value_operation: str = 'none'
|
||||
state: str = None
|
||||
id: int = 0
|
||||
aux: Any = None
|
||||
real_name: str = 'none'
|
||||
|
||||
|
||||
@dataclass
|
||||
class Floor:
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f"Floor({self.pos})"
|
||||
|
||||
@property
|
||||
def pos(self):
|
||||
return self.x, self.y
|
||||
|
||||
x: int
|
||||
y: int
|
||||
var_is_blocking_light: bool = False
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.name == other.name
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
Reference in New Issue
Block a user