Merge branch 'main' into unit_testing

This commit is contained in:
Chanumask 2023-11-16 19:45:52 +01:00
commit 0f6ede3f1f
27 changed files with 290 additions and 126 deletions

View File

@ -46,7 +46,7 @@ class LoopMAPPO(LoopSNAC):
# monte carlo returns # monte carlo returns
mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma) mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma)
mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) # todo: norm across agent ok? mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) # todo: norm across agent ok?
advantages = mc_returns - out[nms.CRITIC][:, :-1] advantages = mc_returns - out[nms.CRITIC][:, :-1]
# policy loss # policy loss
log_ap = torch.log_softmax(logits, -1) log_ap = torch.log_softmax(logits, -1)

View File

@ -0,0 +1,66 @@
General:
env_seed: 69
individual_rewards: true
level_name: obs_test_map
pomdp_r: 0
verbose: True
tests: false
Agents:
Wolfgang:
Actions:
- Noop
Observations:
- Walls
- Doors
- Other
- DirtPiles
Positions:
- (1, 3)
Soeren:
Actions:
- Noop
Observations:
- Walls
- Doors
- Other
- DirtPiles
Positions:
- (1, 1)
Juergen:
Actions:
- Noop
Observations:
- Walls
- Doors
- Other
- DirtPiles
Positions:
- (1, 2)
Walter:
Actions:
- Noop
Observations:
- Walls
- Doors
- Other
- DirtPiles
Positions:
- (1, 4)
Entities:
DirtPiles:
Doors:
Rules:
# Utilities
WatchCollisions:
done_at_collisions: false
# Done Conditions
DoneAtMaxStepsReached:
max_steps: 500

View File

@ -2,7 +2,7 @@ Agents:
Wolfgang: Wolfgang:
Actions: Actions:
- Noop - Noop
- BtryCharge - Charge
- CleanUp - CleanUp
- DestAction - DestAction
- DoorUse - DoorUse
@ -79,7 +79,7 @@ Rules:
done_at_collisions: false done_at_collisions: false
# Done Conditions # Done Conditions
DoneAtDestinationReachAny: DoneAtDestinationReach:
DoneOnAllDirtCleaned: DoneOnAllDirtCleaned:
DoneAtBatteryDischarge: DoneAtBatteryDischarge:
DoneAtMaintainerCollision: DoneAtMaintainerCollision:

View File

@ -1,8 +1,20 @@
General:
env_seed: 69
individual_rewards: true
level_name: eight_puzzle
pomdp_r: 0
verbose: True
tests: false
Agents: Agents:
Wolfgang: Wolfgang:
Actions: Actions:
- Noop Noop:
- Move4 fail_reward: -0
valid_reward: 0
Move4:
fail_reward: -0.1
valid_reward: -.01
Observations: Observations:
- Other - Other
- Walls - Walls
@ -35,13 +47,6 @@ Entities:
Walter: 1 Walter: 1
Siggi: 1 Siggi: 1
Dennis: 1 Dennis: 1
General:
env_seed: 69
individual_rewards: true
level_name: eight_puzzle
pomdp_r: 3
verbose: True
tests: false
Rules: Rules:
# Utilities # Utilities

View File

@ -81,9 +81,10 @@ Rules:
reward_at_done: -1 reward_at_done: -1
done_at_collisions: false done_at_collisions: false
# Done Conditions # Done Conditions
# Load any of the rules, to check for done conditions. # Load any of the rules, to check for done conditions.
# DoneAtDestinationReachAny: DoneAtDestinationReach:
DoneAtDestinationReachAll: reward_at_done: 1
# reward_at_done: 1 # We want to give rewards only, when all targets have been reached.
condition: "all"
DoneAtMaxStepsReached: DoneAtMaxStepsReached:
max_steps: 200 max_steps: 200

View File

