Redone the spawn procedure and destination objects

This commit is contained in:
Steffen Illium
2023-10-11 16:36:48 +02:00
parent e64fa84ef1
commit e326a95bf4
32 changed files with 266 additions and 146 deletions

View File

@ -21,7 +21,7 @@ class TSPBaseAgent(ABC):
self.local_optimization = True
self._env = state
self.state = self._env.state[c.AGENT][agent_i]
self._floortile_graph = points_to_graph(self._env[c.FLOOR].positions)
self._floortile_graph = points_to_graph(self._env[c.FLOORS].positions)
self._static_route = None
@abstractmethod

View File

@ -1,4 +1,10 @@
Agents:
Eberhart:
Actions:
- Move8
- Noop
- ItemAction
Observations:
Wolfgang:
Actions:
- Noop
@ -41,7 +47,6 @@ Entities:
Machines: {}
Maintainers: {}
Zones: {}
ReachedDestinations: {}
General:
env_seed: 69

View File

@ -0,0 +1,44 @@
Agents:
Wolfgang:
Actions:
- Noop
- Move8
Observations:
- Walls
- BoundDestination
Positions:
- (2, 1)
- (2, 5)
Karl-Heinz:
Actions:
- Noop
- Move8
Observations:
- Walls
- BoundDestination
Positions:
- (2, 1)
- (2, 5)
Entities:
BoundDestinations: {}
General:
env_seed: 69
individual_rewards: true
level_name: narrow_corridor
pomdp_r: 0
verbose: true
Rules:
SpawnAgents: {}
Collision:
done_at_collisions: true
FixedDestinationSpawn:
per_agent_positions:
Wolfgang:
- (2, 1)
- (2, 5)
Karl-Heinz:
- (2, 1)
- (2, 5)
DestinationReachAll: {}

View File

@ -7,7 +7,6 @@ General:
Entities:
BoundDestinations: {}
ReachedDestinations: {}
Doors: {}
GlobalPositions: {}
Zones: {}

View File

@ -42,13 +42,15 @@ class Move(Action, abc.ABC):
def do(self, entity, env):
new_pos = self._calc_new_pos(entity.pos)
if next_tile := env[c.FLOOR].by_pos(new_pos):
if next_tile := env[c.FLOORS].by_pos(new_pos):
# noinspection PyUnresolvedReferences
valid = entity.move(next_tile)
else:
valid = c.NOT_VALID
reward = r.MOVEMENTS_VALID if valid else r.MOVEMENTS_FAIL
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)
move_validity = entity.move(next_tile)
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
else: # There is no floor, propably collision
# This is currently handled by the Collision rule, so that it can be switched on and off by conf.yml
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0)
def _calc_new_pos(self, pos):
x_diff, y_diff = MOVEMAP[self._identifier]

View File

@ -55,6 +55,12 @@ class Entity(EnvObject, abc.ABC):
curr_x, curr_y = self.pos
return last_x - curr_x, last_y - curr_y
def destroy(self):
    """Remove this entity from its collection and notify all of its observers.

    Returns:
        The value returned by the owning collection's ``remove_item`` call.
    """
    valid = self._collection.remove_item(self)
    # Iterate over a snapshot: an observer's ``notify_del_entity`` may call
    # ``del_observer`` on this entity, which would otherwise change
    # ``self.observers`` mid-iteration and cause observers to be skipped.
    for observer in list(self.observers):
        observer.notify_del_entity(self)
    return valid
