mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-09-15 23:37:14 +02:00
Redone the spawn procedute and destination objects
This commit is contained in:
@@ -42,13 +42,15 @@ class Move(Action, abc.ABC):
|
||||
|
||||
def do(self, entity, env):
|
||||
new_pos = self._calc_new_pos(entity.pos)
|
||||
if next_tile := env[c.FLOOR].by_pos(new_pos):
|
||||
if next_tile := env[c.FLOORS].by_pos(new_pos):
|
||||
# noinspection PyUnresolvedReferences
|
||||
valid = entity.move(next_tile)
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
reward = r.MOVEMENTS_VALID if valid else r.MOVEMENTS_FAIL
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)
|
||||
move_validity = entity.move(next_tile)
|
||||
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
|
||||
else: # There is no floor, propably collision
|
||||
# This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml
|
||||
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0)
|
||||
|
||||
def _calc_new_pos(self, pos):
|
||||
x_diff, y_diff = MOVEMAP[self._identifier]
|
||||
|
@@ -55,6 +55,12 @@ class Entity(EnvObject, abc.ABC):
|
||||
curr_x, curr_y = self.pos
|
||||
return last_x - curr_x, last_y - curr_y
|
||||
|
||||
def destroy(self):
|
||||
valid = self._collection.remove_item(self)
|
||||
for observer in self.observers:
|
||||
observer.notify_del_entity(self)
|
||||
return valid
|
||||
|
||||
def move(self, next_tile):
|
||||
curr_tile = self.tile
|
||||
if not_same_tile := curr_tile != next_tile:
|
||||
@@ -71,7 +77,7 @@ class Entity(EnvObject, abc.ABC):
|
||||
super().__init__(**kwargs)
|
||||
self._status = None
|
||||
self._tile = tile
|
||||
tile.enter(self)
|
||||
assert tile.enter(self, spawn=True), "Positions was not valid!"
|
||||
|
||||
def summarize_state(self) -> dict:
|
||||
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
||||
|
@@ -81,8 +81,12 @@ class Floor(EnvObject):
|
||||
def is_occupied(self):
|
||||
return bool(len(self._guests))
|
||||
|
||||
def enter(self, guest):
|
||||
if (guest.name not in self._guests and not self.is_blocked) and not (guest.var_is_blocking_pos and self.is_occupied()):
|
||||
def enter(self, guest, spawn=False):
|
||||
same_pos = guest.name not in self._guests
|
||||
not_blocked = not self.is_blocked
|
||||
no_become_blocked_when_occupied = not (guest.var_is_blocking_pos and self.is_occupied())
|
||||
not_introduce_collision = not (spawn and guest.var_can_collide and any(x.var_can_collide for x in self.guests))
|
||||
if same_pos and not_blocked and no_become_blocked_when_occupied and not_introduce_collision:
|
||||
self._guests.update({guest.name: guest})
|
||||
return c.VALID
|
||||
else:
|
||||
|
@@ -85,17 +85,14 @@ class Factory(gym.Env):
|
||||
# Init entity:
|
||||
entities = self.map.do_init()
|
||||
|
||||
# Grab all )rules:
|
||||
# Grab all env-rules:
|
||||
rules = self.conf.load_rules()
|
||||
|
||||
# Agents
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self.state = Gamestate(entities, rules, self.conf.env_seed)
|
||||
# Parse the agent conf
|
||||
parsed_agents_conf = self.conf.parse_agents_conf()
|
||||
self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed)
|
||||
|
||||
agents = self.conf.load_agents(self.map.size, self[c.FLOOR].empty_tiles)
|
||||
self.state.entities.add_item({c.AGENT: agents})
|
||||
|
||||
# All is set up, trigger additional init (after agent entity spawn etc)
|
||||
# All is set up, trigger entity init with variable pos
|
||||
self.state.rules.do_all_init(self.state, self.map)
|
||||
|
||||
# Observations
|
||||
@@ -173,6 +170,8 @@ class Factory(gym.Env):
|
||||
# Combine Info dicts into a global one
|
||||
combined_info_dict = defaultdict(lambda: 0.0)
|
||||
for result in chain(tick_results, done_check_results):
|
||||
if not result:
|
||||
raise ValueError()
|
||||
if result.reward is not None:
|
||||
try:
|
||||
rewards[result.entity.name] += result.reward
|
||||
|
@@ -57,6 +57,16 @@ class Objects:
|
||||
observer.notify_add_entity(item)
|
||||
return self
|
||||
|
||||
def remove_item(self, item: _entity):
|
||||
for observer in self.observers:
|
||||
observer.notify_del_entity(item)
|
||||
# noinspection PyTypeChecker
|
||||
del self._data[item.name]
|
||||
return True
|
||||
|
||||
def __delitem__(self, name):
|
||||
return self.remove_item(self[name])
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def del_observer(self, observer):
|
||||
self.observers.remove(observer)
|
||||
@@ -71,12 +81,6 @@ class Objects:
|
||||
if observer not in entity.observers:
|
||||
entity.add_observer(observer)
|
||||
|
||||
def __delitem__(self, name):
|
||||
for observer in self.observers:
|
||||
observer.notify_del_entity(name)
|
||||
# noinspection PyTypeChecker
|
||||
del self._data[name]
|
||||
|
||||
def add_items(self, items: List[_entity]):
|
||||
for item in items:
|
||||
self.add_item(item)
|
||||
@@ -114,7 +118,8 @@ class Objects:
|
||||
raise TypeError
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}[{dict(self._data)}]'
|
||||
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS, c.FLOORS]}
|
||||
return f'{self.__class__.__name__}[{repr_dict}]'
|
||||
|
||||
def spawn(self, n: int):
|
||||
self.add_items([self._entity() for _ in range(n)])
|
||||
@@ -138,6 +143,7 @@ class Objects:
|
||||
|
||||
def notify_del_entity(self, entity: Object):
|
||||
try:
|
||||
entity.del_observer(self)
|
||||
self.pos_dict[entity.pos].remove(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
@@ -146,7 +152,9 @@ class Objects:
|
||||
try:
|
||||
if self not in entity.observers:
|
||||
entity.add_observer(self)
|
||||
self.pos_dict[entity.pos].append(entity)
|
||||
if entity.var_has_position:
|
||||
if entity not in self.pos_dict[entity.pos]:
|
||||
self.pos_dict[entity.pos].append(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
|
@@ -1,6 +1,9 @@
|
||||
import abc
|
||||
from random import shuffle
|
||||
from typing import List
|
||||
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
from marl_factory_grid.environment import rewards as r, constants as c
|
||||
|
||||
@@ -36,6 +39,40 @@ class Rule(abc.ABC):
|
||||
return []
|
||||
|
||||
|
||||
class SpawnAgents(Rule):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
agent_conf = state.agents_conf
|
||||
# agents = Agents(lvl_map.size)
|
||||
agents = state[c.AGENT]
|
||||
empty_tiles = state[c.FLOORS].empty_tiles[:len(agent_conf)]
|
||||
for agent_name in agent_conf:
|
||||
actions = agent_conf[agent_name]['actions'].copy()
|
||||
observations = agent_conf[agent_name]['observations'].copy()
|
||||
positions = agent_conf[agent_name]['positions'].copy()
|
||||
if positions:
|
||||
shuffle(positions)
|
||||
while True:
|
||||
try:
|
||||
tile = state[c.FLOORS].by_pos(positions.pop())
|
||||
except IndexError as e:
|
||||
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
|
||||
f'\n{agent_name[agent_name]["positions"].copy()}')
|
||||
try:
|
||||
agents.add_item(Agent(actions, observations, tile, str_ident=agent_name))
|
||||
except AssertionError:
|
||||
state.print(f'No valid pos:{tile.pos} for {agent_name}')
|
||||
continue
|
||||
break
|
||||
else:
|
||||
agents.add_item(Agent(actions, observations, empty_tiles.pop(), str_ident=agent_name))
|
||||
pass
|
||||
|
||||
|
||||
class MaxStepsReached(Rule):
|
||||
|
||||
def __init__(self, max_steps: int = 500):
|
||||
@@ -91,6 +128,8 @@ class Collision(Rule):
|
||||
return results
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
if self.curr_done and self.done_at_collisions:
|
||||
inter_entity_collision_detected = self.curr_done and self.done_at_collisions
|
||||
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
|
||||
if inter_entity_collision_detected or move_failed:
|
||||
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
||||
|
Reference in New Issue
Block a user