mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-18 18:52:52 +02:00
no more tiles no more floor
This commit is contained in:
@ -21,7 +21,7 @@ class TSPBaseAgent(ABC):
|
|||||||
self.local_optimization = True
|
self.local_optimization = True
|
||||||
self._env = state
|
self._env = state
|
||||||
self.state = self._env.state[c.AGENT][agent_i]
|
self.state = self._env.state[c.AGENT][agent_i]
|
||||||
self._floortile_graph = points_to_graph(self._env[c.FLOORS].positions)
|
self._position_graph = points_to_graph(self._env.entities.floorlist)
|
||||||
self._static_route = None
|
self._static_route = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -50,7 +50,7 @@ class TSPBaseAgent(ABC):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
nodes = [self.state.pos] + positions
|
nodes = [self.state.pos] + positions
|
||||||
route = tsp.traveling_salesman_problem(self._floortile_graph,
|
route = tsp.traveling_salesman_problem(self._position_graph,
|
||||||
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
||||||
return route
|
return route
|
||||||
|
|
||||||
|
@ -4,17 +4,17 @@ import networkx as nx
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
def points_to_graph(coordiniates, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
||||||
"""
|
"""
|
||||||
Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points.
|
Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points.
|
||||||
There are three combinations of settings:
|
There are three combinations of settings:
|
||||||
Allow all neigbors: Distance(a, b) <= sqrt(2)
|
Allow all neigbors: Distance(a, b) <= sqrt(2)
|
||||||
Allow only manhattan: Distance(a, b) == 1
|
Allow only manhattan: Distance(a, b) == 1
|
||||||
Allow only euclidean: Distance(a, b) == sqrt(2)
|
Allow only Euclidean: Distance(a, b) == sqrt(2)
|
||||||
|
|
||||||
|
|
||||||
:param coordiniates_or_tiles: A set of coordinates.
|
:param coordiniates: A set of coordinates.
|
||||||
:type coordiniates_or_tiles: Tiles
|
:type coordiniates: Tuple[int, int]
|
||||||
:param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors
|
:param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors
|
||||||
:type: bool
|
:type: bool
|
||||||
:param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors
|
:param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors
|
||||||
@ -24,9 +24,7 @@ def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, all
|
|||||||
:rtype: nx.Graph
|
:rtype: nx.Graph
|
||||||
"""
|
"""
|
||||||
assert allow_euclidean_connections or allow_manhattan_connections
|
assert allow_euclidean_connections or allow_manhattan_connections
|
||||||
if hasattr(coordiniates_or_tiles, 'positions'):
|
possible_connections = itertools.combinations(coordiniates, 2)
|
||||||
coordiniates_or_tiles = coordiniates_or_tiles.positions
|
|
||||||
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
|
|
||||||
graph = nx.Graph()
|
graph = nx.Graph()
|
||||||
for a, b in possible_connections:
|
for a, b in possible_connections:
|
||||||
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
|
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
|
||||||
|
@ -66,7 +66,6 @@ Rules:
|
|||||||
DestinationDone: {}
|
DestinationDone: {}
|
||||||
DestinationReach:
|
DestinationReach:
|
||||||
n_dests: 1
|
n_dests: 1
|
||||||
tiles: null
|
|
||||||
DestinationSpawn:
|
DestinationSpawn:
|
||||||
n_dests: 1
|
n_dests: 1
|
||||||
spawn_frequency: 5
|
spawn_frequency: 5
|
||||||
|
@ -42,12 +42,12 @@ class Move(Action, abc.ABC):
|
|||||||
|
|
||||||
def do(self, entity, state):
|
def do(self, entity, state):
|
||||||
new_pos = self._calc_new_pos(entity.pos)
|
new_pos = self._calc_new_pos(entity.pos)
|
||||||
if state.check_move_validity(entity, new_pos): # next_tile := state[c.FLOOR].by_pos(new_pos):
|
if state.check_move_validity(entity, new_pos):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
move_validity = entity.move(new_pos, state)
|
move_validity = entity.move(new_pos, state)
|
||||||
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
|
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
|
||||||
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
|
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
|
||||||
else: # There is no floor, propably collision
|
else: # There is no place to go, propably collision
|
||||||
# This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml
|
# This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml
|
||||||
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
|
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
|
||||||
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0)
|
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0)
|
||||||
|
@ -3,15 +3,13 @@ DANGER_ZONE = 'x' # Dange Zone tile _identifier fo
|
|||||||
DEFAULTS = 'Defaults'
|
DEFAULTS = 'Defaults'
|
||||||
SELF = 'Self'
|
SELF = 'Self'
|
||||||
PLACEHOLDER = 'Placeholder'
|
PLACEHOLDER = 'Placeholder'
|
||||||
FLOOR = 'Floor' # Identifier of Floor-objects and groups (groups).
|
|
||||||
FLOORS = 'Floors' # Identifier of Floor-objects and groups (groups).
|
|
||||||
WALL = 'Wall' # Identifier of Wall-objects and groups (groups).
|
WALL = 'Wall' # Identifier of Wall-objects and groups (groups).
|
||||||
WALLS = 'Walls' # Identifier of Wall-objects and groups (groups).
|
WALLS = 'Walls' # Identifier of Wall-objects and groups (groups).
|
||||||
LEVEL = 'Level' # Identifier of Level-objects and groups (groups).
|
LEVEL = 'Level' # Identifier of Level-objects and groups (groups).
|
||||||
AGENT = 'Agent' # Identifier of Agent-objects and groups (groups).
|
AGENT = 'Agent' # Identifier of Agent-objects and groups (groups).
|
||||||
OTHERS = 'Other'
|
OTHERS = 'Other'
|
||||||
COMBINED = 'Combined'
|
COMBINED = 'Combined'
|
||||||
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
|
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
|
||||||
|
|
||||||
# Attributes
|
# Attributes
|
||||||
IS_BLOCKING_LIGHT = 'var_is_blocking_light'
|
IS_BLOCKING_LIGHT = 'var_is_blocking_light'
|
||||||
@ -32,7 +30,7 @@ VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the e
|
|||||||
|
|
||||||
ACTION = 'action' # Identifier of Action-objects and groups (groups).
|
ACTION = 'action' # Identifier of Action-objects and groups (groups).
|
||||||
COLLISION = 'Collision' # Identifier to use in the context of collitions.
|
COLLISION = 'Collision' # Identifier to use in the context of collitions.
|
||||||
LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos.
|
# LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos.
|
||||||
VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ...
|
VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ...
|
||||||
|
|
||||||
# Actions
|
# Actions
|
||||||
|
@ -2,7 +2,7 @@ from typing import List, Union
|
|||||||
|
|
||||||
from marl_factory_grid.environment.actions import Action
|
from marl_factory_grid.environment.actions import Action
|
||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.utils.render import RenderEntity
|
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||||
from marl_factory_grid.utils import renderer
|
from marl_factory_grid.utils import renderer
|
||||||
from marl_factory_grid.utils.helpers import is_move
|
from marl_factory_grid.utils.helpers import is_move
|
||||||
from marl_factory_grid.utils.results import ActionResult, Result
|
from marl_factory_grid.utils.results import ActionResult, Result
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
import abc
|
import abc
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from .. import constants as c
|
from .. import constants as c
|
||||||
from .object import EnvObject
|
from .object import EnvObject
|
||||||
from ...utils.render import RenderEntity
|
from ...utils.utility_classes import RenderEntity
|
||||||
from ...utils.results import ActionResult
|
from ...utils.results import ActionResult
|
||||||
|
|
||||||
|
|
||||||
@ -30,33 +32,32 @@ class Entity(EnvObject, abc.ABC):
|
|||||||
return self._pos
|
return self._pos
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tile(self):
|
def last_pos(self):
|
||||||
return self._tile # wall_n_floors funktionalität
|
try:
|
||||||
|
return self._last_pos
|
||||||
# @property
|
except AttributeError:
|
||||||
# def last_tile(self):
|
# noinspection PyAttributeOutsideInit
|
||||||
# try:
|
self._last_pos = c.VALUE_NO_POS
|
||||||
# return self._last_tile
|
return self._last_pos
|
||||||
# except AttributeError:
|
|
||||||
# # noinspection PyAttributeOutsideInit
|
|
||||||
# self._last_tile = None
|
|
||||||
# return self._last_tile
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def direction_of_view(self):
|
def direction_of_view(self):
|
||||||
last_x, last_y = self._last_pos
|
if self._last_pos != c.VALUE_NO_POS:
|
||||||
curr_x, curr_y = self.pos
|
return 0, 0
|
||||||
return last_x - curr_x, last_y - curr_y
|
else:
|
||||||
|
return np.subtract(self._last_pos, self.pos)
|
||||||
|
|
||||||
def move(self, next_pos, state):
|
def move(self, next_pos, state):
|
||||||
next_pos = next_pos
|
next_pos = next_pos
|
||||||
curr_pos = self._pos
|
curr_pos = self._pos
|
||||||
if not_same_pos := curr_pos != next_pos:
|
if not_same_pos := curr_pos != next_pos:
|
||||||
if valid := state.check_move_validity(self, next_pos):
|
if valid := state.check_move_validity(self, next_pos):
|
||||||
self._pos = next_pos
|
|
||||||
self._last_pos = curr_pos
|
|
||||||
for observer in self.observers:
|
for observer in self.observers:
|
||||||
observer.notify_change_pos(self)
|
observer.notify_del_entity(self)
|
||||||
|
self._view_directory = curr_pos[0]-next_pos[0], curr_pos[1]-next_pos[1]
|
||||||
|
self._pos = next_pos
|
||||||
|
for observer in self.observers:
|
||||||
|
observer.notify_add_entity(self)
|
||||||
return valid
|
return valid
|
||||||
return not_same_pos
|
return not_same_pos
|
||||||
|
|
||||||
@ -64,6 +65,7 @@ class Entity(EnvObject, abc.ABC):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._status = None
|
self._status = None
|
||||||
self._pos = pos
|
self._pos = pos
|
||||||
|
self._last_pos = pos
|
||||||
if bind_to:
|
if bind_to:
|
||||||
try:
|
try:
|
||||||
self.bind_to(bind_to)
|
self.bind_to(bind_to)
|
||||||
|
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
|
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.environment.entity.object import EnvObject
|
from marl_factory_grid.environment.entity.object import EnvObject
|
||||||
from marl_factory_grid.utils.render import RenderEntity
|
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||||
from marl_factory_grid.utils import helpers as h
|
from marl_factory_grid.utils import helpers as h
|
||||||
|
|
||||||
|
|
||||||
@ -30,17 +30,6 @@ class Floor(EnvObject):
|
|||||||
def var_is_blocking_light(self):
|
def var_is_blocking_light(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
|
||||||
def neighboring_floor(self):
|
|
||||||
if self._neighboring_floor:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
self._neighboring_floor = [x for x in [self._collection.by_pos(np.add(self.pos, pos))
|
|
||||||
for pos in h.POS_MASK.reshape(-1, 2)
|
|
||||||
if not np.all(pos == [0, 0])]
|
|
||||||
if x]
|
|
||||||
return self._neighboring_floor
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return c.VALUE_OCCUPIED_CELL
|
return c.VALUE_OCCUPIED_CELL
|
||||||
|
@ -197,7 +197,7 @@ class Factory(gym.Env):
|
|||||||
del rewards['global']
|
del rewards['global']
|
||||||
reward = [rewards[agent.name] for agent in self.state[c.AGENT]]
|
reward = [rewards[agent.name] for agent in self.state[c.AGENT]]
|
||||||
reward = [x + global_rewards for x in reward]
|
reward = [x + global_rewards for x in reward]
|
||||||
self.state.print(f"rewards are {rewards}")
|
self.state.print(f"Individual rewards are {dict(rewards)}")
|
||||||
return reward, combined_info_dict, done
|
return reward, combined_info_dict, done
|
||||||
else:
|
else:
|
||||||
reward = sum(rewards.values())
|
reward = sum(rewards.values())
|
||||||
@ -220,7 +220,7 @@ class Factory(gym.Env):
|
|||||||
|
|
||||||
def summarize_header(self):
|
def summarize_header(self):
|
||||||
header = {'rec_step': self.state.curr_step}
|
header = {'rec_step': self.state.curr_step}
|
||||||
for entity_group in (x for x in self.state if x.name in ['Walls', 'Floors', 'DropOffLocations', 'ChargePods']):
|
for entity_group in (x for x in self.state if x.name in ['Walls', 'DropOffLocations', 'ChargePods']):
|
||||||
header.update({f'rec{entity_group.name}': entity_group.summarize_states()})
|
header.update({f'rec{entity_group.name}': entity_group.summarize_states()})
|
||||||
return header
|
return header
|
||||||
|
|
||||||
@ -229,7 +229,7 @@ class Factory(gym.Env):
|
|||||||
|
|
||||||
# Todo: Protobuff Compatibility Section #######
|
# Todo: Protobuff Compatibility Section #######
|
||||||
# for entity_group in (x for x in self.state if x.name not in [c.WALLS, c.FLOORS]):
|
# for entity_group in (x for x in self.state if x.name not in [c.WALLS, c.FLOORS]):
|
||||||
for entity_group in (x for x in self.state if x.name not in [c.FLOORS]):
|
for entity_group in self.state:
|
||||||
summary.update({entity_group.name.lower(): entity_group.summarize_states()})
|
summary.update({entity_group.name.lower(): entity_group.summarize_states()})
|
||||||
# TODO Section End ########
|
# TODO Section End ########
|
||||||
for key in list(summary.keys()):
|
for key in list(summary.keys()):
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
|
from random import shuffle
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from marl_factory_grid.environment.groups.objects import Objects
|
from marl_factory_grid.environment.groups.objects import Objects
|
||||||
@ -13,7 +14,7 @@ class Entities(Objects):
|
|||||||
def neighboring_positions(pos):
|
def neighboring_positions(pos):
|
||||||
return (POS_MASK + pos).reshape(-1, 2)
|
return (POS_MASK + pos).reshape(-1, 2)
|
||||||
|
|
||||||
def get_near_pos(self, pos):
|
def get_entities_near_pos(self, pos):
|
||||||
return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x]
|
return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x]
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
@ -38,11 +39,17 @@ class Entities(Objects):
|
|||||||
def guests_that_can_collide(self, pos):
|
def guests_that_can_collide(self, pos):
|
||||||
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
||||||
|
|
||||||
def empty_tiles(self):
|
@property
|
||||||
return[key for key in self.floorlist if not any(self.pos_dict[key])]
|
def empty_positions(self):
|
||||||
|
empty_positions= [key for key in self.floorlist if self.pos_dict[key]]
|
||||||
|
shuffle(empty_positions)
|
||||||
|
return empty_positions
|
||||||
|
|
||||||
def occupied_tiles(self): # positions that are not empty
|
@property
|
||||||
return[key for key in self.floorlist if any(self.pos_dict[key])]
|
def occupied_positions(self): # positions that are not empty
|
||||||
|
empty_positions = [key for key in self.floorlist if self.pos_dict[key]]
|
||||||
|
shuffle(empty_positions)
|
||||||
|
return empty_positions
|
||||||
|
|
||||||
def is_blocked(self):
|
def is_blocked(self):
|
||||||
return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
|
return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
|
||||||
|
@ -37,7 +37,11 @@ class PositionMixin:
|
|||||||
|
|
||||||
def __delitem__(self, name):
|
def __delitem__(self, name):
|
||||||
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||||
obj.tile.leave(obj) # observer notify?
|
try:
|
||||||
|
for observer in obj.observers:
|
||||||
|
observer.notify_del_entity(obj)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
super().__delitem__(name)
|
super().__delitem__(name)
|
||||||
|
|
||||||
def by_pos(self, pos: (int, int)):
|
def by_pos(self, pos: (int, int)):
|
||||||
|
@ -103,6 +103,9 @@ class Objects:
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def by_name(self, name):
|
||||||
|
return next(x for x in self if x.name == name)
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
if isinstance(item, (int, np.int64, np.int32)):
|
if isinstance(item, (int, np.int64, np.int32)):
|
||||||
if item < 0:
|
if item < 0:
|
||||||
@ -120,7 +123,7 @@ class Objects:
|
|||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS, c.FLOORS]}
|
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS]}
|
||||||
return f'{self.__class__.__name__}[{repr_dict}]'
|
return f'{self.__class__.__name__}[{repr_dict}]'
|
||||||
|
|
||||||
def spawn(self, n: int):
|
def spawn(self, n: int):
|
||||||
@ -132,22 +135,25 @@ class Objects:
|
|||||||
for item in items:
|
for item in items:
|
||||||
del self[item]
|
del self[item]
|
||||||
|
|
||||||
def notify_change_pos(self, entity: object):
|
# def notify_change_pos(self, entity: object):
|
||||||
try:
|
# try:
|
||||||
self.pos_dict[entity.last_pos].remove(entity)
|
# self.pos_dict[entity.last_pos].remove(entity)
|
||||||
except (ValueError, AttributeError):
|
# except (ValueError, AttributeError):
|
||||||
pass
|
# pass
|
||||||
if entity.var_has_position:
|
# if entity.var_has_position:
|
||||||
try:
|
# try:
|
||||||
self.pos_dict[entity.pos].append(entity)
|
# self.pos_dict[entity.pos].append(entity)
|
||||||
except (ValueError, AttributeError):
|
# except (ValueError, AttributeError):
|
||||||
pass
|
# pass
|
||||||
|
|
||||||
def notify_del_entity(self, entity: Object):
|
def notify_del_entity(self, entity: Object):
|
||||||
try:
|
try:
|
||||||
entity.del_observer(self)
|
entity.del_observer(self)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
self.pos_dict[entity.pos].remove(entity)
|
self.pos_dict[entity.pos].remove(entity)
|
||||||
except (ValueError, AttributeError):
|
except (AttributeError, ValueError, IndexError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def notify_add_entity(self, entity: Object):
|
def notify_add_entity(self, entity: Object):
|
||||||
|
@ -15,6 +15,7 @@ class Walls(PositionMixin, EnvObjects):
|
|||||||
super(Walls, self).__init__(*args, **kwargs)
|
super(Walls, self).__init__(*args, **kwargs)
|
||||||
self._value = c.VALUE_OCCUPIED_CELL
|
self._value = c.VALUE_OCCUPIED_CELL
|
||||||
|
|
||||||
|
#ToDo: Do we need this? Move to spawn methode?
|
||||||
# @classmethod
|
# @classmethod
|
||||||
# def from_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
# def from_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
||||||
# tiles = cls(*args, **kwargs)
|
# tiles = cls(*args, **kwargs)
|
@ -49,7 +49,7 @@ class SpawnAgents(Rule):
|
|||||||
agent_conf = state.agents_conf
|
agent_conf = state.agents_conf
|
||||||
# agents = Agents(lvl_map.size)
|
# agents = Agents(lvl_map.size)
|
||||||
agents = state[c.AGENT]
|
agents = state[c.AGENT]
|
||||||
empty_tiles = state[c.FLOORS].empty_tiles[:len(agent_conf)]
|
empty_positions = state.entities.empty_positions[:len(agent_conf)]
|
||||||
for agent_name in agent_conf:
|
for agent_name in agent_conf:
|
||||||
actions = agent_conf[agent_name]['actions'].copy()
|
actions = agent_conf[agent_name]['actions'].copy()
|
||||||
observations = agent_conf[agent_name]['observations'].copy()
|
observations = agent_conf[agent_name]['observations'].copy()
|
||||||
@ -58,18 +58,17 @@ class SpawnAgents(Rule):
|
|||||||
shuffle(positions)
|
shuffle(positions)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
tile = state[c.FLOORS].by_pos(positions.pop())
|
pos = positions.pop()
|
||||||
except IndexError as e:
|
except IndexError as e:
|
||||||
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
|
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
|
||||||
f'\n{agent_name[agent_name]["positions"].copy()}')
|
f'\n{agent_name[agent_name]["positions"].copy()}')
|
||||||
try:
|
if agents.by_pos(pos) and state.check_pos_validity(pos):
|
||||||
agents.add_item(Agent(actions, observations, tile, str_ident=agent_name))
|
|
||||||
except AssertionError:
|
|
||||||
state.print(f'No valid pos:{tile.pos} for {agent_name}')
|
|
||||||
continue
|
continue
|
||||||
|
else:
|
||||||
|
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name))
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
agents.add_item(Agent(actions, observations, empty_tiles.pop(), str_ident=agent_name))
|
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name))
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
|
|||||||
from marl_factory_grid.environment.entity.object import EnvObject
|
from marl_factory_grid.environment.entity.object import EnvObject
|
||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.utils.render import RenderEntity
|
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||||
|
|
||||||
from marl_factory_grid.modules.batteries import constants as b
|
from marl_factory_grid.modules.batteries import constants as b
|
||||||
|
|
||||||
|
@ -70,8 +70,8 @@ class PodRules(Rule):
|
|||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
pod_collection = state[b.CHARGE_PODS]
|
pod_collection = state[b.CHARGE_PODS]
|
||||||
empty_tiles = state[c.FLOORS].empty_tiles[:self.n_pods]
|
empty_positions = state.entities.empty_positions()
|
||||||
pods = pod_collection.from_coordinates(empty_tiles, entity_kwargs=dict(
|
pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict(
|
||||||
multi_charge=self.multi_charge, charge_rate=self.charge_rate)
|
multi_charge=self.multi_charge, charge_rate=self.charge_rate)
|
||||||
)
|
)
|
||||||
pod_collection.add_items(pods)
|
pod_collection.add_items(pods)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from numpy import random
|
from numpy import random
|
||||||
|
|
||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.utils.render import RenderEntity
|
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||||
from marl_factory_grid.modules.clean_up import constants as d
|
from marl_factory_grid.modules.clean_up import constants as d
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||||
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
||||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
|
||||||
from marl_factory_grid.modules.clean_up.entitites import DirtPile
|
from marl_factory_grid.modules.clean_up.entitites import DirtPile
|
||||||
|
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
@ -31,8 +30,6 @@ class DirtPiles(PositionMixin, EnvObjects):
|
|||||||
self.max_local_amount = max_local_amount
|
self.max_local_amount = max_local_amount
|
||||||
|
|
||||||
def spawn(self, then_dirty_positions, amount) -> bool:
|
def spawn(self, then_dirty_positions, amount) -> bool:
|
||||||
# if isinstance(then_dirty_tiles, Floor):
|
|
||||||
# then_dirty_tiles = [then_dirty_tiles]
|
|
||||||
for pos in then_dirty_positions:
|
for pos in then_dirty_positions:
|
||||||
if not self.amount > self.max_global_amount:
|
if not self.amount > self.max_global_amount:
|
||||||
if dirt := self.by_pos(pos):
|
if dirt := self.by_pos(pos):
|
||||||
@ -56,8 +53,8 @@ class DirtPiles(PositionMixin, EnvObjects):
|
|||||||
|
|
||||||
var = self.dirt_spawn_r_var
|
var = self.dirt_spawn_r_var
|
||||||
new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0))
|
new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0))
|
||||||
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
n_dirty_positions = max(0, int(new_spawn * len(free_for_dirt)))
|
||||||
return self.spawn(free_for_dirt[:n_dirt_tiles], self.initial_amount)
|
return self.spawn(free_for_dirt[:n_dirty_positions], self.initial_amount)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
s = super(DirtPiles, self).__repr__()
|
s = super(DirtPiles, self).__repr__()
|
||||||
|
@ -4,7 +4,7 @@ from marl_factory_grid.environment.entity.agent import Agent
|
|||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
|
from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
|
||||||
from marl_factory_grid.utils.render import RenderEntity
|
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||||
from marl_factory_grid.modules.destinations import constants as d
|
from marl_factory_grid.modules.destinations import constants as d
|
||||||
|
|
||||||
|
|
||||||
@ -17,7 +17,6 @@ class Destination(BoundEntityMixin, Entity):
|
|||||||
var_is_blocking_light = False
|
var_is_blocking_light = False
|
||||||
var_can_be_bound = True # Introduce this globally!
|
var_can_be_bound = True # Introduce this globally!
|
||||||
|
|
||||||
@property
|
|
||||||
def was_reached(self):
|
def was_reached(self):
|
||||||
return self._was_reached
|
return self._was_reached
|
||||||
|
|
||||||
@ -35,11 +34,10 @@ class Destination(BoundEntityMixin, Entity):
|
|||||||
self._per_agent_actions[agent.name] += 1
|
self._per_agent_actions[agent.name] += 1
|
||||||
return c.VALID
|
return c.VALID
|
||||||
|
|
||||||
@property
|
def has_just_been_reached(self, state):
|
||||||
def has_just_been_reached(self):
|
if self.was_reached():
|
||||||
if self.was_reached:
|
|
||||||
return False
|
return False
|
||||||
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in state.entities.pos_dict[self.pos] if x.var_can_collide)
|
agent_at_position = any(state[c.AGENT].by_pos(self.pos))
|
||||||
|
|
||||||
if self.bound_entity:
|
if self.bound_entity:
|
||||||
return ((agent_at_position and not self.action_counts)
|
return ((agent_at_position and not self.action_counts)
|
||||||
@ -57,7 +55,7 @@ class Destination(BoundEntityMixin, Entity):
|
|||||||
return state_summary
|
return state_summary
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
if self.was_reached:
|
if self.was_reached():
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return RenderEntity(d.DESTINATION, self.pos)
|
return RenderEntity(d.DESTINATION, self.pos)
|
||||||
|
@ -16,28 +16,29 @@ class DestinationReachAll(Rule):
|
|||||||
|
|
||||||
def tick_step(self, state) -> List[TickResult]:
|
def tick_step(self, state) -> List[TickResult]:
|
||||||
results = []
|
results = []
|
||||||
|
reached = False
|
||||||
for dest in state[d.DESTINATION]:
|
for dest in state[d.DESTINATION]:
|
||||||
if dest.has_just_been_reached and not dest.was_reached:
|
if dest.has_just_been_reached(state) and not dest.was_reached():
|
||||||
# Dest has just been reached, some agent needs to stand here, grab any first.
|
# Dest has just been reached, some agent needs to stand here
|
||||||
for agent in state[c.AGENT].by_pos(dest.pos):
|
for agent in state[c.AGENT].by_pos(dest.pos):
|
||||||
if dest.bound_entity:
|
if dest.bound_entity:
|
||||||
if dest.bound_entity == agent:
|
if dest.bound_entity == agent:
|
||||||
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
reached = True
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
reached = True
|
||||||
state.print(f'{dest.name} is reached now, mark as reached...')
|
|
||||||
dest.mark_as_reached()
|
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
if reached:
|
||||||
|
state.print(f'{dest.name} is reached now, mark as reached...')
|
||||||
|
dest.mark_as_reached()
|
||||||
|
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def tick_post_step(self, state) -> List[TickResult]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def on_check_done(self, state) -> List[DoneResult]:
|
def on_check_done(self, state) -> List[DoneResult]:
|
||||||
if all(x.was_reached for x in state[d.DESTINATION]):
|
if all(x.was_reached() for x in state[d.DESTINATION]):
|
||||||
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
|
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
|
||||||
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
|
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
|
||||||
|
|
||||||
@ -48,7 +49,7 @@ class DestinationReachAny(DestinationReachAll):
|
|||||||
super(DestinationReachAny, self).__init__()
|
super(DestinationReachAny, self).__init__()
|
||||||
|
|
||||||
def on_check_done(self, state) -> List[DoneResult]:
|
def on_check_done(self, state) -> List[DoneResult]:
|
||||||
if any(x.was_reached for x in state[d.DESTINATION]):
|
if any(x.was_reached() for x in state[d.DESTINATION]):
|
||||||
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
|
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@ -63,7 +64,7 @@ class DestinationSpawn(Rule):
|
|||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
self.trigger_destination_spawn(self.n_dests, state)
|
state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def tick_pre_step(self, state) -> List[TickResult]:
|
def tick_pre_step(self, state) -> List[TickResult]:
|
||||||
@ -72,24 +73,14 @@ class DestinationSpawn(Rule):
|
|||||||
def tick_step(self, state) -> List[TickResult]:
|
def tick_step(self, state) -> List[TickResult]:
|
||||||
if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])):
|
if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])):
|
||||||
if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
|
if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
|
||||||
validity = self.trigger_destination_spawn(n_dest_spawn, state)
|
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
|
||||||
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
|
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
|
||||||
elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn:
|
elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn:
|
||||||
validity = self.trigger_destination_spawn(n_dest_spawn, state)
|
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
|
||||||
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
|
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def trigger_destination_spawn(self, n_dests, state):
|
|
||||||
empty_positions = state[c.FLOORS].empty_tiles[:n_dests]
|
|
||||||
if destinations := [Destination(pos) for pos in empty_positions]:
|
|
||||||
state[d.DESTINATION].add_items(destinations)
|
|
||||||
state.print(f'{n_dests} new destinations have been spawned')
|
|
||||||
return c.VALID
|
|
||||||
else:
|
|
||||||
state.print('No Destiantions are spawning, limit is reached.')
|
|
||||||
return c.NOT_VALID
|
|
||||||
|
|
||||||
|
|
||||||
class FixedDestinationSpawn(Rule):
|
class FixedDestinationSpawn(Rule):
|
||||||
def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]):
|
def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]):
|
||||||
@ -99,11 +90,17 @@ class FixedDestinationSpawn(Rule):
|
|||||||
def on_init(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
for (agent_name, position_list) in self.per_agent_positions.items():
|
for (agent_name, position_list) in self.per_agent_positions.items():
|
||||||
agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF
|
agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF
|
||||||
|
position_list = position_list.copy()
|
||||||
shuffle(position_list)
|
shuffle(position_list)
|
||||||
while True:
|
while True:
|
||||||
pos = position_list.pop()
|
try:
|
||||||
if pos != agent.pos and not state[d.DESTINATION].by_pos(pos):
|
pos = position_list.pop()
|
||||||
destination = Destination(state[c.FLOORS].by_pos(pos), bind_to=agent)
|
except IndexError:
|
||||||
|
print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}")
|
||||||
|
print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...')
|
||||||
|
exit(9999)
|
||||||
|
if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)):
|
||||||
|
destination = Destination(pos, bind_to=agent)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from .actions import DoorUse
|
from .actions import DoorUse
|
||||||
from .entitites import Door, DoorIndicator
|
from .entitites import Door, DoorIndicator
|
||||||
from .groups import Doors
|
from .groups import Doors
|
||||||
from .rule_door_auto_close import DoorAutoClose
|
from .rules import DoorAutoClose, DoorIndicateArea
|
||||||
|
@ -13,8 +13,9 @@ class DoorUse(Action):
|
|||||||
|
|
||||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||||
# Check if agent really is standing on a door:
|
# Check if agent really is standing on a door:
|
||||||
e = state.entities.get_near_pos(entity.pos)
|
e = state.entities.get_entities_near_pos(entity.pos)
|
||||||
try:
|
try:
|
||||||
|
# Only one door opens TODO introcude loop
|
||||||
door = next(x for x in e if x.name.startswith(d.DOOR))
|
door = next(x for x in e if x.name.startswith(d.DOOR))
|
||||||
valid = door.use()
|
valid = door.use()
|
||||||
state.print(f'{entity.name} just used a {door.name} at {door.pos}')
|
state.print(f'{entity.name} just used a {door.name} at {door.pos}')
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.utils.render import RenderEntity
|
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
from marl_factory_grid.modules.doors import constants as d
|
from marl_factory_grid.modules.doors import constants as d
|
||||||
@ -41,7 +41,7 @@ class Door(Entity):
|
|||||||
def str_state(self):
|
def str_state(self):
|
||||||
return 'open' if self.is_open else 'closed'
|
return 'open' if self.is_open else 'closed'
|
||||||
|
|
||||||
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, indicate_area=False, **kwargs):
|
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
|
||||||
self._status = d.STATE_CLOSED
|
self._status = d.STATE_CLOSED
|
||||||
super(Door, self).__init__(*args, **kwargs)
|
super(Door, self).__init__(*args, **kwargs)
|
||||||
self.auto_close_interval = auto_close_interval
|
self.auto_close_interval = auto_close_interval
|
||||||
@ -50,8 +50,6 @@ class Door(Entity):
|
|||||||
self._open()
|
self._open()
|
||||||
else:
|
else:
|
||||||
self._close()
|
self._close()
|
||||||
if indicate_area:
|
|
||||||
self._collection.add_items([DoorIndicator(x) for x in self.tile.neighboring_floor])
|
|
||||||
|
|
||||||
def summarize_state(self):
|
def summarize_state(self):
|
||||||
state_dict = super().summarize_state()
|
state_dict = super().summarize_state()
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
from marl_factory_grid.environment.rules import Rule
|
from marl_factory_grid.environment.rules import Rule
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.utils.results import TickResult
|
from marl_factory_grid.utils.results import TickResult
|
||||||
from marl_factory_grid.modules.doors import constants as d
|
from . import constants as d
|
||||||
|
from .entitites import DoorIndicator
|
||||||
|
|
||||||
|
|
||||||
class DoorAutoClose(Rule):
|
class DoorAutoClose(Rule):
|
||||||
@ -19,3 +20,13 @@ class DoorAutoClose(Rule):
|
|||||||
return [TickResult(self.name, validity=c.VALID, value=0)]
|
return [TickResult(self.name, validity=c.VALID, value=0)]
|
||||||
state.print('There are no doors, but you loaded the corresponding Module')
|
state.print('There are no doors, but you loaded the corresponding Module')
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class DoorIndicateArea(Rule):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def on_init(self, state, lvl_map):
|
||||||
|
for door in state[d.DOORS]:
|
||||||
|
state[d.DOORS].add_items([DoorIndicator(x) for x in state.entities.neighboring_positions(door.pos)])
|
@ -9,6 +9,8 @@ from marl_factory_grid.utils.results import TickResult
|
|||||||
class AgentSingleZonePlacementBeta(Rule):
|
class AgentSingleZonePlacementBeta(Rule):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
# TODO!!!! Is this concept needed any more?
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
@ -21,9 +23,9 @@ class AgentSingleZonePlacementBeta(Rule):
|
|||||||
coordinates = random.choices(self.coordinates, k=len(agents))
|
coordinates = random.choices(self.coordinates, k=len(agents))
|
||||||
else:
|
else:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
tiles = [state[c.FLOORS].by_pos(pos) for pos in coordinates]
|
|
||||||
for agent, tile in zip(agents, tiles):
|
for agent, pos in zip(agents, coordinates):
|
||||||
agent.move(tile, state)
|
agent.move(pos, state)
|
||||||
|
|
||||||
def tick_step(self, state):
|
def tick_step(self, state):
|
||||||
return []
|
return []
|
||||||
|
@ -2,7 +2,7 @@ from collections import deque
|
|||||||
|
|
||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.utils.render import RenderEntity
|
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||||
from marl_factory_grid.modules.items import constants as i
|
from marl_factory_grid.modules.items import constants as i
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from random import shuffle
|
||||||
|
|
||||||
from marl_factory_grid.modules.items import constants as i
|
from marl_factory_grid.modules.items import constants as i
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
@ -19,10 +21,12 @@ class Items(PositionMixin, EnvObjects):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def trigger_item_spawn(state, n_items, spawn_frequency):
|
def trigger_item_spawn(state, n_items, spawn_frequency):
|
||||||
if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))):
|
if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))):
|
||||||
floor_list = state.entities.floorlist[:item_to_spawns]
|
position_list = [x for x in state.entities.floorlist]
|
||||||
state[i.ITEM].spawn(floor_list)
|
shuffle(position_list)
|
||||||
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}') # spawn in self._next_item_spawn ?
|
position_list = state.entities.floorlist[:item_to_spawns]
|
||||||
return len(floor_list)
|
state[i.ITEM].spawn(position_list)
|
||||||
|
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}')
|
||||||
|
return len(position_list)
|
||||||
else:
|
else:
|
||||||
state.print('No Items are spawning, limit is reached.')
|
state.print('No Items are spawning, limit is reached.')
|
||||||
return 0
|
return 0
|
||||||
@ -100,7 +104,7 @@ class DropOffLocations(PositionMixin, EnvObjects):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def trigger_drop_off_location_spawn(state, n_locations):
|
def trigger_drop_off_location_spawn(state, n_locations):
|
||||||
empty_tiles = state.entities.floorlist[:n_locations]
|
empty_positions = state.entities.empty_positions()[:n_locations]
|
||||||
do_entites = state[i.DROP_OFF]
|
do_entites = state[i.DROP_OFF]
|
||||||
drop_offs = [DropOffLocation(tile) for tile in empty_tiles]
|
drop_offs = [DropOffLocation(pos) for pos in empty_positions]
|
||||||
do_entites.add_items(drop_offs)
|
do_entites.add_items(drop_offs)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.utils.render import RenderEntity
|
from ...utils.utility_classes import RenderEntity
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.utils.results import TickResult
|
from marl_factory_grid.utils.results import TickResult
|
||||||
|
|
||||||
|
@ -13,8 +13,8 @@ class MachineRule(Rule):
|
|||||||
self.n_machines = n_machines
|
self.n_machines = n_machines
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
empty_tiles = state[c.FLOORS].empty_tiles[:self.n_machines]
|
# TODO Move to spawn!!!
|
||||||
state[m.MACHINES].add_items(Machine(tile) for tile in empty_tiles)
|
state[m.MACHINES].add_items(Machine(pos) for pos in state.entities.empty_positions())
|
||||||
|
|
||||||
def tick_pre_step(self, state) -> List[TickResult]:
|
def tick_pre_step(self, state) -> List[TickResult]:
|
||||||
pass
|
pass
|
||||||
|
@ -8,7 +8,7 @@ from ...environment.entity.entity import Entity
|
|||||||
from ..doors import constants as do
|
from ..doors import constants as do
|
||||||
from ..maintenance import constants as mi
|
from ..maintenance import constants as mi
|
||||||
from ...utils.helpers import MOVEMAP
|
from ...utils.helpers import MOVEMAP
|
||||||
from ...utils.render import RenderEntity
|
from ...utils.utility_classes import RenderEntity
|
||||||
from ...utils.states import Gamestate
|
from ...utils.states import Gamestate
|
||||||
|
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ class Maintainer(Entity):
|
|||||||
self._next = []
|
self._next = []
|
||||||
self._last = []
|
self._last = []
|
||||||
self._last_serviced = 'None'
|
self._last_serviced = 'None'
|
||||||
self._floortile_graph = points_to_graph(state[c.FLOORS].positions)
|
self._floortile_graph = points_to_graph(state.entities.floorlist)
|
||||||
|
|
||||||
def tick(self, state):
|
def tick(self, state):
|
||||||
if found_objective := state[self.objective].by_pos(self.pos):
|
if found_objective := state[self.objective].by_pos(self.pos):
|
||||||
|
@ -14,7 +14,8 @@ class MaintenanceRule(Rule):
|
|||||||
self.n_maintainer = n_maintainer
|
self.n_maintainer = n_maintainer
|
||||||
|
|
||||||
def on_init(self, state: Gamestate, lvl_map):
|
def on_init(self, state: Gamestate, lvl_map):
|
||||||
state[M.MAINTAINERS].spawn(state[c.FLOORS].empty_tiles[:self.n_maintainer], state)
|
# Move to spawn? : #TODO
|
||||||
|
state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def tick_pre_step(self, state) -> List[TickResult]:
|
def tick_pre_step(self, state) -> List[TickResult]:
|
||||||
|
@ -3,8 +3,7 @@ from typing import List, Tuple
|
|||||||
|
|
||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.environment.entity.object import Object
|
from marl_factory_grid.environment.entity.object import Object
|
||||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||||
from marl_factory_grid.utils.render import RenderEntity
|
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
from marl_factory_grid.modules.doors import constants as d
|
from marl_factory_grid.modules.doors import constants as d
|
||||||
@ -21,5 +20,5 @@ class Zone(Object):
|
|||||||
self.coords = coords
|
self.coords = coords
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def random_tile(self):
|
def random_pos(self):
|
||||||
return random.choice(self.coords)
|
return random.choice(self.coords)
|
||||||
|
@ -19,7 +19,7 @@ class ZoneInit(Rule):
|
|||||||
while z_idx:
|
while z_idx:
|
||||||
zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
|
zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
|
||||||
if len(zone_positions):
|
if len(zone_positions):
|
||||||
zones.append(Zone([state[c.FLOORS].by_pos(pos) for pos in zone_positions]))
|
zones.append(Zone(zone_positions))
|
||||||
z_idx += 1
|
z_idx += 1
|
||||||
else:
|
else:
|
||||||
z_idx = 0
|
z_idx = 0
|
||||||
@ -38,7 +38,7 @@ class AgentSingleZonePlacement(Rule):
|
|||||||
|
|
||||||
z_idxs = choices(list(range(len(state[z.ZONES]))), k=n_agents)
|
z_idxs = choices(list(range(len(state[z.ZONES]))), k=n_agents)
|
||||||
for agent in state[c.AGENT]:
|
for agent in state[c.AGENT]:
|
||||||
agent.move(state[z.ZONES][z_idxs.pop()].random_tile, state)
|
agent.move(state[z.ZONES][z_idxs.pop()].random_pos, state)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def tick_step(self, state):
|
def tick_step(self, state):
|
||||||
@ -65,10 +65,10 @@ class IndividualDestinationZonePlacement(Rule):
|
|||||||
other_zones = [x for x in state[z.ZONES] if x not in agent_zones]
|
other_zones = [x for x in state[z.ZONES] if x not in agent_zones]
|
||||||
already_has_destination = True
|
already_has_destination = True
|
||||||
while already_has_destination:
|
while already_has_destination:
|
||||||
tile = choice(other_zones).random_tile
|
pos = choice(other_zones).random_pos
|
||||||
if state[d.DESTINATION].by_pos(tile.pos) is None:
|
if state[d.DESTINATION].by_pos(pos) is None:
|
||||||
already_has_destination = False
|
already_has_destination = False
|
||||||
destination = Destination(tile, bind_to=agent)
|
destination = Destination(pos, bind_to=agent)
|
||||||
|
|
||||||
state[d.DESTINATION].add_item(destination)
|
state[d.DESTINATION].add_item(destination)
|
||||||
continue
|
continue
|
||||||
|
@ -25,10 +25,8 @@ This file is used for:
|
|||||||
LEVELS_DIR = 'modules/levels' # for use in studies and experiments
|
LEVELS_DIR = 'modules/levels' # for use in studies and experiments
|
||||||
STEPS_START = 1 # Define where to the stepcount; which is the first step
|
STEPS_START = 1 # Define where to the stepcount; which is the first step
|
||||||
|
|
||||||
# Not used anymore? Clean!
|
|
||||||
# TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
|
||||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
|
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
|
||||||
'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation',
|
'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation',
|
||||||
'episode']
|
'episode']
|
||||||
|
|
||||||
POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]],
|
POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]],
|
||||||
@ -223,7 +221,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
|||||||
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
|
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
|
||||||
mod = importlib.import_module('.'.join(module_parts))
|
mod = importlib.import_module('.'.join(module_parts))
|
||||||
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
|
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
|
||||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'Floor'
|
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
|
||||||
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
|
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
|
||||||
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
|
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
|
||||||
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
|
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
|
||||||
|
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
|
|
||||||
from marl_factory_grid.environment.groups.agents import Agents
|
from marl_factory_grid.environment.groups.agents import Agents
|
||||||
from marl_factory_grid.environment.groups.global_entities import Entities
|
from marl_factory_grid.environment.groups.global_entities import Entities
|
||||||
from marl_factory_grid.environment.groups.wall_n_floors import Walls, Floors
|
from marl_factory_grid.environment.groups.walls import Walls
|
||||||
from marl_factory_grid.utils import helpers as h
|
from marl_factory_grid.utils import helpers as h
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
@ -34,16 +34,14 @@ class LevelParser(object):
|
|||||||
|
|
||||||
def do_init(self):
|
def do_init(self):
|
||||||
# Global Entities
|
# Global Entities
|
||||||
list_of_all_floors = ([tuple(floor) for floor in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)])
|
list_of_all_positions = ([tuple(f) for f in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)])
|
||||||
entities = Entities(list_of_all_floors)
|
entities = Entities(list_of_all_positions)
|
||||||
|
|
||||||
# Walls
|
# Walls
|
||||||
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
|
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
|
||||||
entities.add_items({c.WALLS: walls})
|
entities.add_items({c.WALLS: walls})
|
||||||
|
|
||||||
# Floor
|
# Agents
|
||||||
floor = Floors.from_coordinates(list_of_all_floors, self.size)
|
|
||||||
entities.add_items({c.FLOOR: floor})
|
|
||||||
entities.add_items({c.AGENT: Agents(self.size)})
|
entities.add_items({c.AGENT: Agents(self.size)})
|
||||||
|
|
||||||
# All other
|
# All other
|
||||||
|
@ -9,6 +9,7 @@ from numba import njit
|
|||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.environment.groups.utils import Combined
|
from marl_factory_grid.environment.groups.utils import Combined
|
||||||
from marl_factory_grid.utils.states import Gamestate
|
from marl_factory_grid.utils.states import Gamestate
|
||||||
|
from marl_factory_grid.utils.utility_classes import Floor
|
||||||
|
|
||||||
|
|
||||||
class OBSBuilder(object):
|
class OBSBuilder(object):
|
||||||
@ -39,6 +40,7 @@ class OBSBuilder(object):
|
|||||||
|
|
||||||
self.reset_struc_obs_block(state)
|
self.reset_struc_obs_block(state)
|
||||||
self.curr_lightmaps = dict()
|
self.curr_lightmaps = dict()
|
||||||
|
self._floortiles = defaultdict(list, {pos: [Floor(*pos)] for pos in state.entities.floorlist})
|
||||||
|
|
||||||
def reset_struc_obs_block(self, state):
|
def reset_struc_obs_block(self, state):
|
||||||
self._curr_env_step = state.curr_step
|
self._curr_env_step = state.curr_step
|
||||||
@ -82,19 +84,23 @@ class OBSBuilder(object):
|
|||||||
self._sort_and_name_observation_conf(agent)
|
self._sort_and_name_observation_conf(agent)
|
||||||
agent_want_obs = self.obs_layers[agent.name]
|
agent_want_obs = self.obs_layers[agent.name]
|
||||||
|
|
||||||
# Handle in-grid observations aka visible observations
|
# Handle in-grid observations aka visible observations (Things on the map, with pos)
|
||||||
visible_entitites = self.ray_caster[agent.name].visible_entities(state.entities)
|
visible_entitites = self.ray_caster[agent.name].visible_entities(state.entities.pos_dict)
|
||||||
pre_sort_obs = defaultdict(lambda: np.zeros((self.pomdp_d, self.pomdp_d)))
|
pre_sort_obs = defaultdict(lambda: np.zeros(self.obs_shape))
|
||||||
for e in set(visible_entitites):
|
if self.pomdp_r:
|
||||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
for e in set(visible_entitites):
|
||||||
try:
|
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||||
pre_sort_obs[e.obs_tag][x, y] += e.encoding
|
try:
|
||||||
except IndexError:
|
pre_sort_obs[e.obs_tag][x, y] += e.encoding
|
||||||
# Seemded to be visible but is out or range
|
except IndexError:
|
||||||
pass
|
# Seemded to be visible but is out or range
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
for e in set(visible_entitites):
|
||||||
|
pre_sort_obs[e.obs_tag][e.x, e.y] += e.encoding
|
||||||
|
|
||||||
pre_sort_obs = dict(pre_sort_obs)
|
pre_sort_obs = dict(pre_sort_obs)
|
||||||
obs = np.zeros((len(agent_want_obs), self.pomdp_d, self.pomdp_d))
|
obs = np.zeros((len(agent_want_obs), self.obs_shape[0], self.obs_shape[1]))
|
||||||
|
|
||||||
for idx, l_name in enumerate(agent_want_obs):
|
for idx, l_name in enumerate(agent_want_obs):
|
||||||
try:
|
try:
|
||||||
@ -144,13 +150,26 @@ class OBSBuilder(object):
|
|||||||
raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.')
|
raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.curr_lightmaps[agent.name] = pre_sort_obs[c.FLOORS].astype(bool)
|
light_map = np.zeros(self.obs_shape)
|
||||||
|
visible_floor = set(self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False))
|
||||||
|
if self.pomdp_r:
|
||||||
|
coords = [((f.x - agent.x) + self.pomdp_r, (f.y - agent.y) + self.pomdp_r) for f in visible_floor]
|
||||||
|
else:
|
||||||
|
coords = [x.pos for x in visible_floor]
|
||||||
|
np.put(light_map, np.ravel_multi_index(np.asarray(coords).T, light_map.shape), 1)
|
||||||
|
self.curr_lightmaps[agent.name] = light_map
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print()
|
print()
|
||||||
return obs, self.obs_layers[agent.name]
|
return obs, self.obs_layers[agent.name]
|
||||||
|
|
||||||
def _sort_and_name_observation_conf(self, agent):
|
def _sort_and_name_observation_conf(self, agent):
|
||||||
self.ray_caster[agent.name] = RayCaster(agent, self.pomdp_r)
|
'''
|
||||||
|
Builds the useable observation scheme per agent from conf.yaml.
|
||||||
|
:param agent:
|
||||||
|
:return:
|
||||||
|
'''
|
||||||
|
# Fixme: no asymetric shapes possible.
|
||||||
|
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
|
||||||
obs_layers = []
|
obs_layers = []
|
||||||
|
|
||||||
for obs_str in agent.observations:
|
for obs_str in agent.observations:
|
||||||
@ -173,7 +192,7 @@ class OBSBuilder(object):
|
|||||||
names.extend([x.name for x in agent.collection if x.name != agent.name])
|
names.extend([x.name for x in agent.collection if x.name != agent.name])
|
||||||
else:
|
else:
|
||||||
names.append(val)
|
names.append(val)
|
||||||
combined = Combined(names, self.pomdp_r, identifier=agent.name)
|
combined = Combined(names, self.size, identifier=agent.name)
|
||||||
self.all_obs[combined.name] = combined
|
self.all_obs[combined.name] = combined
|
||||||
obs_layers.append(combined.name)
|
obs_layers.append(combined.name)
|
||||||
elif obs_str == c.OTHERS:
|
elif obs_str == c.OTHERS:
|
||||||
@ -183,19 +202,18 @@ class OBSBuilder(object):
|
|||||||
else:
|
else:
|
||||||
obs_layers.append(obs_str)
|
obs_layers.append(obs_str)
|
||||||
self.obs_layers[agent.name] = obs_layers
|
self.obs_layers[agent.name] = obs_layers
|
||||||
self.curr_lightmaps[agent.name] = np.zeros((self.pomdp_d or self.level_shape[0],
|
self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)
|
||||||
self.pomdp_d or self.level_shape[1]
|
|
||||||
))
|
|
||||||
|
|
||||||
|
|
||||||
class RayCaster:
|
class RayCaster:
|
||||||
def __init__(self, agent, pomdp_r, degs=360):
|
def __init__(self, agent, pomdp_r, degs=360):
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
self.pomdp_r = pomdp_r
|
self.pomdp_r = pomdp_r
|
||||||
self.n_rays = 100 # (self.pomdp_r + 1) * 8
|
self.n_rays = (self.pomdp_r + 1) * 8
|
||||||
self.degs = degs
|
self.degs = degs
|
||||||
self.ray_targets = self.build_ray_targets()
|
self.ray_targets = self.build_ray_targets()
|
||||||
self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
|
self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
|
||||||
|
self._cache_dict = {}
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.__class__.__name__}({self.agent.name})'
|
return f'{self.__class__.__name__}({self.agent.name})'
|
||||||
@ -211,30 +229,30 @@ class RayCaster:
|
|||||||
rot_M = np.unique(np.round(rot_M @ north), axis=0)
|
rot_M = np.unique(np.round(rot_M @ north), axis=0)
|
||||||
return rot_M.astype(int)
|
return rot_M.astype(int)
|
||||||
|
|
||||||
def ray_block_cache(self, cache_dict, key, callback):
|
def ray_block_cache(self, key, callback):
|
||||||
if key not in cache_dict:
|
if key not in self._cache_dict:
|
||||||
cache_dict[key] = callback()
|
self._cache_dict[key] = callback()
|
||||||
return cache_dict[key]
|
return self._cache_dict[key]
|
||||||
|
|
||||||
def visible_entities(self, entities):
|
def visible_entities(self, pos_dict, reset_cache=True):
|
||||||
visible = list()
|
visible = list()
|
||||||
cache_blocking = {}
|
if reset_cache:
|
||||||
|
self._cache_dict = {}
|
||||||
|
|
||||||
for ray in self.get_rays():
|
for ray in self.get_rays():
|
||||||
rx, ry = ray[0]
|
rx, ry = ray[0]
|
||||||
for x, y in ray:
|
for x, y in ray:
|
||||||
cx, cy = x - rx, y - ry
|
cx, cy = x - rx, y - ry
|
||||||
|
|
||||||
entities_hit = entities.pos_dict[(x, y)]
|
entities_hit = pos_dict[(x, y)]
|
||||||
hits = self.ray_block_cache(cache_blocking,
|
hits = self.ray_block_cache((x, y),
|
||||||
(x, y),
|
lambda: any(True for e in entities_hit if e.var_is_blocking_light)
|
||||||
lambda: any(True for e in entities_hit if e.var_is_blocking_light))
|
)
|
||||||
|
|
||||||
diag_hits = all([
|
diag_hits = all([
|
||||||
self.ray_block_cache(
|
self.ray_block_cache(
|
||||||
cache_blocking,
|
|
||||||
key,
|
key,
|
||||||
lambda: all(False for e in entities.pos_dict[key] if not e.var_is_blocking_light))
|
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(pos_dict[key]))
|
||||||
for key in ((x, y-cy), (x-cx, y))
|
for key in ((x, y-cy), (x-cx, y))
|
||||||
]) if (cx != 0 and cy != 0) else False
|
]) if (cx != 0 and cy != 0) else False
|
||||||
|
|
||||||
|
@ -1,16 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RenderEntity:
|
|
||||||
name: str
|
|
||||||
pos: np.array
|
|
||||||
value: float = 1
|
|
||||||
value_operation: str = 'none'
|
|
||||||
state: str = None
|
|
||||||
id: int = 0
|
|
||||||
aux: Any = None
|
|
||||||
real_name: str = 'none'
|
|
@ -9,7 +9,7 @@ import pygame
|
|||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from marl_factory_grid.utils.render import RenderEntity
|
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||||
|
|
||||||
AGENT: str = 'agent'
|
AGENT: str = 'agent'
|
||||||
STATE_IDLE: str = 'idle'
|
STATE_IDLE: str = 'idle'
|
||||||
|
@ -3,8 +3,6 @@ from typing import List, Dict, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
|
||||||
from marl_factory_grid.environment.groups.global_entities import Entities
|
|
||||||
from marl_factory_grid.environment.rules import Rule
|
from marl_factory_grid.environment.rules import Rule
|
||||||
from marl_factory_grid.utils.results import Result
|
from marl_factory_grid.utils.results import Result
|
||||||
|
|
||||||
@ -112,15 +110,10 @@ class Gamestate(object):
|
|||||||
results.extend(on_check_done_result)
|
results.extend(on_check_done_result)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
# def get_all_tiles_with_collisions(self) -> List[Floor]:
|
|
||||||
# tiles = [self[c.FLOORS].by_pos(pos) for pos, e in self.entities.pos_dict.items()
|
|
||||||
# if sum([x.var_can_collide for x in e]) > 1]
|
|
||||||
# # tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
|
||||||
# return tiles
|
|
||||||
|
|
||||||
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
|
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
|
||||||
positions = [pos for pos, e in self.entities.pos_dict.items()
|
positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items()
|
||||||
if sum([x.var_can_collide for x in e]) > 1]
|
if any([e.var_can_collide for e in entity_list_for_position])]
|
||||||
return positions
|
return positions
|
||||||
|
|
||||||
def check_move_validity(self, moving_entity, position):
|
def check_move_validity(self, moving_entity, position):
|
||||||
@ -128,6 +121,14 @@ class Gamestate(object):
|
|||||||
# and not (guest.var_is_blocking_pos and self.is_occupied()):
|
# and not (guest.var_is_blocking_pos and self.is_occupied()):
|
||||||
if moving_entity.pos != position and not any(
|
if moving_entity.pos != position and not any(
|
||||||
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
|
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
|
||||||
moving_entity.var_is_blocking_pos and moving_entity.is_occupied()):
|
moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)):
|
||||||
return True
|
return True
|
||||||
return False
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def check_pos_validity(self, position):
|
||||||
|
if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ ENTITIES = 'Objects'
|
|||||||
OBSERVATIONS = 'Observations'
|
OBSERVATIONS = 'Observations'
|
||||||
RULES = 'Rule'
|
RULES = 'Rule'
|
||||||
ASSETS = 'Assets'
|
ASSETS = 'Assets'
|
||||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Floor', 'Agent', 'GlobalPositions', 'Walls',
|
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls',
|
||||||
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
|
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,8 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class MarlFrameStack(gym.ObservationWrapper):
|
class MarlFrameStack(gym.ObservationWrapper):
|
||||||
@ -10,3 +14,37 @@ class MarlFrameStack(gym.ObservationWrapper):
|
|||||||
if isinstance(self.env, gym.wrappers.FrameStack) and self.env.unwrapped.n_agents > 1:
|
if isinstance(self.env, gym.wrappers.FrameStack) and self.env.unwrapped.n_agents > 1:
|
||||||
return observation[0:].swapaxes(0, 1)
|
return observation[0:].swapaxes(0, 1)
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RenderEntity:
|
||||||
|
name: str
|
||||||
|
pos: np.array
|
||||||
|
value: float = 1
|
||||||
|
value_operation: str = 'none'
|
||||||
|
state: str = None
|
||||||
|
id: int = 0
|
||||||
|
aux: Any = None
|
||||||
|
real_name: str = 'none'
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Floor:
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return f"Floor({self.pos})"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pos(self):
|
||||||
|
return self.x, self.y
|
||||||
|
|
||||||
|
x: int
|
||||||
|
y: int
|
||||||
|
var_is_blocking_light: bool = False
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.name == other.name
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.name)
|
||||||
|
Reference in New Issue
Block a user