def move(self, next_tile):
curr_tile = self.tile
if not_same_tile := curr_tile != next_tile:
@ -71,7 +77,7 @@ class Entity(EnvObject, abc.ABC):
super().__init__(**kwargs)
self._status = None
self._tile = tile
tile.enter(self)
assert tile.enter(self, spawn=True), "Positions was not valid!"
def summarize_state(self) -> dict:
return dict(name=str(self.name), x=int(self.x), y=int(self.y),

View File

@ -81,8 +81,12 @@ class Floor(EnvObject):
def is_occupied(self):
    """Tell whether at least one guest is currently registered on this tile."""
    guest_count = len(self._guests)
    return guest_count > 0
def enter(self, guest):
if (guest.name not in self._guests and not self.is_blocked) and not (guest.var_is_blocking_pos and self.is_occupied()):
def enter(self, guest, spawn=False):
same_pos = guest.name not in self._guests
not_blocked = not self.is_blocked
no_become_blocked_when_occupied = not (guest.var_is_blocking_pos and self.is_occupied())
not_introduce_collision = not (spawn and guest.var_can_collide and any(x.var_can_collide for x in self.guests))
if same_pos and not_blocked and no_become_blocked_when_occupied and not_introduce_collision:
self._guests.update({guest.name: guest})
return c.VALID
else:

View File

@ -85,17 +85,14 @@ class Factory(gym.Env):
# Init entity:
entities = self.map.do_init()
# Grab all )rules:
# Grab all env-rules:
rules = self.conf.load_rules()
# Agents
# noinspection PyAttributeOutsideInit
self.state = Gamestate(entities, rules, self.conf.env_seed)
# Parse the agent conf
parsed_agents_conf = self.conf.parse_agents_conf()
self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed)
agents = self.conf.load_agents(self.map.size, self[c.FLOOR].empty_tiles)
self.state.entities.add_item({c.AGENT: agents})
# All is set up, trigger additional init (after agent entity spawn etc)
# All is set up, trigger entity init with variable pos
self.state.rules.do_all_init(self.state, self.map)
# Observations
@ -173,6 +170,8 @@ class Factory(gym.Env):
# Combine Info dicts into a global one
combined_info_dict = defaultdict(lambda: 0.0)
for result in chain(tick_results, done_check_results):
if not result:
raise ValueError()
if result.reward is not None:
try:
rewards[result.entity.name] += result.reward

View File

@ -57,6 +57,16 @@ class Objects:
observer.notify_add_entity(item)
return self
def remove_item(self, item: _entity):
    """Drop ``item`` from this collection after announcing the removal.

    Every registered observer is informed through ``notify_del_entity`` first;
    afterwards the entry stored under ``item.name`` is deleted from the
    internal data mapping.

    Returns:
        Always ``True``.
    """
    for watcher in self.observers:
        watcher.notify_del_entity(item)
    # noinspection PyTypeChecker
    del self._data[item.name]
    return True
def __delitem__(self, name):
    """Support ``del collection[name]`` by resolving the entry and delegating to ``remove_item``."""
    target = self[name]
    return self.remove_item(target)
# noinspection PyUnresolvedReferences
def del_observer(self, observer):
self.observers.remove(observer)
@ -71,12 +81,6 @@ class Objects:
if observer not in entity.observers:
entity.add_observer(observer)
def __delitem__(self, name):
for observer in self.observers:
observer.notify_del_entity(name)
# noinspection PyTypeChecker
del self._data[name]
def add_items(self, items: List[_entity]):
for item in items:
self.add_item(item)
@ -114,7 +118,8 @@ class Objects:
raise TypeError
def __repr__(self):
return f'{self.__class__.__name__}[{dict(self._data)}]'
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS, c.FLOORS]}
return f'{self.__class__.__name__}[{repr_dict}]'
def spawn(self, n: int):
self.add_items([self._entity() for _ in range(n)])
@ -138,6 +143,7 @@ class Objects:
def notify_del_entity(self, entity: Object):
# Observer callback for a removed entity: unregister this object as one of the
# entity's observers and drop the entity from the per-position index.
try:
entity.del_observer(self)
self.pos_dict[entity.pos].remove(entity)
# Best-effort cleanup: a ValueError or AttributeError raised by either call above
# is deliberately ignored (presumably for entities without a position or entities
# that were never indexed -- TODO confirm).
except (ValueError, AttributeError):
pass
@ -146,7 +152,9 @@ class Objects:
try:
if self not in entity.observers:
entity.add_observer(self)
self.pos_dict[entity.pos].append(entity)
if entity.var_has_position:
if entity not in self.pos_dict[entity.pos]:
self.pos_dict[entity.pos].append(entity)
except (ValueError, AttributeError):
pass

