mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 13:37:27 +01:00
Redone the spawn procedute and destination objects
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import ast
|
||||
from collections import defaultdict
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
@@ -80,15 +82,15 @@ class FactoryConfigParser(object):
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e3:
|
||||
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
|
||||
raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are>:', str(ents))
|
||||
raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
|
||||
|
||||
entity_kwargs = self.entities.get(entity, {})
|
||||
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
|
||||
entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}})
|
||||
return entity_classes
|
||||
|
||||
def load_agents(self, size, free_tiles):
|
||||
agents = Agents(size)
|
||||
def parse_agents_conf(self):
|
||||
parsed_agents_conf = dict()
|
||||
base_env_actions = self.default_actions.copy() + [c.MOVE4]
|
||||
for name in self.agents:
|
||||
# Actions
|
||||
@@ -116,9 +118,9 @@ class FactoryConfigParser(object):
|
||||
if c.DEFAULTS in self.agents[name]['Observations']:
|
||||
observations.extend(self.default_observations)
|
||||
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
|
||||
agent = Agent(parsed_actions, observations, free_tiles.pop(), str_ident=name)
|
||||
agents.add_item(agent)
|
||||
return agents
|
||||
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
|
||||
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
|
||||
return parsed_agents_conf
|
||||
|
||||
def load_rules(self):
|
||||
# entites = Entities()
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.groups.agents import Agents
|
||||
from marl_factory_grid.environment.groups.global_entities import Entities
|
||||
from marl_factory_grid.environment.groups.wall_n_floors import Walls, Floors
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
@@ -35,11 +36,12 @@ class LevelParser(object):
|
||||
entities = Entities()
|
||||
# Walls
|
||||
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
|
||||
entities.add_items({c.WALL: walls})
|
||||
entities.add_items({c.WALLS: walls})
|
||||
|
||||
# Floor
|
||||
floor = Floors.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True), self.size)
|
||||
entities.add_items({c.FLOOR: floor})
|
||||
entities.add_items({c.FLOORS: floor})
|
||||
entities.add_items({c.AGENT: Agents(self.size)})
|
||||
|
||||
# All other
|
||||
for es_name in self.e_p_dict:
|
||||
@@ -52,8 +54,9 @@ class LevelParser(object):
|
||||
for symbol in symbols:
|
||||
level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
|
||||
if np.any(level_array):
|
||||
# TODO: Get rid of this!
|
||||
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
|
||||
entities[c.FLOOR], self.size, entity_kwargs=e_kwargs
|
||||
entities[c.FLOORS], self.size, entity_kwargs=e_kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
|
||||
|
||||
@@ -41,7 +41,7 @@ class OBSBuilder(object):
|
||||
self.curr_lightmaps = dict()
|
||||
|
||||
def reset_struc_obs_block(self, state):
|
||||
self._curr_env_step = state.curr_step.copy()
|
||||
self._curr_env_step = state.curr_step
|
||||
# Construct an empty obs (array) for possible placeholders
|
||||
self.all_obs[c.PLACEHOLDER] = np.full(self.obs_shape, 0, dtype=float)
|
||||
# Fill the all_obs-dict with all available entities
|
||||
|
||||
@@ -5,6 +5,7 @@ import numpy as np
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
from marl_factory_grid.environment.groups.global_entities import Entities
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
@@ -43,7 +44,7 @@ class StepRules:
|
||||
def tick_pre_step_all(self, state):
|
||||
results = list()
|
||||
for rule in self.rules:
|
||||
if tick_pre_step_result := rule.tick_post_step(state):
|
||||
if tick_pre_step_result := rule.tick_pre_step(state):
|
||||
results.extend(tick_pre_step_result)
|
||||
return results
|
||||
|
||||
@@ -61,11 +62,12 @@ class Gamestate(object):
|
||||
def moving_entites(self):
|
||||
return [y for x in self.entities for y in x if x.var_can_move]
|
||||
|
||||
def __init__(self, entitites, rules: Dict[str, dict], env_seed=69, verbose=False):
|
||||
self.entities = entitites
|
||||
def __init__(self, entitites, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False):
|
||||
self.entities: Entities = entitites
|
||||
self.NO_POS_TILE = Floor(c.VALUE_NO_POS)
|
||||
self.curr_step = 0
|
||||
self.curr_actions = None
|
||||
self.agents_conf = agents_conf
|
||||
self.verbose = verbose
|
||||
self.rng = np.random.default_rng(env_seed)
|
||||
self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values()))
|
||||
@@ -113,7 +115,7 @@ class Gamestate(object):
|
||||
return results
|
||||
|
||||
def get_all_tiles_with_collisions(self) -> List[Floor]:
|
||||
tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items()
|
||||
tiles = [self[c.FLOORS].by_pos(pos) for pos, e in self.entities.pos_dict.items()
|
||||
if sum([x.var_can_collide for x in e]) > 1]
|
||||
# tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
||||
return tiles
|
||||
|
||||
Reference in New Issue
Block a user