no more tiles no more floor

This commit is contained in:
Steffen Illium
2023-10-20 14:36:23 +02:00
parent 8709b093b8
commit 7a1d3f84f1
41 changed files with 265 additions and 217 deletions

View File

@ -21,7 +21,7 @@ class TSPBaseAgent(ABC):
self.local_optimization = True self.local_optimization = True
self._env = state self._env = state
self.state = self._env.state[c.AGENT][agent_i] self.state = self._env.state[c.AGENT][agent_i]
self._floortile_graph = points_to_graph(self._env[c.FLOORS].positions) self._position_graph = points_to_graph(self._env.entities.floorlist)
self._static_route = None self._static_route = None
@abstractmethod @abstractmethod
@ -50,7 +50,7 @@ class TSPBaseAgent(ABC):
else: else:
nodes = [self.state.pos] + positions nodes = [self.state.pos] + positions
route = tsp.traveling_salesman_problem(self._floortile_graph, route = tsp.traveling_salesman_problem(self._position_graph,
nodes=nodes, cycle=True, method=tsp.greedy_tsp) nodes=nodes, cycle=True, method=tsp.greedy_tsp)
return route return route

View File

@ -4,17 +4,17 @@ import networkx as nx
import numpy as np import numpy as np
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True): def points_to_graph(coordiniates, allow_euclidean_connections=True, allow_manhattan_connections=True):
""" """
Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points. Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points.
There are three combinations of settings: There are three combinations of settings:
Allow all neigbors: Distance(a, b) <= sqrt(2) Allow all neigbors: Distance(a, b) <= sqrt(2)
Allow only manhattan: Distance(a, b) == 1 Allow only manhattan: Distance(a, b) == 1
Allow only euclidean: Distance(a, b) == sqrt(2) Allow only Euclidean: Distance(a, b) == sqrt(2)
:param coordiniates_or_tiles: A set of coordinates. :param coordiniates: A set of coordinates.
:type coordiniates_or_tiles: Tiles :type coordiniates: Tuple[int, int]
:param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors :param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors
:type: bool :type: bool
:param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors :param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors
@ -24,9 +24,7 @@ def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, all
:rtype: nx.Graph :rtype: nx.Graph
""" """
assert allow_euclidean_connections or allow_manhattan_connections assert allow_euclidean_connections or allow_manhattan_connections
if hasattr(coordiniates_or_tiles, 'positions'): possible_connections = itertools.combinations(coordiniates, 2)
coordiniates_or_tiles = coordiniates_or_tiles.positions
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
graph = nx.Graph() graph = nx.Graph()
for a, b in possible_connections: for a, b in possible_connections:
diff = np.linalg.norm(np.asarray(a)-np.asarray(b)) diff = np.linalg.norm(np.asarray(a)-np.asarray(b))

View File

@ -66,7 +66,6 @@ Rules:
DestinationDone: {} DestinationDone: {}
DestinationReach: DestinationReach:
n_dests: 1 n_dests: 1
tiles: null
DestinationSpawn: DestinationSpawn:
n_dests: 1 n_dests: 1
spawn_frequency: 5 spawn_frequency: 5

View File

@ -42,12 +42,12 @@ class Move(Action, abc.ABC):
def do(self, entity, state): def do(self, entity, state):
new_pos = self._calc_new_pos(entity.pos) new_pos = self._calc_new_pos(entity.pos)
if state.check_move_validity(entity, new_pos): # next_tile := state[c.FLOOR].by_pos(new_pos): if state.check_move_validity(entity, new_pos):
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
move_validity = entity.move(new_pos, state) move_validity = entity.move(new_pos, state)
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward) return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
else: # There is no floor, propably collision else: # There is no place to go, propably collision
# This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml # This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION) # return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0) return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0)

View File

@ -3,15 +3,13 @@ DANGER_ZONE = 'x' # Dange Zone tile _identifier fo
DEFAULTS = 'Defaults' DEFAULTS = 'Defaults'
SELF = 'Self' SELF = 'Self'
PLACEHOLDER = 'Placeholder' PLACEHOLDER = 'Placeholder'
FLOOR = 'Floor' # Identifier of Floor-objects and groups (groups).
FLOORS = 'Floors' # Identifier of Floor-objects and groups (groups).
WALL = 'Wall' # Identifier of Wall-objects and groups (groups). WALL = 'Wall' # Identifier of Wall-objects and groups (groups).
WALLS = 'Walls' # Identifier of Wall-objects and groups (groups). WALLS = 'Walls' # Identifier of Wall-objects and groups (groups).
LEVEL = 'Level' # Identifier of Level-objects and groups (groups). LEVEL = 'Level' # Identifier of Level-objects and groups (groups).
AGENT = 'Agent' # Identifier of Agent-objects and groups (groups). AGENT = 'Agent' # Identifier of Agent-objects and groups (groups).
OTHERS = 'Other' OTHERS = 'Other'
COMBINED = 'Combined' COMBINED = 'Combined'
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
# Attributes # Attributes
IS_BLOCKING_LIGHT = 'var_is_blocking_light' IS_BLOCKING_LIGHT = 'var_is_blocking_light'
@ -32,7 +30,7 @@ VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the e
ACTION = 'action' # Identifier of Action-objects and groups (groups). ACTION = 'action' # Identifier of Action-objects and groups (groups).
COLLISION = 'Collision' # Identifier to use in the context of collitions. COLLISION = 'Collision' # Identifier to use in the context of collitions.
LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos. # LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos.
VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ... VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ...
# Actions # Actions

View File

@ -2,7 +2,7 @@ from typing import List, Union
from marl_factory_grid.environment.actions import Action from marl_factory_grid.environment.actions import Action
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils.render import RenderEntity from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.utils import renderer from marl_factory_grid.utils import renderer
from marl_factory_grid.utils.helpers import is_move from marl_factory_grid.utils.helpers import is_move
from marl_factory_grid.utils.results import ActionResult, Result from marl_factory_grid.utils.results import ActionResult, Result

View File