@ -1,4 +1,5 @@
import abc import abc
import random
from typing import Union from typing import Union
from marl_factory_grid.environment import rewards as r, constants as c from marl_factory_grid.environment import rewards as r, constants as c
@ -13,45 +14,55 @@ class Action(abc.ABC):
return self._identifier return self._identifier
@abc.abstractmethod @abc.abstractmethod
def __init__(self, identifier: str): def __init__(self, identifier: str, default_valid_reward: float, default_fail_reward: float,
valid_reward: float | None = None, fail_reward: float | None = None):
self.fail_reward = fail_reward if fail_reward is not None else default_fail_reward
self.valid_reward = valid_reward if valid_reward is not None else default_valid_reward
self._identifier = identifier self._identifier = identifier
@abc.abstractmethod @abc.abstractmethod
def do(self, entity, state) -> Union[None, ActionResult]: def do(self, entity, state) -> Union[None, ActionResult]:
print() validity = bool(random.choice([0, 1]))
return return self.get_result(validity, entity)
def __repr__(self): def __repr__(self):
return f'Action[{self._identifier}]' return f'Action[{self._identifier}]'
def get_result(self, validity, entity):
reward = self.valid_reward if validity else self.fail_reward
return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity)
class Noop(Action): class Noop(Action):
def __init__(self): def __init__(self, **kwargs):
super().__init__(c.NOOP) super().__init__(c.NOOP, r.NOOP, r.NOOP, **kwargs)
def do(self, entity, *_) -> Union[None, ActionResult]: def do(self, entity, *_) -> Union[None, ActionResult]:
return ActionResult(identifier=self._identifier, validity=c.VALID, return self.get_result(c.VALID, entity)
reward=r.NOOP, entity=entity)
class Move(Action, abc.ABC): class Move(Action, abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def __init__(self, *args, **kwargs): def __init__(self, identifier, **kwargs):
super().__init__(*args, **kwargs) super().__init__(identifier, r.MOVEMENTS_VALID, r.MOVEMENTS_FAIL, **kwargs)
def do(self, entity, state): def do(self, entity, state):
new_pos = self._calc_new_pos(entity.pos) new_pos = self._calc_new_pos(entity.pos)
if state.check_move_validity(entity, new_pos): if state.check_move_validity(entity, new_pos):
# noinspection PyUnresolvedReferences valid = entity.move(new_pos, state)
move_validity = entity.move(new_pos, state)
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL else:
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward) # There is no place to go, propably collision
else: # There is no place to go, propably collision
# This is currently handeld by the WatchCollisions rule, so that it can be switched on and off by conf.yml # This is currently handeld by the WatchCollisions rule, so that it can be switched on and off by conf.yml
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION) # return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID) valid = c.NOT_VALID
if valid:
state.print(f'{entity.name} just moved to {entity.pos}.')
else:
state.print(f'{entity.name} just tried to move to {new_pos} but either failed or hat a Collision.')
return self.get_result(valid, entity)
def _calc_new_pos(self, pos): def _calc_new_pos(self, pos):
x_diff, y_diff = MOVEMAP[self._identifier] x_diff, y_diff = MOVEMAP[self._identifier]
@ -59,43 +70,43 @@ class Move(Action, abc.ABC):
class North(Move): class North(Move):
def __init__(self, *args, **kwargs): def __init__(self, **kwargs):
super().__init__(c.NORTH, *args, **kwargs) super().__init__(c.NORTH, **kwargs)
class NorthEast(Move): class NorthEast(Move):
def __init__(self, *args, **kwargs): def __init__(self, **kwargs):
super().__init__(c.NORTHEAST, *args, **kwargs) super().__init__(c.NORTHEAST, **kwargs)
class East(Move): class East(Move):
def __init__(self, *args, **kwargs): def __init__(self, **kwargs):
super().__init__(c.EAST, *args, **kwargs) super().__init__(c.EAST, **kwargs)
class SouthEast(Move): class SouthEast(Move):
def __init__(self, *args, **kwargs): def __init__(self, **kwargs):
super().__init__(c.SOUTHEAST, *args, **kwargs) super().__init__(c.SOUTHEAST, **kwargs)
class South(Move): class South(Move):
def __init__(self, *args, **kwargs): def __init__(self, **kwargs):
super().__init__(c.SOUTH, *args, **kwargs) super().__init__(c.SOUTH, **kwargs)
class SouthWest(Move): class SouthWest(Move):
def __init__(self, *args, **kwargs): def __init__(self, **kwargs):
super().__init__(c.SOUTHWEST, *args, **kwargs) super().__init__(c.SOUTHWEST, **kwargs)
class West(Move): class West(Move):
def __init__(self, *args, **kwargs): def __init__(self, **kwargs):
super().__init__(c.WEST, *args, **kwargs) super().__init__(c.WEST, **kwargs)
class NorthWest(Move): class NorthWest(Move):
def __init__(self, *args, **kwargs): def __init__(self, **kwargs):
super().__init__(c.NORTHWEST, *args, **kwargs) super().__init__(c.NORTHWEST, **kwargs)
Move4 = [North, East, South, West] Move4 = [North, East, South, West]

