mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-09-15 23:37:14 +02:00
Remove BoundDestination Object
New Variable 'var_can_be_bound' Observations adjusted accordingly
This commit is contained in:
@@ -56,6 +56,7 @@ class Entity(EnvObject, abc.ABC):
|
||||
return last_x - curr_x, last_y - curr_y
|
||||
|
||||
def destroy(self):
|
||||
if
|
||||
valid = self._collection.remove_item(self)
|
||||
for observer in self.observers:
|
||||
observer.notify_del_entity(self)
|
||||
@@ -73,10 +74,17 @@ class Entity(EnvObject, abc.ABC):
|
||||
return valid
|
||||
return not_same_tile
|
||||
|
||||
def __init__(self, tile, **kwargs):
|
||||
def __init__(self, tile, bind_to=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._status = None
|
||||
self._tile = tile
|
||||
if bind_to:
|
||||
try:
|
||||
self.bind_to(bind_to)
|
||||
except AttributeError:
|
||||
print(f'Objects of {self.__class__.__name__} can not be bound to other entities.')
|
||||
exit()
|
||||
|
||||
assert tile.enter(self, spawn=True), "Positions was not valid!"
|
||||
|
||||
def summarize_state(self) -> dict:
|
||||
|
@@ -9,10 +9,16 @@ class BoundEntityMixin:
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f'{self.__class__.__name__}({self.bound_entity.name})'
|
||||
if self.bound_entity:
|
||||
return f'{self.__class__.__name__}({self.bound_entity.name})'
|
||||
else:
|
||||
print()
|
||||
|
||||
def belongs_to_entity(self, entity):
|
||||
return entity == self.bound_entity
|
||||
|
||||
def bind_to(self, entity):
|
||||
self._bound_entity = entity
|
||||
|
||||
def unbind(self):
|
||||
self._bound_entity = None
|
||||
|
@@ -91,6 +91,13 @@ class EnvObject(Object):
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
try:
|
||||
return self._collection.var_can_be_bound or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
try:
|
||||
|
@@ -90,7 +90,7 @@ class Factory(gym.Env):
|
||||
|
||||
# Parse the agent conf
|
||||
parsed_agents_conf = self.conf.parse_agents_conf()
|
||||
self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed)
|
||||
self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed, self.conf.verbose)
|
||||
|
||||
# All is set up, trigger entity init with variable pos
|
||||
self.state.rules.do_all_init(self.state, self.map)
|
||||
@@ -235,10 +235,6 @@ class Factory(gym.Env):
|
||||
del summary[key]
|
||||
return summary
|
||||
|
||||
def print(self, string):
|
||||
if self.conf.verbose:
|
||||
print(string)
|
||||
|
||||
def save_params(self, filepath: Path):
|
||||
# noinspection PyProtectedMember
|
||||
filepath = Path(filepath)
|
||||
|
@@ -11,10 +11,6 @@ class Agents(PositionMixin, EnvObjects):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
return [(a.name, a) for a in self]
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
from gymnasium import spaces
|
||||
|
@@ -9,6 +9,7 @@ class EnvObjects(Objects):
|
||||
var_can_collide: bool = False
|
||||
var_has_position: bool = False
|
||||
var_can_move: bool = False
|
||||
var_can_be_bound: bool = False
|
||||
|
||||
@property
|
||||
def encodings(self):
|
||||
|
@@ -92,11 +92,11 @@ class HasBoundMixin:
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
@@ -24,7 +24,9 @@ class Objects:
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
return [(self.name, self)]
|
||||
pair_list = [(self.name, self)]
|
||||
pair_list.extend([(a.name, a) for a in self])
|
||||
return pair_list
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
|
@@ -124,12 +124,13 @@ class Collision(Rule):
|
||||
pass
|
||||
results.append(TickResult(entity=guest, identifier=c.COLLISION,
|
||||
reward=r.COLLISION, validity=c.VALID))
|
||||
self.curr_done = True
|
||||
self.curr_done = True if self.done_at_collisions else False
|
||||
return results
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
inter_entity_collision_detected = self.curr_done and self.done_at_collisions
|
||||
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
|
||||
if inter_entity_collision_detected or move_failed:
|
||||
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
|
||||
if self.done_at_collisions:
|
||||
inter_entity_collision_detected = self.curr_done
|
||||
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
|
||||
if inter_entity_collision_detected or move_failed:
|
||||
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
||||
|
Reference in New Issue
Block a user