View File

@ -1,6 +1,9 @@
import abc
from random import shuffle
from typing import List
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import rewards as r, constants as c
@ -36,6 +39,40 @@ class Rule(abc.ABC):
return []
class SpawnAgents(Rule):
    """Rule that creates the configured agents and places them on floor tiles at init time."""

    def __init__(self):
        super().__init__()

    def on_init(self, state, lvl_map):
        """Spawn one agent per entry of ``state.agents_conf``.

        Agents whose configuration lists ``positions`` are placed on one of
        those positions, tried in random order. All other agents take a tile
        from the floor tiles that were empty when this hook started.

        Args:
            state: Game state providing ``agents_conf``, the agent and floor
                collections and ``print``.
            lvl_map: Unused here; part of the rule hook signature.

        Raises:
            ValueError: If none of an agent's configured positions could be used.
        """
        agent_conf = state.agents_conf
        agents = state[c.AGENT]
        floors = state[c.FLOORS]
        empty_tiles = floors.empty_tiles[:len(agent_conf)]
        for agent_name, conf in agent_conf.items():
            actions = conf['actions'].copy()
            observations = conf['observations'].copy()
            positions = conf['positions'].copy()
            if not positions:
                # No fixed positions configured: fall back to any empty tile.
                agents.add_item(Agent(actions, observations, empty_tiles.pop(), str_ident=agent_name))
                continue

            shuffle(positions)
            while True:
                try:
                    pos = positions.pop()
                    tile = floors.by_pos(pos)
                except IndexError as e:
                    # All candidates are used up. The original message indexed the
                    # agent *name* instead of the configuration, which raised a
                    # TypeError and hid this error.
                    raise ValueError(f'It was not possible to spawn an Agent on the available position: '
                                     f'\n{conf["positions"]}') from e
                if not tile:
                    # The configured position is not a floor tile; try the next one.
                    state.print(f'No valid pos:{pos} for {agent_name}')
                    continue
                try:
                    # NOTE(review): an invalid spawn position is signalled by an
                    # AssertionError from the entity constructor; this does not
                    # work when Python runs with assertions disabled (-O).
                    agents.add_item(Agent(actions, observations, tile, str_ident=agent_name))
                except AssertionError:
                    state.print(f'No valid pos:{tile.pos} for {agent_name}')
                    continue
                break
class MaxStepsReached(Rule):
def __init__(self, max_steps: int = 500):
@ -91,6 +128,8 @@ class Collision(Rule):
return results
def on_check_done(self, state) -> List[DoneResult]:
if self.curr_done and self.done_at_collisions:
inter_entity_collision_detected = self.curr_done and self.done_at_collisions
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
if inter_entity_collision_detected or move_failed:
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]

View File