@ -1,8 +1,10 @@
import abc import abc
import numpy as np
from .. import constants as c from .. import constants as c
from .object import EnvObject from .object import EnvObject
from ...utils.render import RenderEntity from ...utils.utility_classes import RenderEntity
from ...utils.results import ActionResult from ...utils.results import ActionResult
@ -30,33 +32,32 @@ class Entity(EnvObject, abc.ABC):
return self._pos return self._pos
@property @property
def tile(self): def last_pos(self):
return self._tile # wall_n_floors funktionalität try:
return self._last_pos
# @property except AttributeError:
# def last_tile(self): # noinspection PyAttributeOutsideInit
# try: self._last_pos = c.VALUE_NO_POS
# return self._last_tile return self._last_pos
# except AttributeError:
# # noinspection PyAttributeOutsideInit
# self._last_tile = None
# return self._last_tile
@property @property
def direction_of_view(self): def direction_of_view(self):
last_x, last_y = self._last_pos if self._last_pos != c.VALUE_NO_POS:
curr_x, curr_y = self.pos return 0, 0
return last_x - curr_x, last_y - curr_y else:
return np.subtract(self._last_pos, self.pos)
def move(self, next_pos, state): def move(self, next_pos, state):
next_pos = next_pos next_pos = next_pos
curr_pos = self._pos curr_pos = self._pos
if not_same_pos := curr_pos != next_pos: if not_same_pos := curr_pos != next_pos:
if valid := state.check_move_validity(self, next_pos): if valid := state.check_move_validity(self, next_pos):
self._pos = next_pos
self._last_pos = curr_pos
for observer in self.observers: for observer in self.observers:
observer.notify_change_pos(self) observer.notify_del_entity(self)
self._view_directory = curr_pos[0]-next_pos[0], curr_pos[1]-next_pos[1]
self._pos = next_pos
for observer in self.observers:
observer.notify_add_entity(self)
return valid return valid
return not_same_pos return not_same_pos
@ -64,6 +65,7 @@ class Entity(EnvObject, abc.ABC):
super().__init__(**kwargs) super().__init__(**kwargs)
self._status = None self._status = None
self._pos = pos self._pos = pos
self._last_pos = pos
if bind_to: if bind_to:
try: try:
self.bind_to(bind_to) self.bind_to(bind_to)

View File

@ -4,7 +4,7 @@ import numpy as np
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.object import EnvObject from marl_factory_grid.environment.entity.object import EnvObject
from marl_factory_grid.utils.render import RenderEntity from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.utils import helpers as h from marl_factory_grid.utils import helpers as h
@ -30,17 +30,6 @@ class Floor(EnvObject):
def var_is_blocking_light(self): def var_is_blocking_light(self):
return False return False
@property
def neighboring_floor(self):
if self._neighboring_floor:
pass
else:
self._neighboring_floor = [x for x in [self._collection.by_pos(np.add(self.pos, pos))
for pos in h.POS_MASK.reshape(-1, 2)
if not np.all(pos == [0, 0])]
if x]
return self._neighboring_floor
@property @property
def encoding(self): def encoding(self):
return c.VALUE_OCCUPIED_CELL return c.VALUE_OCCUPIED_CELL

View File

@ -197,7 +197,7 @@ class Factory(gym.Env):
del rewards['global'] del rewards['global']
reward = [rewards[agent.name] for agent in self.state[c.AGENT]] reward = [rewards[agent.name] for agent in self.state[c.AGENT]]
reward = [x + global_rewards for x in reward] reward = [x + global_rewards for x in reward]
self.state.print(f"rewards are {rewards}") self.state.print(f"Individual rewards are {dict(rewards)}")
return reward, combined_info_dict, done return reward, combined_info_dict, done
else: else:
reward = sum(rewards.values()) reward = sum(rewards.values())
@ -220,7 +220,7 @@ class Factory(gym.Env):
def summarize_header(self): def summarize_header(self):
header = {'rec_step': self.state.curr_step} header = {'rec_step': self.state.curr_step}
for entity_group in (x for x in self.state if x.name in ['Walls', 'Floors', 'DropOffLocations', 'ChargePods']): for entity_group in (x for x in self.state if x.name in ['Walls', 'DropOffLocations', 'ChargePods']):
header.update({f'rec{entity_group.name}': entity_group.summarize_states()}) header.update({f'rec{entity_group.name}': entity_group.summarize_states()})
return header return header
@ -229,7 +229,7 @@ class Factory(gym.Env):
# Todo: Protobuff Compatibility Section ####### # Todo: Protobuff Compatibility Section #######
# for entity_group in (x for x in self.state if x.name not in [c.WALLS, c.FLOORS]): # for entity_group in (x for x in self.state if x.name not in [c.WALLS, c.FLOORS]):
for entity_group in (x for x in self.state if x.name not in [c.FLOORS]): for entity_group in self.state:
summary.update({entity_group.name.lower(): entity_group.summarize_states()}) summary.update({entity_group.name.lower(): entity_group.summarize_states()})
# TODO Section End ######## # TODO Section End ########
for key in list(summary.keys()): for key in list(summary.keys()):

View File

@ -1,5 +1,6 @@
from collections import defaultdict from collections import defaultdict
from operator import itemgetter from operator import itemgetter
from random import shuffle
from typing import Dict from typing import Dict
from marl_factory_grid.environment.groups.objects import Objects from marl_factory_grid.environment.groups.objects import Objects
@ -13,7 +14,7 @@ class Entities(Objects):
def neighboring_positions(pos): def neighboring_positions(pos):
return (POS_MASK + pos).reshape(-1, 2) return (POS_MASK + pos).reshape(-1, 2)
def get_near_pos(self, pos): def get_entities_near_pos(self, pos):
return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x] return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x]
def render(self): def render(self):
@ -38,11 +39,17 @@ class Entities(Objects):
def guests_that_can_collide(self, pos): def guests_that_can_collide(self, pos):
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide] return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
def empty_tiles(self): @property
return[key for key in self.floorlist if not any(self.pos_dict[key])] def empty_positions(self):
empty_positions= [key for key in self.floorlist if self.pos_dict[key]]
shuffle(empty_positions)
return empty_positions
def occupied_tiles(self): # positions that are not empty @property
return[key for key in self.floorlist if any(self.pos_dict[key])] def occupied_positions(self): # positions that are not empty
empty_positions = [key for key in self.floorlist if self.pos_dict[key]]
shuffle(empty_positions)
return empty_positions
def is_blocked(self): def is_blocked(self):
return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])] return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]

