Redone the spawn procedute and destination objects

This commit is contained in:
Steffen Illium
2023-10-11 16:36:48 +02:00
parent e64fa84ef1
commit e326a95bf4
32 changed files with 266 additions and 146 deletions

View File

@@ -42,13 +42,15 @@ class Move(Action, abc.ABC):
def do(self, entity, env):
new_pos = self._calc_new_pos(entity.pos)
if next_tile := env[c.FLOOR].by_pos(new_pos):
if next_tile := env[c.FLOORS].by_pos(new_pos):
# noinspection PyUnresolvedReferences
valid = entity.move(next_tile)
else:
valid = c.NOT_VALID
reward = r.MOVEMENTS_VALID if valid else r.MOVEMENTS_FAIL
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)
move_validity = entity.move(next_tile)
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
else: # There is no floor, propably collision
# This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0)
def _calc_new_pos(self, pos):
x_diff, y_diff = MOVEMAP[self._identifier]

View File

@@ -55,6 +55,12 @@ class Entity(EnvObject, abc.ABC):
curr_x, curr_y = self.pos
return last_x - curr_x, last_y - curr_y
def destroy(self):
valid = self._collection.remove_item(self)
for observer in self.observers:
observer.notify_del_entity(self)
return valid
def move(self, next_tile):
curr_tile = self.tile
if not_same_tile := curr_tile != next_tile:
@@ -71,7 +77,7 @@ class Entity(EnvObject, abc.ABC):
super().__init__(**kwargs)
self._status = None
self._tile = tile
tile.enter(self)
assert tile.enter(self, spawn=True), "Positions was not valid!"
def summarize_state(self) -> dict:
return dict(name=str(self.name), x=int(self.x), y=int(self.y),

View File

@@ -81,8 +81,12 @@ class Floor(EnvObject):
def is_occupied(self):
return bool(len(self._guests))
def enter(self, guest):
if (guest.name not in self._guests and not self.is_blocked) and not (guest.var_is_blocking_pos and self.is_occupied()):
def enter(self, guest, spawn=False):
same_pos = guest.name not in self._guests
not_blocked = not self.is_blocked
no_become_blocked_when_occupied = not (guest.var_is_blocking_pos and self.is_occupied())
not_introduce_collision = not (spawn and guest.var_can_collide and any(x.var_can_collide for x in self.guests))
if same_pos and not_blocked and no_become_blocked_when_occupied and not_introduce_collision:
self._guests.update({guest.name: guest})
return c.VALID
else:

View File

@@ -85,17 +85,14 @@ class Factory(gym.Env):
# Init entity:
entities = self.map.do_init()
# Grab all )rules:
# Grab all env-rules:
rules = self.conf.load_rules()
# Agents
# noinspection PyAttributeOutsideInit
self.state = Gamestate(entities, rules, self.conf.env_seed)
# Parse the agent conf
parsed_agents_conf = self.conf.parse_agents_conf()
self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed)
agents = self.conf.load_agents(self.map.size, self[c.FLOOR].empty_tiles)
self.state.entities.add_item({c.AGENT: agents})
# All is set up, trigger additional init (after agent entity spawn etc)
# All is set up, trigger entity init with variable pos
self.state.rules.do_all_init(self.state, self.map)
# Observations
@@ -173,6 +170,8 @@ class Factory(gym.Env):
# Combine Info dicts into a global one
combined_info_dict = defaultdict(lambda: 0.0)
for result in chain(tick_results, done_check_results):
if not result:
raise ValueError()
if result.reward is not None:
try:
rewards[result.entity.name] += result.reward

View File

@@ -57,6 +57,16 @@ class Objects:
observer.notify_add_entity(item)
return self
def remove_item(self, item: _entity):
for observer in self.observers:
observer.notify_del_entity(item)
# noinspection PyTypeChecker
del self._data[item.name]
return True
def __delitem__(self, name):
return self.remove_item(self[name])
# noinspection PyUnresolvedReferences
def del_observer(self, observer):
self.observers.remove(observer)
@@ -71,12 +81,6 @@ class Objects:
if observer not in entity.observers:
entity.add_observer(observer)
def __delitem__(self, name):
for observer in self.observers:
observer.notify_del_entity(name)
# noinspection PyTypeChecker
del self._data[name]
def add_items(self, items: List[_entity]):
for item in items:
self.add_item(item)
@@ -114,7 +118,8 @@ class Objects:
raise TypeError
def __repr__(self):
return f'{self.__class__.__name__}[{dict(self._data)}]'
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS, c.FLOORS]}
return f'{self.__class__.__name__}[{repr_dict}]'
def spawn(self, n: int):
self.add_items([self._entity() for _ in range(n)])
@@ -138,6 +143,7 @@ class Objects:
def notify_del_entity(self, entity: Object):
try:
entity.del_observer(self)
self.pos_dict[entity.pos].remove(entity)
except (ValueError, AttributeError):
pass
@@ -146,7 +152,9 @@ class Objects:
try:
if self not in entity.observers:
entity.add_observer(self)
self.pos_dict[entity.pos].append(entity)
if entity.var_has_position:
if entity not in self.pos_dict[entity.pos]:
self.pos_dict[entity.pos].append(entity)
except (ValueError, AttributeError):
pass

View File

@@ -1,6 +1,9 @@
import abc
from random import shuffle
from typing import List
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import rewards as r, constants as c
@@ -36,6 +39,40 @@ class Rule(abc.ABC):
return []
class SpawnAgents(Rule):
def __init__(self):
super().__init__()
pass
def on_init(self, state, lvl_map):
agent_conf = state.agents_conf
# agents = Agents(lvl_map.size)
agents = state[c.AGENT]
empty_tiles = state[c.FLOORS].empty_tiles[:len(agent_conf)]
for agent_name in agent_conf:
actions = agent_conf[agent_name]['actions'].copy()
observations = agent_conf[agent_name]['observations'].copy()
positions = agent_conf[agent_name]['positions'].copy()
if positions:
shuffle(positions)
while True:
try:
tile = state[c.FLOORS].by_pos(positions.pop())
except IndexError as e:
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
f'\n{agent_name[agent_name]["positions"].copy()}')
try:
agents.add_item(Agent(actions, observations, tile, str_ident=agent_name))
except AssertionError:
state.print(f'No valid pos:{tile.pos} for {agent_name}')
continue
break
else:
agents.add_item(Agent(actions, observations, empty_tiles.pop(), str_ident=agent_name))
pass
class MaxStepsReached(Rule):
def __init__(self, max_steps: int = 500):
@@ -91,6 +128,8 @@ class Collision(Rule):
return results
def on_check_done(self, state) -> List[DoneResult]:
if self.curr_done and self.done_at_collisions:
inter_entity_collision_detected = self.curr_done and self.done_at_collisions
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
if inter_entity_collision_detected or move_failed:
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]