mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-18 18:52:52 +02:00
no more tiles no more floor
This commit is contained in:
@ -21,7 +21,7 @@ class TSPBaseAgent(ABC):
|
||||
self.local_optimization = True
|
||||
self._env = state
|
||||
self.state = self._env.state[c.AGENT][agent_i]
|
||||
self._floortile_graph = points_to_graph(self._env[c.FLOORS].positions)
|
||||
self._position_graph = points_to_graph(self._env.entities.floorlist)
|
||||
self._static_route = None
|
||||
|
||||
@abstractmethod
|
||||
@ -50,7 +50,7 @@ class TSPBaseAgent(ABC):
|
||||
|
||||
else:
|
||||
nodes = [self.state.pos] + positions
|
||||
route = tsp.traveling_salesman_problem(self._floortile_graph,
|
||||
route = tsp.traveling_salesman_problem(self._position_graph,
|
||||
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
||||
return route
|
||||
|
||||
|
@ -4,17 +4,17 @@ import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
||||
def points_to_graph(coordiniates, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
||||
"""
|
||||
Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points.
|
||||
There are three combinations of settings:
|
||||
Allow all neigbors: Distance(a, b) <= sqrt(2)
|
||||
Allow only manhattan: Distance(a, b) == 1
|
||||
Allow only euclidean: Distance(a, b) == sqrt(2)
|
||||
Allow only Euclidean: Distance(a, b) == sqrt(2)
|
||||
|
||||
|
||||
:param coordiniates_or_tiles: A set of coordinates.
|
||||
:type coordiniates_or_tiles: Tiles
|
||||
:param coordiniates: A set of coordinates.
|
||||
:type coordiniates: Tuple[int, int]
|
||||
:param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors
|
||||
:type: bool
|
||||
:param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors
|
||||
@ -24,9 +24,7 @@ def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, all
|
||||
:rtype: nx.Graph
|
||||
"""
|
||||
assert allow_euclidean_connections or allow_manhattan_connections
|
||||
if hasattr(coordiniates_or_tiles, 'positions'):
|
||||
coordiniates_or_tiles = coordiniates_or_tiles.positions
|
||||
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
|
||||
possible_connections = itertools.combinations(coordiniates, 2)
|
||||
graph = nx.Graph()
|
||||
for a, b in possible_connections:
|
||||
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
|
||||
|
@ -66,7 +66,6 @@ Rules:
|
||||
DestinationDone: {}
|
||||
DestinationReach:
|
||||
n_dests: 1
|
||||
tiles: null
|
||||
DestinationSpawn:
|
||||
n_dests: 1
|
||||
spawn_frequency: 5
|
||||
|
@ -42,12 +42,12 @@ class Move(Action, abc.ABC):
|
||||
|
||||
def do(self, entity, state):
|
||||
new_pos = self._calc_new_pos(entity.pos)
|
||||
if state.check_move_validity(entity, new_pos): # next_tile := state[c.FLOOR].by_pos(new_pos):
|
||||
if state.check_move_validity(entity, new_pos):
|
||||
# noinspection PyUnresolvedReferences
|
||||
move_validity = entity.move(new_pos, state)
|
||||
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
|
||||
else: # There is no floor, propably collision
|
||||
else: # There is no place to go, propably collision
|
||||
# This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml
|
||||
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0)
|
||||
|
@ -3,15 +3,13 @@ DANGER_ZONE = 'x' # Dange Zone tile _identifier fo
|
||||
DEFAULTS = 'Defaults'
|
||||
SELF = 'Self'
|
||||
PLACEHOLDER = 'Placeholder'
|
||||
FLOOR = 'Floor' # Identifier of Floor-objects and groups (groups).
|
||||
FLOORS = 'Floors' # Identifier of Floor-objects and groups (groups).
|
||||
WALL = 'Wall' # Identifier of Wall-objects and groups (groups).
|
||||
WALLS = 'Walls' # Identifier of Wall-objects and groups (groups).
|
||||
LEVEL = 'Level' # Identifier of Level-objects and groups (groups).
|
||||
AGENT = 'Agent' # Identifier of Agent-objects and groups (groups).
|
||||
OTHERS = 'Other'
|
||||
COMBINED = 'Combined'
|
||||
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
|
||||
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
|
||||
|
||||
# Attributes
|
||||
IS_BLOCKING_LIGHT = 'var_is_blocking_light'
|
||||
@ -32,7 +30,7 @@ VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the e
|
||||
|
||||
ACTION = 'action' # Identifier of Action-objects and groups (groups).
|
||||
COLLISION = 'Collision' # Identifier to use in the context of collitions.
|
||||
LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos.
|
||||
# LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos.
|
||||
VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ...
|
||||
|
||||
# Actions
|
||||
|
@ -2,7 +2,7 @@ from typing import List, Union
|
||||
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.utils import renderer
|
||||
from marl_factory_grid.utils.helpers import is_move
|
||||
from marl_factory_grid.utils.results import ActionResult, Result
|
||||
|
@ -1,8 +1,10 @@
|
||||
import abc
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import constants as c
|
||||
from .object import EnvObject
|
||||
from ...utils.render import RenderEntity
|
||||
from ...utils.utility_classes import RenderEntity
|
||||
from ...utils.results import ActionResult
|
||||
|
||||
|
||||
@ -30,33 +32,32 @@ class Entity(EnvObject, abc.ABC):
|
||||
return self._pos
|
||||
|
||||
@property
|
||||
def tile(self):
|
||||
return self._tile # wall_n_floors funktionalität
|
||||
|
||||
# @property
|
||||
# def last_tile(self):
|
||||
# try:
|
||||
# return self._last_tile
|
||||
# except AttributeError:
|
||||
# # noinspection PyAttributeOutsideInit
|
||||
# self._last_tile = None
|
||||
# return self._last_tile
|
||||
def last_pos(self):
|
||||
try:
|
||||
return self._last_pos
|
||||
except AttributeError:
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self._last_pos = c.VALUE_NO_POS
|
||||
return self._last_pos
|
||||
|
||||
@property
def direction_of_view(self):
    """Return the offset between the previous and the current position.

    The result is ``last_pos - pos`` (component-wise).  When no previous
    position has been recorded yet (``_last_pos`` equals the
    ``c.VALUE_NO_POS`` sentinel) there is no meaningful direction, so
    ``(0, 0)`` is returned instead of an offset computed from the sentinel.

    :return: ``(0, 0)`` or the component-wise difference ``last_pos - pos``.
    """
    # The sentinel check must be an equality test: subtracting from
    # VALUE_NO_POS would yield a huge, meaningless offset.
    if self._last_pos == c.VALUE_NO_POS:
        return 0, 0
    return np.subtract(self._last_pos, self.pos)
|
||||
|
||||
def move(self, next_pos, state):
|
||||
next_pos = next_pos
|
||||
curr_pos = self._pos
|
||||
if not_same_pos := curr_pos != next_pos:
|
||||
if valid := state.check_move_validity(self, next_pos):
|
||||
self._pos = next_pos
|
||||
self._last_pos = curr_pos
|
||||
for observer in self.observers:
|
||||
observer.notify_change_pos(self)
|
||||
observer.notify_del_entity(self)
|
||||
self._view_directory = curr_pos[0]-next_pos[0], curr_pos[1]-next_pos[1]
|
||||
self._pos = next_pos
|
||||
for observer in self.observers:
|
||||
observer.notify_add_entity(self)
|
||||
return valid
|
||||
return not_same_pos
|
||||
|
||||
@ -64,6 +65,7 @@ class Entity(EnvObject, abc.ABC):
|
||||
super().__init__(**kwargs)
|
||||
self._status = None
|
||||
self._pos = pos
|
||||
self._last_pos = pos
|
||||
if bind_to:
|
||||
try:
|
||||
self.bind_to(bind_to)
|
||||
|
@ -4,7 +4,7 @@ import numpy as np
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.object import EnvObject
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
|
||||
@ -30,17 +30,6 @@ class Floor(EnvObject):
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def neighboring_floor(self):
|
||||
if self._neighboring_floor:
|
||||
pass
|
||||
else:
|
||||
self._neighboring_floor = [x for x in [self._collection.by_pos(np.add(self.pos, pos))
|
||||
for pos in h.POS_MASK.reshape(-1, 2)
|
||||
if not np.all(pos == [0, 0])]
|
||||
if x]
|
||||
return self._neighboring_floor
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.VALUE_OCCUPIED_CELL
|
||||
|
@ -197,7 +197,7 @@ class Factory(gym.Env):
|
||||
del rewards['global']
|
||||
reward = [rewards[agent.name] for agent in self.state[c.AGENT]]
|
||||
reward = [x + global_rewards for x in reward]
|
||||
self.state.print(f"rewards are {rewards}")
|
||||
self.state.print(f"Individual rewards are {dict(rewards)}")
|
||||
return reward, combined_info_dict, done
|
||||
else:
|
||||
reward = sum(rewards.values())
|
||||
@ -220,7 +220,7 @@ class Factory(gym.Env):
|
||||
|
||||
def summarize_header(self):
|
||||
header = {'rec_step': self.state.curr_step}
|
||||
for entity_group in (x for x in self.state if x.name in ['Walls', 'Floors', 'DropOffLocations', 'ChargePods']):
|
||||
for entity_group in (x for x in self.state if x.name in ['Walls', 'DropOffLocations', 'ChargePods']):
|
||||
header.update({f'rec{entity_group.name}': entity_group.summarize_states()})
|
||||
return header
|
||||
|
||||
@ -229,7 +229,7 @@ class Factory(gym.Env):
|
||||
|
||||
# Todo: Protobuff Compatibility Section #######
|
||||
# for entity_group in (x for x in self.state if x.name not in [c.WALLS, c.FLOORS]):
|
||||
for entity_group in (x for x in self.state if x.name not in [c.FLOORS]):
|
||||
for entity_group in self.state:
|
||||
summary.update({entity_group.name.lower(): entity_group.summarize_states()})
|
||||
# TODO Section End ########
|
||||
for key in list(summary.keys()):
|
||||
|
@ -1,5 +1,6 @@
|
||||
from collections import defaultdict
|
||||
from operator import itemgetter
|
||||
from random import shuffle
|
||||
from typing import Dict
|
||||
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
@ -13,7 +14,7 @@ class Entities(Objects):
|
||||
def neighboring_positions(pos):
|
||||
return (POS_MASK + pos).reshape(-1, 2)
|
||||
|
||||
def get_near_pos(self, pos):
|
||||
def get_entities_near_pos(self, pos):
|
||||
return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x]
|
||||
|
||||
def render(self):
|
||||
@ -38,11 +39,17 @@ class Entities(Objects):
|
||||
def guests_that_can_collide(self, pos):
|
||||
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
||||
|
||||
def empty_tiles(self):
|
||||
return[key for key in self.floorlist if not any(self.pos_dict[key])]
|
||||
@property
def empty_positions(self):
    """Return all floor positions that currently hold no entity, shuffled.

    A position counts as empty when its ``pos_dict`` entry is empty.  A new
    list is built and shuffled on every access, so callers may slice it or
    ``pop`` from it to draw random free positions.

    :return: Shuffled list of unoccupied positions taken from ``floorlist``.
    :rtype: list
    """
    # The ``not`` is essential: without it this property returns the occupied
    # positions, i.e. exactly what ``occupied_positions`` yields.
    empty_positions = [key for key in self.floorlist if not self.pos_dict[key]]
    shuffle(empty_positions)
    return empty_positions
|
||||
|
||||
def occupied_tiles(self): # positions that are not empty
|
||||
return[key for key in self.floorlist if any(self.pos_dict[key])]
|
||||
@property
def occupied_positions(self):
    """Return all floor positions that hold at least one entity, shuffled.

    A position counts as occupied when its ``pos_dict`` entry is non-empty.
    A new list is built and shuffled on every access.

    :return: Shuffled list of occupied positions taken from ``floorlist``.
    :rtype: list
    """
    occupied_positions = [key for key in self.floorlist if self.pos_dict[key]]
    shuffle(occupied_positions)
    return occupied_positions
|
||||
|
||||
def is_blocked(self):
|
||||
return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
|
||||
|
@ -37,7 +37,11 @@ class PositionMixin:
|
||||
|
||||
def __delitem__(self, name):
|
||||
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||
obj.tile.leave(obj) # observer notify?
|
||||
try:
|
||||
for observer in obj.observers:
|
||||
observer.notify_del_entity(obj)
|
||||
except AttributeError:
|
||||
pass
|
||||
super().__delitem__(name)
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
|
@ -103,6 +103,9 @@ class Objects:
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def by_name(self, name):
|
||||
return next(x for x in self if x.name == name)
|
||||
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, (int, np.int64, np.int32)):
|
||||
if item < 0:
|
||||
@ -120,7 +123,7 @@ class Objects:
|
||||
raise TypeError
|
||||
|
||||
def __repr__(self):
|
||||
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS, c.FLOORS]}
|
||||
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS]}
|
||||
return f'{self.__class__.__name__}[{repr_dict}]'
|
||||
|
||||
def spawn(self, n: int):
|
||||
@ -132,22 +135,25 @@ class Objects:
|
||||
for item in items:
|
||||
del self[item]
|
||||
|
||||
def notify_change_pos(self, entity: object):
|
||||
try:
|
||||
self.pos_dict[entity.last_pos].remove(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
if entity.var_has_position:
|
||||
try:
|
||||
self.pos_dict[entity.pos].append(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
# def notify_change_pos(self, entity: object):
|
||||
# try:
|
||||
# self.pos_dict[entity.last_pos].remove(entity)
|
||||
# except (ValueError, AttributeError):
|
||||
# pass
|
||||
# if entity.var_has_position:
|
||||
# try:
|
||||
# self.pos_dict[entity.pos].append(entity)
|
||||
# except (ValueError, AttributeError):
|
||||
# pass
|
||||
|
||||
def notify_del_entity(self, entity: Object):
|
||||
try:
|
||||
entity.del_observer(self)
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
self.pos_dict[entity.pos].remove(entity)
|
||||
except (ValueError, AttributeError):
|
||||
except (AttributeError, ValueError, IndexError):
|
||||
pass
|
||||
|
||||
def notify_add_entity(self, entity: Object):
|
||||
|
@ -15,6 +15,7 @@ class Walls(PositionMixin, EnvObjects):
|
||||
super(Walls, self).__init__(*args, **kwargs)
|
||||
self._value = c.VALUE_OCCUPIED_CELL
|
||||
|
||||
#ToDo: Do we need this? Move to spawn methode?
|
||||
# @classmethod
|
||||
# def from_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
||||
# tiles = cls(*args, **kwargs)
|
@ -49,7 +49,7 @@ class SpawnAgents(Rule):
|
||||
agent_conf = state.agents_conf
|
||||
# agents = Agents(lvl_map.size)
|
||||
agents = state[c.AGENT]
|
||||
empty_tiles = state[c.FLOORS].empty_tiles[:len(agent_conf)]
|
||||
empty_positions = state.entities.empty_positions[:len(agent_conf)]
|
||||
for agent_name in agent_conf:
|
||||
actions = agent_conf[agent_name]['actions'].copy()
|
||||
observations = agent_conf[agent_name]['observations'].copy()
|
||||
@ -58,18 +58,17 @@ class SpawnAgents(Rule):
|
||||
shuffle(positions)
|
||||
while True:
|
||||
try:
|
||||
tile = state[c.FLOORS].by_pos(positions.pop())
|
||||
pos = positions.pop()
|
||||
except IndexError as e:
|
||||
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
|
||||
f'\n{agent_name[agent_name]["positions"].copy()}')
|
||||
try:
|
||||
agents.add_item(Agent(actions, observations, tile, str_ident=agent_name))
|
||||
except AssertionError:
|
||||
state.print(f'No valid pos:{tile.pos} for {agent_name}')
|
||||
if agents.by_pos(pos) and state.check_pos_validity(pos):
|
||||
continue
|
||||
else:
|
||||
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name))
|
||||
break
|
||||
else:
|
||||
agents.add_item(Agent(actions, observations, empty_tiles.pop(), str_ident=agent_name))
|
||||
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name))
|
||||
pass
|
||||
|
||||
|
||||
|
@ -2,7 +2,7 @@ from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
|
||||
from marl_factory_grid.environment.entity.object import EnvObject
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
|
||||
from marl_factory_grid.modules.batteries import constants as b
|
||||
|
||||
|
@ -70,8 +70,8 @@ class PodRules(Rule):
|
||||
|
||||
def on_init(self, state, lvl_map):
    """Spawn up to ``self.n_pods`` charge pods on random empty positions.

    :param state: Game state giving access to the charge pod collection and
                  to ``state.entities.empty_positions``.
    :param lvl_map: Level map; not used by this rule.
    """
    pod_collection = state[b.CHARGE_PODS]
    # ``empty_positions`` is a property returning a shuffled list, so it is
    # sliced (not called); the slice keeps the number of pods at ``n_pods``.
    empty_positions = state.entities.empty_positions[:self.n_pods]
    pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict(
        multi_charge=self.multi_charge, charge_rate=self.charge_rate)
    )
    pod_collection.add_items(pods)