@ -0,0 +1,28 @@
from random import shuffle
from typing import List, Tuple
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.destinations import constants as d
from marl_factory_grid.modules.destinations.entitites import BoundDestination
class NarrowCorridorSpawn(Rule):
    """Spawn rule for the narrow-corridor scenario (work in progress).

    NOTE(review): ``on_init`` does not place agents yet, and nothing in this
    class fills ``per_agent_positions``; confirm the intended behaviour.
    """

    def __init__(self, positions: List[Tuple[int, int]], fixed: bool = False):
        """
        Args:
            positions: Candidate positions used by this rule.
            fixed: When ``False`` the positions are shuffled during ``on_init``.
        """
        super().__init__()
        self.fixed = fixed
        # Keep a private copy so that shuffling never reorders the caller's list.
        self.positions = list(positions)
        # Read by ``trigger_destination_spawn``. It was never initialised before,
        # which made that method fail with an AttributeError.
        self.per_agent_positions = {}

    def on_init(self, state, lvl_map):
        """Shuffle the stored positions unless the rule is configured as ``fixed``."""
        if not self.fixed:
            shuffle(self.positions)
        # NOTE(review): the original looped over state[c.AGENT] without doing
        # anything; agent placement is presumably still to be implemented.

    def trigger_destination_spawn(self, state):
        """Create bound destinations for every agent listed in ``per_agent_positions``.

        Returns:
            ``c.NOT_VALID`` when no per-agent positions are configured (nothing
            is spawned), ``c.VALID`` otherwise.
        """
        if not self.per_agent_positions:
            return c.NOT_VALID
        for agent_name, position_list in self.per_agent_positions.items():
            agent = state[c.AGENT][agent_name]
            destinations = [BoundDestination(agent, pos) for pos in position_list]
            state[d.DESTINATION].add_items(destinations)
        return c.VALID

View File

@ -70,7 +70,7 @@ class PodRules(Rule):
def on_init(self, state, lvl_map):
pod_collection = state[b.CHARGE_PODS]
empty_tiles = state[c.FLOOR].empty_tiles[:self.n_pods]
empty_tiles = state[c.FLOORS].empty_tiles[:self.n_pods]
pods = pod_collection.from_tiles(empty_tiles, entity_kwargs=dict(
multi_charge=self.multi_charge, charge_rate=self.charge_rate)
)

View File

@ -47,7 +47,7 @@ class DirtPiles(PositionMixin, EnvObjects):
return c.VALID
def trigger_dirt_spawn(self, state, initial_spawn=False) -> bool:
free_for_dirt = [x for x in state[c.FLOOR]
free_for_dirt = [x for x in state[c.FLOORS]
if len(x.guests) == 0 or (
len(x.guests) == 1 and
isinstance(next(y for y in x.guests), DirtPile))

View File

@ -1,4 +1,4 @@
from .actions import DestAction
from .entitites import Destination
from .groups import ReachedDestinations, Destinations
from .rules import DestinationDone, DestinationReach, DestinationSpawn
from .groups import Destinations, BoundDestinations
from .rules import DestinationReachAll, DestinationSpawn

View File

@ -3,8 +3,6 @@
DESTINATION = 'Destinations'
BOUNDDESTINATION = 'BoundDestinations'
DEST_SYMBOL = 1
DEST_REACHED_REWARD = 0.5
DEST_REACHED = 'ReachedDestinations'
WAIT_ON_DEST = 'WAIT'

View File

@ -16,42 +16,31 @@ class Destination(Entity):
var_is_blocking_pos = False
var_is_blocking_light = False
@property
def any_agent_has_dwelled(self):
return bool(len(self._per_agent_times))
@property
def currently_dwelling_names(self):
return list(self._per_agent_times.keys())
@property
def encoding(self):
return d.DEST_SYMBOL
def __init__(self, *args, dwell_time: int = 0, **kwargs):
def __init__(self, *args, action_counts=0, **kwargs):
super(Destination, self).__init__(*args, **kwargs)
self.dwell_time = dwell_time
self._per_agent_times = defaultdict(lambda: dwell_time)
self.action_counts = action_counts
self._per_agent_actions = defaultdict(lambda: 0)
def do_wait_action(self, agent: Agent):
self._per_agent_times[agent.name] -= 1
self._per_agent_actions[agent.name] += 1
return c.VALID
def leave(self, agent: Agent):
del self._per_agent_times[agent.name]
@property
def is_considered_reached(self):
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
return agent_at_position or any(x >= self.action_counts for x in self._per_agent_actions.values())
def agent_is_dwelling(self, agent: Agent):
return self._per_agent_times[agent.name] < self.dwell_time
def agent_did_action(self, agent: Agent):
return self._per_agent_actions[agent.name] >= self.action_counts
def summarize_state(self) -> dict:
state_summary = super().summarize_state()
state_summary.update(per_agent_times=[
dict(belongs_to=key, time=val) for key, val in self._per_agent_times.items()], dwell_time=self.dwell_time)
dict(belongs_to=key, time=val) for key, val in self._per_agent_actions.items()], counts=self.action_counts)
return state_summary
def render(self):
@ -68,9 +57,8 @@ class BoundDestination(BoundEntityMixin, Destination):
self.bind_to(entity)
super().__init__(*args, **kwargs)
@property
def is_considered_reached(self):
agent_at_position = any(self.bound_entity == x for x in self.tile.guests_that_can_collide)
return (agent_at_position and not self.dwell_time) \
or any(x == 0 for x in self._per_agent_times[self.bound_entity.name])
return ((agent_at_position and not self.action_counts)
or self._per_agent_actions[self.bound_entity.name] >= self.action_counts >= 1)

