mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-09-15 23:37:14 +02:00
Merge branch 'main' into refactor_rename
# Conflicts: # marl_factory_grid/configs/default_config.yaml # marl_factory_grid/environment/entity/object.py
This commit is contained in:
@@ -60,7 +60,7 @@ Just define what your environment needs in a *yaml*-configfile like:
|
|||||||
done_at_collisions: !!bool True
|
done_at_collisions: !!bool True
|
||||||
ItemRespawn:
|
ItemRespawn:
|
||||||
spawn_freq: 5
|
spawn_freq: 5
|
||||||
DoDoorAutoClose: {}
|
DoorAutoClose: {}
|
||||||
|
|
||||||
Assets:
|
Assets:
|
||||||
- Defaults
|
- Defaults
|
||||||
|
@@ -1,23 +1,4 @@
|
|||||||
Agents:
|
Agents:
|
||||||
Eberhart:
|
|
||||||
Actions:
|
|
||||||
- Move8
|
|
||||||
- Noop
|
|
||||||
- ItemAction
|
|
||||||
Observations:
|
|
||||||
- Combined:
|
|
||||||
- Other
|
|
||||||
- Walls
|
|
||||||
- GlobalPosition
|
|
||||||
- Battery
|
|
||||||
- ChargePods
|
|
||||||
- DirtPiles
|
|
||||||
- Destinations
|
|
||||||
- Doors
|
|
||||||
- Items
|
|
||||||
- Inventory
|
|
||||||
- DropOffLocations
|
|
||||||
- Maintainers
|
|
||||||
Wolfgang:
|
Wolfgang:
|
||||||
Actions:
|
Actions:
|
||||||
- Noop
|
- Noop
|
||||||
@@ -42,7 +23,9 @@ Agents:
|
|||||||
- DropOffLocations
|
- DropOffLocations
|
||||||
- Maintainers
|
- Maintainers
|
||||||
Entities:
|
Entities:
|
||||||
Batteries: {}
|
Batteries:
|
||||||
|
initial_charge: 0.8
|
||||||
|
per_action_costs: 0.02
|
||||||
ChargePods: {}
|
ChargePods: {}
|
||||||
Destinations: {}
|
Destinations: {}
|
||||||
DirtPiles:
|
DirtPiles:
|
||||||
@@ -70,23 +53,21 @@ General:
|
|||||||
|
|
||||||
Rules:
|
Rules:
|
||||||
SpawnAgents: {}
|
SpawnAgents: {}
|
||||||
BtryDoneAtDischarge: {}
|
DoneAtBatteryDischarge: {}
|
||||||
Collision:
|
Collision:
|
||||||
done_at_collisions: false
|
done_at_collisions: false
|
||||||
AssignGlobalPositions: {}
|
AssignGlobalPositions: {}
|
||||||
|
DoneAtDestinationReachAny: {}
|
||||||
DestinationReachReward: {}
|
DestinationReachReward: {}
|
||||||
SpawnDestinations:
|
SpawnDestinations:
|
||||||
n_dests: 1
|
n_dests: 1
|
||||||
spawn_mode: GROUPED
|
spawn_mode: GROUPED
|
||||||
DoneOnAllDirtCleaned: {}
|
DoneOnAllDirtCleaned: {}
|
||||||
SpawnDirt:
|
SpawnDirt:
|
||||||
initial_n: 4
|
|
||||||
initial_amount: 0.5
|
|
||||||
respawn_n: 2
|
|
||||||
respawn_amount: 0.2
|
|
||||||
spawn_freq: 15
|
spawn_freq: 15
|
||||||
EntitiesSmearDirtOnMove: {}
|
EntitiesSmearDirtOnMove:
|
||||||
DoDoorAutoClose:
|
smear_ratio: 0.2
|
||||||
|
DoorAutoClose:
|
||||||
close_frequency: 10
|
close_frequency: 10
|
||||||
ItemRules:
|
ItemRules:
|
||||||
max_dropoff_storage_size: 0
|
max_dropoff_storage_size: 0
|
||||||
|
@@ -12,7 +12,7 @@ class BoundEntityMixin:
|
|||||||
if self.bound_entity:
|
if self.bound_entity:
|
||||||
return f'{self.__class__.__name__}({self.bound_entity.name})'
|
return f'{self.__class__.__name__}({self.bound_entity.name})'
|
||||||
else:
|
else:
|
||||||
print()
|
pass
|
||||||
|
|
||||||
def belongs_to_entity(self, entity):
|
def belongs_to_entity(self, entity):
|
||||||
return entity == self.bound_entity
|
return entity == self.bound_entity
|
||||||
|
@@ -40,17 +40,6 @@ class _Object:
|
|||||||
name = h.add_pos_name(name, self)
|
name = h.add_pos_name(name, self)
|
||||||
return name
|
return name
|
||||||
|
|
||||||
# @property
|
|
||||||
# def name(self):
|
|
||||||
# name = f"{self.__class__.__name__}"
|
|
||||||
# if self.bound_entity:
|
|
||||||
# name += f"[{self.bound_entity.name}]"
|
|
||||||
# if self._str_ident is not None:
|
|
||||||
# name += f"({self._str_ident})"
|
|
||||||
# else:
|
|
||||||
# name += f"(#{self.u_int})"
|
|
||||||
# return name
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def identifier(self):
|
def identifier(self):
|
||||||
if self._str_ident is not None:
|
if self._str_ident is not None:
|
||||||
@@ -165,25 +154,30 @@ class _Object:
|
|||||||
# except AttributeError:
|
# except AttributeError:
|
||||||
# return False
|
# return False
|
||||||
#
|
#
|
||||||
# @property
|
# @property
|
||||||
# def var_can_collide(self):
|
# def var_can_collide(self):
|
||||||
# try:
|
# try:
|
||||||
# return self._collection.var_can_collide or False
|
# return self._collection.var_can_collide or False
|
||||||
# except AttributeError:
|
# except AttributeError:
|
||||||
# return False
|
# return False
|
||||||
#
|
#
|
||||||
# @property
|
|
||||||
# def encoding(self):
|
|
||||||
# return c.VALUE_OCCUPIED_CELL
|
|
||||||
#
|
#
|
||||||
# def __init__(self, **kwargs):
|
# @property
|
||||||
# super(EnvObject, self).__init__(**kwargs)
|
# def encoding(self):
|
||||||
|
# return c.VALUE_OCCUPIED_CELL
|
||||||
#
|
#
|
||||||
# def change_parent_collection(self, other_collection):
|
|
||||||
# other_collection.add_item(self)
|
|
||||||
# self._collection.delete_env_object(self)
|
|
||||||
# self._collection = other_collection
|
|
||||||
# return self._collection == other_collection
|
|
||||||
#
|
#
|
||||||
# def summarize_state(self):
|
# def __init__(self, **kwargs):
|
||||||
# return dict(name=str(self.name))
|
# self._bound_entity = None
|
||||||
|
# super(EnvObject, self).__init__(**kwargs)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def change_parent_collection(self, other_collection):
|
||||||
|
# other_collection.add_item(self)
|
||||||
|
# self._collection.delete_env_object(self)
|
||||||
|
# self._collection = other_collection
|
||||||
|
# return self._collection == other_collection
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def summarize_state(self):
|
||||||
|
# return dict(name=str(self.name))
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from random import shuffle
|
from random import shuffle, random
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from marl_factory_grid.environment.groups.objects import _Objects
|
from marl_factory_grid.environment.groups.objects import _Objects
|
||||||
@@ -26,6 +26,7 @@ class Entities(_Objects):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def floorlist(self):
|
def floorlist(self):
|
||||||
|
shuffle(self._floor_positions)
|
||||||
return self._floor_positions
|
return self._floor_positions
|
||||||
|
|
||||||
def __init__(self, floor_positions):
|
def __init__(self, floor_positions):
|
||||||
|
@@ -59,7 +59,7 @@ class _Objects:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def remove_item(self, item: _entity):
|
def remove_item(self, item: _entity):
|
||||||
for observer in self.observers:
|
for observer in item.observers:
|
||||||
observer.notify_del_entity(item)
|
observer.notify_del_entity(item)
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
del self._data[item.name]
|
del self._data[item.name]
|
||||||
@@ -126,10 +126,6 @@ class _Objects:
|
|||||||
return f'{self.__class__.__name__}[{repr_dict}]'
|
return f'{self.__class__.__name__}[{repr_dict}]'
|
||||||
|
|
||||||
def notify_del_entity(self, entity: _Object):
|
def notify_del_entity(self, entity: _Object):
|
||||||
try:
|
|
||||||
entity.del_observer(self)
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
try:
|
try:
|
||||||
self.pos_dict[entity.pos].remove(entity)
|
self.pos_dict[entity.pos].remove(entity)
|
||||||
except (AttributeError, ValueError, IndexError):
|
except (AttributeError, ValueError, IndexError):
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
from .actions import BtryCharge
|
from .actions import BtryCharge
|
||||||
from .entitites import Pod, Battery
|
from .entitites import Pod, Battery
|
||||||
from .groups import ChargePods, Batteries
|
from .groups import ChargePods, Batteries
|
||||||
from .rules import BtryDoneAtDischarge, BatteryDecharge
|
from .rules import DoneAtBatteryDischarge, BatteryDecharge
|
||||||
|
@@ -94,7 +94,7 @@ class BatteryDecharge(Rule):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
class BtryDoneAtDischarge(BatteryDecharge):
|
class DoneAtBatteryDischarge(BatteryDecharge):
|
||||||
|
|
||||||
def __init__(self, reward_discharge_done=b.REWARD_DISCHARGE_DONE, mode: str = b.SINGLE, **kwargs):
|
def __init__(self, reward_discharge_done=b.REWARD_DISCHARGE_DONE, mode: str = b.SINGLE, **kwargs):
|
||||||
f"""
|
f"""
|
||||||
|
@@ -45,6 +45,7 @@ class DirtPiles(Collection):
|
|||||||
if not self.amount > self.max_global_amount:
|
if not self.amount > self.max_global_amount:
|
||||||
amount = amount_s[idx] if isinstance(amount_s, list) else amount_s
|
amount = amount_s[idx] if isinstance(amount_s, list) else amount_s
|
||||||
if dirt := self.by_pos(pos):
|
if dirt := self.by_pos(pos):
|
||||||
|
dirt = next(dirt.iter())
|
||||||
new_value = dirt.amount + amount
|
new_value = dirt.amount + amount
|
||||||
dirt.set_new_amount(new_value)
|
dirt.set_new_amount(new_value)
|
||||||
else:
|
else:
|
||||||
@@ -57,8 +58,8 @@ class DirtPiles(Collection):
|
|||||||
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, reward=0, value=spawn_counter)
|
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, reward=0, value=spawn_counter)
|
||||||
|
|
||||||
def trigger_dirt_spawn(self, n, amount, state, n_var=0.2, amount_var=0.2) -> Result:
|
def trigger_dirt_spawn(self, n, amount, state, n_var=0.2, amount_var=0.2) -> Result:
|
||||||
free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 1 or (
|
free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 0 or (
|
||||||
len(state.entities.pos_dict[x]) == 2 and isinstance(next(y for y in x), DirtPile))]
|
len(state.entities.pos_dict[x]) >= 1 and isinstance(next(y for y in x), DirtPile))]
|
||||||
# free_for_dirt = [x for x in state[c.FLOOR]
|
# free_for_dirt = [x for x in state[c.FLOOR]
|
||||||
# if len(x.guests) == 0 or (
|
# if len(x.guests) == 0 or (
|
||||||
# len(x.guests) == 1 and
|
# len(x.guests) == 1 and
|
||||||
|
@@ -27,7 +27,8 @@ class DoneOnAllDirtCleaned(Rule):
|
|||||||
|
|
||||||
class SpawnDirt(Rule):
|
class SpawnDirt(Rule):
|
||||||
|
|
||||||
def __init__(self, initial_n: int, initial_amount: float, respawn_n: int, respawn_amount: float,
|
def __init__(self, initial_n: int = 5, initial_amount: float = 1.3,
|
||||||
|
respawn_n: int = 3, respawn_amount: float = 0.8,
|
||||||
n_var: float = 0.2, amount_var: float = 0.2, spawn_freq: int = 15):
|
n_var: float = 0.2, amount_var: float = 0.2, spawn_freq: int = 15):
|
||||||
"""
|
"""
|
||||||
Defines the spawn pattern of intial and additional 'Dirt'-entitites.
|
Defines the spawn pattern of intial and additional 'Dirt'-entitites.
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
from .actions import DoorUse
|
from .actions import DoorUse
|
||||||
from .entitites import Door, DoorIndicator
|
from .entitites import Door, DoorIndicator
|
||||||
from .groups import Doors
|
from .groups import Doors
|
||||||
from .rules import DoDoorAutoClose, IndicateDoorAreaInObservation
|
from .rules import DoorAutoClose, IndicateDoorAreaInObservation
|
||||||
|
@@ -5,7 +5,7 @@ from . import constants as d
|
|||||||
from .entitites import DoorIndicator
|
from .entitites import DoorIndicator
|
||||||
|
|
||||||
|
|
||||||
class DoDoorAutoClose(Rule):
|
class DoorAutoClose(Rule):
|
||||||
|
|
||||||
def __init__(self, close_frequency: int = 10):
|
def __init__(self, close_frequency: int = 10):
|
||||||
"""
|
"""
|
||||||
|
@@ -135,7 +135,7 @@ class DropOffLocations(Collection):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def trigger_drop_off_location_spawn(state, n_locations):
|
def trigger_drop_off_location_spawn(state, n_locations):
|
||||||
empty_positions = state.entities.empty_positions()[:n_locations]
|
empty_positions = state.entities.empty_positions[:n_locations]
|
||||||
do_entites = state[i.DROP_OFF]
|
do_entites = state[i.DROP_OFF]
|
||||||
drop_offs = [DropOffLocation(pos) for pos in empty_positions]
|
drop_offs = [DropOffLocation(pos) for pos in empty_positions]
|
||||||
do_entites.add_items(drop_offs)
|
do_entites.add_items(drop_offs)
|
||||||
|
@@ -8,6 +8,7 @@ import yaml
|
|||||||
|
|
||||||
from marl_factory_grid.environment.groups.agents import Agents
|
from marl_factory_grid.environment.groups.agents import Agents
|
||||||
from marl_factory_grid.environment.entity.agent import Agent
|
from marl_factory_grid.environment.entity.agent import Agent
|
||||||
|
from marl_factory_grid.environment.rules import Rule
|
||||||
from marl_factory_grid.utils.helpers import locate_and_import_class
|
from marl_factory_grid.utils.helpers import locate_and_import_class
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
@@ -81,7 +82,15 @@ class FactoryConfigParser(object):
|
|||||||
entity_class = locate_and_import_class(entity, folder_path)
|
entity_class = locate_and_import_class(entity, folder_path)
|
||||||
except AttributeError as e3:
|
except AttributeError as e3:
|
||||||
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
|
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
|
||||||
raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
|
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||||
|
print()
|
||||||
|
print(f'Class "{entity}" was not found in "{folder_path.name}"')
|
||||||
|
print('Possible Entitys are:', str(ents))
|
||||||
|
print()
|
||||||
|
print('Goodbye')
|
||||||
|
print()
|
||||||
|
exit()
|
||||||
|
# raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
|
||||||
|
|
||||||
entity_kwargs = self.entities.get(entity, {})
|
entity_kwargs = self.entities.get(entity, {})
|
||||||
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
|
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
|
||||||
@@ -114,6 +123,7 @@ class FactoryConfigParser(object):
|
|||||||
|
|
||||||
# Observation
|
# Observation
|
||||||
observations = list()
|
observations = list()
|
||||||
|
assert self.agents[name]['Observations'] is not None, 'Did you specify any Observation?'
|
||||||
if c.DEFAULTS in self.agents[name]['Observations']:
|
if c.DEFAULTS in self.agents[name]['Observations']:
|
||||||
observations.extend(self.default_observations)
|
observations.extend(self.default_observations)
|
||||||
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
|
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
|
||||||
@@ -141,6 +151,8 @@ class FactoryConfigParser(object):
|
|||||||
rule_class = locate_and_import_class(rule, folder_path)
|
rule_class = locate_and_import_class(rule, folder_path)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||||
|
# Fixme This check does not work!
|
||||||
|
# assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".'
|
||||||
rule_kwargs = self.rules.get(rule, {})
|
rule_kwargs = self.rules.get(rule, {})
|
||||||
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
|
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
|
||||||
return rules_classes
|
return rules_classes
|
||||||
|
@@ -76,6 +76,14 @@ class OBSBuilder(object):
|
|||||||
named_obs_dict[agent.name] = {'observation': obs, 'names': names}
|
named_obs_dict[agent.name] = {'observation': obs, 'names': names}
|
||||||
return named_obs_dict
|
return named_obs_dict
|
||||||
|
|
||||||
|
def place_entity_in_observation(self, obs_array, agent, e):
|
||||||
|
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||||
|
try:
|
||||||
|
obs_array[x, y] += e.encoding
|
||||||
|
except IndexError:
|
||||||
|
# Seemded to be visible but is out of range
|
||||||
|
pass
|
||||||
|
|
||||||
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
|
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
|
||||||
assert self._curr_env_step == state.curr_step, (
|
assert self._curr_env_step == state.curr_step, (
|
||||||
"The observation objekt has not been reset this state! Call 'reset_struc_obs_block(state)'"
|
"The observation objekt has not been reset this state! Call 'reset_struc_obs_block(state)'"
|
||||||
@@ -91,12 +99,7 @@ class OBSBuilder(object):
|
|||||||
pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
|
pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
|
||||||
if self.pomdp_r:
|
if self.pomdp_r:
|
||||||
for e in set(visible_entitites):
|
for e in set(visible_entitites):
|
||||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
self.place_entity_in_observation(pre_sort_obs[e.obs_tag], agent, e)
|
||||||
try:
|
|
||||||
pre_sort_obs[e.obs_tag][x, y] += e.encoding
|
|
||||||
except IndexError:
|
|
||||||
# Seemded to be visible but is out or range
|
|
||||||
pass
|
|
||||||
else:
|
else:
|
||||||
for e in set(visible_entitites):
|
for e in set(visible_entitites):
|
||||||
pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding
|
pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding
|
||||||
@@ -157,18 +160,20 @@ class OBSBuilder(object):
|
|||||||
np.put(obs[idx], 0, v, mode='raise')
|
np.put(obs[idx], 0, v, mode='raise')
|
||||||
except IndexError:
|
except IndexError:
|
||||||
raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.')
|
raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.')
|
||||||
|
if self.pomdp_r:
|
||||||
try:
|
try:
|
||||||
light_map = np.zeros(self.obs_shape)
|
light_map = np.zeros(self.obs_shape)
|
||||||
visible_floor = set(self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False))
|
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
|
||||||
if self.pomdp_r:
|
if self.pomdp_r:
|
||||||
coords = [((f.x - agent.x) + self.pomdp_r, (f.y - agent.y) + self.pomdp_r) for f in visible_floor]
|
for f in set(visible_floor):
|
||||||
else:
|
self.place_entity_in_observation(light_map, agent, f)
|
||||||
coords = [x.pos for x in visible_floor]
|
else:
|
||||||
np.put(light_map, np.ravel_multi_index(np.asarray(coords).T, light_map.shape), 1)
|
for f in set(visible_floor):
|
||||||
self.curr_lightmaps[agent.name] = light_map
|
light_map[f.x, f.y] += f.encoding
|
||||||
except KeyError:
|
self.curr_lightmaps[agent.name] = light_map
|
||||||
print()
|
except (KeyError, ValueError):
|
||||||
|
print()
|
||||||
|
pass
|
||||||
return obs, self.obs_layers[agent.name]
|
return obs, self.obs_layers[agent.name]
|
||||||
|
|
||||||
def _sort_and_name_observation_conf(self, agent):
|
def _sort_and_name_observation_conf(self, agent):
|
||||||
|
130
marl_factory_grid/utils/ray_caster.py
Normal file
130
marl_factory_grid/utils/ray_caster.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
import math
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from numba import njit
|
||||||
|
|
||||||
|
|
||||||
|
class RayCaster:
|
||||||
|
def __init__(self, agent, pomdp_r, degs=360):
|
||||||
|
self.agent = agent
|
||||||
|
self.pomdp_r = pomdp_r
|
||||||
|
self.n_rays = 100 # (self.pomdp_r + 1) * 8
|
||||||
|
self.degs = degs
|
||||||
|
self.ray_targets = self.build_ray_targets()
|
||||||
|
self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
|
||||||
|
self._cache_dict = {}
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'{self.__class__.__name__}({self.agent.name})'
|
||||||
|
|
||||||
|
def build_ray_targets(self):
|
||||||
|
north = np.array([0, -1])*self.pomdp_r
|
||||||
|
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
||||||
|
rot_M = [
|
||||||
|
[[math.cos(theta), -math.sin(theta)],
|
||||||
|
[math.sin(theta), math.cos(theta)]] for theta in thetas
|
||||||
|
]
|
||||||
|
rot_M = np.stack(rot_M, 0)
|
||||||
|
rot_M = np.unique(np.round(rot_M @ north), axis=0)
|
||||||
|
return rot_M.astype(int)
|
||||||
|
|
||||||
|
def ray_block_cache(self, key, callback):
|
||||||
|
if key not in self._cache_dict:
|
||||||
|
self._cache_dict[key] = callback()
|
||||||
|
return self._cache_dict[key]
|
||||||
|
|
||||||
|
def visible_entities(self, pos_dict, reset_cache=True):
|
||||||
|
visible = list()
|
||||||
|
if reset_cache:
|
||||||
|
self._cache_dict = dict()
|
||||||
|
|
||||||
|
for ray in self.get_rays():
|
||||||
|
rx, ry = ray[0]
|
||||||
|
for x, y in ray:
|
||||||
|
cx, cy = x - rx, y - ry
|
||||||
|
|
||||||
|
entities_hit = pos_dict[(x, y)]
|
||||||
|
hits = self.ray_block_cache((x, y),
|
||||||
|
lambda: any(True for e in entities_hit if e.var_is_blocking_light)
|
||||||
|
)
|
||||||
|
|
||||||
|
diag_hits = all([
|
||||||
|
self.ray_block_cache(
|
||||||
|
key,
|
||||||
|
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
|
||||||
|
for key in ((x, y-cy), (x-cx, y))
|
||||||
|
]) if (cx != 0 and cy != 0) else False
|
||||||
|
|
||||||
|
visible += entities_hit if not diag_hits else []
|
||||||
|
if hits or diag_hits:
|
||||||
|
break
|
||||||
|
rx, ry = x, y
|
||||||
|
return visible
|
||||||
|
|
||||||
|
def get_rays(self):
|
||||||
|
a_pos = self.agent.pos
|
||||||
|
outline = self.ray_targets + a_pos
|
||||||
|
return self.bresenham_loop(a_pos, outline)
|
||||||
|
|
||||||
|
# todo do this once and cache the points!
|
||||||
|
def get_fov_outline(self) -> np.ndarray:
|
||||||
|
return self.ray_targets + self.agent.pos
|
||||||
|
|
||||||
|
def get_square_outline(self):
|
||||||
|
agent = self.agent
|
||||||
|
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
|
||||||
|
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
|
||||||
|
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
|
||||||
|
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
|
||||||
|
return outline
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@njit
|
||||||
|
def bresenham_loop(a_pos, points):
|
||||||
|
results = []
|
||||||
|
for end in points:
|
||||||
|
x1, y1 = a_pos
|
||||||
|
x2, y2 = end
|
||||||
|
dx = x2 - x1
|
||||||
|
dy = y2 - y1
|
||||||
|
|
||||||
|
# Determine how steep the line is
|
||||||
|
is_steep = abs(dy) > abs(dx)
|
||||||
|
|
||||||
|
# Rotate line
|
||||||
|
if is_steep:
|
||||||
|
x1, y1 = y1, x1
|
||||||
|
x2, y2 = y2, x2
|
||||||
|
|
||||||
|
# Swap start and end points if necessary and store swap state
|
||||||
|
swapped = False
|
||||||
|
if x1 > x2:
|
||||||
|
x1, x2 = x2, x1
|
||||||
|
y1, y2 = y2, y1
|
||||||
|
swapped = True
|
||||||
|
|
||||||
|
# Recalculate differentials
|
||||||
|
dx = x2 - x1
|
||||||
|
dy = y2 - y1
|
||||||
|
|
||||||
|
# Calculate error
|
||||||
|
error = int(dx / 2.0)
|
||||||
|
ystep = 1 if y1 < y2 else -1
|
||||||
|
|
||||||
|
# Iterate over bounding box generating points between start and end
|
||||||
|
y = y1
|
||||||
|
points = []
|
||||||
|
for x in range(int(x1), int(x2) + 1):
|
||||||
|
coord = [y, x] if is_steep else [x, y]
|
||||||
|
points.append(coord)
|
||||||
|
error -= abs(dy)
|
||||||
|
if error < 0:
|
||||||
|
y += ystep
|
||||||
|
error += dx
|
||||||
|
|
||||||
|
# Reverse the list if the coordinates were swapped
|
||||||
|
if swapped:
|
||||||
|
points.reverse()
|
||||||
|
results.append(points)
|
||||||
|
return results
|
@@ -88,8 +88,8 @@ class Gamestate(object):
|
|||||||
results.extend(self.rules.tick_pre_step_all(self))
|
results.extend(self.rules.tick_pre_step_all(self))
|
||||||
|
|
||||||
for idx, action_int in enumerate(actions):
|
for idx, action_int in enumerate(actions):
|
||||||
|
agent = self[c.AGENT][idx].clear_temp_state()
|
||||||
if not agent.var_is_paralyzed:
|
if not agent.var_is_paralyzed:
|
||||||
agent = self[c.AGENT][idx].clear_temp_state()
|
|
||||||
action = agent.actions[action_int]
|
action = agent.actions[action_int]
|
||||||
action_result = action.do(agent, self)
|
action_result = action.do(agent, self)
|
||||||
results.append(action_result)
|
results.append(action_result)
|
||||||
|
@@ -31,6 +31,10 @@ class RenderEntity:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Floor:
|
class Floor:
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return f"Floor({self.pos})"
|
return f"Floor({self.pos})"
|
||||||
|
Reference in New Issue
Block a user