|
||||
|
@ -1,7 +1,7 @@
|
||||
from numpy import random
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.modules.clean_up import constants as d
|
||||
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
from marl_factory_grid.modules.clean_up.entitites import DirtPile
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
@ -31,8 +30,6 @@ class DirtPiles(PositionMixin, EnvObjects):
|
||||
self.max_local_amount = max_local_amount
|
||||
|
||||
def spawn(self, then_dirty_positions, amount) -> bool:
|
||||
# if isinstance(then_dirty_tiles, Floor):
|
||||
# then_dirty_tiles = [then_dirty_tiles]
|
||||
for pos in then_dirty_positions:
|
||||
if not self.amount > self.max_global_amount:
|
||||
if dirt := self.by_pos(pos):
|
||||
@ -56,8 +53,8 @@ class DirtPiles(PositionMixin, EnvObjects):
|
||||
|
||||
var = self.dirt_spawn_r_var
|
||||
new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0))
|
||||
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
||||
return self.spawn(free_for_dirt[:n_dirt_tiles], self.initial_amount)
|
||||
n_dirty_positions = max(0, int(new_spawn * len(free_for_dirt)))
|
||||
return self.spawn(free_for_dirt[:n_dirty_positions], self.initial_amount)
|
||||
|
||||
def __repr__(self):
|
||||
s = super(DirtPiles, self).__repr__()
|
||||
|
@ -4,7 +4,7 @@ from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.modules.destinations import constants as d
|
||||
|
||||
|
||||
@ -17,7 +17,6 @@ class Destination(BoundEntityMixin, Entity):
|
||||
var_is_blocking_light = False
|
||||
var_can_be_bound = True # Introduce this globally!
|
||||
|
||||
@property
|
||||
def was_reached(self):
|
||||
return self._was_reached
|
||||
|
||||
@ -35,11 +34,10 @@ class Destination(BoundEntityMixin, Entity):
|
||||
self._per_agent_actions[agent.name] += 1
|
||||
return c.VALID
|
||||
|
||||
@property
|
||||
def has_just_been_reached(self):
|
||||
if self.was_reached:
|
||||
def has_just_been_reached(self, state):
|
||||
if self.was_reached():
|
||||
return False
|
||||
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in state.entities.pos_dict[self.pos] if x.var_can_collide)
|
||||
agent_at_position = any(state[c.AGENT].by_pos(self.pos))
|
||||
|
||||
if self.bound_entity:
|
||||
return ((agent_at_position and not self.action_counts)
|
||||
@ -57,7 +55,7 @@ class Destination(BoundEntityMixin, Entity):
|
||||
return state_summary
|
||||
|
||||
def render(self):
|
||||
if self.was_reached:
|
||||
if self.was_reached():
|
||||
return None
|
||||
else:
|
||||
return RenderEntity(d.DESTINATION, self.pos)
|
||||
|
@ -16,28 +16,29 @@ class DestinationReachAll(Rule):
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
results = []
|
||||
reached = False
|
||||
for dest in state[d.DESTINATION]:
|
||||
if dest.has_just_been_reached and not dest.was_reached:
|
||||
# Dest has just been reached, some agent needs to stand here, grab any first.
|
||||
if dest.has_just_been_reached(state) and not dest.was_reached():
|
||||
# Dest has just been reached, some agent needs to stand here
|
||||
for agent in state[c.AGENT].by_pos(dest.pos):
|
||||
if dest.bound_entity:
|
||||
if dest.bound_entity == agent:
|
||||
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
||||
reached = True
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
||||
state.print(f'{dest.name} is reached now, mark as reached...')
|
||||
dest.mark_as_reached()
|
||||
reached = True
|
||||
else:
|
||||
pass
|
||||
if reached:
|
||||
state.print(f'{dest.name} is reached now, mark as reached...')
|
||||
dest.mark_as_reached()
|
||||
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
||||
return results
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
return []
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
if all(x.was_reached for x in state[d.DESTINATION]):
|
||||
if all(x.was_reached() for x in state[d.DESTINATION]):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
|
||||
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
|
||||
|
||||
@ -48,7 +49,7 @@ class DestinationReachAny(DestinationReachAll):
|
||||
super(DestinationReachAny, self).__init__()
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
if any(x.was_reached for x in state[d.DESTINATION]):
|
||||
if any(x.was_reached() for x in state[d.DESTINATION]):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
|
||||
return []
|
||||
|
||||
@ -63,7 +64,7 @@ class DestinationSpawn(Rule):
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self.trigger_destination_spawn(self.n_dests, state)
|
||||
state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state)
|
||||
pass
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
@ -72,24 +73,14 @@ class DestinationSpawn(Rule):
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])):
|
||||
if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
|
||||
validity = self.trigger_destination_spawn(n_dest_spawn, state)
|
||||
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
|
||||
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
|
||||
elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn:
|
||||
validity = self.trigger_destination_spawn(n_dest_spawn, state)
|
||||
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
|
||||
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
|
||||
else:
|
||||
pass
|
||||
|
||||
def trigger_destination_spawn(self, n_dests, state):
|
||||
empty_positions = state[c.FLOORS].empty_tiles[:n_dests]
|
||||
if destinations := [Destination(pos) for pos in empty_positions]:
|
||||
state[d.DESTINATION].add_items(destinations)
|
||||
state.print(f'{n_dests} new destinations have been spawned')
|
||||
return c.VALID
|
||||
else:
|
||||
state.print('No Destiantions are spawning, limit is reached.')
|
||||
return c.NOT_VALID
|
||||
|
||||
|
||||
class FixedDestinationSpawn(Rule):
|
||||
def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]):
|
||||
@ -99,11 +90,17 @@ class FixedDestinationSpawn(Rule):
|
||||
def on_init(self, state, lvl_map):
|
||||
for (agent_name, position_list) in self.per_agent_positions.items():
|
||||
agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF
|
||||
position_list = position_list.copy()
|
||||
shuffle(position_list)
|
||||
while True:
|
||||
pos = position_list.pop()
|
||||
if pos != agent.pos and not state[d.DESTINATION].by_pos(pos):
|
||||
destination = Destination(state[c.FLOORS].by_pos(pos), bind_to=agent)
|
||||
try:
|
||||
pos = position_list.pop()
|
||||
except IndexError:
|
||||
print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}")
|
||||
print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...')
|
||||
exit(9999)
|
||||
if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)):
|
||||
destination = Destination(pos, bind_to=agent)
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .actions import DoorUse
|
||||
from .entitites import Door, DoorIndicator
|
||||
from .groups import Doors
|
||||
from .rule_door_auto_close import DoorAutoClose
|
||||
from .rules import DoorAutoClose, DoorIndicateArea
|
||||
|
@ -13,8 +13,9 @@ class DoorUse(Action):
|
||||
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
# Check if agent really is standing on a door:
|
||||
e = state.entities.get_near_pos(entity.pos)
|
||||
e = state.entities.get_entities_near_pos(entity.pos)
|
||||
try:
|
||||
# Only one door opens TODO introcude loop
|
||||
door = next(x for x in e if x.name.startswith(d.DOOR))
|
||||
valid = door.use()
|
||||
state.print(f'{entity.name} just used a {door.name} at {door.pos}')
|
||||
|
@ -1,5 +1,5 @@
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
from marl_factory_grid.modules.doors import constants as d
|
||||
@ -41,7 +41,7 @@ class Door(Entity):
|
||||
def str_state(self):
|
||||
return 'open' if self.is_open else 'closed'
|
||||
|
||||
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, indicate_area=False, **kwargs):
|
||||
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
|
||||
self._status = d.STATE_CLOSED
|
||||
super(Door, self).__init__(*args, **kwargs)
|
||||
self.auto_close_interval = auto_close_interval
|
||||
@ -50,8 +50,6 @@ class Door(Entity):
|
||||
self._open()
|
||||
else:
|
||||
self._close()
|
||||
if indicate_area:
|
||||
self._collection.add_items([DoorIndicator(x) for x in self.tile.neighboring_floor])
|
||||
|
||||
def summarize_state(self):
|
||||
state_dict = super().summarize_state()
|
||||
|
@ -1,7 +1,8 @@
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.results import TickResult
|
||||
from marl_factory_grid.modules.doors import constants as d
|
||||
from . import constants as d
|
||||
from .entitites import DoorIndicator
|
||||
|
||||
|
||||
class DoorAutoClose(Rule):
|
||||
@ -19,3 +20,13 @@ class DoorAutoClose(Rule):
|
||||
return [TickResult(self.name, validity=c.VALID, value=0)]
|
||||
state.print('There are no doors, but you loaded the corresponding Module')
|
||||
return []
|
||||
|
||||
|
||||
class DoorIndicateArea(Rule):
    """Rule that places ``DoorIndicator`` entities around every door on init."""

    def __init__(self):
        super().__init__()

    def on_init(self, state, lvl_map):
        """Add door indicators on the positions neighboring each door.

        :param state: Game state giving access to the door collection and to
                      ``state.entities.neighboring_positions``.
        :param lvl_map: Level map; not used by this rule.
        """
        doors = state[d.DOORS]
        # Iterate over a snapshot: the indicators are added to the very
        # collection being looped over, and mutating a container while
        # iterating it is unsafe (and would revisit the new indicators).
        # NOTE(review): ``neighboring_positions`` is built from POS_MASK + pos;
        # confirm whether it includes the door's own position and whether the
        # rows need converting to tuples before being used as positions.
        for door in list(doors):
            doors.add_items([DoorIndicator(x) for x in state.entities.neighboring_positions(door.pos)])
