Merge branch 'main' into refactor_rename

# Conflicts:
#	marl_factory_grid/configs/default_config.yaml
#	marl_factory_grid/environment/entity/object.py
This commit is contained in:
Chanumask
2023-10-31 10:28:25 +01:00
18 changed files with 217 additions and 92 deletions

View File

@@ -60,7 +60,7 @@ Just define what your environment needs in a *yaml*-configfile like:
done_at_collisions: !!bool True done_at_collisions: !!bool True
ItemRespawn: ItemRespawn:
spawn_freq: 5 spawn_freq: 5
DoDoorAutoClose: {} DoorAutoClose: {}
Assets: Assets:
- Defaults - Defaults

View File

@@ -1,23 +1,4 @@
Agents: Agents:
Eberhart:
Actions:
- Move8
- Noop
- ItemAction
Observations:
- Combined:
- Other
- Walls
- GlobalPosition
- Battery
- ChargePods
- DirtPiles
- Destinations
- Doors
- Items
- Inventory
- DropOffLocations
- Maintainers
Wolfgang: Wolfgang:
Actions: Actions:
- Noop - Noop
@@ -42,7 +23,9 @@ Agents:
- DropOffLocations - DropOffLocations
- Maintainers - Maintainers
Entities: Entities:
Batteries: {} Batteries:
initial_charge: 0.8
per_action_costs: 0.02
ChargePods: {} ChargePods: {}
Destinations: {} Destinations: {}
DirtPiles: DirtPiles:
@@ -70,23 +53,21 @@ General:
Rules: Rules:
SpawnAgents: {} SpawnAgents: {}
BtryDoneAtDischarge: {} DoneAtBatteryDischarge: {}
Collision: Collision:
done_at_collisions: false done_at_collisions: false
AssignGlobalPositions: {} AssignGlobalPositions: {}
DoneAtDestinationReachAny: {}
DestinationReachReward: {} DestinationReachReward: {}
SpawnDestinations: SpawnDestinations:
n_dests: 1 n_dests: 1
spawn_mode: GROUPED spawn_mode: GROUPED
DoneOnAllDirtCleaned: {} DoneOnAllDirtCleaned: {}
SpawnDirt: SpawnDirt:
initial_n: 4
initial_amount: 0.5
respawn_n: 2
respawn_amount: 0.2
spawn_freq: 15 spawn_freq: 15
EntitiesSmearDirtOnMove: {} EntitiesSmearDirtOnMove:
DoDoorAutoClose: smear_ratio: 0.2
DoorAutoClose:
close_frequency: 10 close_frequency: 10
ItemRules: ItemRules:
max_dropoff_storage_size: 0 max_dropoff_storage_size: 0

View File

@@ -12,7 +12,7 @@ class BoundEntityMixin:
if self.bound_entity: if self.bound_entity:
return f'{self.__class__.__name__}({self.bound_entity.name})' return f'{self.__class__.__name__}({self.bound_entity.name})'
else: else:
print() pass
def belongs_to_entity(self, entity): def belongs_to_entity(self, entity):
return entity == self.bound_entity return entity == self.bound_entity

View File