View File

@ -37,7 +37,11 @@ class PositionMixin:
def __delitem__(self, name): def __delitem__(self, name):
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name) idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
obj.tile.leave(obj) # observer notify? try:
for observer in obj.observers:
observer.notify_del_entity(obj)
except AttributeError:
pass
super().__delitem__(name) super().__delitem__(name)
def by_pos(self, pos: (int, int)): def by_pos(self, pos: (int, int)):

View File

@ -103,6 +103,9 @@ class Objects:
except StopIteration: except StopIteration:
return None return None
def by_name(self, name):
return next(x for x in self if x.name == name)
def __getitem__(self, item): def __getitem__(self, item):
if isinstance(item, (int, np.int64, np.int32)): if isinstance(item, (int, np.int64, np.int32)):
if item < 0: if item < 0:
@ -120,7 +123,7 @@ class Objects:
raise TypeError raise TypeError
def __repr__(self): def __repr__(self):
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS, c.FLOORS]} repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS]}
return f'{self.__class__.__name__}[{repr_dict}]' return f'{self.__class__.__name__}[{repr_dict}]'
def spawn(self, n: int): def spawn(self, n: int):
@ -132,22 +135,25 @@ class Objects:
for item in items: for item in items:
del self[item] del self[item]
def notify_change_pos(self, entity: object): # def notify_change_pos(self, entity: object):
try: # try:
self.pos_dict[entity.last_pos].remove(entity) # self.pos_dict[entity.last_pos].remove(entity)
except (ValueError, AttributeError): # except (ValueError, AttributeError):
pass # pass
if entity.var_has_position: # if entity.var_has_position:
try: # try:
self.pos_dict[entity.pos].append(entity) # self.pos_dict[entity.pos].append(entity)
except (ValueError, AttributeError): # except (ValueError, AttributeError):
pass # pass
def notify_del_entity(self, entity: Object): def notify_del_entity(self, entity: Object):
try: try:
entity.del_observer(self) entity.del_observer(self)
except AttributeError:
pass
try:
self.pos_dict[entity.pos].remove(entity) self.pos_dict[entity.pos].remove(entity)
except (ValueError, AttributeError): except (AttributeError, ValueError, IndexError):
pass pass
def notify_add_entity(self, entity: Object): def notify_add_entity(self, entity: Object):

View File

@ -15,6 +15,7 @@ class Walls(PositionMixin, EnvObjects):
super(Walls, self).__init__(*args, **kwargs) super(Walls, self).__init__(*args, **kwargs)
self._value = c.VALUE_OCCUPIED_CELL self._value = c.VALUE_OCCUPIED_CELL
#ToDo: Do we need this? Move to spawn methode?
# @classmethod # @classmethod
# def from_coordinates(cls, argwhere_coordinates, *args, **kwargs): # def from_coordinates(cls, argwhere_coordinates, *args, **kwargs):
# tiles = cls(*args, **kwargs) # tiles = cls(*args, **kwargs)

View File

@ -49,7 +49,7 @@ class SpawnAgents(Rule):
agent_conf = state.agents_conf agent_conf = state.agents_conf
# agents = Agents(lvl_map.size) # agents = Agents(lvl_map.size)
agents = state[c.AGENT] agents = state[c.AGENT]
empty_tiles = state[c.FLOORS].empty_tiles[:len(agent_conf)] empty_positions = state.entities.empty_positions[:len(agent_conf)]
for agent_name in agent_conf: for agent_name in agent_conf:
actions = agent_conf[agent_name]['actions'].copy() actions = agent_conf[agent_name]['actions'].copy()
observations = agent_conf[agent_name]['observations'].copy() observations = agent_conf[agent_name]['observations'].copy()
@ -58,18 +58,17 @@ class SpawnAgents(Rule):
shuffle(positions) shuffle(positions)
while True: while True:
try: try:
tile = state[c.FLOORS].by_pos(positions.pop()) pos = positions.pop()
except IndexError as e: except IndexError as e:
raise ValueError(f'It was not possible to spawn an Agent on the available position: ' raise ValueError(f'It was not possible to spawn an Agent on the available position: '
f'\n{agent_name[agent_name]["positions"].copy()}') f'\n{agent_name[agent_name]["positions"].copy()}')
try: if agents.by_pos(pos) and state.check_pos_validity(pos):
agents.add_item(Agent(actions, observations, tile, str_ident=agent_name))
except AssertionError:
state.print(f'No valid pos:{tile.pos} for {agent_name}')
continue continue
else:
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name))
break break
else: else:
agents.add_item(Agent(actions, observations, empty_tiles.pop(), str_ident=agent_name)) agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name))
pass pass

View File

@ -2,7 +2,7 @@ from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
from marl_factory_grid.environment.entity.object import EnvObject from marl_factory_grid.environment.entity.object import EnvObject
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.render import RenderEntity from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.modules.batteries import constants as b from marl_factory_grid.modules.batteries import constants as b

View File

@ -70,8 +70,8 @@ class PodRules(Rule):
def on_init(self, state, lvl_map): def on_init(self, state, lvl_map):
pod_collection = state[b.CHARGE_PODS] pod_collection = state[b.CHARGE_PODS]
empty_tiles = state[c.FLOORS].empty_tiles[:self.n_pods] empty_positions = state.entities.empty_positions()
pods = pod_collection.from_coordinates(empty_tiles, entity_kwargs=dict( pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict(
multi_charge=self.multi_charge, charge_rate=self.charge_rate) multi_charge=self.multi_charge, charge_rate=self.charge_rate)
) )
pod_collection.add_items(pods) pod_collection.add_items(pods)