View File

@ -23,14 +23,3 @@ class BoundDestinations(HasBoundMixin, Destinations):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class ReachedDestinations(Destinations):
_entity = Destination
is_blocking_light = False
can_collide = False
def __init__(self, *args, **kwargs):
super(ReachedDestinations, self).__init__(*args, **kwargs)
def __repr__(self):
return super(ReachedDestinations, self).__repr__()

View File

@ -1,84 +1,61 @@
from typing import List, Union
import ast
from random import shuffle
from typing import List, Union, Dict, Tuple
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.destinations import constants as d, rewards as r
from marl_factory_grid.modules.destinations.entitites import Destination
from marl_factory_grid.modules.destinations.entitites import Destination, BoundDestination
class DestinationReach(Rule):
class DestinationReachAll(Rule):
def __init__(self, n_dests: int = 1, tiles: Union[List, None] = None):
super(DestinationReach, self).__init__()
self.n_dests = n_dests or len(tiles)
self._tiles = tiles
def __init__(self):
super(DestinationReachAll, self).__init__()
def tick_step(self, state) -> List[TickResult]:
for dest in list(state[d.DESTINATION].values()):
results = []
for dest in list(state[next(key for key in state.entities.names if d.DESTINATION in key)]):
if dest.is_considered_reached:
dest.change_parent_collection(state[d.DEST_REACHED])
agent = state[c.AGENT].by_pos(dest.pos)
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
state.print(f'{dest.name} is reached now, removing...')
assert dest.destroy(), f'{dest.name} could not be destroyed. Critical Error.'
else:
for agent_name in dest.currently_dwelling_names:
agent = state[c.AGENT][agent_name]
if agent.pos == dest.pos:
state.print(f'{agent.name} is still waiting.')
pass
else:
dest.leave(agent)
state.print(f'{agent.name} left the destination early.')
pass
return [TickResult(self.name, validity=c.VALID, reward=0, entity=None)]
def tick_post_step(self, state) -> List[TickResult]:
results = list()
for reached_dest in state[d.DEST_REACHED]:
for guest in reached_dest.tile.guests:
if guest in state[c.AGENT]:
state.print(f'{guest.name} just reached destination at {guest.pos}')
state[d.DEST_REACHED].delete_env_object(reached_dest)
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=guest))
return results
class DestinationDone(Rule):
def __init__(self):
super(DestinationDone, self).__init__()
def on_check_done(self, state) -> List[DoneResult]:
if not len(state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
return []
class DoneOnReach(Rule):
def __init__(self):
super(DoneOnReach, self).__init__()
def on_check_done(self, state) -> List[DoneResult]:
dests = [x.pos for x in state[d.DESTINATION]]
agents = [x.pos for x in state[c.AGENT]]
if any([x in dests for x in agents]):
if not len(state[next(key for key in state.entities.names if d.DESTINATION in key)]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
class DestinationReachAny(DestinationReachAll):
def __init__(self):
super(DestinationReachAny, self).__init__()
def on_check_done(self, state) -> List[DoneResult]:
if not len(state[next(key for key in state.entities.names if d.DESTINATION in key)]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
return []
class DestinationSpawn(Rule):
def __init__(self, spawn_frequency: int = 5, n_dests: int = 1,
def __init__(self, n_dests: int = 1,
spawn_mode: str = d.MODE_GROUPED):
super(DestinationSpawn, self).__init__()
self.spawn_frequency = spawn_frequency
self.n_dests = n_dests
self.spawn_mode = spawn_mode
def on_init(self, state, lvl_map):
# noinspection PyAttributeOutsideInit
self._dest_spawn_timer = self.spawn_frequency
self.trigger_destination_spawn(self.n_dests, state)
pass
@ -88,16 +65,40 @@ class DestinationSpawn(Rule):
def tick_step(self, state) -> List[TickResult]:
if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])):
if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
validity = state.rules['DestinationReach'].trigger_destination_spawn(n_dest_spawn, state)
validity = self.trigger_destination_spawn(n_dest_spawn, state)
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn:
validity = self.trigger_destination_spawn(n_dest_spawn, state)
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
else:
pass
@staticmethod
def trigger_destination_spawn(n_dests, state, tiles=None):
tiles = tiles or state[c.FLOOR].empty_tiles[:n_dests]
if destinations := [Destination(tile) for tile in tiles]:
def trigger_destination_spawn(self, n_dests, state):
empty_positions = state[c.FLOORS].empty_tiles[:n_dests]
if destinations := [Destination(pos) for pos in empty_positions]:
state[d.DESTINATION].add_items(destinations)
state.print(f'{n_dests} new destinations have been spawned')
return c.VALID
else:
state.print('No Destiantions are spawning, limit is reached.')
return c.NOT_VALID
class FixedDestinationSpawn(Rule):
    """Rule that gives every configured agent one bound destination at a configured position."""

    def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]):
        """
        Args:
            per_agent_positions: Maps an agent name to its candidate destination
                positions. Every position is passed through ``ast.literal_eval``,
                i.e. it is expected as a literal such as ``"(2, 1)"``.
        """
        # ``super(Rule, self)`` skipped ``Rule.__init__``; use the plain form so
        # the base rule is initialised like in every other rule.
        super().__init__()
        self.per_agent_positions = {
            agent_name: [ast.literal_eval(pos) for pos in positions]
            for agent_name, positions in per_agent_positions.items()
        }

    def on_init(self, state, lvl_map):
        """Create one ``BoundDestination`` per configured agent.

        Candidate positions are tried in random order. A position is skipped
        if the agent already stands on it, if another bound destination
        occupies it, or if there is no floor tile at that position.

        Args:
            state: Game state providing the agent, floor and bound-destination collections.
            lvl_map: Unused here; part of the rule hook signature.

        Raises:
            ValueError: If a configured agent does not exist or none of its
                positions can be used.
        """
        bound_destinations = state[d.BOUNDDESTINATION]
        floors = state[c.FLOORS]
        for agent_name, position_list in self.per_agent_positions.items():
            # Fixme: agents are matched by a substring of their name.
            agent = next((x for x in state[c.AGENT] if agent_name in x.name), None)
            if agent is None:
                raise ValueError(f'No agent matching "{agent_name}" was found.')

            # Work on a copy: shuffling and popping the stored list would leave
            # it empty for the next call of this hook.
            candidates = list(position_list)
            shuffle(candidates)
            while candidates:
                pos = candidates.pop()
                if pos == agent.pos or bound_destinations.by_pos(pos):
                    continue
                tile = floors.by_pos(pos)
                if not tile:
                    continue
                bound_destinations.add_item(BoundDestination(agent, tile))
                break
            else:
                raise ValueError(f'It was not possible to spawn a destination for {agent_name} '
                                 f'on the available positions: {position_list}')