@@ -40,17 +40,6 @@ class _Object:
name = h.add_pos_name(name, self) name = h.add_pos_name(name, self)
return name return name
# @property
# def name(self):
# name = f"{self.__class__.__name__}"
# if self.bound_entity:
# name += f"[{self.bound_entity.name}]"
# if self._str_ident is not None:
# name += f"({self._str_ident})"
# else:
# name += f"(#{self.u_int})"
# return name
@property @property
def identifier(self): def identifier(self):
if self._str_ident is not None: if self._str_ident is not None:
@@ -165,25 +154,30 @@ class _Object:
# except AttributeError: # except AttributeError:
# return False # return False
# #
# @property # @property
# def var_can_collide(self): # def var_can_collide(self):
# try: # try:
# return self._collection.var_can_collide or False # return self._collection.var_can_collide or False
# except AttributeError: # except AttributeError:
# return False # return False
# #
# @property
# def encoding(self):
# return c.VALUE_OCCUPIED_CELL
# #
# def __init__(self, **kwargs): # @property
# super(EnvObject, self).__init__(**kwargs) # def encoding(self):
# return c.VALUE_OCCUPIED_CELL
# #
# def change_parent_collection(self, other_collection):
# other_collection.add_item(self)
# self._collection.delete_env_object(self)
# self._collection = other_collection
# return self._collection == other_collection
# #
# def summarize_state(self): # def __init__(self, **kwargs):
# return dict(name=str(self.name)) # self._bound_entity = None
# super(EnvObject, self).__init__(**kwargs)
#
#
# def change_parent_collection(self, other_collection):
# other_collection.add_item(self)
# self._collection.delete_env_object(self)
# self._collection = other_collection
# return self._collection == other_collection
#
#
# def summarize_state(self):
# return dict(name=str(self.name))

View File

@@ -1,6 +1,6 @@
from collections import defaultdict from collections import defaultdict
from operator import itemgetter from operator import itemgetter
from random import shuffle from random import shuffle, random
from typing import Dict from typing import Dict
from marl_factory_grid.environment.groups.objects import _Objects from marl_factory_grid.environment.groups.objects import _Objects
@@ -26,6 +26,7 @@ class Entities(_Objects):
@property @property
def floorlist(self): def floorlist(self):
shuffle(self._floor_positions)
return self._floor_positions return self._floor_positions
def __init__(self, floor_positions): def __init__(self, floor_positions):

View File

@@ -59,7 +59,7 @@ class _Objects:
return self return self
def remove_item(self, item: _entity): def remove_item(self, item: _entity):
for observer in self.observers: for observer in item.observers:
observer.notify_del_entity(item) observer.notify_del_entity(item)
# noinspection PyTypeChecker # noinspection PyTypeChecker
del self._data[item.name] del self._data[item.name]
@@ -126,10 +126,6 @@ class _Objects:
return f'{self.__class__.__name__}[{repr_dict}]' return f'{self.__class__.__name__}[{repr_dict}]'
def notify_del_entity(self, entity: _Object): def notify_del_entity(self, entity: _Object):
try:
entity.del_observer(self)
except AttributeError:
pass
try: try:
self.pos_dict[entity.pos].remove(entity) self.pos_dict[entity.pos].remove(entity)
except (AttributeError, ValueError, IndexError): except (AttributeError, ValueError, IndexError):

View File

@@ -1,4 +1,4 @@
from .actions import BtryCharge from .actions import BtryCharge
from .entitites import Pod, Battery from .entitites import Pod, Battery
from .groups import ChargePods, Batteries from .groups import ChargePods, Batteries
from .rules import BtryDoneAtDischarge, BatteryDecharge from .rules import DoneAtBatteryDischarge, BatteryDecharge

View File

@@ -94,7 +94,7 @@ class BatteryDecharge(Rule):
return results return results
class BtryDoneAtDischarge(BatteryDecharge): class DoneAtBatteryDischarge(BatteryDecharge):
def __init__(self, reward_discharge_done=b.REWARD_DISCHARGE_DONE, mode: str = b.SINGLE, **kwargs): def __init__(self, reward_discharge_done=b.REWARD_DISCHARGE_DONE, mode: str = b.SINGLE, **kwargs):
f""" f"""

View File