View File

@ -1,7 +1,7 @@
from numpy import random from numpy import random
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils.render import RenderEntity from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.modules.clean_up import constants as d from marl_factory_grid.modules.clean_up import constants as d

View File

@ -1,6 +1,5 @@
from marl_factory_grid.environment.groups.env_objects import EnvObjects from marl_factory_grid.environment.groups.env_objects import EnvObjects
from marl_factory_grid.environment.groups.mixins import PositionMixin from marl_factory_grid.environment.groups.mixins import PositionMixin
from marl_factory_grid.environment.entity.wall_floor import Floor
from marl_factory_grid.modules.clean_up.entitites import DirtPile from marl_factory_grid.modules.clean_up.entitites import DirtPile
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
@ -31,8 +30,6 @@ class DirtPiles(PositionMixin, EnvObjects):
self.max_local_amount = max_local_amount self.max_local_amount = max_local_amount
def spawn(self, then_dirty_positions, amount) -> bool: def spawn(self, then_dirty_positions, amount) -> bool:
# if isinstance(then_dirty_tiles, Floor):
# then_dirty_tiles = [then_dirty_tiles]
for pos in then_dirty_positions: for pos in then_dirty_positions:
if not self.amount > self.max_global_amount: if not self.amount > self.max_global_amount:
if dirt := self.by_pos(pos): if dirt := self.by_pos(pos):
@ -56,8 +53,8 @@ class DirtPiles(PositionMixin, EnvObjects):
var = self.dirt_spawn_r_var var = self.dirt_spawn_r_var
new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0)) new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0))
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt))) n_dirty_positions = max(0, int(new_spawn * len(free_for_dirt)))
return self.spawn(free_for_dirt[:n_dirt_tiles], self.initial_amount) return self.spawn(free_for_dirt[:n_dirty_positions], self.initial_amount)
def __repr__(self): def __repr__(self):
s = super(DirtPiles, self).__repr__() s = super(DirtPiles, self).__repr__()

View File

@ -4,7 +4,7 @@ from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.mixin import BoundEntityMixin from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
from marl_factory_grid.utils.render import RenderEntity from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.modules.destinations import constants as d from marl_factory_grid.modules.destinations import constants as d
@ -17,7 +17,6 @@ class Destination(BoundEntityMixin, Entity):
var_is_blocking_light = False var_is_blocking_light = False
var_can_be_bound = True # Introduce this globally! var_can_be_bound = True # Introduce this globally!
@property
def was_reached(self): def was_reached(self):
return self._was_reached return self._was_reached
@ -35,11 +34,10 @@ class Destination(BoundEntityMixin, Entity):
self._per_agent_actions[agent.name] += 1 self._per_agent_actions[agent.name] += 1
return c.VALID return c.VALID
@property def has_just_been_reached(self, state):
def has_just_been_reached(self): if self.was_reached():
if self.was_reached:
return False return False
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in state.entities.pos_dict[self.pos] if x.var_can_collide) agent_at_position = any(state[c.AGENT].by_pos(self.pos))
if self.bound_entity: if self.bound_entity:
return ((agent_at_position and not self.action_counts) return ((agent_at_position and not self.action_counts)
@ -57,7 +55,7 @@ class Destination(BoundEntityMixin, Entity):
return state_summary return state_summary
def render(self): def render(self):
if self.was_reached: if self.was_reached():
return None return None
else: else:
return RenderEntity(d.DESTINATION, self.pos) return RenderEntity(d.DESTINATION, self.pos)

View File

@ -16,28 +16,29 @@ class DestinationReachAll(Rule):
def tick_step(self, state) -> List[TickResult]: def tick_step(self, state) -> List[TickResult]:
results = [] results = []
reached = False
for dest in state[d.DESTINATION]: for dest in state[d.DESTINATION]:
if dest.has_just_been_reached and not dest.was_reached: if dest.has_just_been_reached(state) and not dest.was_reached():
# Dest has just been reached, some agent needs to stand here, grab any first. # Dest has just been reached, some agent needs to stand here
for agent in state[c.AGENT].by_pos(dest.pos): for agent in state[c.AGENT].by_pos(dest.pos):
if dest.bound_entity: if dest.bound_entity:
if dest.bound_entity == agent: if dest.bound_entity == agent:
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent)) reached = True
else: else:
pass pass
else: else:
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent)) reached = True
state.print(f'{dest.name} is reached now, mark as reached...')
dest.mark_as_reached()
else: else:
pass pass
if reached:
state.print(f'{dest.name} is reached now, mark as reached...')
dest.mark_as_reached()
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
return results return results
def tick_post_step(self, state) -> List[TickResult]:
return []
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
if all(x.was_reached for x in state[d.DESTINATION]): if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)] return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
@ -48,7 +49,7 @@ class DestinationReachAny(DestinationReachAll):
super(DestinationReachAny, self).__init__() super(DestinationReachAny, self).__init__()
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
if any(x.was_reached for x in state[d.DESTINATION]): if any(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)] return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
return [] return []
@ -63,7 +64,7 @@ class DestinationSpawn(Rule):
def on_init(self, state, lvl_map): def on_init(self, state, lvl_map):
# noinspection PyAttributeOutsideInit # noinspection PyAttributeOutsideInit
self.trigger_destination_spawn(self.n_dests, state) state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state)
pass pass
def tick_pre_step(self, state) -> List[TickResult]: def tick_pre_step(self, state) -> List[TickResult]:
@ -72,24 +73,14 @@ class DestinationSpawn(Rule):
def tick_step(self, state) -> List[TickResult]: def tick_step(self, state) -> List[TickResult]:
if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])): if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])):
if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests: if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
validity = self.trigger_destination_spawn(n_dest_spawn, state) validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)] return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn: elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn:
validity = self.trigger_destination_spawn(n_dest_spawn, state) validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)] return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
else: else:
pass pass
def trigger_destination_spawn(self, n_dests, state):
empty_positions = state[c.FLOORS].empty_tiles[:n_dests]
if destinations := [Destination(pos) for pos in empty_positions]:
state[d.DESTINATION].add_items(destinations)
state.print(f'{n_dests} new destinations have been spawned')
return c.VALID
else:
state.print('No Destiantions are spawning, limit is reached.')
return c.NOT_VALID
class FixedDestinationSpawn(Rule): class FixedDestinationSpawn(Rule):
def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]): def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]):
@ -99,11 +90,17 @@ class FixedDestinationSpawn(Rule):
def on_init(self, state, lvl_map): def on_init(self, state, lvl_map):
for (agent_name, position_list) in self.per_agent_positions.items(): for (agent_name, position_list) in self.per_agent_positions.items():
agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF
position_list = position_list.copy()
shuffle(position_list) shuffle(position_list)
while True: while True:
pos = position_list.pop() try:
if pos != agent.pos and not state[d.DESTINATION].by_pos(pos): pos = position_list.pop()
destination = Destination(state[c.FLOORS].by_pos(pos), bind_to=agent) except IndexError:
print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}")
print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...')
exit(9999)
if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)):
destination = Destination(pos, bind_to=agent)
break break
else: else:
continue continue

