mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 13:37:27 +01:00
WIP: removing tiles
This commit is contained in:
@@ -14,7 +14,6 @@ MODULE_PATH = 'modules'
|
||||
|
||||
|
||||
class FactoryConfigParser(object):
|
||||
|
||||
default_entites = []
|
||||
default_rules = ['MaxStepsReached', 'Collision']
|
||||
default_actions = [c.MOVE8, c.NOOP]
|
||||
@@ -89,7 +88,7 @@ class FactoryConfigParser(object):
|
||||
|
||||
def load_agents(self, size, free_tiles):
|
||||
agents = Agents(size)
|
||||
base_env_actions = self.default_actions.copy() + [c.MOVE4]
|
||||
base_env_actions = self.default_actions.copy() + [c.MOVE4]
|
||||
for name in self.agents:
|
||||
# Actions
|
||||
actions = list()
|
||||
|
||||
@@ -22,7 +22,7 @@ class LevelParser(object):
|
||||
self._parsed_level = h.parse_level(Path(level_file_path))
|
||||
level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
|
||||
self.level_shape = level_array.shape
|
||||
self.size = self.pomdp_r**2 if self.pomdp_r else np.prod(self.level_shape)
|
||||
self.size = self.pomdp_r ** 2 if self.pomdp_r else np.prod(self.level_shape)
|
||||
|
||||
def get_coordinates_for_symbol(self, symbol, negate=False):
|
||||
level_array = h.one_hot_level(self._parsed_level, symbol)
|
||||
@@ -32,16 +32,16 @@ class LevelParser(object):
|
||||
return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL)
|
||||
|
||||
def do_init(self):
|
||||
entities = Entities()
|
||||
list_of_all_floors = ([tuple(floor) for floor in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)])
|
||||
entities = Entities(list_of_all_floors)
|
||||
|
||||
# Walls
|
||||
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
|
||||
# walls = self.get_coordinates_for_symbol(c.SYMBOL_WALL)
|
||||
entities.add_items({c.WALL: walls})
|
||||
|
||||
# Floor
|
||||
floor = Floors.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True), self.size)
|
||||
floor = Floors.from_coordinates(list_of_all_floors, self.size)
|
||||
entities.add_items({c.FLOOR: floor})
|
||||
# entities.add_items({c.WALL: self.get_coordinates_for_symbol(c.SYMBOL_WALL, negative=True)})
|
||||
|
||||
# All other
|
||||
for es_name in self.e_p_dict:
|
||||
@@ -55,10 +55,7 @@ class LevelParser(object):
|
||||
level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
|
||||
if np.any(level_array):
|
||||
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
|
||||
entities[c.FLOOR], self.size, entity_kwargs=e_kwargs
|
||||
)
|
||||
# e_coords = (np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist()) # braucht e_class?
|
||||
# entities.add_items({e.name: e_coords})
|
||||
self.size, entity_kwargs=e_kwargs)
|
||||
else:
|
||||
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
|
||||
f'Check your level file!')
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
@@ -108,8 +107,22 @@ class Gamestate(object):
|
||||
results.extend(on_check_done_result)
|
||||
return results
|
||||
|
||||
def get_all_tiles_with_collisions(self) -> List[Floor]: # -> List[Tuple(Int,Int)]
|
||||
tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items()
|
||||
if sum([x.var_can_collide for x in e]) > 1]
|
||||
# tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
||||
return tiles
|
||||
# def get_all_tiles_with_collisions(self) -> List[Floor]:
|
||||
# tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items()
|
||||
# if sum([x.var_can_collide for x in e]) > 1]
|
||||
# # tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
||||
# return tiles
|
||||
|
||||
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
|
||||
positions = [pos for pos, e in self.entities.pos_dict.items()
|
||||
if sum([x.var_can_collide for x in e]) > 1]
|
||||
return positions
|
||||
|
||||
def check_move_validity(self, moving_entity, position):
|
||||
# if (guest.name not in self._guests and not self.is_blocked)
|
||||
# and not (guest.var_is_blocking_pos and self.is_occupied()):
|
||||
if moving_entity.pos != position and not any(
|
||||
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
|
||||
moving_entity.var_is_blocking_pos and moving_entity.is_occupied()):
|
||||
return True
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user