@@ -45,6 +45,7 @@ class DirtPiles(Collection):
if not self.amount > self.max_global_amount: if not self.amount > self.max_global_amount:
amount = amount_s[idx] if isinstance(amount_s, list) else amount_s amount = amount_s[idx] if isinstance(amount_s, list) else amount_s
if dirt := self.by_pos(pos): if dirt := self.by_pos(pos):
dirt = next(dirt.iter())
new_value = dirt.amount + amount new_value = dirt.amount + amount
dirt.set_new_amount(new_value) dirt.set_new_amount(new_value)
else: else:
@@ -57,8 +58,8 @@ class DirtPiles(Collection):
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, reward=0, value=spawn_counter) return Result(identifier=f'{self.name}_spawn', validity=c.VALID, reward=0, value=spawn_counter)
def trigger_dirt_spawn(self, n, amount, state, n_var=0.2, amount_var=0.2) -> Result: def trigger_dirt_spawn(self, n, amount, state, n_var=0.2, amount_var=0.2) -> Result:
free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 1 or ( free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 0 or (
len(state.entities.pos_dict[x]) == 2 and isinstance(next(y for y in x), DirtPile))] len(state.entities.pos_dict[x]) >= 1 and isinstance(next(y for y in x), DirtPile))]
# free_for_dirt = [x for x in state[c.FLOOR] # free_for_dirt = [x for x in state[c.FLOOR]
# if len(x.guests) == 0 or ( # if len(x.guests) == 0 or (
# len(x.guests) == 1 and # len(x.guests) == 1 and

View File

@@ -27,7 +27,8 @@ class DoneOnAllDirtCleaned(Rule):
class SpawnDirt(Rule): class SpawnDirt(Rule):
def __init__(self, initial_n: int, initial_amount: float, respawn_n: int, respawn_amount: float, def __init__(self, initial_n: int = 5, initial_amount: float = 1.3,
respawn_n: int = 3, respawn_amount: float = 0.8,
n_var: float = 0.2, amount_var: float = 0.2, spawn_freq: int = 15): n_var: float = 0.2, amount_var: float = 0.2, spawn_freq: int = 15):
""" """
Defines the spawn pattern of intial and additional 'Dirt'-entitites. Defines the spawn pattern of intial and additional 'Dirt'-entitites.

View File

@@ -1,4 +1,4 @@
from .actions import DoorUse from .actions import DoorUse
from .entitites import Door, DoorIndicator from .entitites import Door, DoorIndicator
from .groups import Doors from .groups import Doors
from .rules import DoDoorAutoClose, IndicateDoorAreaInObservation from .rules import DoorAutoClose, IndicateDoorAreaInObservation

View File

@@ -5,7 +5,7 @@ from . import constants as d
from .entitites import DoorIndicator from .entitites import DoorIndicator
class DoDoorAutoClose(Rule): class DoorAutoClose(Rule):
def __init__(self, close_frequency: int = 10): def __init__(self, close_frequency: int = 10):
""" """

View File

@@ -135,7 +135,7 @@ class DropOffLocations(Collection):
@staticmethod @staticmethod
def trigger_drop_off_location_spawn(state, n_locations): def trigger_drop_off_location_spawn(state, n_locations):
empty_positions = state.entities.empty_positions()[:n_locations] empty_positions = state.entities.empty_positions[:n_locations]
do_entites = state[i.DROP_OFF] do_entites = state[i.DROP_OFF]
drop_offs = [DropOffLocation(pos) for pos in empty_positions] drop_offs = [DropOffLocation(pos) for pos in empty_positions]
do_entites.add_items(drop_offs) do_entites.add_items(drop_offs)

View File