View File

@ -1,4 +1,4 @@
from .actions import DoorUse from .actions import DoorUse
from .entitites import Door, DoorIndicator from .entitites import Door, DoorIndicator
from .groups import Doors from .groups import Doors
from .rule_door_auto_close import DoorAutoClose from .rules import DoorAutoClose, DoorIndicateArea

View File

@ -13,8 +13,9 @@ class DoorUse(Action):
def do(self, entity, state) -> Union[None, ActionResult]: def do(self, entity, state) -> Union[None, ActionResult]:
# Check if agent really is standing on a door: # Check if agent really is standing on a door:
e = state.entities.get_near_pos(entity.pos) e = state.entities.get_entities_near_pos(entity.pos)
try: try:
# Only one door opens TODO introcude loop
door = next(x for x in e if x.name.startswith(d.DOOR)) door = next(x for x in e if x.name.startswith(d.DOOR))
valid = door.use() valid = door.use()
state.print(f'{entity.name} just used a {door.name} at {door.pos}') state.print(f'{entity.name} just used a {door.name} at {door.pos}')

View File

@ -1,5 +1,5 @@
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils.render import RenderEntity from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.doors import constants as d from marl_factory_grid.modules.doors import constants as d
@ -41,7 +41,7 @@ class Door(Entity):
def str_state(self): def str_state(self):
return 'open' if self.is_open else 'closed' return 'open' if self.is_open else 'closed'
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, indicate_area=False, **kwargs): def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
self._status = d.STATE_CLOSED self._status = d.STATE_CLOSED
super(Door, self).__init__(*args, **kwargs) super(Door, self).__init__(*args, **kwargs)
self.auto_close_interval = auto_close_interval self.auto_close_interval = auto_close_interval
@ -50,8 +50,6 @@ class Door(Entity):
self._open() self._open()
else: else:
self._close() self._close()
if indicate_area:
self._collection.add_items([DoorIndicator(x) for x in self.tile.neighboring_floor])
def summarize_state(self): def summarize_state(self):
state_dict = super().summarize_state() state_dict = super().summarize_state()

View File

@ -1,7 +1,8 @@
from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.results import TickResult from marl_factory_grid.utils.results import TickResult
from marl_factory_grid.modules.doors import constants as d from . import constants as d
from .entitites import DoorIndicator
class DoorAutoClose(Rule): class DoorAutoClose(Rule):
@ -19,3 +20,13 @@ class DoorAutoClose(Rule):
return [TickResult(self.name, validity=c.VALID, value=0)] return [TickResult(self.name, validity=c.VALID, value=0)]
state.print('There are no doors, but you loaded the corresponding Module') state.print('There are no doors, but you loaded the corresponding Module')
return [] return []
class DoorIndicateArea(Rule):
def __init__(self):
super().__init__()
def on_init(self, state, lvl_map):
for door in state[d.DOORS]:
state[d.DOORS].add_items([DoorIndicator(x) for x in state.entities.neighboring_positions(door.pos)])

View File

@ -9,6 +9,8 @@ from marl_factory_grid.utils.results import TickResult
class AgentSingleZonePlacementBeta(Rule): class AgentSingleZonePlacementBeta(Rule):
def __init__(self): def __init__(self):
raise NotImplementedError()
# TODO!!!! Is this concept needed any more?
super().__init__() super().__init__()
def on_init(self, state, lvl_map): def on_init(self, state, lvl_map):
@ -21,9 +23,9 @@ class AgentSingleZonePlacementBeta(Rule):
coordinates = random.choices(self.coordinates, k=len(agents)) coordinates = random.choices(self.coordinates, k=len(agents))
else: else:
raise ValueError raise ValueError
tiles = [state[c.FLOORS].by_pos(pos) for pos in coordinates]
for agent, tile in zip(agents, tiles): for agent, pos in zip(agents, coordinates):
agent.move(tile, state) agent.move(pos, state)
def tick_step(self, state): def tick_step(self, state):
return [] return []

View File

@ -2,7 +2,7 @@ from collections import deque
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.render import RenderEntity from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.modules.items import constants as i from marl_factory_grid.modules.items import constants as i

View File