|
@ -9,6 +9,8 @@ from marl_factory_grid.utils.results import TickResult
|
||||
class AgentSingleZonePlacementBeta(Rule):
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError()
|
||||
# TODO!!!! Is this concept needed any more?
|
||||
super().__init__()
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
@ -21,9 +23,9 @@ class AgentSingleZonePlacementBeta(Rule):
|
||||
coordinates = random.choices(self.coordinates, k=len(agents))
|
||||
else:
|
||||
raise ValueError
|
||||
tiles = [state[c.FLOORS].by_pos(pos) for pos in coordinates]
|
||||
for agent, tile in zip(agents, tiles):
|
||||
agent.move(tile, state)
|
||||
|
||||
for agent, pos in zip(agents, coordinates):
|
||||
agent.move(pos, state)
|
||||
|
||||
def tick_step(self, state):
|
||||
return []
|
||||
|
@ -2,7 +2,7 @@ from collections import deque
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.modules.items import constants as i
|
||||
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from random import shuffle
|
||||
|
||||
from marl_factory_grid.modules.items import constants as i
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
@ -19,10 +21,12 @@ class Items(PositionMixin, EnvObjects):
|
||||
@staticmethod
def trigger_item_spawn(state, n_items, spawn_frequency):
    """Spawn items on random positions until ``n_items`` exist.

    :param state: Game state giving access to the item collection and to
                  ``state.entities.floorlist``.
    :param n_items: Target number of items in the environment.
    :param spawn_frequency: Only used for the log message announcing the
                            next spawn.
    :return: Number of positions handed to the item collection's ``spawn``,
             or ``0`` when the item limit is already reached.
    :rtype: int
    """
    if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))):
        # Shuffle a copy so the shared floor list keeps its order, then take
        # the slice from the *shuffled* copy so the spawn spots are random.
        position_list = list(state.entities.floorlist)
        shuffle(position_list)
        position_list = position_list[:item_to_spawns]
        state[i.ITEM].spawn(position_list)
        state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}')
        return len(position_list)
    else:
        state.print('No Items are spawning, limit is reached.')
        return 0