View File

@ -43,9 +43,6 @@ class Agent(Entity):
def var_is_blocking_pos(self): def var_is_blocking_pos(self):
return self._is_blocking_pos return self._is_blocking_pos
@property
def state(self):
return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID)
def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs): def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs):
super(Agent, self).__init__(*args, **kwargs) super(Agent, self).__init__(*args, **kwargs)
@ -53,21 +50,16 @@ class Agent(Entity):
self.step_result = dict() self.step_result = dict()
self._actions = actions self._actions = actions
self._observations = observations self._observations = observations
self._state: Union[Result, None] = None self._status: Union[Result, None] = None
self._is_blocking_pos = is_blocking_pos self._is_blocking_pos = is_blocking_pos
# noinspection PyAttributeOutsideInit
def clear_temp_state(self):
self._state = None
return self
def summarize_state(self): def summarize_state(self):
state_dict = super().summarize_state() state_dict = super().summarize_state()
state_dict.update(valid=bool(self.state.validity), action=str(self.state.identifier)) state_dict.update(valid=bool(self.state.validity), action=str(self.state.identifier))
return state_dict return state_dict
def set_state(self, action_result): def set_state(self, action_result):
self._state = action_result self._status = action_result
def paralyze(self, reason): def paralyze(self, reason):
self._paralyzed.add(reason) self._paralyzed.add(reason)

View File

@ -90,7 +90,14 @@ class Entity(Object, abc.ABC):
self.set_pos(next_pos) self.set_pos(next_pos)
for observer in self.observers: for observer in self.observers:
observer.notify_add_entity(self) observer.notify_add_entity(self)
# Aftermath Collision Check
if len([x for x in state.entities.by_pos(next_pos) if x.var_can_collide]) > 1:
# The entity did move, but there was something to collide with...
# Is then reported as a non-valid move, which did work.
valid = False
return valid return valid
# Bad naming... Was the same was the same pos, not moving....
return not_same_pos return not_same_pos
def __init__(self, pos, bind_to=None, **kwargs): def __init__(self, pos, bind_to=None, **kwargs):

View File

@ -40,6 +40,7 @@ class Object:
return True return True
def __init__(self, str_ident: Union[str, None] = None, **kwargs): def __init__(self, str_ident: Union[str, None] = None, **kwargs):
self._status = None
self._bound_entity = None self._bound_entity = None
self._observers = set() self._observers = set()
self._str_ident = str_ident self._str_ident = str_ident
@ -84,6 +85,10 @@ class Object:
def summarize_state(self): def summarize_state(self):
return dict() return dict()
def clear_temp_state(self):
self._status = None
return self
def bind_to(self, entity): def bind_to(self, entity):
# noinspection PyAttributeOutsideInit # noinspection PyAttributeOutsideInit
self._bound_entity = entity self._bound_entity = entity