@ -1,3 +1,5 @@
from random import shuffle
from marl_factory_grid.modules.items import constants as i from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
@ -19,10 +21,12 @@ class Items(PositionMixin, EnvObjects):
@staticmethod @staticmethod
def trigger_item_spawn(state, n_items, spawn_frequency): def trigger_item_spawn(state, n_items, spawn_frequency):
if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))): if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))):
floor_list = state.entities.floorlist[:item_to_spawns] position_list = [x for x in state.entities.floorlist]
state[i.ITEM].spawn(floor_list) shuffle(position_list)
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}') # spawn in self._next_item_spawn ? position_list = state.entities.floorlist[:item_to_spawns]
return len(floor_list) state[i.ITEM].spawn(position_list)
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}')
return len(position_list)
else: else:
state.print('No Items are spawning, limit is reached.') state.print('No Items are spawning, limit is reached.')
return 0 return 0
@ -100,7 +104,7 @@ class DropOffLocations(PositionMixin, EnvObjects):
@staticmethod @staticmethod
def trigger_drop_off_location_spawn(state, n_locations): def trigger_drop_off_location_spawn(state, n_locations):
empty_tiles = state.entities.floorlist[:n_locations] empty_positions = state.entities.empty_positions()[:n_locations]
do_entites = state[i.DROP_OFF] do_entites = state[i.DROP_OFF]
drop_offs = [DropOffLocation(tile) for tile in empty_tiles] drop_offs = [DropOffLocation(pos) for pos in empty_positions]
do_entites.add_items(drop_offs) do_entites.add_items(drop_offs)

View File

@ -1,5 +1,5 @@
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils.render import RenderEntity from ...utils.utility_classes import RenderEntity
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.results import TickResult from marl_factory_grid.utils.results import TickResult

View File

@ -13,8 +13,8 @@ class MachineRule(Rule):
self.n_machines = n_machines self.n_machines = n_machines
def on_init(self, state, lvl_map): def on_init(self, state, lvl_map):
empty_tiles = state[c.FLOORS].empty_tiles[:self.n_machines] # TODO Move to spawn!!!
state[m.MACHINES].add_items(Machine(tile) for tile in empty_tiles) state[m.MACHINES].add_items(Machine(pos) for pos in state.entities.empty_positions())
def tick_pre_step(self, state) -> List[TickResult]: def tick_pre_step(self, state) -> List[TickResult]:
pass pass

View File

@ -8,7 +8,7 @@ from ...environment.entity.entity import Entity
from ..doors import constants as do from ..doors import constants as do
from ..maintenance import constants as mi from ..maintenance import constants as mi
from ...utils.helpers import MOVEMAP from ...utils.helpers import MOVEMAP
from ...utils.render import RenderEntity from ...utils.utility_classes import RenderEntity
from ...utils.states import Gamestate from ...utils.states import Gamestate
@ -39,7 +39,7 @@ class Maintainer(Entity):
self._next = [] self._next = []
self._last = [] self._last = []
self._last_serviced = 'None' self._last_serviced = 'None'
self._floortile_graph = points_to_graph(state[c.FLOORS].positions) self._floortile_graph = points_to_graph(state.entities.floorlist)
def tick(self, state): def tick(self, state):
if found_objective := state[self.objective].by_pos(self.pos): if found_objective := state[self.objective].by_pos(self.pos):

View File

@ -14,7 +14,8 @@ class MaintenanceRule(Rule):
self.n_maintainer = n_maintainer self.n_maintainer = n_maintainer
def on_init(self, state: Gamestate, lvl_map): def on_init(self, state: Gamestate, lvl_map):
state[M.MAINTAINERS].spawn(state[c.FLOORS].empty_tiles[:self.n_maintainer], state) # Move to spawn? : #TODO
state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state)
pass pass
def tick_pre_step(self, state) -> List[TickResult]: def tick_pre_step(self, state) -> List[TickResult]:

View File

@ -3,8 +3,7 @@ from typing import List, Tuple
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.entity.object import Object from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.environment.entity.wall_floor import Floor from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.utils.render import RenderEntity
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.doors import constants as d from marl_factory_grid.modules.doors import constants as d
@ -21,5 +20,5 @@ class Zone(Object):
self.coords = coords self.coords = coords
@property @property
def random_tile(self): def random_pos(self):
return random.choice(self.coords) return random.choice(self.coords)

View File

@ -19,7 +19,7 @@ class ZoneInit(Rule):
while z_idx: while z_idx:
zone_positions = lvl_map.get_coordinates_for_symbol(z_idx) zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
if len(zone_positions): if len(zone_positions):
zones.append(Zone([state[c.FLOORS].by_pos(pos) for pos in zone_positions])) zones.append(Zone(zone_positions))
z_idx += 1 z_idx += 1
else: else:
z_idx = 0 z_idx = 0
@ -38,7 +38,7 @@ class AgentSingleZonePlacement(Rule):
z_idxs = choices(list(range(len(state[z.ZONES]))), k=n_agents) z_idxs = choices(list(range(len(state[z.ZONES]))), k=n_agents)
for agent in state[c.AGENT]: for agent in state[c.AGENT]:
agent.move(state[z.ZONES][z_idxs.pop()].random_tile, state) agent.move(state[z.ZONES][z_idxs.pop()].random_pos, state)
return [] return []
def tick_step(self, state): def tick_step(self, state):
@ -65,10 +65,10 @@ class IndividualDestinationZonePlacement(Rule):
other_zones = [x for x in state[z.ZONES] if x not in agent_zones] other_zones = [x for x in state[z.ZONES] if x not in agent_zones]
already_has_destination = True already_has_destination = True
while already_has_destination: while already_has_destination:
tile = choice(other_zones).random_tile pos = choice(other_zones).random_pos
if state[d.DESTINATION].by_pos(tile.pos) is None: if state[d.DESTINATION].by_pos(pos) is None:
already_has_destination = False already_has_destination = False
destination = Destination(tile, bind_to=agent) destination = Destination(pos, bind_to=agent)
state[d.DESTINATION].add_item(destination) state[d.DESTINATION].add_item(destination)
continue continue

View File

@ -25,10 +25,8 @@ This file is used for:
LEVELS_DIR = 'modules/levels' # for use in studies and experiments LEVELS_DIR = 'modules/levels' # for use in studies and experiments
STEPS_START = 1 # Define where to the stepcount; which is the first step STEPS_START = 1 # Define where to the stepcount; which is the first step
# Not used anymore? Clean!
# TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation', 'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation',
'episode'] 'episode']
POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]], POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]],
@ -223,7 +221,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos] module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
mod = importlib.import_module('.'.join(module_parts)) mod = importlib.import_module('.'.join(module_parts))
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle()) all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'Floor' and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin', 'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin', 'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any' 'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'

