mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 21:47:25 +01:00
fixed observation,
fixed observer reporting
This commit is contained in:
@@ -60,7 +60,7 @@ class Objects:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def remove_item(self, item: _entity):
|
def remove_item(self, item: _entity):
|
||||||
for observer in self.observers:
|
for observer in item.observers:
|
||||||
observer.notify_del_entity(item)
|
observer.notify_del_entity(item)
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
del self._data[item.name]
|
del self._data[item.name]
|
||||||
@@ -147,10 +147,6 @@ class Objects:
|
|||||||
# pass
|
# pass
|
||||||
|
|
||||||
def notify_del_entity(self, entity: Object):
|
def notify_del_entity(self, entity: Object):
|
||||||
try:
|
|
||||||
entity.del_observer(self)
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
try:
|
try:
|
||||||
self.pos_dict[entity.pos].remove(entity)
|
self.pos_dict[entity.pos].remove(entity)
|
||||||
except (AttributeError, ValueError, IndexError):
|
except (AttributeError, ValueError, IndexError):
|
||||||
|
|||||||
130
marl_factory_grid/utils/ray_caster.py
Normal file
130
marl_factory_grid/utils/ray_caster.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
import math
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from numba import njit
|
||||||
|
|
||||||
|
|
||||||
|
class RayCaster:
|
||||||
|
def __init__(self, agent, pomdp_r, degs=360):
|
||||||
|
self.agent = agent
|
||||||
|
self.pomdp_r = pomdp_r
|
||||||
|
self.n_rays = 100 # (self.pomdp_r + 1) * 8
|
||||||
|
self.degs = degs
|
||||||
|
self.ray_targets = self.build_ray_targets()
|
||||||
|
self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
|
||||||
|
self._cache_dict = {}
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'{self.__class__.__name__}({self.agent.name})'
|
||||||
|
|
||||||
|
def build_ray_targets(self):
|
||||||
|
north = np.array([0, -1])*self.pomdp_r
|
||||||
|
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
||||||
|
rot_M = [
|
||||||
|
[[math.cos(theta), -math.sin(theta)],
|
||||||
|
[math.sin(theta), math.cos(theta)]] for theta in thetas
|
||||||
|
]
|
||||||
|
rot_M = np.stack(rot_M, 0)
|
||||||
|
rot_M = np.unique(np.round(rot_M @ north), axis=0)
|
||||||
|
return rot_M.astype(int)
|
||||||
|
|
||||||
|
def ray_block_cache(self, key, callback):
|
||||||
|
if key not in self._cache_dict:
|
||||||
|
self._cache_dict[key] = callback()
|
||||||
|
return self._cache_dict[key]
|
||||||
|
|
||||||
|
def visible_entities(self, pos_dict, reset_cache=True):
|
||||||
|
visible = list()
|
||||||
|
if reset_cache:
|
||||||
|
self._cache_dict = dict()
|
||||||
|
|
||||||
|
for ray in self.get_rays():
|
||||||
|
rx, ry = ray[0]
|
||||||
|
for x, y in ray:
|
||||||
|
cx, cy = x - rx, y - ry
|
||||||
|
|
||||||
|
entities_hit = pos_dict[(x, y)]
|
||||||
|
hits = self.ray_block_cache((x, y),
|
||||||
|
lambda: any(True for e in entities_hit if e.var_is_blocking_light)
|
||||||
|
)
|
||||||
|
|
||||||
|
diag_hits = all([
|
||||||
|
self.ray_block_cache(
|
||||||
|
key,
|
||||||
|
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
|
||||||
|
for key in ((x, y-cy), (x-cx, y))
|
||||||
|
]) if (cx != 0 and cy != 0) else False
|
||||||
|
|
||||||
|
visible += entities_hit if not diag_hits else []
|
||||||
|
if hits or diag_hits:
|
||||||
|
break
|
||||||
|
rx, ry = x, y
|
||||||
|
return visible
|
||||||
|
|
||||||
|
def get_rays(self):
|
||||||
|
a_pos = self.agent.pos
|
||||||
|
outline = self.ray_targets + a_pos
|
||||||
|
return self.bresenham_loop(a_pos, outline)
|
||||||
|
|
||||||
|
# todo do this once and cache the points!
|
||||||
|
def get_fov_outline(self) -> np.ndarray:
|
||||||
|
return self.ray_targets + self.agent.pos
|
||||||
|
|
||||||
|
def get_square_outline(self):
|
||||||
|
agent = self.agent
|
||||||
|
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
|
||||||
|
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
|
||||||
|
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
|
||||||
|
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
|
||||||
|
return outline
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@njit
|
||||||
|
def bresenham_loop(a_pos, points):
|
||||||
|
results = []
|
||||||
|
for end in points:
|
||||||
|
x1, y1 = a_pos
|
||||||
|
x2, y2 = end
|
||||||
|
dx = x2 - x1
|
||||||
|
dy = y2 - y1
|
||||||
|
|
||||||
|
# Determine how steep the line is
|
||||||
|
is_steep = abs(dy) > abs(dx)
|
||||||
|
|
||||||
|
# Rotate line
|
||||||
|
if is_steep:
|
||||||
|
x1, y1 = y1, x1
|
||||||
|
x2, y2 = y2, x2
|
||||||
|
|
||||||
|
# Swap start and end points if necessary and store swap state
|
||||||
|
swapped = False
|
||||||
|
if x1 > x2:
|
||||||
|
x1, x2 = x2, x1
|
||||||
|
y1, y2 = y2, y1
|
||||||
|
swapped = True
|
||||||
|
|
||||||
|
# Recalculate differentials
|
||||||
|
dx = x2 - x1
|
||||||
|
dy = y2 - y1
|
||||||
|
|
||||||
|
# Calculate error
|
||||||
|
error = int(dx / 2.0)
|
||||||
|
ystep = 1 if y1 < y2 else -1
|
||||||
|
|
||||||
|
# Iterate over bounding box generating points between start and end
|
||||||
|
y = y1
|
||||||
|
points = []
|
||||||
|
for x in range(int(x1), int(x2) + 1):
|
||||||
|
coord = [y, x] if is_steep else [x, y]
|
||||||
|
points.append(coord)
|
||||||
|
error -= abs(dy)
|
||||||
|
if error < 0:
|
||||||
|
y += ystep
|
||||||
|
error += dx
|
||||||
|
|
||||||
|
# Reverse the list if the coordinates were swapped
|
||||||
|
if swapped:
|
||||||
|
points.reverse()
|
||||||
|
results.append(points)
|
||||||
|
return results
|
||||||
Reference in New Issue
Block a user