View File

@ -132,7 +132,8 @@ class WatchCollisions(Rule):
for i, guest in enumerate(guests): for i, guest in enumerate(guests):
try: try:
guest.set_state(TickResult(identifier=c.COLLISION, reward=self.reward, guest.set_state(TickResult(identifier=c.COLLISION, reward=self.reward,
validity=c.NOT_VALID, entity=self)) validity=c.NOT_VALID, entity=guest)
)
except AttributeError: except AttributeError:
pass pass
results.append(TickResult(entity=guest, identifier=c.COLLISION, results.append(TickResult(entity=guest, identifier=c.COLLISION,

View File

@ -0,0 +1,12 @@
############
#----------#
#-#######--#
#-#-----D--#
#-#######--#
#-D-----D--#
#-#-#-#-#-##
#----------#
#----------#
#----------#
#----------#
############

View File

@ -1,4 +1,4 @@
from .actions import BtryCharge from .actions import Charge
from .entitites import ChargePod, Battery from .entitites import ChargePod, Battery
from .groups import ChargePods, Batteries from .groups import ChargePods, Batteries
from .rules import DoneAtBatteryDischarge, BatteryDecharge from .rules import DoneAtBatteryDischarge, BatteryDecharge

View File

@ -8,14 +8,14 @@ from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils import helpers as h from marl_factory_grid.utils import helpers as h
class BtryCharge(Action): class Charge(Action):
def __init__(self): def __init__(self):
super().__init__(b.ACTION_CHARGE) super().__init__(b.ACTION_CHARGE, b.REWARD_CHARGE_VALID, b.Reward_CHARGE_FAIL)
def do(self, entity, state) -> Union[None, ActionResult]: def do(self, entity, state) -> Union[None, ActionResult]:
if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)): if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)):
valid = h.get_first(charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity))) valid = h.get_first(charge_pod.charge_battery(entity, state))
if valid: if valid:
state.print(f'{entity.name} just charged batteries at {charge_pod.name}.') state.print(f'{entity.name} just charged batteries at {charge_pod.name}.')
else: else:
@ -24,5 +24,4 @@ class BtryCharge(Action):
valid = c.NOT_VALID valid = c.NOT_VALID
state.print(f'{entity.name} failed to charged batteries at {entity.pos}.') state.print(f'{entity.name} failed to charged batteries at {entity.pos}.')
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, return self.get_result(valid, entity)
reward=b.REWARD_CHARGE_VALID if valid else b.Reward_CHARGE_FAIL)

View File

@ -1,4 +1,5 @@
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.entity.object import Object from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.modules.batteries import constants as b from marl_factory_grid.modules.batteries import constants as b
@ -62,11 +63,11 @@ class ChargePod(Entity):
self.charge_rate = charge_rate self.charge_rate = charge_rate
self.multi_charge = multi_charge self.multi_charge = multi_charge
def charge_battery(self, battery: Battery): def charge_battery(self, entity, state):
if battery.charge_level == 1.0: battery = state[b.BATTERIES].by_entity(entity)
if battery.charge_level >= 1.0:
return c.NOT_VALID return c.NOT_VALID
if sum(1 for key, val in self.state.entities.pos_dict[self.pos] for guest in val if if len([x for x in state[c.AGENT].by_pos(entity.pos)]) > 1:
'agent' in guest.name.lower()) > 1:
return c.NOT_VALID return c.NOT_VALID
valid = battery.do_charge_action(self.charge_rate) valid = battery.do_charge_action(self.charge_rate)
return valid return valid

View File