@@ -8,6 +8,7 @@ import yaml
from marl_factory_grid.environment.groups.agents import Agents from marl_factory_grid.environment.groups.agents import Agents
from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.helpers import locate_and_import_class from marl_factory_grid.utils.helpers import locate_and_import_class
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
@@ -81,7 +82,15 @@ class FactoryConfigParser(object):
entity_class = locate_and_import_class(entity, folder_path) entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e3: except AttributeError as e3:
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x] ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents)) print('### Error ### Error ### Error ### Error ### Error ###')
print()
print(f'Class "{entity}" was not found in "{folder_path.name}"')
print('Possible Entitys are:', str(ents))
print()
print('Goodbye')
print()
exit()
# raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
entity_kwargs = self.entities.get(entity, {}) entity_kwargs = self.entities.get(entity, {})
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
@@ -114,6 +123,7 @@ class FactoryConfigParser(object):
# Observation # Observation
observations = list() observations = list()
assert self.agents[name]['Observations'] is not None, 'Did you specify any Observation?'
if c.DEFAULTS in self.agents[name]['Observations']: if c.DEFAULTS in self.agents[name]['Observations']:
observations.extend(self.default_observations) observations.extend(self.default_observations)
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS) observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
@@ -141,6 +151,8 @@ class FactoryConfigParser(object):
rule_class = locate_and_import_class(rule, folder_path) rule_class = locate_and_import_class(rule, folder_path)
except AttributeError: except AttributeError:
rule_class = locate_and_import_class(rule, self.custom_modules_path) rule_class = locate_and_import_class(rule, self.custom_modules_path)
# Fixme This check does not work!
# assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".'
rule_kwargs = self.rules.get(rule, {}) rule_kwargs = self.rules.get(rule, {})
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}}) rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
return rules_classes return rules_classes

View File

@@ -76,6 +76,14 @@ class OBSBuilder(object):
named_obs_dict[agent.name] = {'observation': obs, 'names': names} named_obs_dict[agent.name] = {'observation': obs, 'names': names}
return named_obs_dict return named_obs_dict
def place_entity_in_observation(self, obs_array, agent, e):
    """
    Add the encoding of entity ``e`` to ``obs_array`` at its position relative to ``agent``.

    The target index is the entity's offset from the agent, shifted by ``self.pomdp_r``.
    Entities whose index falls outside of ``obs_array`` are silently ignored.

    :param obs_array: Array that is indexed as ``obs_array[x, y]`` and modified in place.
    :param agent: Object exposing ``x`` and ``y``; the reference point of the view.
    :param e: Object exposing ``x``, ``y`` and ``encoding``.
    :return: None
    """
    x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
    if x < 0 or y < 0:
        # A negative index would not raise, it would wrap around and mark a cell on the opposite edge.
        return
    try:
        obs_array[x, y] += e.encoding
    except IndexError:
        # Seemed to be visible but is out of range
        pass
def build_for_agent(self, agent, state) -> (List[str], np.ndarray): def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
assert self._curr_env_step == state.curr_step, ( assert self._curr_env_step == state.curr_step, (
"The observation objekt has not been reset this state! Call 'reset_struc_obs_block(state)'" "The observation objekt has not been reset this state! Call 'reset_struc_obs_block(state)'"
@@ -91,12 +99,7 @@ class OBSBuilder(object):
pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape)) pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
if self.pomdp_r: if self.pomdp_r:
for e in set(visible_entitites): for e in set(visible_entitites):
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r self.place_entity_in_observation(pre_sort_obs[e.obs_tag], agent, e)
try:
pre_sort_obs[e.obs_tag][x, y] += e.encoding
except IndexError:
# Seemded to be visible but is out or range
pass
else: else:
for e in set(visible_entitites): for e in set(visible_entitites):
pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding
@@ -157,18 +160,20 @@ class OBSBuilder(object):
np.put(obs[idx], 0, v, mode='raise') np.put(obs[idx], 0, v, mode='raise')
except IndexError: except IndexError:
raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.') raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.')
if self.pomdp_r:
try: try:
light_map = np.zeros(self.obs_shape) light_map = np.zeros(self.obs_shape)
visible_floor = set(self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)) visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
if self.pomdp_r: if self.pomdp_r:
coords = [((f.x - agent.x) + self.pomdp_r, (f.y - agent.y) + self.pomdp_r) for f in visible_floor] for f in set(visible_floor):
else: self.place_entity_in_observation(light_map, agent, f)
coords = [x.pos for x in visible_floor] else:
np.put(light_map, np.ravel_multi_index(np.asarray(coords).T, light_map.shape), 1) for f in set(visible_floor):
self.curr_lightmaps[agent.name] = light_map light_map[f.x, f.y] += f.encoding
except KeyError: self.curr_lightmaps[agent.name] = light_map
print() except (KeyError, ValueError):
print()
pass
return obs, self.obs_layers[agent.name] return obs, self.obs_layers[agent.name]
def _sort_and_name_observation_conf(self, agent): def _sort_and_name_observation_conf(self, agent):

