mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 13:37:27 +01:00
Comments, small bugfixes removed legacy elements
This commit is contained in:
@@ -249,4 +249,11 @@ def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: T
|
||||
|
||||
|
||||
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
|
||||
"""
|
||||
todo
|
||||
|
||||
:param iterable:
|
||||
:param filter_by:
|
||||
:return:
|
||||
"""
|
||||
return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)
|
||||
|
||||
@@ -18,12 +18,24 @@ class OBSBuilder(object):
|
||||
|
||||
@property
|
||||
def pomdp_d(self):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
if self.pomdp_r:
|
||||
return (self.pomdp_r * 2) + 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
def __init__(self, level_shape: np.size, state: Gamestate, pomdp_r: int):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
self.all_obs = dict()
|
||||
self.ray_caster = dict()
|
||||
|
||||
|
||||
@@ -7,6 +7,12 @@ from numba import njit
|
||||
|
||||
class RayCaster:
|
||||
def __init__(self, agent, pomdp_r, degs=360):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
self.agent = agent
|
||||
self.pomdp_r = pomdp_r
|
||||
self.n_rays = 100 # (self.pomdp_r + 1) * 8
|
||||
|
||||
@@ -33,6 +33,12 @@ class Renderer:
|
||||
lvl_padded_shape: Union[Tuple[int, int], None] = None,
|
||||
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
|
||||
grid_lines: bool = True, view_radius: int = 2):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
# TODO: Customn_assets paths
|
||||
self.grid_h, self.grid_w = lvl_shape
|
||||
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
|
||||
|
||||
@@ -10,6 +10,12 @@ TYPES = [TYPE_VALUE, TYPE_REWARD]
|
||||
|
||||
@dataclass
|
||||
class InfoObject:
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
identifier: str
|
||||
val_type: str
|
||||
value: Union[float, int]
|
||||
@@ -17,6 +23,12 @@ class InfoObject:
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
identifier: str
|
||||
validity: bool
|
||||
reward: Union[float, None] = None
|
||||
@@ -40,6 +52,17 @@ class Result:
|
||||
|
||||
@dataclass
|
||||
class TickResult(Result):
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionResult(Result):
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -47,6 +70,11 @@ class TickResult(Result):
|
||||
class ActionResult(Result):
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class State(Result):
|
||||
# TODO: change identifiert to action/last_action
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DoneResult(Result):
|
||||
|
||||
@@ -12,6 +12,12 @@ from marl_factory_grid.utils.results import Result, DoneResult
|
||||
|
||||
class StepRules:
|
||||
def __init__(self, *args):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
if args:
|
||||
self.rules = list(args)
|
||||
else:
|
||||
@@ -77,6 +83,12 @@ class Gamestate(object):
|
||||
return [y for x in self.entities for y in x if x.var_can_move]
|
||||
|
||||
def __init__(self, entities, agents_conf, rules: List[Rule], lvl_shape, env_seed=69, verbose=False):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
self.lvl_shape = lvl_shape
|
||||
self.entities = entities
|
||||
self.curr_step = 0
|
||||
|
||||
@@ -22,6 +22,12 @@ EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPosition
|
||||
class ConfigExplainer:
|
||||
|
||||
def __init__(self, custom_path: Union[None, PathLike] = None):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
self.base_path = Path(__file__).parent.parent.resolve()
|
||||
self.custom_path = custom_path
|
||||
self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, ASSETS]
|
||||
|
||||
@@ -18,6 +18,12 @@ class MarlFrameStack(gym.ObservationWrapper):
|
||||
|
||||
@dataclass
|
||||
class RenderEntity:
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
name: str
|
||||
pos: np.array
|
||||
value: float = 1
|
||||
@@ -30,6 +36,12 @@ class RenderEntity:
|
||||
|
||||
@dataclass
|
||||
class Floor:
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
|
||||
Reference in New Issue
Block a user