View File

@ -21,7 +21,7 @@ class AgentSingleZonePlacementBeta(Rule):
coordinates = random.choices(self.coordinates, k=len(agents))
else:
raise ValueError
tiles = [state[c.FLOOR].by_pos(pos) for pos in coordinates]
tiles = [state[c.FLOORS].by_pos(pos) for pos in coordinates]
for agent, tile in zip(agents, tiles):
agent.move(tile)

View File

@ -41,7 +41,7 @@ class ItemRules(Rule):
def trigger_item_spawn(self, state):
if item_to_spawns := max(0, (self.n_items - len(state[i.ITEM]))):
empty_tiles = state[c.FLOOR].empty_tiles[:item_to_spawns]
empty_tiles = state[c.FLOORS].empty_tiles[:item_to_spawns]
state[i.ITEM].spawn(empty_tiles)
self._next_item_spawn = self.spawn_frequency
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}')
@ -73,7 +73,7 @@ class ItemRules(Rule):
return []
def trigger_drop_off_location_spawn(self, state):
empty_tiles = state[c.FLOOR].empty_tiles[:self.n_locations]
empty_tiles = state[c.FLOORS].empty_tiles[:self.n_locations]
do_entites = state[i.DROP_OFF]
drop_offs = [DropOffLocation(tile) for tile in empty_tiles]
do_entites.add_items(drop_offs)

