Redone the spawn procedute and destination objects

This commit is contained in:
Steffen Illium
2023-10-11 16:36:48 +02:00
parent e64fa84ef1
commit e326a95bf4
32 changed files with 266 additions and 146 deletions

View File

@@ -1,3 +1,5 @@
import ast
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Union
@@ -80,15 +82,15 @@ class FactoryConfigParser(object):
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e3:
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are>:', str(ents))
raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
entity_kwargs = self.entities.get(entity, {})
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}})
return entity_classes
def load_agents(self, size, free_tiles):
agents = Agents(size)
def parse_agents_conf(self):
parsed_agents_conf = dict()
base_env_actions = self.default_actions.copy() + [c.MOVE4]
for name in self.agents:
# Actions
@@ -116,9 +118,9 @@ class FactoryConfigParser(object):
if c.DEFAULTS in self.agents[name]['Observations']:
observations.extend(self.default_observations)
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
agent = Agent(parsed_actions, observations, free_tiles.pop(), str_ident=name)
agents.add_item(agent)
return agents
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
return parsed_agents_conf
def load_rules(self):
# entites = Entities()

View File

@@ -4,6 +4,7 @@ from typing import Dict
import numpy as np
from marl_factory_grid.environment.groups.agents import Agents
from marl_factory_grid.environment.groups.global_entities import Entities
from marl_factory_grid.environment.groups.wall_n_floors import Walls, Floors
from marl_factory_grid.utils import helpers as h
@@ -35,11 +36,12 @@ class LevelParser(object):
entities = Entities()
# Walls
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
entities.add_items({c.WALL: walls})
entities.add_items({c.WALLS: walls})
# Floor
floor = Floors.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True), self.size)
entities.add_items({c.FLOOR: floor})
entities.add_items({c.FLOORS: floor})
entities.add_items({c.AGENT: Agents(self.size)})
# All other
for es_name in self.e_p_dict:
@@ -52,8 +54,9 @@ class LevelParser(object):
for symbol in symbols:
level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
if np.any(level_array):
# TODO: Get rid of this!
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
entities[c.FLOOR], self.size, entity_kwargs=e_kwargs
entities[c.FLOORS], self.size, entity_kwargs=e_kwargs
)
else:
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'

View File

@@ -41,7 +41,7 @@ class OBSBuilder(object):
self.curr_lightmaps = dict()
def reset_struc_obs_block(self, state):
self._curr_env_step = state.curr_step.copy()
self._curr_env_step = state.curr_step
# Construct an empty obs (array) for possible placeholders
self.all_obs[c.PLACEHOLDER] = np.full(self.obs_shape, 0, dtype=float)
# Fill the all_obs-dict with all available entities

View File

@@ -5,6 +5,7 @@ import numpy as np
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.wall_floor import Floor
from marl_factory_grid.environment.groups.global_entities import Entities
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import Result
@@ -43,7 +44,7 @@ class StepRules:
def tick_pre_step_all(self, state):
results = list()
for rule in self.rules:
if tick_pre_step_result := rule.tick_post_step(state):
if tick_pre_step_result := rule.tick_pre_step(state):
results.extend(tick_pre_step_result)
return results
@@ -61,11 +62,12 @@ class Gamestate(object):
def moving_entites(self):
return [y for x in self.entities for y in x if x.var_can_move]
def __init__(self, entitites, rules: Dict[str, dict], env_seed=69, verbose=False):
self.entities = entitites
def __init__(self, entitites, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False):
self.entities: Entities = entitites
self.NO_POS_TILE = Floor(c.VALUE_NO_POS)
self.curr_step = 0
self.curr_actions = None
self.agents_conf = agents_conf
self.verbose = verbose
self.rng = np.random.default_rng(env_seed)
self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values()))
@@ -113,7 +115,7 @@ class Gamestate(object):
return results
def get_all_tiles_with_collisions(self) -> List[Floor]:
tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items()
tiles = [self[c.FLOORS].by_pos(pos) for pos, e in self.entities.pos_dict.items()
if sum([x.var_can_collide for x in e]) > 1]
# tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
return tiles