@ -11,7 +11,7 @@ from marl_factory_grid.environment import constants as c
class CleanUp(Action): class CleanUp(Action):
def __init__(self): def __init__(self):
super().__init__(d.CLEAN_UP) super().__init__(d.CLEAN_UP, r.CLEAN_UP_VALID, r.CLEAN_UP_FAIL)
def do(self, entity, state) -> Union[None, ActionResult]: def do(self, entity, state) -> Union[None, ActionResult]:
if dirt := next((x for x in state.entities.pos_dict[entity.pos] if "dirt" in x.name.lower()), None): if dirt := next((x for x in state.entities.pos_dict[entity.pos] if "dirt" in x.name.lower()), None):
@ -24,13 +24,10 @@ class CleanUp(Action):
valid = c.VALID valid = c.VALID
print_str = f'{entity.name} did just clean up some dirt at {entity.pos}.' print_str = f'{entity.name} did just clean up some dirt at {entity.pos}.'
state.print(print_str) state.print(print_str)
reward = r.CLEAN_UP_VALID
identifier = d.CLEAN_UP
else: else:
valid = c.NOT_VALID valid = c.NOT_VALID
print_str = f'{entity.name} just tried to clean up some dirt at {entity.pos}, but failed.' print_str = f'{entity.name} just tried to clean up some dirt at {entity.pos}, but failed.'
state.print(print_str) state.print(print_str)
reward = r.CLEAN_UP_FAIL
identifier = d.CLEAN_UP_FAIL
return ActionResult(identifier=identifier, validity=valid, reward=reward, entity=entity) return self.get_result(valid, entity)

View File

@ -11,7 +11,7 @@ from marl_factory_grid.environment import constants as c
class DestAction(Action): class DestAction(Action):
def __init__(self): def __init__(self):
super().__init__(d.DESTINATION) super().__init__(d.DESTINATION, d.REWARD_WAIT_VALID, d.REWARD_WAIT_FAIL)
def do(self, entity, state) -> Union[None, ActionResult]: def do(self, entity, state) -> Union[None, ActionResult]:
if destination := state[d.DESTINATION].by_pos(entity.pos): if destination := state[d.DESTINATION].by_pos(entity.pos):
@ -19,6 +19,5 @@ class DestAction(Action):
state.print(f'{entity.name} just waited at {entity.pos}') state.print(f'{entity.name} just waited at {entity.pos}')
else: else:
valid = c.NOT_VALID valid = c.NOT_VALID
state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed') state.print(f'{entity.name} just tried to "do_wait_action" at {entity.pos} but failed')
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, return self.get_result(valid, entity)
reward=d.REWARD_WAIT_VALID if valid else d.REWARD_WAIT_FAIL)

View File

@ -1,6 +1,7 @@
from typing import Union from typing import Union
from marl_factory_grid.environment.actions import Action from marl_factory_grid.environment.actions import Action
from marl_factory_grid.modules.doors.entitites import Door
from marl_factory_grid.modules.doors import constants as d, rewards as r from marl_factory_grid.modules.doors import constants as d, rewards as r
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.results import ActionResult from marl_factory_grid.utils.results import ActionResult
@ -8,21 +9,23 @@ from marl_factory_grid.utils.results import ActionResult
class DoorUse(Action): class DoorUse(Action):
def __init__(self): def __init__(self, **kwargs):
super().__init__(d.ACTION_DOOR_USE) super().__init__(d.ACTION_DOOR_USE, r.USE_DOOR_VALID, r.USE_DOOR_FAIL, **kwargs)
def do(self, entity, state) -> Union[None, ActionResult]: def do(self, entity, state) -> Union[None, ActionResult]:
# Check if agent really is standing on a door: # Check if agent really is standing on a door:
e = state.entities.get_entities_near_pos(entity.pos) entities_close = state.entities.get_entities_near_pos(entity.pos)
try:
# Only one door opens TODO introduce loop
door = next(x for x in e if x.name.startswith(d.DOOR))
valid = door.use()
state.print(f'{entity.name} just used a {door.name} at {door.pos}')
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.USE_DOOR_VALID)
except StopIteration: valid = False
# When he doesn't... for door in [e for e in entities_close if isinstance(e, Door)]:
try:
# Will always be true, when there is at least a single door.
valid = door.use()
state.print(f'{entity.name} just used a {door.name} at {door.pos}')
except AttributeError:
pass
if not valid:
# When he doesn't stand necxxt to a door tell me.
state.print(f'{entity.name} just tried to use a door at {entity.pos}, but there is none.') state.print(f'{entity.name} just tried to use a door at {entity.pos}, but there is none.')
return ActionResult(entity=entity, identifier=self._identifier, return self.get_result(valid, entity)
validity=c.NOT_VALID, reward=r.USE_DOOR_FAIL)