View File

@ -6,7 +6,7 @@ import numpy as np
from marl_factory_grid.environment.groups.agents import Agents from marl_factory_grid.environment.groups.agents import Agents
from marl_factory_grid.environment.groups.global_entities import Entities from marl_factory_grid.environment.groups.global_entities import Entities
from marl_factory_grid.environment.groups.wall_n_floors import Walls, Floors from marl_factory_grid.environment.groups.walls import Walls
from marl_factory_grid.utils import helpers as h from marl_factory_grid.utils import helpers as h
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
@ -34,16 +34,14 @@ class LevelParser(object):
def do_init(self): def do_init(self):
# Global Entities # Global Entities
list_of_all_floors = ([tuple(floor) for floor in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)]) list_of_all_positions = ([tuple(f) for f in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)])
entities = Entities(list_of_all_floors) entities = Entities(list_of_all_positions)
# Walls # Walls
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size) walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
entities.add_items({c.WALLS: walls}) entities.add_items({c.WALLS: walls})
# Floor # Agents
floor = Floors.from_coordinates(list_of_all_floors, self.size)
entities.add_items({c.FLOOR: floor})
entities.add_items({c.AGENT: Agents(self.size)}) entities.add_items({c.AGENT: Agents(self.size)})
# All other # All other

View File

@ -9,6 +9,7 @@ from numba import njit
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.utils import Combined from marl_factory_grid.environment.groups.utils import Combined
from marl_factory_grid.utils.states import Gamestate from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils.utility_classes import Floor
class OBSBuilder(object): class OBSBuilder(object):
@ -39,6 +40,7 @@ class OBSBuilder(object):
self.reset_struc_obs_block(state) self.reset_struc_obs_block(state)
self.curr_lightmaps = dict() self.curr_lightmaps = dict()
self._floortiles = defaultdict(list, {pos: [Floor(*pos)] for pos in state.entities.floorlist})
def reset_struc_obs_block(self, state): def reset_struc_obs_block(self, state):
self._curr_env_step = state.curr_step self._curr_env_step = state.curr_step
@ -82,19 +84,23 @@ class OBSBuilder(object):
self._sort_and_name_observation_conf(agent) self._sort_and_name_observation_conf(agent)
agent_want_obs = self.obs_layers[agent.name] agent_want_obs = self.obs_layers[agent.name]
# Handle in-grid observations aka visible observations # Handle in-grid observations aka visible observations (Things on the map, with pos)
visible_entitites = self.ray_caster[agent.name].visible_entities(state.entities) visible_entitites = self.ray_caster[agent.name].visible_entities(state.entities.pos_dict)
pre_sort_obs = defaultdict(lambda: np.zeros((self.pomdp_d, self.pomdp_d))) pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
for e in set(visible_entitites): if self.pomdp_r:
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r for e in set(visible_entitites):
try: x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
pre_sort_obs[e.obs_tag][x, y] += e.encoding try:
except IndexError: pre_sort_obs[e.obs_tag][x, y] += e.encoding
# Seemded to be visible but is out or range except IndexError:
pass # Seemded to be visible but is out or range
pass
else:
for e in set(visible_entitites):
pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding
pre_sort_obs = dict(pre_sort_obs) pre_sort_obs = dict(pre_sort_obs)
obs = np.zeros((len(agent_want_obs), self.pomdp_d, self.pomdp_d)) obs = np.zeros((len(agent_want_obs), self.obs_shape[0], self.obs_shape[1]))
for idx, l_name in enumerate(agent_want_obs): for idx, l_name in enumerate(agent_want_obs):
try: try:
@ -144,13 +150,26 @@ class OBSBuilder(object):
raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.') raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.')
try: try:
self.curr_lightmaps[agent.name] = pre_sort_obs[c.FLOORS].astype(bool) light_map = np.zeros(self.obs_shape)
visible_floor = set(self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False))
if self.pomdp_r:
coords = [((f.x - agent.x) + self.pomdp_r, (f.y - agent.y) + self.pomdp_r) for f in visible_floor]
else:
coords = [x.pos for x in visible_floor]
np.put(light_map, np.ravel_multi_index(np.asarray(coords).T, light_map.shape), 1)
self.curr_lightmaps[agent.name] = light_map
except KeyError: except KeyError:
print() print()
return obs, self.obs_layers[agent.name] return obs, self.obs_layers[agent.name]
def _sort_and_name_observation_conf(self, agent): def _sort_and_name_observation_conf(self, agent):
self.ray_caster[agent.name] = RayCaster(agent, self.pomdp_r) '''
Builds the useable observation scheme per agent from conf.yaml.
:param agent:
:return:
'''
# Fixme: no asymetric shapes possible.
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
obs_layers = [] obs_layers = []
for obs_str in agent.observations: for obs_str in agent.observations:
@ -173,7 +192,7 @@ class OBSBuilder(object):
names.extend([x.name for x in agent.collection if x.name != agent.name]) names.extend([x.name for x in agent.collection if x.name != agent.name])
else: else:
names.append(val) names.append(val)
combined = Combined(names, self.pomdp_r, identifier=agent.name) combined = Combined(names, self.size, identifier=agent.name)
self.all_obs[combined.name] = combined self.all_obs[combined.name] = combined
obs_layers.append(combined.name) obs_layers.append(combined.name)
elif obs_str == c.OTHERS: elif obs_str == c.OTHERS:
@ -183,19 +202,18 @@ class OBSBuilder(object):
else: else:
obs_layers.append(obs_str) obs_layers.append(obs_str)
self.obs_layers[agent.name] = obs_layers self.obs_layers[agent.name] = obs_layers
self.curr_lightmaps[agent.name] = np.zeros((self.pomdp_d or self.level_shape[0], self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)
self.pomdp_d or self.level_shape[1]
))
class RayCaster: class RayCaster:
def __init__(self, agent, pomdp_r, degs=360): def __init__(self, agent, pomdp_r, degs=360):
self.agent = agent self.agent = agent
self.pomdp_r = pomdp_r self.pomdp_r = pomdp_r
self.n_rays = 100 # (self.pomdp_r + 1) * 8 self.n_rays = (self.pomdp_r + 1) * 8
self.degs = degs self.degs = degs
self.ray_targets = self.build_ray_targets() self.ray_targets = self.build_ray_targets()
self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r]) self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
self._cache_dict = {}
def __repr__(self): def __repr__(self):
return f'{self.__class__.__name__}({self.agent.name})' return f'{self.__class__.__name__}({self.agent.name})'
@ -211,30 +229,30 @@ class RayCaster:
rot_M = np.unique(np.round(rot_M @ north), axis=0) rot_M = np.unique(np.round(rot_M @ north), axis=0)
return rot_M.astype(int) return rot_M.astype(int)
def ray_block_cache(self, cache_dict, key, callback): def ray_block_cache(self, key, callback):
if key not in cache_dict: if key not in self._cache_dict:
cache_dict[key] = callback() self._cache_dict[key] = callback()
return cache_dict[key] return self._cache_dict[key]
def visible_entities(self, entities): def visible_entities(self, pos_dict, reset_cache=True):
visible = list() visible = list()
cache_blocking = {} if reset_cache:
self._cache_dict = {}
for ray in self.get_rays(): for ray in self.get_rays():
rx, ry = ray[0] rx, ry = ray[0]
for x, y in ray: for x, y in ray:
cx, cy = x - rx, y - ry cx, cy = x - rx, y - ry
entities_hit = entities.pos_dict[(x, y)] entities_hit = pos_dict[(x, y)]
hits = self.ray_block_cache(cache_blocking, hits = self.ray_block_cache((x, y),
(x, y), lambda: any(True for e in entities_hit if e.var_is_blocking_light)
lambda: any(True for e in entities_hit if e.var_is_blocking_light)) )
diag_hits = all([ diag_hits = all([
self.ray_block_cache( self.ray_block_cache(
cache_blocking,
key, key,
lambda: all(False for e in entities.pos_dict[key] if not e.var_is_blocking_light)) lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(pos_dict[key]))
for key in ((x, y-cy), (x-cx, y)) for key in ((x, y-cy), (x-cx, y))
]) if (cx != 0 and cy != 0) else False ]) if (cx != 0 and cy != 0) else False