View File

@ -0,0 +1,5 @@
#######
###-###
#1---2#
###-###
#######

View File

@ -13,7 +13,7 @@ class MachineRule(Rule):
self.n_machines = n_machines
def on_init(self, state, lvl_map):
empty_tiles = state[c.FLOOR].empty_tiles[:self.n_machines]
empty_tiles = state[c.FLOORS].empty_tiles[:self.n_machines]
state[m.MACHINES].add_items(Machine(tile) for tile in empty_tiles)
def tick_pre_step(self, state) -> List[TickResult]:

View File

@ -39,7 +39,7 @@ class Maintainer(Entity):
self._next = []
self._last = []
self._last_serviced = 'None'
self._floortile_graph = points_to_graph(state[c.FLOOR].positions)
self._floortile_graph = points_to_graph(state[c.FLOORS].positions)
def tick(self, state):
if found_objective := state[self.objective].by_pos(self.pos):
@ -89,7 +89,7 @@ class Maintainer(Entity):
def _predict_move(self, state):
next_pos = self._path[0]
if len(state[c.FLOOR].by_pos(next_pos).guests_that_can_collide) > 0:
if len(state[c.FLOORS].by_pos(next_pos).guests_that_can_collide) > 0:
action = c.NOOP
else:
next_pos = self._path.pop(0)

View File

@ -5,11 +5,9 @@ from marl_factory_grid.environment.entity.wall_floor import Floor
from marl_factory_grid.environment.groups.env_objects import EnvObjects
from marl_factory_grid.environment.groups.mixins import PositionMixin
from ..machines.actions import MachineAction
from ...utils.render import RenderEntity
from ...utils.states import Gamestate
from ..machines import constants as mc
from . import constants as mi
class Maintainers(PositionMixin, EnvObjects):

View File

@ -14,7 +14,7 @@ class MaintenanceRule(Rule):
self.n_maintainer = n_maintainer
def on_init(self, state: Gamestate, lvl_map):
state[M.MAINTAINERS].spawn(state[c.FLOOR].empty_tiles[:self.n_maintainer], state)
state[M.MAINTAINERS].spawn(state[c.FLOORS].empty_tiles[:self.n_maintainer], state)
pass
def tick_pre_step(self, state) -> List[TickResult]:

View File

@ -19,7 +19,7 @@ class ZoneInit(Rule):
while z_idx:
zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
if len(zone_positions):
zones.append(Zone([state[c.FLOOR].by_pos(pos) for pos in zone_positions]))
zones.append(Zone([state[c.FLOORS].by_pos(pos) for pos in zone_positions]))
z_idx += 1
else:
z_idx = 0

