mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-07 01:51:35 +02:00
fixed render funciton and obsbuilder
This commit is contained in:
@ -66,7 +66,13 @@ class Collection(_Objects):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def obs_pairs(self):
|
def obs_pairs(self):
|
||||||
return [(x.name, x) for x in self]
|
pair_list = [(self.name, self)]
|
||||||
|
try:
|
||||||
|
if self.var_can_be_bound:
|
||||||
|
pair_list.extend([(a.name, a) for a in self])
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
return pair_list
|
||||||
|
|
||||||
def by_entity(self, entity):
|
def by_entity(self, entity):
|
||||||
try:
|
try:
|
||||||
@ -81,7 +87,10 @@ class Collection(_Objects):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
|
if self.var_has_position:
|
||||||
return [y for y in [x.render() for x in self] if y is not None]
|
return [y for y in [x.render() for x in self] if y is not None]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
|
def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
|
||||||
|
@ -44,5 +44,9 @@ class GlobalPositions(Collection):
|
|||||||
def var_can_collide(self):
|
def var_can_collide(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_can_be_bound(self):
|
||||||
|
return True
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(GlobalPositions, self).__init__(*args, **kwargs)
|
super(GlobalPositions, self).__init__(*args, **kwargs)
|
||||||
|
@ -49,9 +49,6 @@ class Battery(_Object):
|
|||||||
summary.update(dict(belongs_to=self._bound_entity.name, chargeLevel=self.charge_level))
|
summary.update(dict(belongs_to=self._bound_entity.name, chargeLevel=self.charge_level))
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
def render(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class Pod(Entity):
|
class Pod(Entity):
|
||||||
|
|
||||||
|
@ -23,6 +23,10 @@ class Batteries(Collection):
|
|||||||
def var_has_position(self):
|
def var_has_position(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_can_be_bound(self):
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def obs_tag(self):
|
def obs_tag(self):
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
@ -12,7 +12,7 @@ class DoorIndicator(Entity):
|
|||||||
return d.VALUE_ACCESS_INDICATOR
|
return d.VALUE_ACCESS_INDICATOR
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return None
|
return []
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
@ -135,7 +135,7 @@ class DropOffLocations(Collection):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def trigger_drop_off_location_spawn(state, n_locations):
|
def trigger_drop_off_location_spawn(state, n_locations):
|
||||||
empty_positions = state.entities.empty_positions[:n_locations]
|
empty_positions = state.entities.empty_positions()[:n_locations]
|
||||||
do_entites = state[i.DROP_OFF]
|
do_entites = state[i.DROP_OFF]
|
||||||
drop_offs = [DropOffLocation(pos) for pos in empty_positions]
|
drop_offs = [DropOffLocation(pos) for pos in empty_positions]
|
||||||
do_entites.add_items(drop_offs)
|
do_entites.add_items(drop_offs)
|
||||||
|
@ -15,7 +15,6 @@ from marl_factory_grid.utils.utility_classes import Floor
|
|||||||
|
|
||||||
|
|
||||||
class OBSBuilder(object):
|
class OBSBuilder(object):
|
||||||
|
|
||||||
default_obs = [c.WALLS, c.OTHERS]
|
default_obs = [c.WALLS, c.OTHERS]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -95,20 +94,19 @@ class OBSBuilder(object):
|
|||||||
agent_want_obs = self.obs_layers[agent.name]
|
agent_want_obs = self.obs_layers[agent.name]
|
||||||
|
|
||||||
# Handle in-grid observations aka visible observations (Things on the map, with pos)
|
# Handle in-grid observations aka visible observations (Things on the map, with pos)
|
||||||
visible_entitites = self.ray_caster[agent.name].visible_entities(state.entities.pos_dict)
|
visible_entities = self.ray_caster[agent.name].visible_entities(state.entities.pos_dict)
|
||||||
pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
|
pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
|
||||||
if self.pomdp_r:
|
if self.pomdp_r:
|
||||||
for e in set(visible_entitites):
|
for e in set(visible_entities):
|
||||||
self.place_entity_in_observation(pre_sort_obs[e.obs_tag], agent, e)
|
self.place_entity_in_observation(pre_sort_obs[e.obs_tag], agent, e)
|
||||||
else:
|
else:
|
||||||
for e in set(visible_entitites):
|
for e in set(visible_entities):
|
||||||
pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding
|
pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding
|
||||||
|
|
||||||
pre_sort_obs = dict(pre_sort_obs)
|
pre_sort_obs = dict(pre_sort_obs)
|
||||||
obs = np.zeros((len(agent_want_obs), self.obs_shape[0], self.obs_shape[1]))
|
obs = np.zeros((len(agent_want_obs), self.obs_shape[0], self.obs_shape[1]))
|
||||||
|
|
||||||
for idx, l_name in enumerate(agent_want_obs):
|
for idx, l_name in enumerate(agent_want_obs):
|
||||||
print(l_name)
|
|
||||||
try:
|
try:
|
||||||
obs[idx] = pre_sort_obs[l_name]
|
obs[idx] = pre_sort_obs[l_name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -125,12 +123,11 @@ class OBSBuilder(object):
|
|||||||
try:
|
try:
|
||||||
# Look for bound entity names!
|
# Look for bound entity names!
|
||||||
pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}')
|
pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}')
|
||||||
print(pattern)
|
|
||||||
name = next((x for x in self.all_obs if pattern.search(x)), None)
|
name = next((x for x in self.all_obs if pattern.search(x)), None)
|
||||||
e = self.all_obs[name]
|
e = self.all_obs[name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
try:
|
try:
|
||||||
e = next(v for k in self.all_obs.items() if l_name in k and agent.name in k)
|
e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise KeyError(
|
raise KeyError(
|
||||||
f'Check for spelling errors! \n '
|
f'Check for spelling errors! \n '
|
||||||
@ -233,7 +230,7 @@ class RayCaster:
|
|||||||
return f'{self.__class__.__name__}({self.agent.name})'
|
return f'{self.__class__.__name__}({self.agent.name})'
|
||||||
|
|
||||||
def build_ray_targets(self):
|
def build_ray_targets(self):
|
||||||
north = np.array([0, -1])*self.pomdp_r
|
north = np.array([0, -1]) * self.pomdp_r
|
||||||
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
||||||
rot_M = [
|
rot_M = [
|
||||||
[[math.cos(theta), -math.sin(theta)],
|
[[math.cos(theta), -math.sin(theta)],
|
||||||
@ -266,8 +263,9 @@ class RayCaster:
|
|||||||
diag_hits = all([
|
diag_hits = all([
|
||||||
self.ray_block_cache(
|
self.ray_block_cache(
|
||||||
key,
|
key,
|
||||||
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(pos_dict[key]))
|
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(
|
||||||
for key in ((x, y-cy), (x-cx, y))
|
pos_dict[key]))
|
||||||
|
for key in ((x, y - cy), (x - cx, y))
|
||||||
]) if (cx != 0 and cy != 0) else False
|
]) if (cx != 0 and cy != 0) else False
|
||||||
|
|
||||||
visible += entities_hit if not diag_hits else []
|
visible += entities_hit if not diag_hits else []
|
||||||
|
Reference in New Issue
Block a user