mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-26 05:01:36 +02:00
Comments, small bugfixes removed legacy elements
This commit is contained in:
@ -26,12 +26,24 @@ class Agents(Collection):
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
from gymnasium import spaces
|
||||
space = spaces.Tuple([spaces.Discrete(len(x.actions)) for x in self])
|
||||
return space
|
||||
|
||||
@property
|
||||
def named_action_space(self):
|
||||
def named_action_space(self) -> dict[str, dict[str, list[int]]]:
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
named_space = dict()
|
||||
for agent in self:
|
||||
named_space[agent.name] = {action.name: idx for idx, action in enumerate(agent.actions)}
|
||||
|
@ -118,12 +118,6 @@ class Collection(Objects):
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
||||
def render(self):
|
||||
if self.var_has_position:
|
||||
return [y for y in [x.render() for x in self] if y is not None]
|
||||
|
@ -28,9 +28,3 @@ class HasBoundMixin:
|
||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
@ -160,7 +160,7 @@ class Objects:
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity))
|
||||
return h.get_first_index(self, filter_by=lambda x: x == entity)
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Iterable
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||
@ -39,17 +39,36 @@ class GlobalPositions(Collection):
|
||||
|
||||
_entity = GlobalPosition
|
||||
|
||||
var_is_blocking_light = False
|
||||
var_can_be_bound = True
|
||||
var_can_collide = False
|
||||
var_has_position = False
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
return True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
super(GlobalPositions, self).__init__(*args, **kwargs)
|
||||
|
||||
def spawn(self, agents, level_shape, *args, **kwargs):
|
||||
def spawn(self, agents, level_shape, *args, **kwargs) -> list[Result]:
|
||||
self.add_items([self._entity(agent, level_shape, *args, **kwargs) for agent in agents])
|
||||
return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))]
|
||||
|
||||
def trigger_spawn(self, state: Gamestate, *args, **kwargs) -> [Result]:
|
||||
return self.spawn(state[c.AGENT], state.lvl_shape, *args, **kwargs)
|
||||
def trigger_spawn(self, state: Gamestate, *args, **kwargs) -> list[Result]:
|
||||
result = self.spawn(state[c.AGENT], state.lvl_shape, *args, **kwargs)
|
||||
state.print(f'{len(self)} new {self.__class__.__name__} have been spawned for {[x for x in state[c.AGENT]]}')
|
||||
return result
|
||||
|
Reference in New Issue
Block a user