View File

@ -44,22 +44,19 @@ class Door(Entity):
@property @property
def is_closed(self): def is_closed(self):
return self._status == d.STATE_CLOSED return self._state == d.STATE_CLOSED
@property @property
def is_open(self): def is_open(self):
return self._status == d.STATE_OPEN return self._state == d.STATE_OPEN
@property
def status(self):
return self._status
@property @property
def time_to_close(self): def time_to_close(self):
return self._time_to_close return self._time_to_close
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs): def __init__(self, *args, closed_on_init=True, auto_close_interval=10, **kwargs):
self._status = d.STATE_CLOSED self._state = d.STATE_CLOSED
super(Door, self).__init__(*args, **kwargs) super(Door, self).__init__(*args, **kwargs)
self._auto_close_interval = auto_close_interval self._auto_close_interval = auto_close_interval
self._time_to_close = 0 self._time_to_close = 0
@ -78,7 +75,7 @@ class Door(Entity):
return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1) return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1)
def use(self): def use(self):
if self._status == d.STATE_OPEN: if self._state == d.STATE_OPEN:
self._close() self._close()
else: else:
self._open() self._open()
@ -102,12 +99,12 @@ class Door(Entity):
return Result(f"{d.DOOR}_reset", c.VALID, entity=self) return Result(f"{d.DOOR}_reset", c.VALID, entity=self)
def _open(self): def _open(self):
self._status = d.STATE_OPEN self._state = d.STATE_OPEN
self._reset_timer() self._reset_timer()
return True return True
def _close(self): def _close(self):
self._status = d.STATE_CLOSED self._state = d.STATE_CLOSED
return True return True
def _decrement_timer(self): def _decrement_timer(self):

View File

@ -9,8 +9,14 @@ from marl_factory_grid.environment import constants as c
class ItemAction(Action): class ItemAction(Action):
def __init__(self): def __init__(self, failed_dropoff_reward: float | None = None, valid_dropoff_reward: float | None = None, **kwargs):
super().__init__(i.ITEM_ACTION) super().__init__(i.ITEM_ACTION, r.PICK_UP_FAIL, r.PICK_UP_VALID, **kwargs)
self.failed_drop_off_reward = failed_dropoff_reward if failed_dropoff_reward is not None else r.DROP_OFF_FAIL
self.valid_drop_off_reward = valid_dropoff_reward if valid_dropoff_reward is not None else r.DROP_OFF_FAIL
def get_dropoff_result(self, validity, entity):
reward = self.valid_drop_off_reward if validity else self.failed_drop_off_reward
return ActionResult(self.__name__, validity, reward=reward, entity=entity)
def do(self, entity, state) -> Union[None, ActionResult]: def do(self, entity, state) -> Union[None, ActionResult]:
inventory = state[i.INVENTORY].by_entity(entity) inventory = state[i.INVENTORY].by_entity(entity)
@ -23,16 +29,15 @@ class ItemAction(Action):
state.print(f'{entity.name} just dropped of an item at {drop_off.pos}.') state.print(f'{entity.name} just dropped of an item at {drop_off.pos}.')
else: else:
state.print(f'{entity.name} just tried to drop off at {entity.pos}, but failed.') state.print(f'{entity.name} just tried to drop off at {entity.pos}, but failed.')
reward = r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL return self.get_dropoff_result(valid, entity)
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)
elif items := state[i.ITEM].by_pos(entity.pos): elif items := state[i.ITEM].by_pos(entity.pos):
item = items[0] item = items[0]
item.change_parent_collection(inventory) item.change_parent_collection(inventory)
item.set_pos(c.VALUE_NO_POS) item.set_pos(c.VALUE_NO_POS)
state.print(f'{entity.name} just picked up an item at {entity.pos}') state.print(f'{entity.name} just picked up an item at {entity.pos}')
return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID) return self.get_result(c.VALID, entity)
else: else:
state.print(f'{entity.name} just tried to pick up an item at {entity.pos}, but failed.') state.print(f'{entity.name} just tried to pick up an item at {entity.pos}, but failed.')
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.PICK_UP_FAIL) return self.get_result(c.NOT_VALID, entity)