View File

@ -1,3 +1,5 @@
import ast
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Union
@ -80,15 +82,15 @@ class FactoryConfigParser(object):
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e3:
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are>:', str(ents))
raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
entity_kwargs = self.entities.get(entity, {})
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}})
return entity_classes
def load_agents(self, size, free_tiles):
agents = Agents(size)
def parse_agents_conf(self):
parsed_agents_conf = dict()
base_env_actions = self.default_actions.copy() + [c.MOVE4]
for name in self.agents:
# Actions
@ -116,9 +118,9 @@ class FactoryConfigParser(object):
if c.DEFAULTS in self.agents[name]['Observations']:
observations.extend(self.default_observations)
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
agent = Agent(parsed_actions, observations, free_tiles.pop(), str_ident=name)
agents.add_item(agent)
return agents
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
return parsed_agents_conf
def load_rules(self):
# entites = Entities()

View File

@ -4,6 +4,7 @@ from typing import Dict
import numpy as np
from marl_factory_grid.environment.groups.agents import Agents
from marl_factory_grid.environment.groups.global_entities import Entities
from marl_factory_grid.environment.groups.wall_n_floors import Walls, Floors
from marl_factory_grid.utils import helpers as h
@ -35,11 +36,12 @@ class LevelParser(object):
entities = Entities()
# Walls
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
entities.add_items({c.WALL: walls})
entities.add_items({c.WALLS: walls})
# Floor
floor = Floors.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True), self.size)
entities.add_items({c.FLOOR: floor})
entities.add_items({c.FLOORS: floor})
entities.add_items({c.AGENT: Agents(self.size)})
# All other
for es_name in self.e_p_dict:
@ -52,8 +54,9 @@ class LevelParser(object):
for symbol in symbols:
level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
if np.any(level_array):
# TODO: Get rid of this!
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
entities[c.FLOOR], self.size, entity_kwargs=e_kwargs
entities[c.FLOORS], self.size, entity_kwargs=e_kwargs
)
else:
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'

View File

@ -41,7 +41,7 @@ class OBSBuilder(object):
self.curr_lightmaps = dict()
def reset_struc_obs_block(self, state):
self._curr_env_step = state.curr_step.copy()
self._curr_env_step = state.curr_step
# Construct an empty obs (array) for possible placeholders
self.all_obs[c.PLACEHOLDER] = np.full(self.obs_shape, 0, dtype=float)
# Fill the all_obs-dict with all available entities

View File

@ -5,6 +5,7 @@ import numpy as np
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.wall_floor import Floor
from marl_factory_grid.environment.groups.global_entities import Entities
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import Result
@ -43,7 +44,7 @@ class StepRules:
def tick_pre_step_all(self, state):
results = list()
for rule in self.rules:
if tick_pre_step_result := rule.tick_post_step(state):
if tick_pre_step_result := rule.tick_pre_step(state):
results.extend(tick_pre_step_result)
return results
@ -61,11 +62,12 @@ class Gamestate(object):
def moving_entites(self):
return [y for x in self.entities for y in x if x.var_can_move]
def __init__(self, entitites, rules: Dict[str, dict], env_seed=69, verbose=False):
self.entities = entitites
def __init__(self, entitites, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False):
self.entities: Entities = entitites
self.NO_POS_TILE = Floor(c.VALUE_NO_POS)
self.curr_step = 0
self.curr_actions = None
self.agents_conf = agents_conf
self.verbose = verbose
self.rng = np.random.default_rng(env_seed)
self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values()))
@ -113,7 +115,7 @@ class Gamestate(object):
return results
def get_all_tiles_with_collisions(self) -> List[Floor]:
tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items()
tiles = [self[c.FLOORS].by_pos(pos) for pos, e in self.entities.pos_dict.items()
if sum([x.var_can_collide for x in e]) > 1]
# tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
return tiles