mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-09-13 22:44:00 +02:00
WIP: removing tiles
This commit is contained in:
@@ -54,9 +54,10 @@ class TSPBaseAgent(ABC):
|
|||||||
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
||||||
return route
|
return route
|
||||||
|
|
||||||
def _door_is_close(self):
|
def _door_is_close(self, state):
|
||||||
try:
|
try:
|
||||||
return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
|
# return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
|
||||||
|
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@@ -14,7 +14,7 @@ class TSPDirtAgent(TSPBaseAgent):
|
|||||||
if self._env.state[di.DIRT].by_pos(self.state.pos) is not None:
|
if self._env.state[di.DIRT].by_pos(self.state.pos) is not None:
|
||||||
# Translate the action_object to an integer to have the same output as any other model
|
# Translate the action_object to an integer to have the same output as any other model
|
||||||
action = di.CLEAN_UP
|
action = di.CLEAN_UP
|
||||||
elif door := self._door_is_close():
|
elif door := self._door_is_close(self._env):
|
||||||
action = self._use_door_or_move(door, di.DIRT)
|
action = self._use_door_or_move(door, di.DIRT)
|
||||||
else:
|
else:
|
||||||
action = self._predict_move(di.DIRT)
|
action = self._predict_move(di.DIRT)
|
||||||
|
@@ -24,7 +24,7 @@ class TSPItemAgent(TSPBaseAgent):
|
|||||||
elif self._env.state[i.DROP_OFF].by_pos(self.state.pos) is not None:
|
elif self._env.state[i.DROP_OFF].by_pos(self.state.pos) is not None:
|
||||||
# Translate the action_object to an integer to have the same output as any other model
|
# Translate the action_object to an integer to have the same output as any other model
|
||||||
action = i.ITEM_ACTION
|
action = i.ITEM_ACTION
|
||||||
elif door := self._door_is_close():
|
elif door := self._door_is_close(self._env):
|
||||||
action = self._use_door_or_move(door, i.DROP_OFF if self.mode == MODE_BRING else i.ITEM)
|
action = self._use_door_or_move(door, i.DROP_OFF if self.mode == MODE_BRING else i.ITEM)
|
||||||
else:
|
else:
|
||||||
action = self._choose()
|
action = self._choose()
|
||||||
|
@@ -11,15 +11,16 @@ class TSPTargetAgent(TSPBaseAgent):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(TSPTargetAgent, self).__init__(*args, **kwargs)
|
super(TSPTargetAgent, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _handle_doors(self):
|
def _handle_doors(self, state):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
|
# return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
|
||||||
|
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def predict(self, *_, **__):
|
def predict(self, *_, **__):
|
||||||
if door := self._door_is_close():
|
if door := self._door_is_close(self._env):
|
||||||
action = self._use_door_or_move(door, d.DESTINATION)
|
action = self._use_door_or_move(door, d.DESTINATION)
|
||||||
else:
|
else:
|
||||||
action = self._predict_move(d.DESTINATION)
|
action = self._predict_move(d.DESTINATION)
|
||||||
|
@@ -40,11 +40,11 @@ class Move(Action, abc.ABC):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def do(self, entity, env):
|
def do(self, entity, state):
|
||||||
new_pos = self._calc_new_pos(entity.pos)
|
new_pos = self._calc_new_pos(entity.pos)
|
||||||
if next_tile := env[c.FLOOR].by_pos(new_pos):
|
if state.check_move_validity(entity, new_pos): # next_tile := state[c.FLOOR].by_pos(new_pos):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
valid = entity.move(next_tile)
|
valid = entity.move(new_pos, state)
|
||||||
else:
|
else:
|
||||||
valid = c.NOT_VALID
|
valid = c.NOT_VALID
|
||||||
reward = r.MOVEMENTS_VALID if valid else r.MOVEMENTS_FAIL
|
reward = r.MOVEMENTS_VALID if valid else r.MOVEMENTS_FAIL
|
||||||
|
@@ -31,7 +31,7 @@ class Entity(EnvObject, abc.ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def tile(self):
|
def tile(self):
|
||||||
return self._tile # wall_n_floors funktionalität
|
return self._tile # wall_n_floors funktionalität
|
||||||
|
|
||||||
# @property
|
# @property
|
||||||
# def last_tile(self):
|
# def last_tile(self):
|
||||||
@@ -48,12 +48,11 @@ class Entity(EnvObject, abc.ABC):
|
|||||||
curr_x, curr_y = self.pos
|
curr_x, curr_y = self.pos
|
||||||
return last_x - curr_x, last_y - curr_y
|
return last_x - curr_x, last_y - curr_y
|
||||||
|
|
||||||
def move(self, next_pos):
|
def move(self, next_pos, state):
|
||||||
next_pos = next_pos
|
next_pos = next_pos
|
||||||
curr_pos = self._pos
|
curr_pos = self._pos
|
||||||
if not_same_pos := curr_pos != next_pos:
|
if not_same_pos := curr_pos != next_pos:
|
||||||
if valid := next_tile.enter(self): # muss abgefragt werden über observer? alle obs? wie sonst posdict
|
if valid := state.check_move_validity(self, next_pos):
|
||||||
# curr_tile.leave(self) kann raus wegen notify change pos
|
|
||||||
self._pos = next_pos
|
self._pos = next_pos
|
||||||
self._last_pos = curr_pos
|
self._last_pos = curr_pos
|
||||||
for observer in self.observers:
|
for observer in self.observers:
|
||||||
@@ -66,7 +65,6 @@ class Entity(EnvObject, abc.ABC):
|
|||||||
self._status = None
|
self._status = None
|
||||||
self._pos = position
|
self._pos = position
|
||||||
self._last_pos = c.VALUE_NO_POS
|
self._last_pos = c.VALUE_NO_POS
|
||||||
# tile.enter(self)
|
|
||||||
|
|
||||||
def summarize_state(self) -> dict: # tile=str(self.tile.name)
|
def summarize_state(self) -> dict: # tile=str(self.tile.name)
|
||||||
return dict(name=str(self.name), x=int(self.x), y=int(self.y), can_collide=bool(self.var_can_collide))
|
return dict(name=str(self.name), x=int(self.x), y=int(self.y), can_collide=bool(self.var_can_collide))
|
||||||
|
@@ -45,9 +45,9 @@ class Floor(EnvObject):
|
|||||||
def encoding(self):
|
def encoding(self):
|
||||||
return c.VALUE_OCCUPIED_CELL
|
return c.VALUE_OCCUPIED_CELL
|
||||||
|
|
||||||
@property
|
# @property
|
||||||
def guests_that_can_collide(self):
|
# def guests_that_can_collide(self):
|
||||||
return [x for x in self.guests if x.var_can_collide]
|
# return [x for x in self.guests if x.var_can_collide]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def guests(self):
|
def guests(self):
|
||||||
|
@@ -64,7 +64,8 @@ class Factory(gym.Env):
|
|||||||
self.map: LevelParser
|
self.map: LevelParser
|
||||||
self.obs_builder: OBSBuilder
|
self.obs_builder: OBSBuilder
|
||||||
|
|
||||||
# TODO: Reset ---> document this
|
# reset env to initial state, preparing env for new episode.
|
||||||
|
# returns tuple where the first dict contains initial observation for each agent in the env
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
@@ -80,22 +81,22 @@ class Factory(gym.Env):
|
|||||||
|
|
||||||
self.state = None
|
self.state = None
|
||||||
|
|
||||||
# Init entity:
|
# Init entities
|
||||||
entities = self.map.do_init() # done
|
entities = self.map.do_init()
|
||||||
|
|
||||||
# Grab all )rules:
|
# Init rules
|
||||||
rules = self.conf.load_rules()
|
rules = self.conf.load_rules()
|
||||||
|
|
||||||
# Agents
|
# Init agents
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
self.state = Gamestate(entities, rules, self.conf.env_seed) # get_all_tiles_with_collisions
|
self.state = Gamestate(entities, rules, self.conf.env_seed) # get_all_tiles_with_collisions
|
||||||
agents = self.conf.load_agents(self.map.size, self[c.FLOOR].empty_tiles) # empty_tiles -> entity(tile)
|
agents = self.conf.load_agents(self.map.size, self.state.entities.floorlist)
|
||||||
self.state.entities.add_item({c.AGENT: agents})
|
self.state.entities.add_item({c.AGENT: agents})
|
||||||
|
|
||||||
# All is set up, trigger additional init (after agent entity spawn etc)
|
# All is set up, trigger additional init (after agent entity spawn etc)
|
||||||
self.state.rules.do_all_init(self.state, self.map)
|
self.state.rules.do_all_init(self.state, self.map)
|
||||||
|
|
||||||
# Observations
|
# Build initial observations for all agents
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
|
self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
|
||||||
return self.obs_builder.refresh_and_build_for_all(self.state)
|
return self.obs_builder.refresh_and_build_for_all(self.state)
|
||||||
@@ -165,7 +166,7 @@ class Factory(gym.Env):
|
|||||||
if not self._renderer: # lazy init
|
if not self._renderer: # lazy init
|
||||||
from marl_factory_grid.utils.renderer import Renderer
|
from marl_factory_grid.utils.renderer import Renderer
|
||||||
global Renderer
|
global Renderer
|
||||||
self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=10)
|
self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=10)
|
||||||
|
|
||||||
render_entities = self.state.entities.render()
|
render_entities = self.state.entities.render()
|
||||||
if self.conf.pomdp_r:
|
if self.conf.pomdp_r:
|
||||||
|
@@ -23,10 +23,33 @@ class Entities(Objects):
|
|||||||
def names(self):
|
def names(self):
|
||||||
return list(self._data.keys())
|
return list(self._data.keys())
|
||||||
|
|
||||||
def __init__(self):
|
@property
|
||||||
|
def floorlist(self):
|
||||||
|
return self._floor_positions
|
||||||
|
|
||||||
|
def __init__(self, floor_positions):
|
||||||
|
self._floor_positions = floor_positions
|
||||||
self.pos_dict = defaultdict(list)
|
self.pos_dict = defaultdict(list)
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
# def all_floors(self):
|
||||||
|
# return[key for key, val in self.pos_dict.items() if any('floor' in x.name.lower() for x in val)]
|
||||||
|
|
||||||
|
def guests_that_can_collide(self, pos):
|
||||||
|
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
||||||
|
|
||||||
|
def empty_tiles(self):
|
||||||
|
return[key for key in self.floorlist if not any(self.pos_dict[key])]
|
||||||
|
|
||||||
|
def occupied_tiles(self): # positions that are not empty
|
||||||
|
return[key for key in self.floorlist if any(self.pos_dict[key])]
|
||||||
|
|
||||||
|
def is_blocked(self):
|
||||||
|
return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
|
||||||
|
|
||||||
|
def is_not_blocked(self):
|
||||||
|
return[key for key, val in self.pos_dict.items() if not all([x.var_is_blocking_pos for x in val])]
|
||||||
|
|
||||||
def iter_entities(self):
|
def iter_entities(self):
|
||||||
return iter((x for sublist in self.values() for x in sublist))
|
return iter((x for sublist in self.values() for x in sublist))
|
||||||
|
|
||||||
|
@@ -1,5 +1,7 @@
|
|||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||||
@@ -11,30 +13,31 @@ class PositionMixin:
|
|||||||
var_can_collide: bool = True
|
var_can_collide: bool = True
|
||||||
var_has_position: bool = True
|
var_has_position: bool = True
|
||||||
|
|
||||||
def spawn(self, coords: List[Tuple[(int, int)]]): # runde klammern?
|
def spawn(self, coords: List[Tuple[(int, int)]]):
|
||||||
self.add_items([self._entity(pos) for pos in coords])
|
self.add_items([self._entity(pos) for pos in coords])
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return [y for y in [x.render() for x in self] if y is not None]
|
return [y for y in [x.render() for x in self] if y is not None]
|
||||||
|
|
||||||
@classmethod
|
# @classmethod
|
||||||
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
# def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
||||||
collection = cls(*args, **kwargs)
|
# collection = cls(*args, **kwargs)
|
||||||
entities = [cls._entity(tile, str_ident=i,
|
# entities = [cls._entity(tile, str_ident=i,
|
||||||
**entity_kwargs if entity_kwargs is not None else {})
|
# **entity_kwargs if entity_kwargs is not None else {})
|
||||||
for i, tile in enumerate(tiles)]
|
# for i, tile in enumerate(tiles)]
|
||||||
collection.add_items(entities)
|
# collection.add_items(entities)
|
||||||
return collection
|
# return collection
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
|
def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
|
||||||
return cls.from_tiles([tiles.by_pos(position) for position in positions], tiles.size, *args,
|
collection = cls(*args, **kwargs)
|
||||||
entity_kwargs=entity_kwargs,
|
collection.add_items(
|
||||||
**kwargs)
|
[cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions])
|
||||||
|
return collection
|
||||||
|
|
||||||
def __delitem__(self, name):
|
def __delitem__(self, name):
|
||||||
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||||
obj.tile.leave(obj)
|
obj.tile.leave(obj) # observer notify?
|
||||||
super().__delitem__(name)
|
super().__delitem__(name)
|
||||||
|
|
||||||
def by_pos(self, pos: (int, int)):
|
def by_pos(self, pos: (int, int)):
|
||||||
|
@@ -126,17 +126,12 @@ class Objects:
|
|||||||
del self[item]
|
del self[item]
|
||||||
|
|
||||||
def notify_change_pos(self, entity: object):
|
def notify_change_pos(self, entity: object):
|
||||||
# print("notifychange")
|
|
||||||
try:
|
try:
|
||||||
# print("lastpos")
|
|
||||||
# print(self.pos_dict[entity.last_pos])
|
|
||||||
self.pos_dict[entity.last_pos].remove(entity)
|
self.pos_dict[entity.last_pos].remove(entity)
|
||||||
except (ValueError, AttributeError):
|
except (ValueError, AttributeError):
|
||||||
pass
|
pass
|
||||||
if entity.var_has_position:
|
if entity.var_has_position:
|
||||||
try:
|
try:
|
||||||
# print("pos")
|
|
||||||
# print(self.pos_dict[entity.pos])
|
|
||||||
self.pos_dict[entity.pos].append(entity)
|
self.pos_dict[entity.pos].append(entity)
|
||||||
except (ValueError, AttributeError):
|
except (ValueError, AttributeError):
|
||||||
pass
|
pass
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
import random
|
import random
|
||||||
from typing import List
|
from typing import List, Tuple
|
||||||
|
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||||
@@ -15,16 +15,12 @@ class Walls(PositionMixin, EnvObjects):
|
|||||||
super(Walls, self).__init__(*args, **kwargs)
|
super(Walls, self).__init__(*args, **kwargs)
|
||||||
self._value = c.VALUE_OCCUPIED_CELL
|
self._value = c.VALUE_OCCUPIED_CELL
|
||||||
|
|
||||||
@classmethod
|
# @classmethod
|
||||||
def from_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
# def from_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
||||||
tiles = cls(*args, **kwargs)
|
# tiles = cls(*args, **kwargs)
|
||||||
# noinspection PyTypeChecker
|
# # noinspection PyTypeChecker
|
||||||
tiles.add_items([cls._entity(pos) for pos in argwhere_coordinates])
|
# tiles.add_items([cls._entity(pos) for pos in argwhere_coordinates])
|
||||||
return tiles
|
# return tiles
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_tiles(cls, tiles, *args, **kwargs):
|
|
||||||
raise RuntimeError()
|
|
||||||
|
|
||||||
def by_pos(self, pos: (int, int)):
|
def by_pos(self, pos: (int, int)):
|
||||||
try:
|
try:
|
||||||
@@ -43,17 +39,4 @@ class Floors(Walls):
|
|||||||
super(Floors, self).__init__(*args, **kwargs)
|
super(Floors, self).__init__(*args, **kwargs)
|
||||||
self._value = c.VALUE_FREE_CELL
|
self._value = c.VALUE_FREE_CELL
|
||||||
|
|
||||||
@property
|
|
||||||
def occupied_tiles(self):
|
|
||||||
tiles = [tile for tile in self if tile.is_occupied()]
|
|
||||||
random.shuffle(tiles)
|
|
||||||
return tiles
|
|
||||||
|
|
||||||
@property
|
|
||||||
def empty_tiles(self) -> List[Floor]:
|
|
||||||
# def empty_tiles(self) -> List[Tuple[int, int]]:
|
|
||||||
tiles = [tile for tile in self if tile.is_empty()]
|
|
||||||
# positions = [tile.pos for tile in self if tile.is_empty()]
|
|
||||||
random.shuffle(tiles)
|
|
||||||
return tiles
|
|
||||||
|
|
||||||
|
@@ -74,10 +74,10 @@ class Collision(Rule):
|
|||||||
|
|
||||||
def tick_post_step(self, state) -> List[TickResult]:
|
def tick_post_step(self, state) -> List[TickResult]:
|
||||||
self.curr_done = False
|
self.curr_done = False
|
||||||
tiles_with_collisions = state.get_all_tiles_with_collisions()
|
pos_with_collisions = state.get_all_pos_with_collisions()
|
||||||
results = list()
|
results = list()
|
||||||
for tile in tiles_with_collisions:
|
for pos in pos_with_collisions:
|
||||||
guests = tile.guests_that_can_collide
|
guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide]
|
||||||
if len(guests) >= 2:
|
if len(guests) >= 2:
|
||||||
for i, guest in enumerate(guests):
|
for i, guest in enumerate(guests):
|
||||||
try:
|
try:
|
||||||
|
@@ -66,7 +66,8 @@ class Pod(Entity):
|
|||||||
def charge_battery(self, battery: Battery):
|
def charge_battery(self, battery: Battery):
|
||||||
if battery.charge_level == 1.0:
|
if battery.charge_level == 1.0:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
|
# if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
|
||||||
|
if sum(1 for key, val in self.state.entities.pos_dict[self.pos] for guest in val if 'agent' in guest.name.lower()) > 1:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
valid = battery.do_charge_action(self.charge_rate)
|
valid = battery.do_charge_action(self.charge_rate)
|
||||||
return valid
|
return valid
|
||||||
|
@@ -1,6 +1,4 @@
|
|||||||
from .actions import CleanUp
|
from .actions import CleanUp
|
||||||
from .entitites import DirtPile
|
from .entitites import DirtPile
|
||||||
from .groups import DirtPiles
|
from .groups import DirtPiles
|
||||||
from .rule_respawn import DirtRespawnRule
|
from .rules import DirtRespawnRule, DirtSmearOnMove, DirtAllCleanDone
|
||||||
from .rule_smear_on_move import DirtSmearOnMove
|
|
||||||
from .rule_done_on_all_clean import DirtAllCleanDone
|
|
||||||
|
@@ -14,7 +14,7 @@ class CleanUp(Action):
|
|||||||
super().__init__(d.CLEAN_UP)
|
super().__init__(d.CLEAN_UP)
|
||||||
|
|
||||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||||
if dirt := state[d.DIRT].by_pos(entity.pos):
|
if dirt := next((x for x in state.entities.pos_dict[entity.pos] if "dirt" in x.name.lower()), None):
|
||||||
new_dirt_amount = dirt.amount - state[d.DIRT].clean_amount
|
new_dirt_amount = dirt.amount - state[d.DIRT].clean_amount
|
||||||
|
|
||||||
if new_dirt_amount <= 0:
|
if new_dirt_amount <= 0:
|
||||||
|
@@ -7,7 +7,6 @@ from marl_factory_grid.environment import constants as c
|
|||||||
|
|
||||||
|
|
||||||
class DirtPiles(PositionMixin, EnvObjects):
|
class DirtPiles(PositionMixin, EnvObjects):
|
||||||
|
|
||||||
_entity = DirtPile
|
_entity = DirtPile
|
||||||
is_blocking_light: bool = False
|
is_blocking_light: bool = False
|
||||||
can_collide: bool = False
|
can_collide: bool = False
|
||||||
@@ -31,27 +30,28 @@ class DirtPiles(PositionMixin, EnvObjects):
|
|||||||
self.max_global_amount = max_global_amount
|
self.max_global_amount = max_global_amount
|
||||||
self.max_local_amount = max_local_amount
|
self.max_local_amount = max_local_amount
|
||||||
|
|
||||||
def spawn(self, then_dirty_tiles, amount) -> bool:
|
def spawn(self, then_dirty_positions, amount) -> bool:
|
||||||
if isinstance(then_dirty_tiles, Floor):
|
# if isinstance(then_dirty_tiles, Floor):
|
||||||
then_dirty_tiles = [then_dirty_tiles]
|
# then_dirty_tiles = [then_dirty_tiles]
|
||||||
for tile in then_dirty_tiles:
|
for pos in then_dirty_positions:
|
||||||
if not self.amount > self.max_global_amount:
|
if not self.amount > self.max_global_amount:
|
||||||
if dirt := self.by_pos(tile.pos):
|
if dirt := self.by_pos(pos):
|
||||||
new_value = dirt.amount + amount
|
new_value = dirt.amount + amount
|
||||||
dirt.set_new_amount(new_value)
|
dirt.set_new_amount(new_value)
|
||||||
else:
|
else:
|
||||||
dirt = DirtPile(tile, initial_amount=amount, spawn_variation=self.dirt_spawn_r_var)
|
dirt = DirtPile(pos, initial_amount=amount, spawn_variation=self.dirt_spawn_r_var)
|
||||||
self.add_item(dirt)
|
self.add_item(dirt)
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
return c.VALID
|
return c.VALID
|
||||||
|
|
||||||
def trigger_dirt_spawn(self, state, initial_spawn=False) -> bool:
|
def trigger_dirt_spawn(self, state, initial_spawn=False) -> bool:
|
||||||
free_for_dirt = [x for x in state[c.FLOOR]
|
free_for_dirt = [x for x in state.entities.floorlist if len(state.entities.pos_dict[x]) == 1 or (
|
||||||
if len(x.guests) == 0 or (
|
len(state.entities.pos_dict[x]) == 2 and isinstance(next(y for y in x), DirtPile))]
|
||||||
len(x.guests) == 1 and
|
# free_for_dirt = [x for x in state[c.FLOOR]
|
||||||
isinstance(next(y for y in x.guests), DirtPile))
|
# if len(x.guests) == 0 or (
|
||||||
]
|
# len(x.guests) == 1 and
|
||||||
|
# isinstance(next(y for y in x.guests), DirtPile))]
|
||||||
state.rng.shuffle(free_for_dirt)
|
state.rng.shuffle(free_for_dirt)
|
||||||
|
|
||||||
var = self.dirt_spawn_r_var
|
var = self.dirt_spawn_r_var
|
||||||
|
@@ -1,15 +0,0 @@
|
|||||||
from marl_factory_grid.environment import constants as c
|
|
||||||
from marl_factory_grid.environment.rules import Rule
|
|
||||||
from marl_factory_grid.utils.results import DoneResult
|
|
||||||
from marl_factory_grid.modules.clean_up import constants as d, rewards as r
|
|
||||||
|
|
||||||
|
|
||||||
class DirtAllCleanDone(Rule):
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def on_check_done(self, state) -> [DoneResult]:
|
|
||||||
if len(state[d.DIRT]) == 0 and state.curr_step:
|
|
||||||
return [DoneResult(validity=c.VALID, identifier=self.name, reward=r.CLEAN_UP_ALL)]
|
|
||||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
|
@@ -1,28 +0,0 @@
|
|||||||
from marl_factory_grid.environment.rules import Rule
|
|
||||||
from marl_factory_grid.utils.results import TickResult
|
|
||||||
|
|
||||||
from marl_factory_grid.modules.clean_up import constants as d
|
|
||||||
|
|
||||||
|
|
||||||
class DirtRespawnRule(Rule):
|
|
||||||
|
|
||||||
def __init__(self, spawn_freq=15):
|
|
||||||
super().__init__()
|
|
||||||
self.spawn_freq = spawn_freq
|
|
||||||
self._next_dirt_spawn = spawn_freq
|
|
||||||
|
|
||||||
def on_init(self, state, lvl_map) -> str:
|
|
||||||
state[d.DIRT].trigger_dirt_spawn(state, initial_spawn=True)
|
|
||||||
return f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}'
|
|
||||||
|
|
||||||
def tick_step(self, state):
|
|
||||||
if self._next_dirt_spawn < 0:
|
|
||||||
pass # No DirtPile Spawn
|
|
||||||
elif not self._next_dirt_spawn:
|
|
||||||
validity = state[d.DIRT].trigger_dirt_spawn(state)
|
|
||||||
|
|
||||||
return [TickResult(entity=None, validity=validity, identifier=self.name, reward=0)]
|
|
||||||
self._next_dirt_spawn = self.spawn_freq
|
|
||||||
else:
|
|
||||||
self._next_dirt_spawn -= 1
|
|
||||||
return []
|
|
@@ -1,24 +0,0 @@
|
|||||||
from marl_factory_grid.environment.rules import Rule
|
|
||||||
from marl_factory_grid.utils.helpers import is_move
|
|
||||||
from marl_factory_grid.utils.results import TickResult
|
|
||||||
|
|
||||||
from marl_factory_grid.environment import constants as c
|
|
||||||
from marl_factory_grid.modules.clean_up import constants as d
|
|
||||||
|
|
||||||
|
|
||||||
class DirtSmearOnMove(Rule):
|
|
||||||
|
|
||||||
def __init__(self, smear_amount: float = 0.2):
|
|
||||||
super().__init__()
|
|
||||||
self.smear_amount = smear_amount
|
|
||||||
|
|
||||||
def tick_post_step(self, state):
|
|
||||||
results = list()
|
|
||||||
for entity in state.moving_entites:
|
|
||||||
if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
|
|
||||||
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
|
|
||||||
if smeared_dirt := round(old_pos_dirt.amount * self.smear_amount, 2):
|
|
||||||
if state[d.DIRT].spawn(entity.tile, amount=smeared_dirt):
|
|
||||||
results.append(TickResult(identifier=self.name, entity=entity,
|
|
||||||
reward=0, validity=c.VALID))
|
|
||||||
return results
|
|
60
marl_factory_grid/modules/clean_up/rules.py
Normal file
60
marl_factory_grid/modules/clean_up/rules.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
from marl_factory_grid.modules.clean_up import constants as d, rewards as r
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
|
from marl_factory_grid.environment.rules import Rule
|
||||||
|
from marl_factory_grid.utils.helpers import is_move
|
||||||
|
from marl_factory_grid.utils.results import TickResult
|
||||||
|
from marl_factory_grid.utils.results import DoneResult
|
||||||
|
|
||||||
|
|
||||||
|
class DirtAllCleanDone(Rule):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def on_check_done(self, state) -> [DoneResult]:
|
||||||
|
if len(state[d.DIRT]) == 0 and state.curr_step:
|
||||||
|
return [DoneResult(validity=c.VALID, identifier=self.name, reward=r.CLEAN_UP_ALL)]
|
||||||
|
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
||||||
|
|
||||||
|
|
||||||
|
class DirtRespawnRule(Rule):
|
||||||
|
|
||||||
|
def __init__(self, spawn_freq=15):
|
||||||
|
super().__init__()
|
||||||
|
self.spawn_freq = spawn_freq
|
||||||
|
self._next_dirt_spawn = spawn_freq
|
||||||
|
|
||||||
|
def on_init(self, state, lvl_map) -> str:
|
||||||
|
state[d.DIRT].trigger_dirt_spawn(state, initial_spawn=True)
|
||||||
|
return f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}'
|
||||||
|
|
||||||
|
def tick_step(self, state):
|
||||||
|
if self._next_dirt_spawn < 0:
|
||||||
|
pass # No DirtPile Spawn
|
||||||
|
elif not self._next_dirt_spawn:
|
||||||
|
validity = state[d.DIRT].trigger_dirt_spawn(state)
|
||||||
|
|
||||||
|
return [TickResult(entity=None, validity=validity, identifier=self.name, reward=0)]
|
||||||
|
self._next_dirt_spawn = self.spawn_freq
|
||||||
|
else:
|
||||||
|
self._next_dirt_spawn -= 1
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class DirtSmearOnMove(Rule):
|
||||||
|
|
||||||
|
def __init__(self, smear_amount: float = 0.2):
|
||||||
|
super().__init__()
|
||||||
|
self.smear_amount = smear_amount
|
||||||
|
|
||||||
|
def tick_post_step(self, state):
|
||||||
|
results = list()
|
||||||
|
for entity in state.moving_entites:
|
||||||
|
if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
|
||||||
|
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
|
||||||
|
if smeared_dirt := round(old_pos_dirt.amount * self.smear_amount, 2):
|
||||||
|
if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt): # pos statt tile
|
||||||
|
results.append(TickResult(identifier=self.name, entity=entity,
|
||||||
|
reward=0, validity=c.VALID))
|
||||||
|
return results
|
@@ -9,7 +9,6 @@ from marl_factory_grid.modules.destinations import constants as d
|
|||||||
|
|
||||||
|
|
||||||
class Destination(Entity):
|
class Destination(Entity):
|
||||||
|
|
||||||
var_can_move = False
|
var_can_move = False
|
||||||
var_can_collide = False
|
var_can_collide = False
|
||||||
var_has_position = True
|
var_has_position = True
|
||||||
@@ -40,9 +39,9 @@ class Destination(Entity):
|
|||||||
def leave(self, agent: Agent):
|
def leave(self, agent: Agent):
|
||||||
del self._per_agent_times[agent.name]
|
del self._per_agent_times[agent.name]
|
||||||
|
|
||||||
@property
|
def is_considered_reached(self, state):
|
||||||
def is_considered_reached(self):
|
agent_at_position = any(
|
||||||
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
c.AGENT.lower() in x.name.lower() for x in state.entities.pos_dict[self.pos] if x.var_can_collide)
|
||||||
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
|
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
|
||||||
|
|
||||||
def agent_is_dwelling(self, agent: Agent):
|
def agent_is_dwelling(self, agent: Agent):
|
||||||
@@ -68,9 +67,9 @@ class BoundDestination(BoundEntityMixin, Destination):
|
|||||||
self.bind_to(entity)
|
self.bind_to(entity)
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_considered_reached(self):
|
def is_considered_reached(self):
|
||||||
agent_at_position = any(self.bound_entity == x for x in self.tile.guests_that_can_collide)
|
agent_at_position = any(
|
||||||
|
self.bound_entity == x for x in self.state.entities.pos_dict[self.pos] if x.var_can_collide)
|
||||||
return (agent_at_position and not self.dwell_time) \
|
return (agent_at_position and not self.dwell_time) \
|
||||||
or any(x == 0 for x in self._per_agent_times[self.bound_entity.name])
|
or any(x == 0 for x in self._per_agent_times[self.bound_entity.name])
|
||||||
|
@@ -1,10 +1,11 @@
|
|||||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||||
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin
|
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin
|
||||||
from marl_factory_grid.modules.destinations.entitites import Destination, BoundDestination
|
from marl_factory_grid.modules.destinations.entitites import Destination, BoundDestination
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
|
from marl_factory_grid.modules.destinations import constants as d
|
||||||
|
|
||||||
|
|
||||||
class Destinations(PositionMixin, EnvObjects):
|
class Destinations(PositionMixin, EnvObjects):
|
||||||
|
|
||||||
_entity = Destination
|
_entity = Destination
|
||||||
is_blocking_light: bool = False
|
is_blocking_light: bool = False
|
||||||
can_collide: bool = False
|
can_collide: bool = False
|
||||||
@@ -15,9 +16,19 @@ class Destinations(PositionMixin, EnvObjects):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return super(Destinations, self).__repr__()
|
return super(Destinations, self).__repr__()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def trigger_destination_spawn(n_dests, state):
|
||||||
|
coordinates = state.entities.floorlist[:n_dests]
|
||||||
|
if destinations := [Destination(pos) for pos in coordinates]:
|
||||||
|
state[d.DESTINATION].add_items(destinations)
|
||||||
|
state.print(f'{n_dests} new destinations have been spawned')
|
||||||
|
return c.VALID
|
||||||
|
else:
|
||||||
|
state.print('No Destiantions are spawning, limit is reached.')
|
||||||
|
return c.NOT_VALID
|
||||||
|
|
||||||
|
|
||||||
class BoundDestinations(HasBoundMixin, Destinations):
|
class BoundDestinations(HasBoundMixin, Destinations):
|
||||||
|
|
||||||
_entity = BoundDestination
|
_entity = BoundDestination
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
@@ -12,12 +12,12 @@ class DestinationReach(Rule):
|
|||||||
def __init__(self, n_dests: int = 1, tiles: Union[List, None] = None):
|
def __init__(self, n_dests: int = 1, tiles: Union[List, None] = None):
|
||||||
super(DestinationReach, self).__init__()
|
super(DestinationReach, self).__init__()
|
||||||
self.n_dests = n_dests or len(tiles)
|
self.n_dests = n_dests or len(tiles)
|
||||||
self._tiles = tiles
|
# self._tiles = tiles
|
||||||
|
|
||||||
def tick_step(self, state) -> List[TickResult]:
|
def tick_step(self, state) -> List[TickResult]:
|
||||||
|
|
||||||
for dest in list(state[d.DESTINATION].values()):
|
for dest in list(state[d.DESTINATION].values()):
|
||||||
if dest.is_considered_reached:
|
if dest.is_considered_reached(state):
|
||||||
dest.change_parent_collection(state[d.DEST_REACHED])
|
dest.change_parent_collection(state[d.DEST_REACHED])
|
||||||
state.print(f'{dest.name} is reached now, removing...')
|
state.print(f'{dest.name} is reached now, removing...')
|
||||||
else:
|
else:
|
||||||
@@ -34,7 +34,7 @@ class DestinationReach(Rule):
|
|||||||
def tick_post_step(self, state) -> List[TickResult]:
|
def tick_post_step(self, state) -> List[TickResult]:
|
||||||
results = list()
|
results = list()
|
||||||
for reached_dest in state[d.DEST_REACHED]:
|
for reached_dest in state[d.DEST_REACHED]:
|
||||||
for guest in reached_dest.tile.guests:
|
for guest in state.entities.pos_dict[reached_dest].values(): # reached_dest.tile.guests:
|
||||||
if guest in state[c.AGENT]:
|
if guest in state[c.AGENT]:
|
||||||
state.print(f'{guest.name} just reached destination at {guest.pos}')
|
state.print(f'{guest.name} just reached destination at {guest.pos}')
|
||||||
state[d.DEST_REACHED].delete_env_object(reached_dest)
|
state[d.DEST_REACHED].delete_env_object(reached_dest)
|
||||||
@@ -79,7 +79,7 @@ class DestinationSpawn(Rule):
|
|||||||
def on_init(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
self._dest_spawn_timer = self.spawn_frequency
|
self._dest_spawn_timer = self.spawn_frequency
|
||||||
self.trigger_destination_spawn(self.n_dests, state)
|
state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def tick_pre_step(self, state) -> List[TickResult]:
|
def tick_pre_step(self, state) -> List[TickResult]:
|
||||||
@@ -91,13 +91,3 @@ class DestinationSpawn(Rule):
|
|||||||
validity = state.rules['DestinationReach'].trigger_destination_spawn(n_dest_spawn, state)
|
validity = state.rules['DestinationReach'].trigger_destination_spawn(n_dest_spawn, state)
|
||||||
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
|
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def trigger_destination_spawn(n_dests, state, tiles=None):
|
|
||||||
tiles = tiles or state[c.FLOOR].empty_tiles[:n_dests]
|
|
||||||
if destinations := [Destination(tile) for tile in tiles]:
|
|
||||||
state[d.DESTINATION].add_items(destinations)
|
|
||||||
state.print(f'{n_dests} new destinations have been spawned')
|
|
||||||
return c.VALID
|
|
||||||
else:
|
|
||||||
state.print('No Destiantions are spawning, limit is reached.')
|
|
||||||
return c.NOT_VALID
|
|
||||||
|
@@ -81,11 +81,11 @@ class Door(Entity):
|
|||||||
self._open()
|
self._open()
|
||||||
return c.VALID
|
return c.VALID
|
||||||
|
|
||||||
def tick(self):
|
def tick(self, state):
|
||||||
if self.is_open and len(self.tile) == 1 and self.time_to_close:
|
if self.is_open and len(state.entities.pos_dict[self.pos]) == 2 and self.time_to_close:
|
||||||
self.time_to_close -= 1
|
self.time_to_close -= 1
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
elif self.is_open and not self.time_to_close and len(self.tile) == 1:
|
elif self.is_open and not self.time_to_close and len(state.entities.pos_dict[self.pos]) == 2:
|
||||||
self.use()
|
self.use()
|
||||||
return c.VALID
|
return c.VALID
|
||||||
else:
|
else:
|
||||||
|
@@ -14,9 +14,9 @@ class Doors(PositionMixin, EnvObjects):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Doors, self).__init__(*args, can_collide=True, **kwargs)
|
super(Doors, self).__init__(*args, can_collide=True, **kwargs)
|
||||||
|
|
||||||
def tick_doors(self):
|
def tick_doors(self, state):
|
||||||
result_dict = dict()
|
result_dict = dict()
|
||||||
for door in self:
|
for door in self:
|
||||||
did_tick = door.tick()
|
did_tick = door.tick(state)
|
||||||
result_dict.update({door.name: did_tick})
|
result_dict.update({door.name: did_tick})
|
||||||
return result_dict
|
return result_dict
|
||||||
|
@@ -12,7 +12,7 @@ class DoorAutoClose(Rule):
|
|||||||
|
|
||||||
def tick_step(self, state):
|
def tick_step(self, state):
|
||||||
if doors := state[d.DOORS]:
|
if doors := state[d.DOORS]:
|
||||||
doors_tick_result = doors.tick_doors()
|
doors_tick_result = doors.tick_doors(state)
|
||||||
doors_that_ticked = [key for key, val in doors_tick_result.items() if val]
|
doors_that_ticked = [key for key, val in doors_tick_result.items() if val]
|
||||||
state.print(f'{doors_that_ticked} were auto-closed'
|
state.print(f'{doors_that_ticked} were auto-closed'
|
||||||
if doors_that_ticked else 'No Doors were auto-closed')
|
if doors_that_ticked else 'No Doors were auto-closed')
|
||||||
|
@@ -23,7 +23,7 @@ class AgentSingleZonePlacementBeta(Rule):
|
|||||||
raise ValueError
|
raise ValueError
|
||||||
tiles = [state[c.FLOOR].by_pos(pos) for pos in coordinates]
|
tiles = [state[c.FLOOR].by_pos(pos) for pos in coordinates]
|
||||||
for agent, tile in zip(agents, tiles):
|
for agent, tile in zip(agents, tiles):
|
||||||
agent.move(tile)
|
agent.move(tile, state)
|
||||||
|
|
||||||
def tick_step(self, state):
|
def tick_step(self, state):
|
||||||
return []
|
return []
|
||||||
|
@@ -1,15 +1,14 @@
|
|||||||
from typing import List
|
from marl_factory_grid.modules.items import constants as i
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||||
from marl_factory_grid.environment.groups.objects import Objects
|
from marl_factory_grid.environment.groups.objects import Objects
|
||||||
from marl_factory_grid.environment.groups.mixins import PositionMixin, IsBoundMixin, HasBoundMixin
|
from marl_factory_grid.environment.groups.mixins import PositionMixin, IsBoundMixin, HasBoundMixin
|
||||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
|
||||||
from marl_factory_grid.environment.entity.agent import Agent
|
from marl_factory_grid.environment.entity.agent import Agent
|
||||||
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
|
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
|
||||||
|
|
||||||
|
|
||||||
class Items(PositionMixin, EnvObjects):
|
class Items(PositionMixin, EnvObjects):
|
||||||
|
|
||||||
_entity = Item
|
_entity = Item
|
||||||
is_blocking_light: bool = False
|
is_blocking_light: bool = False
|
||||||
can_collide: bool = False
|
can_collide: bool = False
|
||||||
@@ -17,9 +16,19 @@ class Items(PositionMixin, EnvObjects):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def trigger_item_spawn(state, n_items, spawn_frequency):
|
||||||
|
if item_to_spawns := max(0, (n_items - len(state[i.ITEM]))):
|
||||||
|
floor_list = state.entities.floorlist[:item_to_spawns]
|
||||||
|
state[i.ITEM].spawn(floor_list)
|
||||||
|
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {spawn_frequency}') # spawn in self._next_item_spawn ?
|
||||||
|
return len(floor_list)
|
||||||
|
else:
|
||||||
|
state.print('No Items are spawning, limit is reached.')
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
class Inventory(IsBoundMixin, EnvObjects):
|
class Inventory(IsBoundMixin, EnvObjects):
|
||||||
|
|
||||||
_accepted_objects = Item
|
_accepted_objects = Item
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -27,7 +36,7 @@ class Inventory(IsBoundMixin, EnvObjects):
|
|||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
def __init__(self, agent: Agent, *args, **kwargs):
|
def __init__(self, agent: Agent, *args, **kwargs):
|
||||||
super(Inventory, self).__init__(*args, **kwargs)
|
super(Inventory, self).__init__(*args, **kwargs)
|
||||||
self._collection = None
|
self._collection = None
|
||||||
self.bind(agent)
|
self.bind(agent)
|
||||||
|
|
||||||
@@ -47,7 +56,6 @@ class Inventory(IsBoundMixin, EnvObjects):
|
|||||||
|
|
||||||
|
|
||||||
class Inventories(HasBoundMixin, Objects):
|
class Inventories(HasBoundMixin, Objects):
|
||||||
|
|
||||||
_entity = Inventory
|
_entity = Inventory
|
||||||
var_can_move = False
|
var_can_move = False
|
||||||
|
|
||||||
@@ -58,7 +66,7 @@ class Inventories(HasBoundMixin, Objects):
|
|||||||
self._lazy_eval_transforms = []
|
self._lazy_eval_transforms = []
|
||||||
|
|
||||||
def spawn(self, agents):
|
def spawn(self, agents):
|
||||||
inventories = [self._entity(agent, self.size,)
|
inventories = [self._entity(agent, self.size, )
|
||||||
for _, agent in enumerate(agents)]
|
for _, agent in enumerate(agents)]
|
||||||
self.add_items(inventories)
|
self.add_items(inventories)
|
||||||
|
|
||||||
@@ -77,12 +85,22 @@ class Inventories(HasBoundMixin, Objects):
|
|||||||
def summarize_states(self, **kwargs):
|
def summarize_states(self, **kwargs):
|
||||||
return [val.summarize_states(**kwargs) for key, val in self.items()]
|
return [val.summarize_states(**kwargs) for key, val in self.items()]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def trigger_inventory_spawn(state):
|
||||||
|
state[i.INVENTORY].spawn(state[c.AGENT])
|
||||||
|
|
||||||
|
|
||||||
class DropOffLocations(PositionMixin, EnvObjects):
|
class DropOffLocations(PositionMixin, EnvObjects):
|
||||||
|
|
||||||
_entity = DropOffLocation
|
_entity = DropOffLocation
|
||||||
is_blocking_light: bool = False
|
is_blocking_light: bool = False
|
||||||
can_collide: bool = False
|
can_collide: bool = False
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(DropOffLocations, self).__init__(*args, **kwargs)
|
super(DropOffLocations, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def trigger_drop_off_location_spawn(state, n_locations):
|
||||||
|
empty_tiles = state.entities.floorlist[:n_locations]
|
||||||
|
do_entites = state[i.DROP_OFF]
|
||||||
|
drop_offs = [DropOffLocation(tile) for tile in empty_tiles]
|
||||||
|
do_entites.add_items(drop_offs)
|
||||||
|
@@ -4,7 +4,6 @@ from marl_factory_grid.environment.rules import Rule
|
|||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.utils.results import TickResult
|
from marl_factory_grid.utils.results import TickResult
|
||||||
from marl_factory_grid.modules.items import constants as i
|
from marl_factory_grid.modules.items import constants as i
|
||||||
from marl_factory_grid.modules.items.entitites import DropOffLocation
|
|
||||||
|
|
||||||
|
|
||||||
class ItemRules(Rule):
|
class ItemRules(Rule):
|
||||||
@@ -19,10 +18,10 @@ class ItemRules(Rule):
|
|||||||
self.n_locations = n_locations
|
self.n_locations = n_locations
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
self.trigger_drop_off_location_spawn(state)
|
state[i.DROP_OFF].trigger_drop_off_location_spawn(state, self.n_locations)
|
||||||
self._next_item_spawn = self.spawn_frequency
|
self._next_item_spawn = self.spawn_frequency
|
||||||
self.trigger_inventory_spawn(state)
|
state[i.INVENTORY].trigger_inventory_spawn(state)
|
||||||
self.trigger_item_spawn(state)
|
state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency)
|
||||||
|
|
||||||
def tick_step(self, state):
|
def tick_step(self, state):
|
||||||
for item in list(state[i.ITEM].values()):
|
for item in list(state[i.ITEM].values()):
|
||||||
@@ -34,26 +33,11 @@ class ItemRules(Rule):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if not self._next_item_spawn:
|
if not self._next_item_spawn:
|
||||||
self.trigger_item_spawn(state)
|
state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency)
|
||||||
else:
|
else:
|
||||||
self._next_item_spawn = max(0, self._next_item_spawn - 1)
|
self._next_item_spawn = max(0, self._next_item_spawn - 1)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def trigger_item_spawn(self, state):
|
|
||||||
if item_to_spawns := max(0, (self.n_items - len(state[i.ITEM]))):
|
|
||||||
empty_tiles = state[c.FLOOR].empty_tiles[:item_to_spawns]
|
|
||||||
state[i.ITEM].spawn(empty_tiles)
|
|
||||||
self._next_item_spawn = self.spawn_frequency
|
|
||||||
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}')
|
|
||||||
return len(empty_tiles)
|
|
||||||
else:
|
|
||||||
state.print('No Items are spawning, limit is reached.')
|
|
||||||
return 0
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def trigger_inventory_spawn(state):
|
|
||||||
state[i.INVENTORY].spawn(state[c.AGENT])
|
|
||||||
|
|
||||||
def tick_post_step(self, state) -> List[TickResult]:
|
def tick_post_step(self, state) -> List[TickResult]:
|
||||||
for item in list(state[i.ITEM].values()):
|
for item in list(state[i.ITEM].values()):
|
||||||
if item.auto_despawn >= 1:
|
if item.auto_despawn >= 1:
|
||||||
@@ -64,7 +48,7 @@ class ItemRules(Rule):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if not self._next_item_spawn:
|
if not self._next_item_spawn:
|
||||||
if spawned_items := self.trigger_item_spawn(state):
|
if spawned_items := state[i.ITEM].trigger_item_spawn(state, self.n_items, self.spawn_frequency):
|
||||||
return [TickResult(self.name, validity=c.VALID, value=spawned_items, entity=None)]
|
return [TickResult(self.name, validity=c.VALID, value=spawned_items, entity=None)]
|
||||||
else:
|
else:
|
||||||
return [TickResult(self.name, validity=c.NOT_VALID, value=0, entity=None)]
|
return [TickResult(self.name, validity=c.NOT_VALID, value=0, entity=None)]
|
||||||
@@ -72,8 +56,3 @@ class ItemRules(Rule):
|
|||||||
self._next_item_spawn = max(0, self._next_item_spawn-1)
|
self._next_item_spawn = max(0, self._next_item_spawn-1)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def trigger_drop_off_location_spawn(self, state):
|
|
||||||
empty_tiles = state[c.FLOOR].empty_tiles[:self.n_locations]
|
|
||||||
do_entites = state[i.DROP_OFF]
|
|
||||||
drop_offs = [DropOffLocation(tile) for tile in empty_tiles]
|
|
||||||
do_entites.add_items(drop_offs)
|
|
||||||
|
@@ -47,9 +47,11 @@ class Machine(Entity):
|
|||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
|
|
||||||
def tick(self):
|
def tick(self):
|
||||||
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
|
# if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
|
||||||
|
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]):
|
||||||
return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self)
|
return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self)
|
||||||
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
|
# elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
|
||||||
|
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.state.entities.pos_dict[self.pos]]):
|
||||||
self.status = m.STATE_WORK
|
self.status = m.STATE_WORK
|
||||||
self.reset_counter()
|
self.reset_counter()
|
||||||
return None
|
return None
|
||||||
|
@@ -61,7 +61,7 @@ class Maintainer(Entity):
|
|||||||
self._last.append(self._next.pop())
|
self._last.append(self._next.pop())
|
||||||
self._path = self.calculate_route(self._last[-1])
|
self._path = self.calculate_route(self._last[-1])
|
||||||
|
|
||||||
if door := self._door_is_close():
|
if door := self._door_is_close(state):
|
||||||
if door.is_closed:
|
if door.is_closed:
|
||||||
# Translate the action_object to an integer to have the same output as any other model
|
# Translate the action_object to an integer to have the same output as any other model
|
||||||
action = do.ACTION_DOOR_USE
|
action = do.ACTION_DOOR_USE
|
||||||
@@ -81,15 +81,18 @@ class Maintainer(Entity):
|
|||||||
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
|
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
|
||||||
return route[1:]
|
return route[1:]
|
||||||
|
|
||||||
def _door_is_close(self):
|
def _door_is_close(self, state):
|
||||||
|
print("doorclose")
|
||||||
try:
|
try:
|
||||||
return next(y for x in self.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
|
# return next(y for x in self.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
|
||||||
|
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _predict_move(self, state):
|
def _predict_move(self, state):
|
||||||
next_pos = self._path[0]
|
next_pos = self._path[0]
|
||||||
if len(state[c.FLOOR].by_pos(next_pos).guests_that_can_collide) > 0:
|
# if len(state[c.FLOOR].by_pos(next_pos).guests_that_can_collide) > 0:
|
||||||
|
if any(x for x in state.entities.pos_dict[next_pos] if x.var_can_collide) > 0:
|
||||||
action = c.NOOP
|
action = c.NOOP
|
||||||
else:
|
else:
|
||||||
next_pos = self._path.pop(0)
|
next_pos = self._path.pop(0)
|
||||||
|
@@ -23,5 +23,5 @@ class Maintainers(PositionMixin, EnvObjects):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def spawn(self, tiles: List[Floor], state: Gamestate):
|
def spawn(self, position, state: Gamestate):
|
||||||
self.add_items([self._entity(state, mc.MACHINES, MachineAction(), tile) for tile in tiles])
|
self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in position])
|
||||||
|
@@ -38,7 +38,7 @@ class AgentSingleZonePlacement(Rule):
|
|||||||
|
|
||||||
z_idxs = choices(list(range(len(state[z.ZONES]))), k=n_agents)
|
z_idxs = choices(list(range(len(state[z.ZONES]))), k=n_agents)
|
||||||
for agent in state[c.AGENT]:
|
for agent in state[c.AGENT]:
|
||||||
agent.move(state[z.ZONES][z_idxs.pop()].random_tile)
|
agent.move(state[z.ZONES][z_idxs.pop()].random_tile, state)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def tick_step(self, state):
|
def tick_step(self, state):
|
||||||
|
@@ -14,7 +14,6 @@ MODULE_PATH = 'modules'
|
|||||||
|
|
||||||
|
|
||||||
class FactoryConfigParser(object):
|
class FactoryConfigParser(object):
|
||||||
|
|
||||||
default_entites = []
|
default_entites = []
|
||||||
default_rules = ['MaxStepsReached', 'Collision']
|
default_rules = ['MaxStepsReached', 'Collision']
|
||||||
default_actions = [c.MOVE8, c.NOOP]
|
default_actions = [c.MOVE8, c.NOOP]
|
||||||
@@ -89,7 +88,7 @@ class FactoryConfigParser(object):
|
|||||||
|
|
||||||
def load_agents(self, size, free_tiles):
|
def load_agents(self, size, free_tiles):
|
||||||
agents = Agents(size)
|
agents = Agents(size)
|
||||||
base_env_actions = self.default_actions.copy() + [c.MOVE4]
|
base_env_actions = self.default_actions.copy() + [c.MOVE4]
|
||||||
for name in self.agents:
|
for name in self.agents:
|
||||||
# Actions
|
# Actions
|
||||||
actions = list()
|
actions = list()
|
||||||
|
@@ -22,7 +22,7 @@ class LevelParser(object):
|
|||||||
self._parsed_level = h.parse_level(Path(level_file_path))
|
self._parsed_level = h.parse_level(Path(level_file_path))
|
||||||
level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
|
level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
|
||||||
self.level_shape = level_array.shape
|
self.level_shape = level_array.shape
|
||||||
self.size = self.pomdp_r**2 if self.pomdp_r else np.prod(self.level_shape)
|
self.size = self.pomdp_r ** 2 if self.pomdp_r else np.prod(self.level_shape)
|
||||||
|
|
||||||
def get_coordinates_for_symbol(self, symbol, negate=False):
|
def get_coordinates_for_symbol(self, symbol, negate=False):
|
||||||
level_array = h.one_hot_level(self._parsed_level, symbol)
|
level_array = h.one_hot_level(self._parsed_level, symbol)
|
||||||
@@ -32,16 +32,16 @@ class LevelParser(object):
|
|||||||
return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL)
|
return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL)
|
||||||
|
|
||||||
def do_init(self):
|
def do_init(self):
|
||||||
entities = Entities()
|
list_of_all_floors = ([tuple(floor) for floor in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)])
|
||||||
|
entities = Entities(list_of_all_floors)
|
||||||
|
|
||||||
# Walls
|
# Walls
|
||||||
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
|
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
|
||||||
# walls = self.get_coordinates_for_symbol(c.SYMBOL_WALL)
|
|
||||||
entities.add_items({c.WALL: walls})
|
entities.add_items({c.WALL: walls})
|
||||||
|
|
||||||
# Floor
|
# Floor
|
||||||
floor = Floors.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True), self.size)
|
floor = Floors.from_coordinates(list_of_all_floors, self.size)
|
||||||
entities.add_items({c.FLOOR: floor})
|
entities.add_items({c.FLOOR: floor})
|
||||||
# entities.add_items({c.WALL: self.get_coordinates_for_symbol(c.SYMBOL_WALL, negative=True)})
|
|
||||||
|
|
||||||
# All other
|
# All other
|
||||||
for es_name in self.e_p_dict:
|
for es_name in self.e_p_dict:
|
||||||
@@ -55,10 +55,7 @@ class LevelParser(object):
|
|||||||
level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
|
level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
|
||||||
if np.any(level_array):
|
if np.any(level_array):
|
||||||
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
|
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
|
||||||
entities[c.FLOOR], self.size, entity_kwargs=e_kwargs
|
self.size, entity_kwargs=e_kwargs)
|
||||||
)
|
|
||||||
# e_coords = (np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist()) # braucht e_class?
|
|
||||||
# entities.add_items({e.name: e_coords})
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
|
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
|
||||||
f'Check your level file!')
|
f'Check your level file!')
|
||||||
|
@@ -1,8 +1,7 @@
|
|||||||
from typing import List, Dict
|
from typing import List, Dict, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||||
from marl_factory_grid.environment.rules import Rule
|
from marl_factory_grid.environment.rules import Rule
|
||||||
@@ -108,8 +107,22 @@ class Gamestate(object):
|
|||||||
results.extend(on_check_done_result)
|
results.extend(on_check_done_result)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def get_all_tiles_with_collisions(self) -> List[Floor]: # -> List[Tuple(Int,Int)]
|
# def get_all_tiles_with_collisions(self) -> List[Floor]:
|
||||||
tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items()
|
# tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items()
|
||||||
if sum([x.var_can_collide for x in e]) > 1]
|
# if sum([x.var_can_collide for x in e]) > 1]
|
||||||
# tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
# # tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
||||||
return tiles
|
# return tiles
|
||||||
|
|
||||||
|
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
|
||||||
|
positions = [pos for pos, e in self.entities.pos_dict.items()
|
||||||
|
if sum([x.var_can_collide for x in e]) > 1]
|
||||||
|
return positions
|
||||||
|
|
||||||
|
def check_move_validity(self, moving_entity, position):
|
||||||
|
# if (guest.name not in self._guests and not self.is_blocked)
|
||||||
|
# and not (guest.var_is_blocking_pos and self.is_occupied()):
|
||||||
|
if moving_entity.pos != position and not any(
|
||||||
|
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
|
||||||
|
moving_entity.var_is_blocking_pos and moving_entity.is_occupied()):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
Reference in New Issue
Block a user