View File

@ -1,16 +0,0 @@
from dataclasses import dataclass
from typing import Any
import numpy as np
@dataclass
class RenderEntity:
name: str
pos: np.array
value: float = 1
value_operation: str = 'none'
state: str = None
id: int = 0
aux: Any = None
real_name: str = 'none'

View File

@ -9,7 +9,7 @@ import pygame
from typing import Tuple, Union from typing import Tuple, Union
import time import time
from marl_factory_grid.utils.render import RenderEntity from marl_factory_grid.utils.utility_classes import RenderEntity
AGENT: str = 'agent' AGENT: str = 'agent'
STATE_IDLE: str = 'idle' STATE_IDLE: str = 'idle'

View File

@ -3,8 +3,6 @@ from typing import List, Dict, Tuple
import numpy as np import numpy as np
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.wall_floor import Floor
from marl_factory_grid.environment.groups.global_entities import Entities
from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import Result from marl_factory_grid.utils.results import Result
@ -112,15 +110,10 @@ class Gamestate(object):
results.extend(on_check_done_result) results.extend(on_check_done_result)
return results return results
# def get_all_tiles_with_collisions(self) -> List[Floor]:
# tiles = [self[c.FLOORS].by_pos(pos) for pos, e in self.entities.pos_dict.items()
# if sum([x.var_can_collide for x in e]) > 1]
# # tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
# return tiles
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]: def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
positions = [pos for pos, e in self.entities.pos_dict.items() positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items()
if sum([x.var_can_collide for x in e]) > 1] if any([e.var_can_collide for e in entity_list_for_position])]
return positions return positions
def check_move_validity(self, moving_entity, position): def check_move_validity(self, moving_entity, position):
@ -128,6 +121,14 @@ class Gamestate(object):
# and not (guest.var_is_blocking_pos and self.is_occupied()): # and not (guest.var_is_blocking_pos and self.is_occupied()):
if moving_entity.pos != position and not any( if moving_entity.pos != position and not any(
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not ( entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
moving_entity.var_is_blocking_pos and moving_entity.is_occupied()): moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)):
return True return True
return False else:
return False
def check_pos_validity(self, position):
if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]):
return True
else:
return False

View File

@ -15,7 +15,7 @@ ENTITIES = 'Objects'
OBSERVATIONS = 'Observations' OBSERVATIONS = 'Observations'
RULES = 'Rule' RULES = 'Rule'
ASSETS = 'Assets' ASSETS = 'Assets'
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Floor', 'Agent', 'GlobalPositions', 'Walls', EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls',
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ] 'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]

View File

@ -1,4 +1,8 @@
from dataclasses import dataclass
from typing import Any
import gymnasium as gym import gymnasium as gym
import numpy as np
class MarlFrameStack(gym.ObservationWrapper): class MarlFrameStack(gym.ObservationWrapper):
@ -10,3 +14,37 @@ class MarlFrameStack(gym.ObservationWrapper):
if isinstance(self.env, gym.wrappers.FrameStack) and self.env.unwrapped.n_agents > 1: if isinstance(self.env, gym.wrappers.FrameStack) and self.env.unwrapped.n_agents > 1:
return observation[0:].swapaxes(0, 1) return observation[0:].swapaxes(0, 1)
return observation return observation
@dataclass
class RenderEntity:
name: str
pos: np.array
value: float = 1
value_operation: str = 'none'
state: str = None
id: int = 0
aux: Any = None
real_name: str = 'none'
@dataclass
class Floor:
@property
def name(self):
return f"Floor({self.pos})"
@property
def pos(self):
return self.x, self.y
x: int
y: int
var_is_blocking_light: bool = False
def __eq__(self, other):
return self.name == other.name
def __hash__(self):
return hash(self.name)