View File

@@ -0,0 +1,130 @@
import math
from itertools import product
import numpy as np
from numba import njit
class RayCaster:
    """
    Casts rays from an agent's position to collect the entities it can see.

    A fixed set of ray target offsets is built once per instance. For every query the rays are
    walked cell by cell from the agent outwards; a ray ends at the first cell that contains an
    entity with a truthy ``var_is_blocking_light``, or when a diagonal step is enclosed by two
    such cells.
    """

    def __init__(self, agent, pomdp_r, degs=360):
        """
        :param agent: Object exposing ``pos``, ``x``, ``y`` and ``name``; the origin of all rays.
        :param pomdp_r: View radius; the length of every ray target offset.
        :param degs: Opening angle of the field of view in degrees, centred on "north" (0, -1).
        """
        self.agent = agent
        self.pomdp_r = pomdp_r
        self.n_rays = 100  # (self.pomdp_r + 1) * 8
        self.degs = degs
        self.ray_targets = self.build_ray_targets()
        self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
        # Maps (x, y) -> bool, "this cell contains a light-blocking entity".
        self._cache_dict = {}

    def __repr__(self):
        return f'{self.__class__.__name__}({self.agent.name})'

    def build_ray_targets(self):
        """
        Build the unique integer end point offsets of all rays, relative to the agent.

        :return: Integer array of shape (n_unique_targets, 2).
        """
        north = np.array([0, -1]) * self.pomdp_r
        # True division keeps the field of view symmetric for odd `degs`
        # (`-degs // 2` floors the negated value, e.g. -45 // 2 == -23 but 45 // 2 == 22).
        half_fov = self.degs / 2
        thetas = [np.deg2rad(deg) for deg in np.linspace(-half_fov, half_fov, self.n_rays)[::-1]]
        rot_M = [
            [[math.cos(theta), -math.sin(theta)],
             [math.sin(theta), math.cos(theta)]] for theta in thetas
        ]
        rot_M = np.stack(rot_M, 0)
        # Rounding maps several rays onto the same grid cell; keep each target only once.
        rot_M = np.unique(np.round(rot_M @ north), axis=0)
        return rot_M.astype(int)

    def ray_block_cache(self, key, callback):
        """
        Return the cached value for ``key``; on a miss, compute it via ``callback`` and store it.

        :param key: Hashable cache key (positions are used by this class).
        :param callback: Zero-argument callable that produces the value on a cache miss.
        :return: The cached or freshly computed value.
        """
        if key not in self._cache_dict:
            self._cache_dict[key] = callback()
        return self._cache_dict[key]

    @staticmethod
    def _blocks_light(entities):
        """Return True if any of the given entities has a truthy ``var_is_blocking_light``."""
        return any(e.var_is_blocking_light for e in entities)

    def visible_entities(self, pos_dict, reset_cache=True):
        """
        Collect all entities that are hit by any ray before the ray is blocked.

        Entities may be listed more than once when several rays cross the same cell.

        :param pos_dict: Mapping of ``(x, y)`` to an iterable of entities at that position. Every
                         visited position is looked up by subscription, so the mapping must provide
                         a value for positions without entities.
        :param reset_cache: If True, previously cached blocking information is dropped first.
        :return: List of visible entities.
        """
        visible = list()
        if reset_cache:
            self._cache_dict = dict()

        for ray in self.get_rays():
            rx, ry = ray[0]
            for x, y in ray:
                cx, cy = x - rx, y - ry

                entities_hit = pos_dict[(x, y)]
                hits = self.ray_block_cache((x, y), lambda: self._blocks_light(entities_hit))

                # On a diagonal step the ray squeezes between two orthogonal neighbours. It is
                # blocked if both of them block light. The same predicate as for `hits` is used,
                # since both share one cache; otherwise a cached value would depend on which
                # lookup touched a cell first.
                if cx != 0 and cy != 0:
                    diag_hits = all([
                        self.ray_block_cache(key, lambda key=key: self._blocks_light(pos_dict[key]))
                        for key in ((x, y - cy), (x - cx, y))
                    ])
                else:
                    diag_hits = False

                visible += entities_hit if not diag_hits else []
                if hits or diag_hits:
                    break
                rx, ry = x, y
        return visible

    def get_rays(self):
        """
        Return one list of grid points per ray target, running from the agent's position to it.
        """
        a_pos = self.agent.pos
        outline = self.ray_targets + a_pos
        return self.bresenham_loop(a_pos, outline)

    # todo do this once and cache the points!
    def get_fov_outline(self) -> np.ndarray:
        """Return the absolute positions of all ray targets for the agent's current position."""
        return self.ray_targets + self.agent.pos

    def get_square_outline(self):
        """
        Return the positions on the border of the square with "radius" ``pomdp_r`` around the agent.

        The four corner positions are contained twice (once per adjacent edge).
        """
        agent = self.agent
        x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
        y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
        outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
            + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
        return outline

    @staticmethod
    @njit
    def bresenham_loop(a_pos, points):
        """
        Rasterise a line from ``a_pos`` to every end point in ``points`` (Bresenham's algorithm).

        :param a_pos: Start position ``(x, y)`` shared by all lines.
        :param points: Iterable of end positions ``(x, y)``.
        :return: One list of ``[x, y]`` coordinates per end point, ordered from start to end.
        """
        results = []
        for end in points:
            x1, y1 = a_pos
            x2, y2 = end
            dx = x2 - x1
            dy = y2 - y1

            # Determine how steep the line is
            is_steep = abs(dy) > abs(dx)

            # Rotate line
            if is_steep:
                x1, y1 = y1, x1
                x2, y2 = y2, x2

            # Swap start and end points if necessary and store swap state
            swapped = False
            if x1 > x2:
                x1, x2 = x2, x1
                y1, y2 = y2, y1
                swapped = True

            # Recalculate differentials
            dx = x2 - x1
            dy = y2 - y1

            # Calculate error
            error = int(dx / 2.0)
            ystep = 1 if y1 < y2 else -1

            # Iterate over bounding box generating points between start and end
            y = y1
            # Own name for the current line, so the `points` parameter is not rebound while iterating.
            line = []
            for x in range(int(x1), int(x2) + 1):
                coord = [y, x] if is_steep else [x, y]
                line.append(coord)
                error -= abs(dy)
                if error < 0:
                    y += ystep
                    error += dx

            # Reverse the list if the coordinates were swapped
            if swapped:
                line.reverse()
            results.append(line)
        return results

View File

@@ -88,8 +88,8 @@ class Gamestate(object):
results.extend(self.rules.tick_pre_step_all(self)) results.extend(self.rules.tick_pre_step_all(self))
for idx, action_int in enumerate(actions): for idx, action_int in enumerate(actions):
agent = self[c.AGENT][idx].clear_temp_state()
if not agent.var_is_paralyzed: if not agent.var_is_paralyzed:
agent = self[c.AGENT][idx].clear_temp_state()
action = agent.actions[action_int] action = agent.actions[action_int]
action_result = action.do(agent, self) action_result = action.do(agent, self)
results.append(action_result) results.append(action_result)

View File

@@ -31,6 +31,10 @@ class RenderEntity:
@dataclass @dataclass
class Floor: class Floor:
@property
def encoding(self):
return 1
@property @property
def name(self): def name(self):
return f"Floor({self.pos})" return f"Floor({self.pos})"