|
||||
@ -100,7 +104,7 @@ class DropOffLocations(PositionMixin, EnvObjects):
|
||||
|
||||
@staticmethod
def trigger_drop_off_location_spawn(state, n_locations):
    """Spawn up to ``n_locations`` drop-off locations on random empty positions.

    :param state: Game state giving access to the drop-off collection and to
                  ``state.entities.empty_positions``.
    :param n_locations: Maximum number of drop-off locations to create.
    """
    # ``empty_positions`` is a property returning a shuffled list; it must be
    # sliced directly, calling it would raise a TypeError.
    empty_positions = state.entities.empty_positions[:n_locations]
    do_entites = state[i.DROP_OFF]
    drop_offs = [DropOffLocation(pos) for pos in empty_positions]
    do_entites.add_items(drop_offs)
|
||||
|
@ -1,5 +1,5 @@
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from ...utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.results import TickResult
|
||||
|
||||
|
@ -13,8 +13,8 @@ class MachineRule(Rule):
|
||||
self.n_machines = n_machines
|
||||
|
||||
def on_init(self, state, lvl_map):
    """Spawn up to ``self.n_machines`` machines on random empty positions.

    :param state: Game state giving access to the machine collection and to
                  ``state.entities.empty_positions``.
    :param lvl_map: Level map; not used by this rule.
    """
    # TODO Move to spawn!!!
    # ``empty_positions`` is a property (shuffled list): slice it instead of
    # calling it, and keep the limit so not every free position gets a machine.
    empty_positions = state.entities.empty_positions[:self.n_machines]
    state[m.MACHINES].add_items(Machine(pos) for pos in empty_positions)
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
pass
|
||||
|
@ -8,7 +8,7 @@ from ...environment.entity.entity import Entity
|
||||
from ..doors import constants as do
|
||||
from ..maintenance import constants as mi
|
||||
from ...utils.helpers import MOVEMAP
|
||||
from ...utils.render import RenderEntity
|
||||
from ...utils.utility_classes import RenderEntity
|
||||
from ...utils.states import Gamestate
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ class Maintainer(Entity):
|
||||
self._next = []
|
||||
self._last = []
|
||||
self._last_serviced = 'None'
|
||||
self._floortile_graph = points_to_graph(state[c.FLOORS].positions)
|
||||
self._floortile_graph = points_to_graph(state.entities.floorlist)
|
||||
|
||||
def tick(self, state):
|
||||
if found_objective := state[self.objective].by_pos(self.pos):
|
||||
|
@ -14,7 +14,8 @@ class MaintenanceRule(Rule):
|
||||
self.n_maintainer = n_maintainer
|
||||
|
||||
def on_init(self, state: Gamestate, lvl_map):
|
||||
state[M.MAINTAINERS].spawn(state[c.FLOORS].empty_tiles[:self.n_maintainer], state)
|
||||
# Move to spawn? : #TODO
|
||||
state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state)
|
||||
pass
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
|
@ -3,8 +3,7 @@ from typing import List, Tuple
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
from marl_factory_grid.modules.doors import constants as d
|
||||
@ -21,5 +20,5 @@ class Zone(Object):
|
||||
self.coords = coords
|
||||
|
||||
@property
|
||||
def random_tile(self):
|
||||
def random_pos(self):
|
||||
return random.choice(self.coords)
|
||||
|
@ -19,7 +19,7 @@ class ZoneInit(Rule):
|
||||
while z_idx:
|
||||
zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
|
||||
if len(zone_positions):
|
||||
zones.append(Zone([state[c.FLOORS].by_pos(pos) for pos in zone_positions]))
|
||||
zones.append(Zone(zone_positions))
|
||||
z_idx += 1
|
||||
else:
|
||||
z_idx = 0
|
||||
@ -38,7 +38,7 @@ class AgentSingleZonePlacement(Rule):
|
||||
|
||||
z_idxs = choices(list(range(len(state[z.ZONES]))), k=n_agents)
|
||||
for agent in state[c.AGENT]:
|
||||
agent.move(state[z.ZONES][z_idxs.pop()].random_tile, state)
|
||||
agent.move(state[z.ZONES][z_idxs.pop()].random_pos, state)
|
||||
return []
|
||||
|
||||
def tick_step(self, state):
|
||||
@ -65,10 +65,10 @@ class IndividualDestinationZonePlacement(Rule):
|
||||
other_zones = [x for x in state[z.ZONES] if x not in agent_zones]
|
||||
already_has_destination = True
|
||||
while already_has_destination:
|
||||
tile = choice(other_zones).random_tile
|
||||
if state[d.DESTINATION].by_pos(tile.pos) is None:
|
||||
pos = choice(other_zones).random_pos
|
||||
if state[d.DESTINATION].by_pos(pos) is None:
|
||||
already_has_destination = False
|
||||
destination = Destination(tile, bind_to=agent)
|
||||
destination = Destination(pos, bind_to=agent)
|
||||
|
||||
state[d.DESTINATION].add_item(destination)
|
||||
continue
|
||||
|
@ -25,10 +25,8 @@ This file is used for:
|
||||
LEVELS_DIR = 'modules/levels' # for use in studies and experiments
|
||||
STEPS_START = 1 # Define where to the stepcount; which is the first step
|
||||
|
||||
# Not used anymore? Clean!
|
||||
# TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
|
||||
'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation',
|
||||
'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation',
|
||||
'episode']
|
||||
|
||||
POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]],
|
||||
@ -223,7 +221,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
|
||||
mod = importlib.import_module('.'.join(module_parts))
|
||||
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'Floor'
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
|
||||
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
|
||||
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
|
||||
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
|
||||
|
@ -6,7 +6,7 @@ import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.groups.agents import Agents
|
||||
from marl_factory_grid.environment.groups.global_entities import Entities
|
||||
from marl_factory_grid.environment.groups.wall_n_floors import Walls, Floors
|
||||
from marl_factory_grid.environment.groups.walls import Walls
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
@ -34,16 +34,14 @@ class LevelParser(object):
|
||||
|
||||
def do_init(self):
|
||||
# Global Entities
|
||||
list_of_all_floors = ([tuple(floor) for floor in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)])
|
||||
entities = Entities(list_of_all_floors)
|
||||
list_of_all_positions = ([tuple(f) for f in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)])
|
||||
entities = Entities(list_of_all_positions)
|
||||
|
||||
# Walls
|
||||
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
|
||||
entities.add_items({c.WALLS: walls})
|
||||
|
||||
# Floor
|
||||
floor = Floors.from_coordinates(list_of_all_floors, self.size)
|
||||
entities.add_items({c.FLOOR: floor})
|
||||
# Agents
|
||||
entities.add_items({c.AGENT: Agents(self.size)})
|
||||
|
||||
# All other
|
||||
|
@ -9,6 +9,7 @@ from numba import njit
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.groups.utils import Combined
|
||||
from marl_factory_grid.utils.states import Gamestate
|
||||
from marl_factory_grid.utils.utility_classes import Floor
|
||||
|
||||
|
||||
class OBSBuilder(object):
|
||||
@ -39,6 +40,7 @@ class OBSBuilder(object):
|
||||
|
||||
self.reset_struc_obs_block(state)
|
||||
self.curr_lightmaps = dict()
|
||||
self._floortiles = defaultdict(list, {pos: [Floor(*pos)] for pos in state.entities.floorlist})
|
||||
|
||||
def reset_struc_obs_block(self, state):
|
||||
self._curr_env_step = state.curr_step
|
||||
@ -82,19 +84,23 @@ class OBSBuilder(object):
|
||||
self._sort_and_name_observation_conf(agent)
|
||||
agent_want_obs = self.obs_layers[agent.name]
|
||||
|
||||
# Handle in-grid observations aka visible observations
|
||||
visible_entitites = self.ray_caster[agent.name].visible_entities(state.entities)
|
||||
pre_sort_obs = defaultdict(lambda: np.zeros((self.pomdp_d, self.pomdp_d)))
|
||||
for e in set(visible_entitites):
|
||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||
try:
|
||||
pre_sort_obs[e.obs_tag][x, y] += e.encoding
|
||||
except IndexError:
|
||||
# Seemded to be visible but is out or range
|
||||
pass
|
||||
# Handle in-grid observations aka visible observations (Things on the map, with pos)
|
||||
visible_entitites = self.ray_caster[agent.name].visible_entities(state.entities.pos_dict)
|
||||
pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
|
||||
if self.pomdp_r:
|
||||
for e in set(visible_entitites):
|
||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||
try:
|
||||
pre_sort_obs[e.obs_tag][x, y] += e.encoding
|
||||
except IndexError:
|
||||
# Seemded to be visible but is out or range
|
||||
pass
|
||||
else:
|
||||
for e in set(visible_entitites):
|
||||
pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding
|
||||
|
||||
pre_sort_obs = dict(pre_sort_obs)
|
||||
obs = np.zeros((len(agent_want_obs), self.pomdp_d, self.pomdp_d))
|
||||
obs = np.zeros((len(agent_want_obs), self.obs_shape[0], self.obs_shape[1]))
|
||||
|
||||
for idx, l_name in enumerate(agent_want_obs):
|
||||
try:
|
||||
@ -144,13 +150,26 @@ class OBSBuilder(object):
|
||||
raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.')
|
||||
|
||||
try:
|
||||
self.curr_lightmaps[agent.name] = pre_sort_obs[c.FLOORS].astype(bool)
|
||||
light_map = np.zeros(self.obs_shape)
|
||||
visible_floor = set(self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False))
|
||||
if self.pomdp_r:
|
||||
coords = [((f.x - agent.x) + self.pomdp_r, (f.y - agent.y) + self.pomdp_r) for f in visible_floor]
|
||||
else:
|
||||
coords = [x.pos for x in visible_floor]
|
||||
np.put(light_map, np.ravel_multi_index(np.asarray(coords).T, light_map.shape), 1)
|
||||
self.curr_lightmaps[agent.name] = light_map
|
||||
except KeyError:
|
||||
print()
|
||||
return obs, self.obs_layers[agent.name]
|
||||
|
||||
def _sort_and_name_observation_conf(self, agent):
|
||||
self.ray_caster[agent.name] = RayCaster(agent, self.pomdp_r)
|
||||
'''
|
||||
Builds the useable observation scheme per agent from conf.yaml.
|
||||
:param agent:
|
||||
:return:
|
||||
'''
|
||||
# Fixme: no asymetric shapes possible.
|
||||
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
|
||||
obs_layers = []
|
||||
|
||||
for obs_str in agent.observations:
|
||||
@ -173,7 +192,7 @@ class OBSBuilder(object):
|
||||
names.extend([x.name for x in agent.collection if x.name != agent.name])
|
||||
else:
|
||||
names.append(val)
|
||||
combined = Combined(names, self.pomdp_r, identifier=agent.name)
|
||||
combined = Combined(names, self.size, identifier=agent.name)
|
||||
self.all_obs[combined.name] = combined
|
||||
obs_layers.append(combined.name)
|
||||
elif obs_str == c.OTHERS:
|
||||
@ -183,19 +202,18 @@ class OBSBuilder(object):
|
||||
else:
|
||||
obs_layers.append(obs_str)
|
||||
self.obs_layers[agent.name] = obs_layers
|
||||
self.curr_lightmaps[agent.name] = np.zeros((self.pomdp_d or self.level_shape[0],
|
||||
self.pomdp_d or self.level_shape[1]
|
||||
))
|
||||
self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)
|
||||
|
||||
|
||||
class RayCaster:
|
||||
def __init__(self, agent, pomdp_r, degs=360):
|
||||
self.agent = agent
|
||||
self.pomdp_r = pomdp_r
|
||||
self.n_rays = 100 # (self.pomdp_r + 1) * 8
|
||||
self.n_rays = (self.pomdp_r + 1) * 8
|
||||
self.degs = degs
|
||||
self.ray_targets = self.build_ray_targets()
|
||||
self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
|
||||
self._cache_dict = {}
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
@ -211,30 +229,30 @@ class RayCaster:
|
||||
rot_M = np.unique(np.round(rot_M @ north), axis=0)
|
||||
return rot_M.astype(int)
|
||||
|
||||
def ray_block_cache(self, cache_dict, key, callback):
|
||||
if key not in cache_dict:
|
||||
cache_dict[key] = callback()
|
||||
return cache_dict[key]
|
||||
def ray_block_cache(self, key, callback):
|
||||
if key not in self._cache_dict:
|
||||
self._cache_dict[key] = callback()
|
||||
return self._cache_dict[key]
|
||||
|
||||
def visible_entities(self, entities):
|
||||
def visible_entities(self, pos_dict, reset_cache=True):
|
||||
visible = list()
|
||||
cache_blocking = {}
|
||||
if reset_cache:
|
||||
self._cache_dict = {}
|
||||
|
||||
for ray in self.get_rays():
|
||||
rx, ry = ray[0]
|
||||
for x, y in ray:
|
||||
cx, cy = x - rx, y - ry
|
||||
|
||||
entities_hit = entities.pos_dict[(x, y)]
|
||||
hits = self.ray_block_cache(cache_blocking,
|
||||
(x, y),
|
||||
lambda: any(True for e in entities_hit if e.var_is_blocking_light))
|
||||
entities_hit = pos_dict[(x, y)]
|
||||
hits = self.ray_block_cache((x, y),
|
||||
lambda: any(True for e in entities_hit if e.var_is_blocking_light)
|
||||
)
|
||||
|
||||
diag_hits = all([
|
||||
self.ray_block_cache(
|
||||
cache_blocking,
|
||||
key,
|
||||
lambda: all(False for e in entities.pos_dict[key] if not e.var_is_blocking_light))
|
||||
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(pos_dict[key]))
|
||||
for key in ((x, y-cy), (x-cx, y))
|
||||
]) if (cx != 0 and cy != 0) else False
|
||||
|
||||
|
@ -1,16 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class RenderEntity:
|
||||
name: str
|
||||
pos: np.array
|
||||
value: float = 1
|
||||
value_operation: str = 'none'
|
||||
state: str = None
|
||||
id: int = 0
|
||||
aux: Any = None
|
||||
real_name: str = 'none'
|
@ -9,7 +9,7 @@ import pygame
|
||||
from typing import Tuple, Union
|
||||
import time
|
||||
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
|
||||
AGENT: str = 'agent'
|
||||
STATE_IDLE: str = 'idle'
|
||||
|
@ -3,8 +3,6 @@ from typing import List, Dict, Tuple
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
from marl_factory_grid.environment.groups.global_entities import Entities
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
@ -112,15 +110,10 @@ class Gamestate(object):
|
||||
results.extend(on_check_done_result)
|
||||
return results
|
||||
|
||||
# def get_all_tiles_with_collisions(self) -> List[Floor]:
|
||||
# tiles = [self[c.FLOORS].by_pos(pos) for pos, e in self.entities.pos_dict.items()
|
||||
# if sum([x.var_can_collide for x in e]) > 1]
|
||||
# # tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
||||
# return tiles
|
||||
|
||||
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
|
||||
positions = [pos for pos, e in self.entities.pos_dict.items()
|
||||
if sum([x.var_can_collide for x in e]) > 1]
|
||||
positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items()
|
||||
if any([e.var_can_collide for e in entity_list_for_position])]
|
||||
return positions
|
||||
|
||||
def check_move_validity(self, moving_entity, position):
|
||||
@ -128,6 +121,14 @@ class Gamestate(object):
|
||||
# and not (guest.var_is_blocking_pos and self.is_occupied()):
|
||||
if moving_entity.pos != position and not any(
|
||||
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
|
||||
moving_entity.var_is_blocking_pos and moving_entity.is_occupied()):
|
||||
moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)):
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
def check_pos_validity(self, position):
|
||||
if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
@ -15,7 +15,7 @@ ENTITIES = 'Objects'
|
||||
OBSERVATIONS = 'Observations'
|
||||
RULES = 'Rule'
|
||||
ASSETS = 'Assets'
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Floor', 'Agent', 'GlobalPositions', 'Walls',
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls',
|
||||
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
|
||||
|
||||
|
||||
|
@ -1,4 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MarlFrameStack(gym.ObservationWrapper):
|
||||
@ -10,3 +14,37 @@ class MarlFrameStack(gym.ObservationWrapper):
|
||||
if isinstance(self.env, gym.wrappers.FrameStack) and self.env.unwrapped.n_agents > 1:
|
||||
return observation[0:].swapaxes(0, 1)
|
||||
return observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class RenderEntity:
|
||||
name: str
|
||||
pos: np.array
|
||||
value: float = 1
|
||||
value_operation: str = 'none'
|
||||
state: str = None
|
||||
id: int = 0
|
||||
aux: Any = None
|
||||
real_name: str = 'none'
|
||||
|
||||
|
||||
@dataclass
|
||||
class Floor:
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f"Floor({self.pos})"
|
||||
|
||||
@property
|
||||
def pos(self):
|
||||
return self.x, self.y
|
||||
|
||||
x: int
|
||||
y: int
|
||||
var_is_blocking_light: bool = False
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.name == other.name
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
Reference in New Issue
Block a user