View File

@ -70,6 +70,10 @@ class Inventory(IsBoundMixin, Collection):
def set_collection(self, collection): def set_collection(self, collection):
self._collection = collection self._collection = collection
def clear_temp_state(self):
# Entites need this, but inventories have no state....
pass
class Inventories(Objects): class Inventories(Objects):
_entity = Inventory _entity = Inventory

View File

@ -12,15 +12,12 @@ from marl_factory_grid.utils import helpers as h
class MachineAction(Action): class MachineAction(Action):
def __init__(self): def __init__(self):
super().__init__(m.MACHINE_ACTION) super().__init__(m.MACHINE_ACTION, m.MAINTAIN_VALID, m.MAINTAIN_FAIL)
def do(self, entity, state) -> Union[None, ActionResult]: def do(self, entity, state) -> Union[None, ActionResult]:
if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)): if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)):
if valid := machine.maintain(): valid = machine.maintain()
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_VALID) return self.get_result(valid, entity)
else:
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL)
else: else:
return ActionResult(entity=entity, identifier=self._identifier, return self.get_result(c.NOT_VALID, entity)
validity=c.NOT_VALID, reward=marl_factory_grid.modules.machines.constants.MAINTAIN_FAIL
)

View File

@ -28,14 +28,19 @@ class Maintainer(Entity):
def tick(self, state): def tick(self, state):
if found_objective := h.get_first(state[self.objective].by_pos(self.pos)): if found_objective := h.get_first(state[self.objective].by_pos(self.pos)):
if found_objective.name != self._last_serviced: if found_objective.name != self._last_serviced:
self.action.do(self, state) result = self.action.do(self, state)
self._last_serviced = found_objective.name self._last_serviced = found_objective.name
else: else:
action = self.get_move_action(state) action = self.get_move_action(state)
return action.do(self, state) result = action.do(self, state)
else: else:
action = self.get_move_action(state) action = self.get_move_action(state)
return action.do(self, state) result = action.do(self, state)
self.set_state(result)
return result
def set_state(self, action_result):
self._status = action_result
def get_move_action(self, state) -> Action: def get_move_action(self, state) -> Action:
if self._path is None or not len(self._path): if self._path is None or not len(self._path):

View File

