mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 13:37:27 +01:00
Merge branch 'main' into refactor_rename
# Conflicts: # marl_factory_grid/configs/default_config.yaml # marl_factory_grid/environment/entity/object.py
This commit is contained in:
@@ -8,6 +8,7 @@ import yaml
|
||||
|
||||
from marl_factory_grid.environment.groups.agents import Agents
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.helpers import locate_and_import_class
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
@@ -81,7 +82,15 @@ class FactoryConfigParser(object):
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e3:
|
||||
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
|
||||
raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
|
||||
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||
print()
|
||||
print(f'Class "{entity}" was not found in "{folder_path.name}"')
|
||||
print('Possible Entitys are:', str(ents))
|
||||
print()
|
||||
print('Goodbye')
|
||||
print()
|
||||
exit()
|
||||
# raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
|
||||
|
||||
entity_kwargs = self.entities.get(entity, {})
|
||||
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
|
||||
@@ -114,6 +123,7 @@ class FactoryConfigParser(object):
|
||||
|
||||
# Observation
|
||||
observations = list()
|
||||
assert self.agents[name]['Observations'] is not None, 'Did you specify any Observation?'
|
||||
if c.DEFAULTS in self.agents[name]['Observations']:
|
||||
observations.extend(self.default_observations)
|
||||
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
|
||||
@@ -141,6 +151,8 @@ class FactoryConfigParser(object):
|
||||
rule_class = locate_and_import_class(rule, folder_path)
|
||||
except AttributeError:
|
||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||
# Fixme This check does not work!
|
||||
# assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".'
|
||||
rule_kwargs = self.rules.get(rule, {})
|
||||
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
|
||||
return rules_classes
|
||||
|
||||
@@ -76,6 +76,14 @@ class OBSBuilder(object):
|
||||
named_obs_dict[agent.name] = {'observation': obs, 'names': names}
|
||||
return named_obs_dict
|
||||
|
||||
def place_entity_in_observation(self, obs_array, agent, e):
|
||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||
try:
|
||||
obs_array[x, y] += e.encoding
|
||||
except IndexError:
|
||||
# Seemded to be visible but is out of range
|
||||
pass
|
||||
|
||||
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
|
||||
assert self._curr_env_step == state.curr_step, (
|
||||
"The observation objekt has not been reset this state! Call 'reset_struc_obs_block(state)'"
|
||||
@@ -91,12 +99,7 @@ class OBSBuilder(object):
|
||||
pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
|
||||
if self.pomdp_r:
|
||||
for e in set(visible_entitites):
|
||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||
try:
|
||||
pre_sort_obs[e.obs_tag][x, y] += e.encoding
|
||||
except IndexError:
|
||||
# Seemded to be visible but is out or range
|
||||
pass
|
||||
self.place_entity_in_observation(pre_sort_obs[e.obs_tag], agent, e)
|
||||
else:
|
||||
for e in set(visible_entitites):
|
||||
pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding
|
||||
@@ -157,18 +160,20 @@ class OBSBuilder(object):
|
||||
np.put(obs[idx], 0, v, mode='raise')
|
||||
except IndexError:
|
||||
raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.')
|
||||
|
||||
try:
|
||||
light_map = np.zeros(self.obs_shape)
|
||||
visible_floor = set(self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False))
|
||||
if self.pomdp_r:
|
||||
coords = [((f.x - agent.x) + self.pomdp_r, (f.y - agent.y) + self.pomdp_r) for f in visible_floor]
|
||||
else:
|
||||
coords = [x.pos for x in visible_floor]
|
||||
np.put(light_map, np.ravel_multi_index(np.asarray(coords).T, light_map.shape), 1)
|
||||
self.curr_lightmaps[agent.name] = light_map
|
||||
except KeyError:
|
||||
print()
|
||||
if self.pomdp_r:
|
||||
try:
|
||||
light_map = np.zeros(self.obs_shape)
|
||||
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
|
||||
if self.pomdp_r:
|
||||
for f in set(visible_floor):
|
||||
self.place_entity_in_observation(light_map, agent, f)
|
||||
else:
|
||||
for f in set(visible_floor):
|
||||
light_map[f.x, f.y] += f.encoding
|
||||
self.curr_lightmaps[agent.name] = light_map
|
||||
except (KeyError, ValueError):
|
||||
print()
|
||||
pass
|
||||
return obs, self.obs_layers[agent.name]
|
||||
|
||||
def _sort_and_name_observation_conf(self, agent):
|
||||
|
||||
130
marl_factory_grid/utils/ray_caster.py
Normal file
130
marl_factory_grid/utils/ray_caster.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import math
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
from numba import njit
|
||||
|
||||
|
||||
class RayCaster:
|
||||
def __init__(self, agent, pomdp_r, degs=360):
|
||||
self.agent = agent
|
||||
self.pomdp_r = pomdp_r
|
||||
self.n_rays = 100 # (self.pomdp_r + 1) * 8
|
||||
self.degs = degs
|
||||
self.ray_targets = self.build_ray_targets()
|
||||
self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
|
||||
self._cache_dict = {}
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
|
||||
def build_ray_targets(self):
|
||||
north = np.array([0, -1])*self.pomdp_r
|
||||
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
||||
rot_M = [
|
||||
[[math.cos(theta), -math.sin(theta)],
|
||||
[math.sin(theta), math.cos(theta)]] for theta in thetas
|
||||
]
|
||||
rot_M = np.stack(rot_M, 0)
|
||||
rot_M = np.unique(np.round(rot_M @ north), axis=0)
|
||||
return rot_M.astype(int)
|
||||
|
||||
def ray_block_cache(self, key, callback):
|
||||
if key not in self._cache_dict:
|
||||
self._cache_dict[key] = callback()
|
||||
return self._cache_dict[key]
|
||||
|
||||
def visible_entities(self, pos_dict, reset_cache=True):
|
||||
visible = list()
|
||||
if reset_cache:
|
||||
self._cache_dict = dict()
|
||||
|
||||
for ray in self.get_rays():
|
||||
rx, ry = ray[0]
|
||||
for x, y in ray:
|
||||
cx, cy = x - rx, y - ry
|
||||
|
||||
entities_hit = pos_dict[(x, y)]
|
||||
hits = self.ray_block_cache((x, y),
|
||||
lambda: any(True for e in entities_hit if e.var_is_blocking_light)
|
||||
)
|
||||
|
||||
diag_hits = all([
|
||||
self.ray_block_cache(
|
||||
key,
|
||||
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
|
||||
for key in ((x, y-cy), (x-cx, y))
|
||||
]) if (cx != 0 and cy != 0) else False
|
||||
|
||||
visible += entities_hit if not diag_hits else []
|
||||
if hits or diag_hits:
|
||||
break
|
||||
rx, ry = x, y
|
||||
return visible
|
||||
|
||||
def get_rays(self):
|
||||
a_pos = self.agent.pos
|
||||
outline = self.ray_targets + a_pos
|
||||
return self.bresenham_loop(a_pos, outline)
|
||||
|
||||
# todo do this once and cache the points!
|
||||
def get_fov_outline(self) -> np.ndarray:
|
||||
return self.ray_targets + self.agent.pos
|
||||
|
||||
def get_square_outline(self):
|
||||
agent = self.agent
|
||||
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
|
||||
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
|
||||
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
|
||||
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
|
||||
return outline
|
||||
|
||||
@staticmethod
|
||||
@njit
|
||||
def bresenham_loop(a_pos, points):
|
||||
results = []
|
||||
for end in points:
|
||||
x1, y1 = a_pos
|
||||
x2, y2 = end
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
|
||||
# Determine how steep the line is
|
||||
is_steep = abs(dy) > abs(dx)
|
||||
|
||||
# Rotate line
|
||||
if is_steep:
|
||||
x1, y1 = y1, x1
|
||||
x2, y2 = y2, x2
|
||||
|
||||
# Swap start and end points if necessary and store swap state
|
||||
swapped = False
|
||||
if x1 > x2:
|
||||
x1, x2 = x2, x1
|
||||
y1, y2 = y2, y1
|
||||
swapped = True
|
||||
|
||||
# Recalculate differentials
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
|
||||
# Calculate error
|
||||
error = int(dx / 2.0)
|
||||
ystep = 1 if y1 < y2 else -1
|
||||
|
||||
# Iterate over bounding box generating points between start and end
|
||||
y = y1
|
||||
points = []
|
||||
for x in range(int(x1), int(x2) + 1):
|
||||
coord = [y, x] if is_steep else [x, y]
|
||||
points.append(coord)
|
||||
error -= abs(dy)
|
||||
if error < 0:
|
||||
y += ystep
|
||||
error += dx
|
||||
|
||||
# Reverse the list if the coordinates were swapped
|
||||
if swapped:
|
||||
points.reverse()
|
||||
results.append(points)
|
||||
return results
|
||||
@@ -88,8 +88,8 @@ class Gamestate(object):
|
||||
results.extend(self.rules.tick_pre_step_all(self))
|
||||
|
||||
for idx, action_int in enumerate(actions):
|
||||
agent = self[c.AGENT][idx].clear_temp_state()
|
||||
if not agent.var_is_paralyzed:
|
||||
agent = self[c.AGENT][idx].clear_temp_state()
|
||||
action = agent.actions[action_int]
|
||||
action_result = action.do(agent, self)
|
||||
results.append(action_result)
|
||||
|
||||
@@ -31,6 +31,10 @@ class RenderEntity:
|
||||
@dataclass
|
||||
class Floor:
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return 1
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f"Floor({self.pos})"
|
||||
|
||||
Reference in New Issue
Block a user