@ -124,16 +124,28 @@ class FactoryConfigParser(object):
def parse_agents_conf(self): def parse_agents_conf(self):
parsed_agents_conf = dict() parsed_agents_conf = dict()
base_env_actions = self.default_actions.copy() + [c.MOVE4]
for name in self.agents: for name in self.agents:
# Actions # Actions
conf_actions = self.agents[name]['Actions']
actions = list() actions = list()
if c.DEFAULTS in self.agents[name]['Actions']:
actions.extend(self.default_actions) if isinstance(conf_actions, dict):
actions.extend(x for x in self.agents[name]['Actions'] if x != c.DEFAULTS) conf_kwargs = conf_actions.copy()
conf_actions = list(conf_actions.keys())
elif isinstance(conf_actions, list):
conf_kwargs = {}
if isinstance(conf_actions, dict):
raise ValueError
pass
for action in conf_actions:
if action == c.DEFAULTS:
actions.extend(self.default_actions)
else:
actions.append(action)
parsed_actions = list() parsed_actions = list()
for action in actions: for action in actions:
folder_path = MODULE_PATH if action not in base_env_actions else DEFAULT_PATH folder_path = MODULE_PATH if action not in [c.MOVE8, c.NOOP, c.MOVE4] else DEFAULT_PATH
folder_path = Path(__file__).parent.parent / folder_path folder_path = Path(__file__).parent.parent / folder_path
try: try:
class_or_classes = locate_and_import_class(action, folder_path) class_or_classes = locate_and_import_class(action, folder_path)
@ -144,7 +156,7 @@ class FactoryConfigParser(object):
except TypeError: except TypeError:
parsed_actions.append(class_or_classes) parsed_actions.append(class_or_classes)
parsed_actions = [x() for x in parsed_actions] parsed_actions = [x(**conf_kwargs.get(x, {})) for x in parsed_actions]
# Observation # Observation
observations = list() observations = list()

View File

@ -224,8 +224,8 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
'TickResult', 'ActionResult', 'Action', 'Agent', 'TickResult', 'ActionResult', 'Action', 'Agent',
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin', 'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any' 'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any', 'Factory',
]]) 'Move8']])
try: try:
model_class = mod.__getattribute__(class_name) model_class = mod.__getattribute__(class_name)
return model_class return model_class

View File

@ -158,6 +158,9 @@ class Gamestate(object):
test_results = list() test_results = list()
self.curr_step += 1 self.curr_step += 1
for entity in self.entities.iter_entities():
entity.clear_temp_state()
# Main Agent Step # Main Agent Step
results.extend(self.rules.tick_pre_step_all(self)) results.extend(self.rules.tick_pre_step_all(self))
if self.tests: if self.tests:
@ -222,6 +225,7 @@ class Gamestate(object):
""" """
Whether it is safe to move to the target positions and moving entity does not introduce a blocking attribute, Whether it is safe to move to the target positions and moving entity does not introduce a blocking attribute,
when position is allready occupied. when position is allready occupied.
!!! Will still report true even though, there could be an enity, which var_can_collide == true !!!
:param moving_entity: Entity :param moving_entity: Entity
:param target_position: pos :param target_position: pos

View File

@ -29,7 +29,7 @@ if __name__ == '__main__':
ce.save_all(run_path / 'all_out.yaml') ce.save_all(run_path / 'all_out.yaml')
# Path to config File # Path to config File
path = Path('marl_factory_grid/configs/eight_puzzle.yaml') path = Path('marl_factory_grid/configs/default_config.yaml')
# Env Init # Env Init
factory = Factory(path) factory = Factory(path)

41
test_observations.py Normal file
View File

@ -0,0 +1,41 @@
from pathlib import Path
from random import randint
from tqdm import trange
from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
from marl_factory_grid.utils.logging.recorder import EnvRecorder
from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
from marl_factory_grid.utils.tools import ConfigExplainer
if __name__ == '__main__':
# Render at each step?
render = True
run_path = Path('study_out')
# Path to config File
path = Path('marl_factory_grid/configs/_obs_test.yaml')
# Env Init
factory = Factory(path)
# RL learn Loop
for episode in trange(10):
_ = factory.reset()
done = False
if render:
factory.render()
action_spaces = factory.action_space
while not done:
a = [randint(0, x.n - 1) for x in action_spaces]
obs_type, _, _, done, info = factory.step(a)
if render:
factory.render()
if done:
print(f'Episode {episode} done...')
break
print('Done!!! Goodbye....')