mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 17:11:35 +02:00
Merge remote-tracking branch 'origin/unit_testing' into unit_testing
This commit is contained in:
@ -1 +1,7 @@
|
||||
from .quickstart import init
|
||||
from marl_factory_grid.environment.factory import Factory
|
||||
"""
|
||||
Main module of the 'marl-factory-grid'-environment.
|
||||
Configure the :class:.Factory with any 'conf.yaml' file.
|
||||
Examples can be found in :module:.levels .
|
||||
"""
|
||||
|
@ -17,6 +17,16 @@ future_planning = 7
|
||||
class TSPBaseAgent(ABC):
|
||||
|
||||
def __init__(self, state, agent_i, static_problem: bool = True):
|
||||
"""
|
||||
Abstract base class for agents in the environment.
|
||||
|
||||
:param state: The environment state
|
||||
:type state:
|
||||
:param agent_i: Index of the agent
|
||||
:type agent_i: int
|
||||
:param static_problem: Indicates whether the TSP is a static problem. (Default: True)
|
||||
:type static_problem: bool
|
||||
"""
|
||||
self.static_problem = static_problem
|
||||
self.local_optimization = True
|
||||
self._env = state
|
||||
@ -26,9 +36,25 @@ class TSPBaseAgent(ABC):
|
||||
|
||||
@abstractmethod
def predict(self, *_, **__) -> int:
    """
    Choose the agent's next action.

    Declared abstract, so concrete agents have to override it. All positional and keyword
    arguments are accepted and ignored; this base implementation always yields ``0``.

    :return: Predicted action.
    :rtype: int
    """
    fallback_action = 0
    return fallback_action
|
||||
|
||||
def _use_door_or_move(self, door, target):
|
||||
"""
|
||||
Helper method to decide whether to use a door or move towards a target.
|
||||
|
||||
:param door: Door entity.
|
||||
:type door: Door
|
||||
:param target: Target type. For example 'Dirt', 'Dropoff' or 'Destination'
|
||||
:type target: str
|
||||
|
||||
:return: Action to perform (use door or move).
|
||||
"""
|
||||
if door.is_closed:
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
action = do.ACTION_DOOR_USE
|
||||
@ -37,6 +63,15 @@ class TSPBaseAgent(ABC):
|
||||
return action
|
||||
|
||||
def calculate_tsp_route(self, target_identifier):
|
||||
"""
|
||||
Calculate the TSP route to reach a target.
|
||||
|
||||
:param target_identifier: Identifier of the target entity
|
||||
:type target_identifier: str
|
||||
|
||||
:return: TSP route
|
||||
:rtype: List[int]
|
||||
"""
|
||||
positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
|
||||
if self.local_optimization:
|
||||
nodes = \
|
||||
@ -55,6 +90,15 @@ class TSPBaseAgent(ABC):
|
||||
return route
|
||||
|
||||
def _door_is_close(self, state):
|
||||
"""
|
||||
Check if a door is close to the agent's position.
|
||||
|
||||
:param state: Current environment state.
|
||||
:type state: Gamestate
|
||||
|
||||
:return: Closest door entity or None if no door is close.
|
||||
:rtype: Door | None
|
||||
"""
|
||||
try:
|
||||
return next(y for x in state.entities.neighboring_positions(self.state.pos)
|
||||
for y in state.entities.pos_dict[x] if do.DOOR in y.name)
|
||||
@ -62,9 +106,27 @@ class TSPBaseAgent(ABC):
|
||||
return None
|
||||
|
||||
def _has_targets(self, target_identifier):
    """
    Tell whether at least one target of the given kind currently has a position.

    :param target_identifier: Identifier of the target entity.
    :type target_identifier: str

    :return: True if any entity stored under ``target_identifier`` has a position other than
             ``c.VALUE_NO_POS``, False otherwise.
    :rtype: bool
    """
    candidates = self._env.state[target_identifier]
    return any(candidate.pos != c.VALUE_NO_POS for candidate in candidates)
|
||||
|
||||
def _predict_move(self, target_identifier):
|
||||
"""
|
||||
Predict the next move based on the given target.
|
||||
|
||||
:param target_identifier: Identifier of the target entity.
|
||||
:type target_identifier: str
|
||||
|
||||
:return: Predicted action.
|
||||
:rtype: int
|
||||
"""
|
||||
if self._has_targets(target_identifier):
|
||||
if self.static_problem:
|
||||
if not self._static_route:
|
||||
|
@ -8,9 +8,18 @@ future_planning = 7
|
||||
class TSPDirtAgent(TSPBaseAgent):
|
||||
|
||||
def __init__(self, *args, **kwargs):
    """
    Set up a TSPDirtAgent, an agent that aims to clean dirt in the environment.

    All arguments are forwarded unchanged to the parent initializer.
    """
    super().__init__(*args, **kwargs)
|
||||
|
||||
def predict(self, *_, **__):
|
||||
"""
|
||||
Predicts the next action based on the presence of dirt in the environment.
|
||||
|
||||
:return: Predicted action.
|
||||
:rtype: int
|
||||
"""
|
||||
dirt_at_position = self._env.state[di.DIRT].by_pos(self.state.pos)
|
||||
if dirt_at_position:
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
|
@ -14,6 +14,12 @@ MODE_BRING = 'Mode_Bring'
|
||||
class TSPItemAgent(TSPBaseAgent):
|
||||
|
||||
def __init__(self, *args, mode=MODE_GET, **kwargs):
    """
    Set up a TSPItemAgent that collects items in the environment, stores them in its inventory
    and drops them off at a drop-off location.

    Remaining positional and keyword arguments are forwarded to the parent initializer.

    :param mode: Mode of the agent, either MODE_GET or MODE_BRING. (Default: MODE_GET)
    """
    super().__init__(*args, **kwargs)
    self.mode = mode
|
||||
|
||||
@ -48,6 +54,12 @@ class TSPItemAgent(TSPBaseAgent):
|
||||
return action_obj
|
||||
|
||||
def _choose(self):
|
||||
"""
|
||||
Internal Usage. Chooses the action based on the agent's mode and the environment state.
|
||||
|
||||
:return: Chosen action.
|
||||
:rtype: int
|
||||
"""
|
||||
target = i.DROP_OFF if self.mode == MODE_BRING else i.ITEM
|
||||
if len(self._env.state[i.ITEM]) >= 1:
|
||||
action = self._predict_move(target)
|
||||
|
@ -9,9 +9,20 @@ future_planning = 7
|
||||
class TSPTargetAgent(TSPBaseAgent):
|
||||
|
||||
def __init__(self, *args, **kwargs):
    """
    Set up a TSPTargetAgent, an agent that aims to reach destinations.

    All arguments are forwarded unchanged to the parent initializer.
    """
    super().__init__(*args, **kwargs)
|
||||
|
||||
def _handle_doors(self, state):
|
||||
"""
|
||||
Internal Usage. Handles the doors in the environment.
|
||||
|
||||
:param state: The current environment state.
|
||||
:type state: marl_factory_grid.utils.states.Gamestate
|
||||
:return: Closest door entity or None if no doors are close.
|
||||
:rtype: marl_factory_grid.environment.entity.object.Entity or None
|
||||
"""
|
||||
|
||||
try:
|
||||
return next(y for x in state.entities.neighboring_positions(self.state.pos)
|
||||
|
@ -8,8 +8,20 @@ future_planning = 7
|
||||
class TSPRandomAgent(TSPBaseAgent):

    def __init__(self, n_actions, *args, **kwargs):
        """
        Set up a TSPRandomAgent that performs random actions from within its action space.

        Remaining positional and keyword arguments are forwarded to the parent initializer.

        :param n_actions: Number of possible actions.
        :type n_actions: int
        """
        super().__init__(*args, **kwargs)
        self.n_action = n_actions

    def predict(self, *_, **__):
        """
        Pick the next action at random; all arguments are ignored.

        :return: Random action index between 0 and ``n_action - 1`` (both inclusive).
        :rtype: int
        """
        highest_action = self.n_action - 1
        return randint(0, highest_action)
|
||||
|
@ -8,10 +8,10 @@ def points_to_graph(coordiniates, allow_euclidean_connections=True, allow_manhat
|
||||
"""
|
||||
Given a set of coordinates, this function constructs a non-directed graph by connecting adjacent points.
|
||||
There are three combinations of settings:
|
||||
Allow all neighbors: Distance(a, b) <= sqrt(2)
|
||||
Allow only manhattan: Distance(a, b) == 1
|
||||
Allow only Euclidean: Distance(a, b) == sqrt(2)
|
||||
|
||||
- Allow all neighbors: Distance(a, b) <= sqrt(2)
|
||||
- Allow only manhattan: Distance(a, b) == 1
|
||||
- Allow only Euclidean: Distance(a, b) == sqrt(2)
|
||||
|
||||
:param coordiniates: A set of coordinates.
|
||||
:type coordiniates: Tuple[int, int]
|
||||
|
@ -1,17 +1,35 @@
|
||||
General:
|
||||
# RNG-seed to sample the same "random" numbers every time, to make the different runs comparable.
|
||||
env_seed: 69
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
# The level.txt file to load from marl_factory_grid/levels
|
||||
level_name: rooms
|
||||
# Radius of Partially observable Markov decision process
|
||||
pomdp_r: 3
|
||||
# Print all messages and events
|
||||
verbose: true
|
||||
# Run tests
|
||||
tests: false
|
||||
|
||||
# In the "clean and bring" Scenario one agent aims to pick up all items and drop them at drop-off locations while all
|
||||
# other agents aim to clean dirt piles.
|
||||
Agents:
|
||||
# The clean agents
|
||||
Wolfgang:
|
||||
Actions:
|
||||
- Move8
|
||||
- DoorUse
|
||||
- Clean
|
||||
- Noop
|
||||
- Move8
|
||||
- DoorUse
|
||||
- Clean
|
||||
- Noop
|
||||
Observations:
|
||||
- Walls
|
||||
- Doors
|
||||
- Other
|
||||
- DirtPiles
|
||||
- Walls
|
||||
- Doors
|
||||
- Other
|
||||
- DirtPiles
|
||||
Clones: 8
|
||||
|
||||
# The item agent
|
||||
Juergen:
|
||||
Actions:
|
||||
- Move8
|
||||
@ -38,37 +56,37 @@ Entities:
|
||||
DropOffLocations:
|
||||
coords_or_quantity: 1
|
||||
max_dropoff_storage_size: 0
|
||||
Inventories: {}
|
||||
Inventories: { }
|
||||
Items:
|
||||
coords_or_quantity: 5
|
||||
|
||||
|
||||
General:
|
||||
env_seed: 69
|
||||
individual_rewards: true
|
||||
level_name: rooms
|
||||
pomdp_r: 3
|
||||
verbose: True
|
||||
tests: false
|
||||
|
||||
# Rules section specifies the rules governing the dynamics of the environment.
|
||||
Rules:
|
||||
# Environment Dynamics
|
||||
# When stepping over a dirt pile, entities carry a ratio of the dirt to their next position
|
||||
EntitiesSmearDirtOnMove:
|
||||
smear_ratio: 0.2
|
||||
# Doors automatically close after a certain number of time steps
|
||||
DoorAutoClose:
|
||||
close_frequency: 7
|
||||
|
||||
# Respawn Stuff
|
||||
# Define how dirt should respawn after the initial spawn
|
||||
RespawnDirt:
|
||||
respawn_freq: 30
|
||||
# Define how items should respawn after the initial spawn
|
||||
RespawnItems:
|
||||
respawn_freq: 50
|
||||
|
||||
# Utilities
|
||||
# This rule defines the collision mechanic, introduces a related DoneCondition and lets you specify rewards.
|
||||
# Can be omitted/ignored if you do not want to take care of collisions at all.
|
||||
WatchCollisions:
|
||||
done_at_collisions: false
|
||||
|
||||
# Done Conditions
|
||||
# Define the conditions for the environment to stop. Either success or fail conditions.
|
||||
# The environment stops when all dirt is cleaned
|
||||
DoneOnAllDirtCleaned:
|
||||
DoneAtMaxStepsReached:
|
||||
max_steps: 500
|
||||
|
@ -1,37 +1,74 @@
|
||||
# Default Configuration File
|
||||
|
||||
General:
|
||||
# RNG-seed to sample the same "random" numbers every time, to make the different runs comparable.
|
||||
env_seed: 69
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
# The level.txt file to load from marl_factory_grid/levels
|
||||
level_name: large
|
||||
# View Radius; 0 = full observability
|
||||
pomdp_r: 3
|
||||
# Print all messages and events
|
||||
verbose: false
|
||||
# Run tests
|
||||
tests: false
|
||||
|
||||
# Agents section defines the characteristics of different agents in the environment.
|
||||
|
||||
# An Agent requires a list of actions and observations.
|
||||
# Possible actions: Noop, Charge, Clean, DestAction, DoorUse, ItemAction, MachineAction, Move8, Move4, North, NorthEast, ...
|
||||
# Possible observations: All, Combined, GlobalPosition, Battery, ChargePods, DirtPiles, Destinations, Doors, Items, Inventory, DropOffLocations, Maintainers, ...
|
||||
# You can use 'clone' as the agent name to have multiple instances with either a list of names or an int specifying the number of clones.
|
||||
Agents:
|
||||
Wolfgang:
|
||||
Actions:
|
||||
- Noop
|
||||
- Charge
|
||||
- Clean
|
||||
- DestAction
|
||||
- DoorUse
|
||||
- ItemAction
|
||||
- Move8
|
||||
- Noop
|
||||
- Charge
|
||||
- Clean
|
||||
- DestAction
|
||||
- DoorUse
|
||||
- ItemAction
|
||||
- Move8
|
||||
Observations:
|
||||
- Combined:
|
||||
- Other
|
||||
- Walls
|
||||
- GlobalPosition
|
||||
- Battery
|
||||
- ChargePods
|
||||
- DirtPiles
|
||||
- Destinations
|
||||
- Doors
|
||||
- Items
|
||||
- Inventory
|
||||
- DropOffLocations
|
||||
- Maintainers
|
||||
Entities:
|
||||
- Combined:
|
||||
- Other
|
||||
- Walls
|
||||
- GlobalPosition
|
||||
- Battery
|
||||
- ChargePods
|
||||
- DirtPiles
|
||||
- Destinations
|
||||
- Doors
|
||||
- Items
|
||||
- Inventory
|
||||
- DropOffLocations
|
||||
- Maintainers
|
||||
|
||||
# Entities section defines the initial parameters and behaviors of different entities in the environment.
|
||||
# Entities all spawn using coords_or_quantity, a number of entities or coordinates to place them.
|
||||
Entities:
|
||||
# Batteries: Entities representing power sources for agents.
|
||||
Batteries:
|
||||
initial_charge: 0.8
|
||||
per_action_costs: 0.02
|
||||
|
||||
# ChargePods: Entities representing charging stations for Batteries.
|
||||
ChargePods:
|
||||
coords_or_quantity: 2
|
||||
|
||||
# Destinations: Entities representing target locations for agents.
|
||||
# - spawn_mode: GROUPED or SINGLE. Determines how destinations are spawned.
|
||||
Destinations:
|
||||
coords_or_quantity: 1
|
||||
spawn_mode: GROUPED
|
||||
|
||||
# DirtPiles: Entities representing piles of dirt.
|
||||
# - initial_amount: Initial amount of dirt in each pile.
|
||||
# - clean_amount: Amount of dirt cleaned in each cleaning action.
|
||||
# - dirt_spawn_r_var: Random variation in dirt spawn amounts.
|
||||
# - max_global_amount: Maximum total amount of dirt allowed in the environment.
|
||||
# - max_local_amount: Maximum amount of dirt allowed in one position.
|
||||
DirtPiles:
|
||||
coords_or_quantity: 10
|
||||
initial_amount: 2
|
||||
@ -39,50 +76,71 @@ Entities:
|
||||
dirt_spawn_r_var: 0.1
|
||||
max_global_amount: 20
|
||||
max_local_amount: 5
|
||||
|
||||
# Doors are spawned using the level map.
|
||||
Doors:
|
||||
|
||||
# DropOffLocations: Entities representing locations where agents can drop off items.
|
||||
# - max_dropoff_storage_size: Maximum storage capacity at each drop-off location.
|
||||
DropOffLocations:
|
||||
coords_or_quantity: 1
|
||||
max_dropoff_storage_size: 0
|
||||
GlobalPositions: {}
|
||||
Inventories: {}
|
||||
|
||||
# GlobalPositions.
|
||||
GlobalPositions: { }
|
||||
|
||||
# Inventories: Entities representing inventories for agents.
|
||||
Inventories: { }
|
||||
|
||||
# Items: Entities representing items in the environment.
|
||||
Items:
|
||||
coords_or_quantity: 5
|
||||
|
||||
# Machines: Entities representing machines in the environment.
|
||||
Machines:
|
||||
coords_or_quantity: 2
|
||||
|
||||
# Maintainers: Entities representing maintainers that aim to maintain machines.
|
||||
Maintainers:
|
||||
coords_or_quantity: 1
|
||||
Zones: {}
|
||||
|
||||
General:
|
||||
env_seed: 69
|
||||
individual_rewards: true
|
||||
level_name: large
|
||||
pomdp_r: 3
|
||||
verbose: False
|
||||
tests: false
|
||||
|
||||
# Rules section specifies the rules governing the dynamics of the environment.
|
||||
Rules:
|
||||
# Environment Dynamics
|
||||
# When stepping over a dirt pile, entities carry a ratio of the dirt to their next position
|
||||
EntitiesSmearDirtOnMove:
|
||||
smear_ratio: 0.2
|
||||
# Doors automatically close after a certain number of time steps
|
||||
DoorAutoClose:
|
||||
close_frequency: 10
|
||||
# Maintainers move at every time step
|
||||
MoveMaintainers:
|
||||
|
||||
# Respawn Stuff
|
||||
# Define how dirt should respawn after the initial spawn
|
||||
RespawnDirt:
|
||||
respawn_freq: 15
|
||||
# Define how items should respawn after the initial spawn
|
||||
RespawnItems:
|
||||
respawn_freq: 15
|
||||
|
||||
# Utilities
|
||||
# This rule defines the collision mechanic, introduces a related DoneCondition and lets you specify rewards.
|
||||
# Can be omitted/ignored if you do not want to take care of collisions at all.
|
||||
WatchCollisions:
|
||||
done_at_collisions: false
|
||||
|
||||
# Done Conditions
|
||||
# Define the conditions for the environment to stop. Either success or fail conditions.
|
||||
# The environment stops when an agent reaches a destination
|
||||
DoneAtDestinationReach:
|
||||
# The environment stops when all dirt is cleaned
|
||||
DoneOnAllDirtCleaned:
|
||||
# The environment stops when a battery is discharged
|
||||
DoneAtBatteryDischarge:
|
||||
# The environment stops when a maintainer reports a collision
|
||||
DoneAtMaintainerCollision:
|
||||
# The environment stops after max steps
|
||||
DoneAtMaxStepsReached:
|
||||
max_steps: 500
|
||||
|
@ -84,6 +84,6 @@ Rules:
|
||||
# On every step, should there be a reward for agents that reach their associated destination? No!
|
||||
dest_reach_reward: 0 # Do not touch. This is useful in other settings!
|
||||
# Reward should only be given when all destinations are reached in parallel!
|
||||
condition: "simultanious"
|
||||
condition: "simultaneous"
|
||||
# Reward if this is the case. Granted to each agent when all agents are at their target position simultaneously.
|
||||
reward_at_done: 1
|
||||
|
@ -1,14 +1,16 @@
|
||||
General:
|
||||
# Your Seed
|
||||
env_seed: 69
|
||||
# Individual or global rewards?
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
# The level.txt file to load
|
||||
# The level.txt file to load from marl_factory_grid/levels
|
||||
level_name: narrow_corridor
|
||||
# View Radius; 0 = full observability
|
||||
pomdp_r: 0
|
||||
# print all messages and events
|
||||
verbose: true
|
||||
# Run tests
|
||||
tests: false
|
||||
|
||||
Agents:
|
||||
# Agents are identified by their name
|
||||
|
@ -1,51 +1,61 @@
|
||||
General:
|
||||
env_seed: 69
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
# The level.txt file to load from marl_factory_grid/levels
|
||||
level_name: two_rooms
|
||||
# View Radius; 0 = full observability
|
||||
pomdp_r: 3
|
||||
# Print all messages and events
|
||||
verbose: false
|
||||
# Run tests
|
||||
tests: false
|
||||
|
||||
Entities:
|
||||
Destinations: {}
|
||||
Doors: {}
|
||||
GlobalPositions: {}
|
||||
Zones: {}
|
||||
|
||||
Rules:
|
||||
# Init:
|
||||
AssignGlobalPositions: {}
|
||||
ZoneInit: {}
|
||||
AgentSingleZonePlacement: {}
|
||||
IndividualDestinationZonePlacement: {}
|
||||
# Env Rules
|
||||
MaxStepsReached:
|
||||
max_steps: 10
|
||||
Collision:
|
||||
done_at_collisions: false
|
||||
DoorAutoClose:
|
||||
close_frequency: 10
|
||||
|
||||
# In "two rooms one door" scenario 2 agents spawn in 2 different rooms that are connected by a single door. Their aim
|
||||
# is to reach the destination in the room they didn't spawn in leading to a conflict at the door.
|
||||
Agents:
|
||||
Wolfgang:
|
||||
Actions:
|
||||
- Move8
|
||||
- Noop
|
||||
- DestAction
|
||||
- DoorUse
|
||||
- Move8
|
||||
- Noop
|
||||
- DestAction
|
||||
- DoorUse
|
||||
Observations:
|
||||
- Walls
|
||||
- Other
|
||||
- Doors
|
||||
- Destination
|
||||
- Walls
|
||||
- Other
|
||||
- Doors
|
||||
- Destination
|
||||
Sigmund:
|
||||
Actions:
|
||||
- Move8
|
||||
- Noop
|
||||
- DestAction
|
||||
- DoorUse
|
||||
- Move8
|
||||
- Noop
|
||||
- DestAction
|
||||
- DoorUse
|
||||
Observations:
|
||||
- Combined:
|
||||
- Other
|
||||
- Walls
|
||||
- Destination
|
||||
- Doors
|
||||
- Combined:
|
||||
- Other
|
||||
- Walls
|
||||
- Destination
|
||||
- Doors
|
||||
|
||||
Entities:
|
||||
Destinations: { }
|
||||
Doors: { }
|
||||
GlobalPositions: { }
|
||||
|
||||
Rules:
|
||||
# Environment Dynamics
|
||||
DoorAutoClose:
|
||||
close_frequency: 10
|
||||
|
||||
# Utilities
|
||||
# This rule defines the collision mechanic, introduces a related DoneCondition and lets you specify rewards.
|
||||
WatchCollisions:
|
||||
done_at_collisions: false
|
||||
|
||||
# Init
|
||||
AssignGlobalPositions: { }
|
||||
|
||||
# Done Conditions
|
||||
MaxStepsReached:
|
||||
max_steps: 10
|
||||
|
@ -6,24 +6,44 @@ from marl_factory_grid.environment import rewards as r, constants as c
|
||||
from marl_factory_grid.utils.helpers import MOVEMAP
|
||||
from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
|
||||
TYPE_COLLISION = 'collision'
|
||||
|
||||
class Action(abc.ABC):
|
||||
|
||||
class Action(abc.ABC):
|
||||
@property
def name(self):
    """
    Name of the action.

    :return: The identifier that was stored on this action at initialization.
    """
    identifier = self._identifier
    return identifier
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self, identifier: str, default_valid_reward: float, default_fail_reward: float,
|
||||
def __init__(self, identifier: str, default_valid_reward: float, default_fail_reward: float,
|
||||
valid_reward: float | None = None, fail_reward: float | None = None):
|
||||
"""
|
||||
Abstract base class representing an action that can be performed in the environment.
|
||||
|
||||
:param identifier: A unique identifier for the action.
|
||||
:type identifier: str
|
||||
:param default_valid_reward: Default reward for a valid action.
|
||||
:type default_valid_reward: float
|
||||
:param default_fail_reward: Default reward for a failed action.
|
||||
:type default_fail_reward: float
|
||||
:param valid_reward: Custom reward for a valid action (optional).
|
||||
:type valid_reward: Union[float, optional]
|
||||
:param fail_reward: Custom reward for a failed action (optional).
|
||||
:type fail_reward: Union[float, optional]
|
||||
"""
|
||||
self.fail_reward = fail_reward if fail_reward is not None else default_fail_reward
|
||||
self.valid_reward = valid_reward if valid_reward is not None else default_valid_reward
|
||||
self._identifier = identifier
|
||||
|
||||
@abc.abstractmethod
def do(self, entity, state) -> Union[None, ActionResult]:
    """
    Let the :class:`marl_factory_grid.environment.entity.entity.Entity` perform the given action.

    This base implementation does not inspect ``state``: it draws the validity at random
    (0 or 1) and wraps it in a result via :meth:`get_result`.

    :param entity: The entity to perform the action; mostly `marl_factory_grid.environment.entity.agent.Agent`
    :param state: The current :class:`marl_factory_grid.utils.states.Gamestate`
    :return: The result built by :meth:`get_result` for the drawn validity.
    """
    coin_flip = random.choice([0, 1])
    is_valid = bool(coin_flip)
    return self.get_result(is_valid, entity)
|
||||
|
||||
@ -31,6 +51,9 @@ class Action(abc.ABC):
|
||||
return f'Action[{self._identifier}]'
|
||||
|
||||
def get_result(self, validity, entity, action_introduced_collision=False):
    """
    Generate an ActionResult for the action based on its validity.

    :param validity: Whether the action was valid; selects ``valid_reward`` or ``fail_reward``.
    :param entity: The entity the result is attributed to.
    :param action_introduced_collision: Passed through to the ActionResult. (Default: False)
    :return: ActionResult identified by this action's class name.
    """
    if validity:
        reward = self.valid_reward
    else:
        reward = self.fail_reward
    action_name = self.__class__.__name__
    return ActionResult(action_name, validity, reward=reward, entity=entity,
                        action_introduced_collision=action_introduced_collision)
|
||||
|
@ -74,7 +74,6 @@ class Agent(Entity):
|
||||
self.step_result = dict()
|
||||
self._actions = actions
|
||||
self._observations = observations
|
||||
self._status: Union[Result, None] = None
|
||||
self._is_blocking_pos = is_blocking_pos
|
||||
|
||||
def summarize_state(self) -> dict[str]:
|
||||
|
@ -13,30 +13,28 @@ class Entity(Object, abc.ABC):
|
||||
@property
def state(self):
    """
    Get the current status of the entity. Not to be confused with the Gamestate.

    :return: The stored status if it is truthy; otherwise a freshly built valid NOOP State
             for this entity.
    """
    current_status = self._status
    if current_status:
        return current_status
    return State(entity=self, identifier=c.NOOP, validity=c.VALID)
|
||||
|
||||
@property
def var_has_position(self):
    """
    Check if the entity has a position.

    :return: True if the entity's position differs from ``c.VALUE_NO_POS``, False otherwise.
    :rtype: bool
    """
    current_pos = self.pos
    return current_pos != c.VALUE_NO_POS
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
"""
|
||||
TODO
|
||||
Check if the entity is blocking light.
|
||||
|
||||
|
||||
:return:
|
||||
:return: True if the entity is blocking light, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
try:
|
||||
return self._collection.var_is_blocking_light or False
|
||||
@ -46,10 +44,10 @@ class Entity(Object, abc.ABC):
|
||||
@property
|
||||
def var_can_move(self):
|
||||
"""
|
||||
TODO
|
||||
Check if the entity can move.
|
||||
|
||||
|
||||
:return:
|
||||
:return: True if the entity can move, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
try:
|
||||
return self._collection.var_can_move or False
|
||||
@ -59,10 +57,10 @@ class Entity(Object, abc.ABC):
|
||||
@property
|
||||
def var_is_blocking_pos(self):
|
||||
"""
|
||||
TODO
|
||||
Check if the entity is blocking a position when standing on it.
|
||||
|
||||
|
||||
:return:
|
||||
:return: True if the entity is blocking a position, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
try:
|
||||
return self._collection.var_is_blocking_pos or False
|
||||
@ -72,10 +70,10 @@ class Entity(Object, abc.ABC):
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
"""
|
||||
TODO
|
||||
Check if the entity can collide.
|
||||
|
||||
|
||||
:return:
|
||||
:return: True if the entity can collide, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
try:
|
||||
return self._collection.var_can_collide or False
|
||||
@ -85,39 +83,40 @@ class Entity(Object, abc.ABC):
|
||||
@property
def x(self):
    """
    Get the x-coordinate of the entity's position.

    :return: The first component of the entity's position.
    :rtype: int
    """
    position = self.pos
    return position[0]
|
||||
|
||||
@property
def y(self):
    """
    Get the y-coordinate of the entity's position.

    :return: The second component of the entity's position.
    :rtype: int
    """
    position = self.pos
    return position[1]
|
||||
|
||||
@property
def pos(self):
    """
    Get the current position of the entity.

    :return: The position currently stored on the entity.
    :rtype: tuple
    """
    current_position = self._pos
    return current_position
|
||||
|
||||
def set_pos(self, pos) -> bool:
|
||||
"""
|
||||
TODO
|
||||
Set the position of the entity.
|
||||
|
||||
|
||||
:return:
|
||||
:param pos: The new position.
|
||||
:type pos: tuple
|
||||
:return: True if setting the position is successful, False otherwise.
|
||||
"""
|
||||
assert isinstance(pos, tuple) and len(pos) == 2
|
||||
self._pos = pos
|
||||
@ -126,10 +125,10 @@ class Entity(Object, abc.ABC):
|
||||
@property
|
||||
def last_pos(self):
|
||||
"""
|
||||
TODO
|
||||
Get the last position of the entity.
|
||||
|
||||
|
||||
:return:
|
||||
:return: The last position of the entity.
|
||||
:rtype: tuple
|
||||
"""
|
||||
try:
|
||||
return self._last_pos
|
||||
@ -141,22 +140,49 @@ class Entity(Object, abc.ABC):
|
||||
@property
def direction_of_view(self):
    """
    Get the current direction of view of the entity.

    The direction is the element-wise offset ``last position - current position``. When no
    last position is stored (``c.VALUE_NO_POS``) there is nothing to derive a direction from,
    so ``(0, 0)`` is returned.

    :return: ``(0, 0)`` without a last position, otherwise the offset between the last and the
             current position.
    :rtype: tuple or numpy.ndarray
    """
    # The check used to be inverted: every known last position produced (0, 0) and the
    # subtraction only ever ran against the "no position" sentinel.
    if self._last_pos == c.VALUE_NO_POS:
        return 0, 0
    return np.subtract(self._last_pos, self.pos)
|
||||
|
||||
def __init__(self, pos, bind_to=None, **kwargs):
    """
    Abstract base class representing entities in the environment grid.

    :param pos: The initial position of the entity; also stored as its last position.
    :type pos: tuple
    :param bind_to: Entity to which this entity is bound (Default: None)
    :type bind_to: Entity or None
    :raises SystemExit: If ``bind_to`` is given but binding raises an AttributeError.
    """
    super().__init__(**kwargs)
    self._view_directory = c.VALUE_NO_POS
    self._status = None
    self._pos = pos
    self._last_pos = pos
    self._collection = None
    if bind_to:
        try:
            self.bind_to(bind_to)
        except AttributeError as err:
            # Raise SystemExit directly instead of calling the site-provided ``exit()``:
            # that helper is not guaranteed to exist and, without an argument, ends the
            # process with status 0 although this is an error. The message now goes to
            # stderr and the exit status is non-zero.
            message = f'Objects of class "{self.__class__.__name__}" can not be bound to other entities.'
            raise SystemExit(message) from err
|
||||
|
||||
def move(self, next_pos, state):
|
||||
"""
|
||||
TODO
|
||||
Move the entity to a new position.
|
||||
|
||||
:param next_pos: The next position to move the entity to.
|
||||
:type next_pos: tuple
|
||||
:param state: The current state of the environment.
|
||||
:type state: marl_factory_grid.environment.state.Gamestate
|
||||
|
||||
:return:
|
||||
:return: True if the move is valid, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
next_pos = next_pos
|
||||
curr_pos = self._pos
|
||||
@ -172,43 +198,22 @@ class Entity(Object, abc.ABC):
|
||||
# Bad naming... The position was the same, so the entity is not moving.
|
||||
return not_same_pos
|
||||
|
||||
def __init__(self, pos, bind_to=None, **kwargs):
    """
    Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc...

    :param pos: The initial position of the entity; also stored as its last position.
    :param bind_to: Optional entity to bind this entity to. (Default: None)
    :raises SystemExit: If ``bind_to`` is given but binding raises an AttributeError.
    """
    super().__init__(**kwargs)
    self._view_directory = c.VALUE_NO_POS
    self._status = None
    self._pos = pos
    self._last_pos = pos
    self._collection = None
    if bind_to:
        try:
            self.bind_to(bind_to)
        except AttributeError as err:
            # Raise SystemExit directly instead of calling the site-provided ``exit()``:
            # that helper is not guaranteed to exist and, without an argument, ends the
            # process with status 0 although this is an error.
            message = f'Objects of class "{self.__class__.__name__}" can not be bound to other entities.'
            raise SystemExit(message) from err
|
||||
|
||||
def summarize_state(self) -> dict:
    """
    Summarize the current state of the entity.

    :return: A dictionary with the entity's name (as str), x- and y-coordinate (as int) and
             its can_collide property (as bool).
    :rtype: dict
    """
    summary = {
        'name': str(self.name),
        'x': int(self.x),
        'y': int(self.y),
        'can_collide': bool(self.var_can_collide),
    }
    return summary
|
||||
|
||||
@abc.abstractmethod
def render(self):
    """
    Abstract method to render the entity.

    The base implementation builds a RenderEntity from the lower-cased class name and the
    entity's current position.

    :return: A rendering entity representing the entity's appearance in the environment.
    :rtype: marl_factory_grid.utils.utility_classes.RenderEntity
    """
    render_name = self.__class__.__name__.lower()
    return RenderEntity(render_name, self.pos)
|
||||
|
||||
@ -223,19 +228,22 @@ class Entity(Object, abc.ABC):
|
||||
@property
def encoding(self):
    """
    Get the encoded representation of the entity.

    :return: The constant ``c.VALUE_OCCUPIED_CELL``.
    """
    occupied_marker = c.VALUE_OCCUPIED_CELL
    return occupied_marker
|
||||
|
||||
def change_parent_collection(self, other_collection):
|
||||
"""
|
||||
TODO
|
||||
Change the parent collection of the entity.
|
||||
|
||||
:param other_collection: The new parent collection.
|
||||
:type other_collection: marl_factory_grid.environment.collections.Collection
|
||||
|
||||
:return:
|
||||
:return: True if the change is successful, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
other_collection.add_item(self)
|
||||
self._collection.delete_env_object(self)
|
||||
@ -245,9 +253,9 @@ class Entity(Object, abc.ABC):
|
||||
@property
def collection(self):
    """
    Get the parent collection of the entity.

    :return: The collection currently stored on the entity.
    :rtype: marl_factory_grid.environment.collections.Collection
    """
    parent_collection = self._collection
    return parent_collection
|
||||
|
@ -12,17 +12,15 @@ class Object:
|
||||
@property
def bound_entity(self):
    """
    Returns the entity to which this object is bound.

    :return: The bound entity as currently stored on the object.
    """
    owner = self._bound_entity
    return owner
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self) -> bool:
|
||||
"""
|
||||
TODO
|
||||
Indicates if it is possible to bind this object to another Entity or Object.
|
||||
|
||||
:return: Whether this object can be bound.
|
||||
@ -35,30 +33,27 @@ class Object:
|
||||
@property
def observers(self) -> set:
    """
    Returns the set of observers for this object.

    :return: The observer set stored on the object (the same object, not a copy).
    """
    registered = self._observers
    return registered
|
||||
|
||||
@property
def name(self):
    """
    Returns a string representation of the object's name.

    :return: The class name followed by the object's identifier in square brackets.
    """
    class_name = self.__class__.__name__
    return f'{class_name}[{self.identifier}]'
|
||||
|
||||
@property
|
||||
def identifier(self):
|
||||
"""
|
||||
TODO
|
||||
Returns the unique identifier of the object.
|
||||
|
||||
|
||||
:return:
|
||||
:return: The unique identifier.
|
||||
"""
|
||||
if self._str_ident is not None:
|
||||
return self._str_ident
|
||||
@ -67,23 +62,19 @@ class Object:
|
||||
|
||||
def reset_uid(self):
|
||||
"""
|
||||
TODO
|
||||
Resets the unique identifier counter for this class.
|
||||
|
||||
|
||||
:return:
|
||||
:return: True if the reset was successful.
|
||||
"""
|
||||
self._u_idx = defaultdict(lambda: 0)
|
||||
return True
|
||||
|
||||
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
|
||||
"""
|
||||
Generell Objects for Organisation and Maintanance such as Actions etc...
|
||||
General Objects for Organisation and Maintenance such as Actions, etc.
|
||||
|
||||
TODO
|
||||
|
||||
:param str_ident:
|
||||
|
||||
:return:
|
||||
:param str_ident: A string identifier for the object.
|
||||
:return: None
|
||||
"""
|
||||
self._status = None
|
||||
self._bound_entity = None
|
||||
@ -147,28 +138,28 @@ class Object:
|
||||
|
||||
def bind_to(self, entity):
|
||||
"""
|
||||
TODO
|
||||
Binds the object to a specified entity.
|
||||
|
||||
|
||||
:return:
|
||||
:param entity: The entity to bind to.
|
||||
:return: The validity of the binding.
|
||||
"""
|
||||
self._bound_entity = entity
|
||||
return c.VALID
|
||||
|
||||
def belongs_to_entity(self, entity):
|
||||
"""
|
||||
TODO
|
||||
Checks if the object belongs to a specified entity.
|
||||
|
||||
|
||||
:return:
|
||||
:param entity: The entity to check against.
|
||||
:return: True if the object belongs to the entity, False otherwise.
|
||||
"""
|
||||
return self._bound_entity == entity
|
||||
|
||||
def unbind(self):
|
||||
"""
|
||||
TODO
|
||||
Unbinds the object from its current entity.
|
||||
|
||||
:return:
|
||||
:return: The entity that the object was previously bound to.
|
||||
"""
|
||||
previously_bound = self._bound_entity
|
||||
self._bound_entity = None
|
||||
|
@ -4,7 +4,7 @@ from marl_factory_grid.environment.entity.object import Object
|
||||
|
||||
|
||||
##########################################################################
|
||||
# ####################### Objects and Entitys ########################## #
|
||||
# ####################### Objects and Entities ########################## #
|
||||
##########################################################################
|
||||
|
||||
|
||||
@ -12,10 +12,11 @@ class PlaceHolder(Object):
|
||||
|
||||
def __init__(self, *args, fill_value=0, **kwargs):
|
||||
"""
|
||||
TODO
|
||||
A placeholder object that can be used as an observation during training. It is designed to be later replaced
|
||||
with a meaningful observation that wasn't initially present in the training run.
|
||||
|
||||
|
||||
:return:
|
||||
:param fill_value: The default value to fill the placeholder observation (Default: 0)
|
||||
:type fill_value: Any
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._fill_value = fill_value
|
||||
@ -23,20 +24,20 @@ class PlaceHolder(Object):
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
"""
|
||||
TODO
|
||||
Indicates whether this placeholder object can collide with other entities. Always returns False.
|
||||
|
||||
|
||||
:return:
|
||||
:return: False
|
||||
:rtype: bool
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
"""
|
||||
TODO
|
||||
Get the fill value representing the placeholder observation.
|
||||
|
||||
|
||||
:return:
|
||||
:return: The fill value
|
||||
:rtype: Any
|
||||
"""
|
||||
return self._fill_value
|
||||
|
||||
@ -54,10 +55,10 @@ class GlobalPosition(Object):
|
||||
@property
|
||||
def encoding(self):
|
||||
"""
|
||||
TODO
|
||||
Get the encoded representation of the global position based on whether normalization is enabled.
|
||||
|
||||
|
||||
:return:
|
||||
:return: The encoded representation of the global position
|
||||
:rtype: tuple[float, float] or tuple[int, int]
|
||||
"""
|
||||
if self._normalized:
|
||||
return tuple(np.divide(self._bound_entity.pos, self._shape))
|
||||
@ -66,10 +67,14 @@ class GlobalPosition(Object):
|
||||
|
||||
def __init__(self, agent, level_shape, *args, normalized: bool = True, **kwargs):
|
||||
"""
|
||||
TODO
|
||||
A utility class representing the global position of an entity in the environment.
|
||||
|
||||
|
||||
:return:
|
||||
:param agent: The agent entity to which the global position is bound.
|
||||
:type agent: marl_factory_grid.environment.entity.agent.Agent
|
||||
:param level_shape: The shape of the environment level.
|
||||
:type level_shape: tuple[int, int]
|
||||
:param normalized: Indicates whether the global position should be normalized (Default: True)
|
||||
:type normalized: bool
|
||||
"""
|
||||
super(GlobalPosition, self).__init__(*args, **kwargs)
|
||||
self.bind_to(agent)
|
||||
|
@ -7,10 +7,7 @@ class Wall(Entity):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
A class representing a wall entity in the environment.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
@ -24,47 +24,48 @@ class Factory(gym.Env):
|
||||
@property
|
||||
def action_space(self):
|
||||
"""
|
||||
TODO
|
||||
The action space defines the set of all possible actions that an agent can take in the environment.
|
||||
|
||||
|
||||
:return:
|
||||
:return: Action space
|
||||
:rtype: gym.Space
|
||||
"""
|
||||
return self.state[c.AGENT].action_space
|
||||
|
||||
@property
|
||||
def named_action_space(self):
|
||||
"""
|
||||
TODO
|
||||
Returns the named action space for agents.
|
||||
|
||||
|
||||
:return:
|
||||
:return: Named action space
|
||||
:rtype: dict[str, dict[str, list[int]]]
|
||||
"""
|
||||
return self.state[c.AGENT].named_action_space
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
"""
|
||||
TODO
|
||||
The observation space represents all the information that an agent can receive from the environment at a given
|
||||
time step.
|
||||
|
||||
|
||||
:return:
|
||||
:return: Observation space.
|
||||
:rtype: gym.Space
|
||||
"""
|
||||
return self.obs_builder.observation_space(self.state)
|
||||
|
||||
@property
|
||||
def named_observation_space(self):
|
||||
"""
|
||||
TODO
|
||||
Returns the named observation space for the environment.
|
||||
|
||||
|
||||
:return:
|
||||
:return: Named observation space.
|
||||
:rtype: (dict, dict)
|
||||
"""
|
||||
return self.obs_builder.named_observation_space(self.state)
|
||||
|
||||
@property
|
||||
def params(self) -> dict:
|
||||
"""
|
||||
FIXME LAGEGY
|
||||
FIXME LEGACY
|
||||
|
||||
|
||||
:return:
|
||||
@ -80,10 +81,14 @@ class Factory(gym.Env):
|
||||
def __init__(self, config_file: Union[str, PathLike], custom_modules_path: Union[None, PathLike] = None,
|
||||
custom_level_path: Union[None, PathLike] = None):
|
||||
"""
|
||||
TODO
|
||||
Initializes the marl-factory-grid as Gym environment.
|
||||
|
||||
|
||||
:return:
|
||||
:param config_file: Path to the configuration file.
|
||||
:type config_file: Union[str, PathLike]
|
||||
:param custom_modules_path: Path to custom modules directory. (Default: None)
|
||||
:type custom_modules_path: Union[None, PathLike]
|
||||
:param custom_level_path: Path to custom level file. (Default: None)
|
||||
:type custom_level_path: Union[None, PathLike]
|
||||
"""
|
||||
self._config_file = config_file
|
||||
self.conf = FactoryConfigParser(self._config_file, custom_modules_path)
|
||||
@ -188,6 +193,16 @@ class Factory(gym.Env):
|
||||
return reward, done, info
|
||||
|
||||
def step(self, actions):
|
||||
"""
|
||||
Run one timestep of the environment's dynamics using the agent actions.
|
||||
|
||||
When the end of an episode is reached (``terminated or truncated``), it is necessary to call :meth:`reset` to
|
||||
reset this environment's state for the next episode.
|
||||
|
||||
:param actions: An action or list of actions provided by the agent(s) to update the environment state.
|
||||
:return: observation, reward, terminated, truncated, info, done
|
||||
:rtype: tuple(list(np.ndarray), float, bool, bool, dict, bool)
|
||||
"""
|
||||
|
||||
if not isinstance(actions, list):
|
||||
actions = [int(actions)]
|
||||
|
@ -37,10 +37,10 @@ class Agents(Collection):
|
||||
@property
|
||||
def action_space(self):
|
||||
"""
|
||||
TODO
|
||||
The action space defines the set of all possible actions that an agent can take in the environment.
|
||||
|
||||
|
||||
:return:
|
||||
:return: Action space
|
||||
:rtype: gym.Space
|
||||
"""
|
||||
from gymnasium import spaces
|
||||
space = spaces.Tuple([spaces.Discrete(len(x.actions)) for x in self])
|
||||
@ -49,10 +49,10 @@ class Agents(Collection):
|
||||
@property
|
||||
def named_action_space(self) -> dict[str, dict[str, list[int]]]:
|
||||
"""
|
||||
TODO
|
||||
Returns the named action space for agents.
|
||||
|
||||
|
||||
:return:
|
||||
:return: Named action space
|
||||
:rtype: dict[str, dict[str, list[int]]]
|
||||
"""
|
||||
named_space = dict()
|
||||
for agent in self:
|
||||
|
@ -13,31 +13,65 @@ class Collection(Objects):
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
"""
|
||||
Indicates whether the collection blocks light.
|
||||
|
||||
:return: Always False for a collection.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_is_blocking_pos(self):
|
||||
"""
|
||||
Indicates whether the collection blocks positions.
|
||||
|
||||
:return: Always False for a collection.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
"""
|
||||
Indicates whether the collection can collide.
|
||||
|
||||
:return: Always False for a collection.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
"""
|
||||
Indicates whether the collection can move.
|
||||
|
||||
:return: Always False for a collection.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
"""
|
||||
Indicates whether the collection has positions.
|
||||
|
||||
:return: Always True for a collection.
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def encodings(self):
|
||||
"""
|
||||
Returns a list of encodings for all entities in the collection.
|
||||
|
||||
:return: List of encodings.
|
||||
"""
|
||||
return [x.encoding for x in self]
|
||||
|
||||
@property
|
||||
def spawn_rule(self):
|
||||
"""Prevent SpawnRule creation if Objects are spawned by map, Doors e.g."""
|
||||
"""
|
||||
Prevents SpawnRule creation if Objects are spawned by the map, doors, etc.
|
||||
|
||||
:return: The spawn rule or None.
|
||||
"""
|
||||
if self.symbol:
|
||||
return None
|
||||
elif self._spawnrule:
|
||||
@ -48,6 +82,17 @@ class Collection(Objects):
|
||||
def __init__(self, size, *args, coords_or_quantity: int = None, ignore_blocking=False,
|
||||
spawnrule: Union[None, Dict[str, dict]] = None,
|
||||
**kwargs):
|
||||
"""
|
||||
Initializes the Collection.
|
||||
|
||||
:param size: Size of the collection.
|
||||
:type size: int
|
||||
:param coords_or_quantity: Coordinates or quantity for spawning entities.
|
||||
:param ignore_blocking: Ignore blocking when spawning entities.
|
||||
:type ignore_blocking: bool
|
||||
:param spawnrule: Spawn rule for the collection. Default: None
|
||||
:type spawnrule: Union[None, Dict[str, dict]]
|
||||
"""
|
||||
super(Collection, self).__init__(*args, **kwargs)
|
||||
self._coords_or_quantity = coords_or_quantity
|
||||
self.size = size
|
||||
@ -55,6 +100,17 @@ class Collection(Objects):
|
||||
self._ignore_blocking = ignore_blocking
|
||||
|
||||
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
|
||||
"""
|
||||
Triggers the spawning of entities in the collection.
|
||||
|
||||
:param state: The game state.
|
||||
:type state: marl_factory_grid.utils.states.GameState
|
||||
:param entity_args: Additional arguments for entity creation.
|
||||
:param coords_or_quantity: Coordinates or quantity for spawning entities.
|
||||
:param ignore_blocking: Ignore blocking when spawning entities.
|
||||
:param entity_kwargs: Additional keyword arguments for entity creation.
|
||||
:return: Result of the spawn operation.
|
||||
"""
|
||||
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
|
||||
if self.var_has_position:
|
||||
if self.var_has_position and isinstance(coords_or_quantity, int):
|
||||
@ -74,6 +130,14 @@ class Collection(Objects):
|
||||
raise ValueError(f'{self._entity.__name__} has no position!')
|
||||
|
||||
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs):
|
||||
"""
|
||||
Spawns entities in the collection.
|
||||
|
||||
:param coords_or_quantity: Coordinates or quantity for spawning entities.
|
||||
:param entity_args: Additional arguments for entity creation.
|
||||
:param entity_kwargs: Additional keyword arguments for entity creation.
|
||||
:return: Validity of the spawn operation.
|
||||
"""
|
||||
if self.var_has_position:
|
||||
if isinstance(coords_or_quantity, int):
|
||||
raise ValueError(f'{self._entity.__name__} should have a position!')
|
||||
@ -87,6 +151,11 @@ class Collection(Objects):
|
||||
return c.VALID
|
||||
|
||||
def despawn(self, items: List[Object]):
|
||||
"""
|
||||
Despawns entities from the collection.
|
||||
|
||||
:param items: List of entities to despawn.
|
||||
"""
|
||||
items = [items] if isinstance(items, Object) else items
|
||||
for item in items:
|
||||
del self[item]
|
||||
@ -97,9 +166,19 @@ class Collection(Objects):
|
||||
return self
|
||||
|
||||
def delete_env_object(self, env_object):
|
||||
"""
|
||||
Deletes an environmental object from the collection.
|
||||
|
||||
:param env_object: The environmental object to delete.
|
||||
"""
|
||||
del self[env_object.name]
|
||||
|
||||
def delete_env_object_by_name(self, name):
|
||||
"""
|
||||
Deletes an environmental object from the collection by name.
|
||||
|
||||
:param name: The name of the environmental object to delete.
|
||||
"""
|
||||
del self[name]
|
||||
|
||||
@property
|
||||
@ -126,6 +205,13 @@ class Collection(Objects):
|
||||
|
||||
@classmethod
|
||||
def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
|
||||
"""
|
||||
Creates a collection of entities from specified coordinates.
|
||||
|
||||
:param positions: List of coordinates for entity positions.
|
||||
:param args: Additional positional arguments.
|
||||
:return: The created collection.
|
||||
"""
|
||||
collection = cls(*args, **kwargs)
|
||||
collection.add_items(
|
||||
[cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions])
|
||||
@ -141,6 +227,12 @@ class Collection(Objects):
|
||||
super().__delitem__(name)
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
"""
|
||||
Retrieves an entity from the collection based on its position.
|
||||
|
||||
:param pos: The position tuple.
|
||||
:return: The entity at the specified position or None if not found.
|
||||
"""
|
||||
pos = tuple(pos)
|
||||
try:
|
||||
return self.pos_dict[pos]
|
||||
@ -151,6 +243,11 @@ class Collection(Objects):
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
"""
|
||||
Returns a list of positions for all entities in the collection.
|
||||
|
||||
:return: List of positions.
|
||||
"""
|
||||
return [e.pos for e in self]
|
||||
|
||||
def notify_del_entity(self, entity: Entity):
|
||||
|
@ -11,12 +11,30 @@ class Entities(Objects):
|
||||
_entity = Objects
|
||||
|
||||
def neighboring_positions(self, pos):
|
||||
"""
|
||||
Get all 8 neighboring positions of a given position.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: List of neighboring positions.
|
||||
"""
|
||||
return [tuple(x) for x in (POS_MASK_8 + pos).reshape(-1, 2) if tuple(x) in self._floor_positions]
|
||||
|
||||
def neighboring_4_positions(self, pos):
|
||||
"""
|
||||
Get neighboring 4 positions of a given position. (North, East, South, West)
|
||||
|
||||
:param pos: Reference position.
|
||||
:return: List of neighboring positions.
|
||||
"""
|
||||
return [tuple(x) for x in (POS_MASK_4 + pos) if tuple(x) in self._floor_positions]
|
||||
|
||||
def get_entities_near_pos(self, pos):
|
||||
"""
|
||||
Get entities near a given position.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: List of entities near the position.
|
||||
"""
|
||||
return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x]
|
||||
|
||||
def render(self):
|
||||
@ -28,10 +46,18 @@ class Entities(Objects):
|
||||
|
||||
@property
|
||||
def floorlist(self):
|
||||
"""
|
||||
Shuffle and return the list of floor positions.
|
||||
|
||||
:return: Shuffled list of floor positions.
|
||||
"""
|
||||
shuffle(self._floor_positions)
|
||||
return [x for x in self._floor_positions]
|
||||
|
||||
def __init__(self, floor_positions):
|
||||
"""
|
||||
:param floor_positions: list of all positions that are not blocked by a wall.
|
||||
"""
|
||||
self._floor_positions = floor_positions
|
||||
self.pos_dict = None
|
||||
super().__init__()
|
||||
@ -40,28 +66,54 @@ class Entities(Objects):
|
||||
return f'{self.__class__.__name__}{[x for x in self]}'
|
||||
|
||||
def guests_that_can_collide(self, pos):
|
||||
"""
|
||||
Get entities at a position that can collide.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: List of entities at the position that can collide.
|
||||
"""
|
||||
return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
||||
|
||||
@property
|
||||
def empty_positions(self):
|
||||
"""
|
||||
Get shuffled list of empty positions.
|
||||
|
||||
:return: Shuffled list of empty positions.
|
||||
"""
|
||||
empty_positions = [key for key in self.floorlist if not self.pos_dict[key]]
|
||||
shuffle(empty_positions)
|
||||
return empty_positions
|
||||
|
||||
@property
|
||||
def occupied_positions(self): # positions that are not empty
|
||||
def occupied_positions(self):
|
||||
"""
|
||||
Get shuffled list of occupied positions.
|
||||
|
||||
:return: Shuffled list of occupied positions.
|
||||
"""
|
||||
empty_positions = [key for key in self.floorlist if self.pos_dict[key]]
|
||||
shuffle(empty_positions)
|
||||
return empty_positions
|
||||
|
||||
@property
|
||||
def blocked_positions(self):
|
||||
"""
|
||||
Get shuffled list of blocked positions.
|
||||
|
||||
:return: Shuffled list of blocked positions.
|
||||
"""
|
||||
blocked_positions = [key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
|
||||
shuffle(blocked_positions)
|
||||
return blocked_positions
|
||||
|
||||
@property
|
||||
def free_positions_generator(self):
|
||||
"""
|
||||
Get a generator for free positions.
|
||||
|
||||
:return: Generator for free positions.
|
||||
"""
|
||||
generator = (
|
||||
key for key in self.floorlist if all(not x.var_can_collide and not x.var_is_blocking_pos
|
||||
for x in self.pos_dict[key])
|
||||
@ -70,9 +122,19 @@ class Entities(Objects):
|
||||
|
||||
@property
|
||||
def free_positions_list(self):
|
||||
"""
|
||||
Get a list of free positions.
|
||||
|
||||
:return: List of free positions.
|
||||
"""
|
||||
return [x for x in self.free_positions_generator]
|
||||
|
||||
def iter_entities(self):
|
||||
"""
|
||||
Get an iterator over all entities in the collection.
|
||||
|
||||
:return: Iterator over entities.
|
||||
"""
|
||||
return iter((x for sublist in self.values() for x in sublist))
|
||||
|
||||
def add_items(self, items: Dict):
|
||||
@ -105,13 +167,30 @@ class Entities(Objects):
|
||||
print('OhOh (debug me)')
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
"""
|
||||
Get entities at a specific position.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: List of entities at the position.
|
||||
"""
|
||||
return self.pos_dict[pos]
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
"""
|
||||
Get a list of all positions in the collection.
|
||||
|
||||
:return: List of positions.
|
||||
"""
|
||||
return [k for k, v in self.pos_dict.items() for _ in v]
|
||||
|
||||
def is_occupied(self, pos):
|
||||
"""
|
||||
Check if a position is occupied.
|
||||
|
||||
:param pos: The reference position.
|
||||
:return: True if the position is occupied, False otherwise.
|
||||
"""
|
||||
return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1
|
||||
|
||||
def reset(self):
|
||||
|
@ -1,29 +1,58 @@
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
"""
|
||||
Mixins are a way to modularly extend the functionality of classes in object-oriented programming without using
|
||||
inheritance in the traditional sense. They provide a means to include a set of methods and properties in a class that
|
||||
can be reused across different class hierarchies.
|
||||
"""
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker
|
||||
class IsBoundMixin:
|
||||
"""
|
||||
This mixin is designed to be used in classes that represent objects which can be bound to another entity.
|
||||
"""
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})'
|
||||
|
||||
def bind(self, entity):
|
||||
"""
|
||||
Binds the current object to another entity.
|
||||
|
||||
:param entity: the entity to be bound
|
||||
"""
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self._bound_entity = entity
|
||||
return c.VALID
|
||||
|
||||
def belongs_to_entity(self, entity):
|
||||
"""
|
||||
Checks if the given entity is the bound entity.
|
||||
|
||||
:return: True if the given entity is the bound entity, false otherwise.
|
||||
"""
|
||||
return self._bound_entity == entity
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker
|
||||
class HasBoundMixin:
|
||||
"""
|
||||
This mixin is intended for classes that contain a collection of objects and need functionality to interact with
|
||||
those objects.
|
||||
"""
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
"""
|
||||
Returns a list of pairs containing the names and corresponding objects within the collection.
|
||||
"""
|
||||
return [(x.name, x) for x in self]
|
||||
|
||||
def by_entity(self, entity):
|
||||
"""
|
||||
Retrieves an object from the collection based on its belonging to a specific entity.
|
||||
"""
|
||||
try:
|
||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||
except (StopIteration, AttributeError):
|
||||
|
@ -13,22 +13,37 @@ class Objects:
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
"""
|
||||
Property indicating whether objects in the collection can be bound to another entity.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def observers(self):
|
||||
"""
|
||||
Property returning a set of observers associated with the collection.
|
||||
"""
|
||||
return self._observers
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
"""
|
||||
Property providing a tag for observation purposes.
|
||||
"""
|
||||
return self.__class__.__name__
|
||||
|
||||
@staticmethod
|
||||
def render():
|
||||
"""
|
||||
Static method returning an empty list. Override this method in derived classes for rendering functionality.
|
||||
"""
|
||||
return []
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
"""
|
||||
Property returning a list of pairs containing the names and corresponding objects within the collection.
|
||||
"""
|
||||
pair_list = [(self.name, self)]
|
||||
pair_list.extend([(a.name, a) for a in self])
|
||||
return pair_list
|
||||
@ -48,12 +63,26 @@ class Objects:
|
||||
self.pos_dict = defaultdict(list)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Returns the number of objects in the collection.
|
||||
"""
|
||||
return len(self._data)
|
||||
|
||||
def __iter__(self) -> Iterator[Union[Object, None]]:
|
||||
return iter(self.values())
|
||||
|
||||
def add_item(self, item: _entity):
|
||||
"""
|
||||
Adds an item to the collection.
|
||||
|
||||
|
||||
:param item: The object to add to the collection.
|
||||
|
||||
:returns: The updated collection.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the item is not of the correct type or already exists in the collection.
|
||||
"""
|
||||
assert_str = f'All item names have to be of type {self._entity}, but were {item.__class__}.,'
|
||||
assert isinstance(item, self._entity), assert_str
|
||||
assert self._data[item.name] is None, f'{item.name} allready exists!!!'
|
||||
@ -66,6 +95,9 @@ class Objects:
|
||||
return self
|
||||
|
||||
def remove_item(self, item: _entity):
|
||||
"""
|
||||
Removes an item from the collection.
|
||||
"""
|
||||
for observer in item.observers:
|
||||
observer.notify_del_entity(item)
|
||||
# noinspection PyTypeChecker
|
||||
@ -77,6 +109,9 @@ class Objects:
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def del_observer(self, observer):
|
||||
"""
|
||||
Removes an observer from the collection and its entities.
|
||||
"""
|
||||
self.observers.remove(observer)
|
||||
for entity in self:
|
||||
if observer in entity.observers:
|
||||
@ -84,31 +119,56 @@ class Objects:
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def add_observer(self, observer):
|
||||
"""
|
||||
Adds an observer to the collection and its entities.
|
||||
"""
|
||||
self.observers.add(observer)
|
||||
for entity in self:
|
||||
entity.add_observer(observer)
|
||||
|
||||
def add_items(self, items: List[_entity]):
|
||||
"""
|
||||
Adds a list of items to the collection.
|
||||
|
||||
:param items: List of items to add.
|
||||
:type items: List[_entity]
|
||||
:returns: The updated collection.
|
||||
"""
|
||||
for item in items:
|
||||
self.add_item(item)
|
||||
return self
|
||||
|
||||
def keys(self):
|
||||
"""
|
||||
Returns the keys (names) of the objects in the collection.
|
||||
"""
|
||||
return self._data.keys()
|
||||
|
||||
def values(self):
|
||||
"""
|
||||
Returns the values (objects) in the collection.
|
||||
"""
|
||||
return self._data.values()
|
||||
|
||||
def items(self):
|
||||
"""
|
||||
Returns the items (name-object pairs) in the collection.
|
||||
"""
|
||||
return self._data.items()
|
||||
|
||||
def _get_index(self, item):
|
||||
"""
|
||||
Gets the index of an item in the collection.
|
||||
"""
|
||||
try:
|
||||
return next(i for i, v in enumerate(self._data.values()) if v == item)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def by_name(self, name):
|
||||
"""
|
||||
Gets an object from the collection by its name.
|
||||
"""
|
||||
return next(x for x in self if x.name == name)
|
||||
|
||||
def __getitem__(self, item):
|
||||
@ -131,6 +191,9 @@ class Objects:
|
||||
return f'{self.__class__.__name__}[{len(self)}]'
|
||||
|
||||
def notify_del_entity(self, entity: Object):
|
||||
"""
|
||||
Notifies the collection that an entity has been deleted.
|
||||
"""
|
||||
try:
|
||||
# noinspection PyUnresolvedReferences
|
||||
self.pos_dict[entity.pos].remove(entity)
|
||||
@ -138,6 +201,9 @@ class Objects:
|
||||
pass
|
||||
|
||||
def notify_add_entity(self, entity: Object):
|
||||
"""
|
||||
Notifies the collection that an entity has been added.
|
||||
"""
|
||||
try:
|
||||
if self not in entity.observers:
|
||||
entity.add_observer(self)
|
||||
@ -148,24 +214,38 @@ class Objects:
|
||||
pass
|
||||
|
||||
def summarize_states(self):
|
||||
"""
|
||||
Summarizes the states of all entities in the collection.
|
||||
|
||||
:returns: A list of dictionaries representing the summarized states of the entities.
|
||||
:rtype: List[dict]
|
||||
"""
|
||||
# FIXME PROTOBUFF
|
||||
# return [e.summarize_state() for e in self]
|
||||
return [e.summarize_state() for e in self]
|
||||
|
||||
def by_entity(self, entity):
|
||||
"""
|
||||
Gets an entity from the collection that belongs to a specified entity.
|
||||
"""
|
||||
try:
|
||||
return h.get_first(self, filter_by=lambda x: x.belongs_to_entity(entity))
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
"""
|
||||
Gets the index of an entity in the collection.
|
||||
"""
|
||||
try:
|
||||
return h.get_first_index(self, filter_by=lambda x: x == entity)
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the collection by clearing data and observers.
|
||||
"""
|
||||
self._data = defaultdict(lambda: None)
|
||||
self._observers = set(self)
|
||||
self.pos_dict = defaultdict(list)
|
||||
|
||||
|
@ -16,85 +16,128 @@ class Rule(abc.ABC):
|
||||
@property
|
||||
def name(self):
|
||||
"""
|
||||
TODO
|
||||
Get the name of the rule.
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
:return: The name of the rule.
|
||||
:rtype: str
|
||||
"""
|
||||
return self.__class__.__name__
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
TODO
|
||||
Abstract base class representing a rule in the environment.
|
||||
|
||||
This class provides a framework for defining rules that govern the behavior of the environment. Rules can be
|
||||
implemented by inheriting from this class and overriding specific methods.
|
||||
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
Return a string representation of the rule.
|
||||
|
||||
:return: A string representation of the rule.
|
||||
:rtype: str
|
||||
"""
|
||||
return f'{self.name}'
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
"""
|
||||
TODO
|
||||
Initialize the rule when the environment is created.
|
||||
|
||||
This method is called during the initialization of the environment. It allows the rule to perform any setup or
|
||||
initialization required.
|
||||
|
||||
:return:
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.utils.states.GameState
|
||||
:param lvl_map: The map of the level.
|
||||
:type lvl_map: marl_factory_grid.environment.level.LevelMap
|
||||
:return: List of TickResults generated during initialization.
|
||||
:rtype: List[TickResult]
|
||||
"""
|
||||
return []
|
||||
|
||||
def on_reset_post_spawn(self, state) -> List[TickResult]:
|
||||
"""
|
||||
TODO
|
||||
Execute actions after entities are spawned during a reset.
|
||||
|
||||
This method is called after entities are spawned during a reset. It allows the rule to perform any actions
|
||||
required at this stage.
|
||||
|
||||
:return:
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.utils.states.GameState
|
||||
:return: List of TickResults generated after entity spawning.
|
||||
:rtype: List[TickResult]
|
||||
"""
|
||||
return []
|
||||
|
||||
def on_reset(self, state) -> List[TickResult]:
|
||||
"""
|
||||
TODO
|
||||
Execute actions during a reset.
|
||||
|
||||
This method is called during a reset. It allows the rule to perform any actions required at this stage.
|
||||
|
||||
:return:
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.utils.states.GameState
|
||||
:return: List of TickResults generated during a reset.
|
||||
:rtype: List[TickResult]
|
||||
"""
|
||||
return []
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
"""
|
||||
TODO
|
||||
Execute actions before the main step of the environment.
|
||||
|
||||
This method is called before the main step of the environment. It allows the rule to perform any actions
|
||||
required before the main step.
|
||||
|
||||
:return:
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.utils.states.GameState
|
||||
:return: List of TickResults generated before the main step.
|
||||
:rtype: List[TickResult]
|
||||
"""
|
||||
return []
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
"""
|
||||
TODO
|
||||
Execute actions during the main step of the environment.
|
||||
|
||||
This method is called during the main step of the environment. It allows the rule to perform any actions
|
||||
required during the main step.
|
||||
|
||||
:return:
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.utils.states.GameState
|
||||
:return: List of TickResults generated during the main step.
|
||||
:rtype: List[TickResult]
|
||||
"""
|
||||
return []
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
"""
|
||||
TODO
|
||||
Execute actions after the main step of the environment.
|
||||
|
||||
This method is called after the main step of the environment. It allows the rule to perform any actions
|
||||
required after the main step.
|
||||
|
||||
:return:
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.utils.states.GameState
|
||||
:return: List of TickResults generated after the main step.
|
||||
:rtype: List[TickResult]
|
||||
"""
|
||||
return []
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
"""
|
||||
TODO
|
||||
Check conditions for the termination of the environment.
|
||||
|
||||
This method is called to check conditions for the termination of the environment. It allows the rule to
|
||||
specify conditions under which the environment should be considered done.
|
||||
|
||||
:return:
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.utils.states.GameState
|
||||
:return: List of DoneResults indicating whether the environment is done.
|
||||
:rtype: List[DoneResult]
|
||||
"""
|
||||
return []
|
||||
|
||||
@ -160,15 +203,23 @@ class DoneAtMaxStepsReached(Rule):
|
||||
|
||||
def __init__(self, max_steps: int = 500):
|
||||
"""
|
||||
TODO
|
||||
A rule that terminates the environment when a specified maximum number of steps is reached.
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
:param max_steps: The maximum number of steps before the environment is considered done.
|
||||
:type max_steps: int
|
||||
"""
|
||||
super().__init__()
|
||||
self.max_steps = max_steps
|
||||
|
||||
def on_check_done(self, state):
|
||||
"""
|
||||
Check if the maximum number of steps is reached, and if so, mark the environment as done.
|
||||
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.utils.states.GameState
|
||||
:return: List of DoneResults indicating whether the environment is done.
|
||||
:rtype: List[DoneResult]
|
||||
"""
|
||||
if self.max_steps <= state.curr_step:
|
||||
return [DoneResult(validity=c.VALID, identifier=self.name)]
|
||||
return []
|
||||
@ -178,14 +229,23 @@ class AssignGlobalPositions(Rule):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
TODO
|
||||
A rule that assigns global positions to agents when the environment is reset.
|
||||
|
||||
|
||||
:return:
|
||||
:return: None
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
def on_reset(self, state, lvl_map):
|
||||
"""
|
||||
Assign global positions to agents when the environment is reset.
|
||||
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.utils.states.GameState
|
||||
:param lvl_map: The map of the current level.
|
||||
:type lvl_map: marl_factory_grid.levels.level.LevelMap
|
||||
:return: An empty list, as no additional results are generated by this rule during the reset.
|
||||
:rtype: List[TickResult]
|
||||
"""
|
||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||
for agent in state[c.AGENT]:
|
||||
gp = GlobalPosition(agent, lvl_map.level_shape)
|
||||
@ -197,10 +257,15 @@ class WatchCollisions(Rule):
|
||||
|
||||
def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE):
|
||||
"""
|
||||
TODO
|
||||
A rule that monitors collisions between entities in the environment.
|
||||
|
||||
|
||||
:return:
|
||||
:param reward: The reward assigned for each collision.
|
||||
:type reward: float
|
||||
:param done_at_collisions: If True, marks the environment as done when collisions occur.
|
||||
:type done_at_collisions: bool
|
||||
:param reward_at_done: The reward assigned when the environment is marked as done due to collisions.
|
||||
:type reward_at_done: float
|
||||
:return: None
|
||||
"""
|
||||
super().__init__()
|
||||
self.reward_at_done = reward_at_done
|
||||
@ -209,6 +274,14 @@ class WatchCollisions(Rule):
|
||||
self.curr_done = False
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
"""
|
||||
Monitors collisions between entities after each step in the environment.
|
||||
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.utils.states.GameState
|
||||
:return: A list of TickResult objects representing collisions and their associated rewards.
|
||||
:rtype: List[TickResult]
|
||||
"""
|
||||
self.curr_done = False
|
||||
results = list()
|
||||
for agent in state[c.AGENT]:
|
||||
@ -234,6 +307,14 @@ class WatchCollisions(Rule):
|
||||
return results
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
"""
|
||||
Checks if the environment should be marked as done based on collision conditions.
|
||||
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.utils.states.GameState
|
||||
:return: A list of DoneResult objects representing the conditions for marking the environment as done.
|
||||
:rtype: List[DoneResult]
|
||||
"""
|
||||
if self.done_at_collisions:
|
||||
inter_entity_collision_detected = self.curr_done
|
||||
collision_in_step = any(h.is_move(x.state.identifier) and x.state.action_introduced_collision
|
||||
|
@ -0,0 +1,16 @@
|
||||
"""
|
||||
|
||||
The place to put the level-files.
|
||||
Per default the following levels are provided:
|
||||
|
||||
- eight_puzzle
|
||||
- large
|
||||
- large_qquad
|
||||
- narrow_corridor
|
||||
- rooms
|
||||
- shelves
|
||||
- simple
|
||||
- two_rooms
|
||||
|
||||
|
||||
"""
|
@ -5,3 +5,11 @@ from .doors import *
|
||||
from .items import *
|
||||
from .machines import *
|
||||
from .maintenance import *
|
||||
|
||||
"""
|
||||
modules
|
||||
=======
|
||||
|
||||
Todo
|
||||
|
||||
"""
|
||||
|
@ -12,7 +12,7 @@ class Charge(Action):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Checks if a charge pod is present at the entity's position.
|
||||
Checks if a charge pod is present at the agent's position.
|
||||
If found, it attempts to charge the battery using the charge pod.
|
||||
"""
|
||||
super().__init__(b.ACTION_CHARGE, b.REWARD_CHARGE_VALID, b.Reward_CHARGE_FAIL)
|
||||
|
@ -31,7 +31,7 @@ class Battery(Object):
|
||||
|
||||
def __init__(self, initial_charge_level, owner, *args, **kwargs):
|
||||
"""
|
||||
Represents a battery entity in the environment that can be bound to an agent and charged at chargepods.
|
||||
Represents a battery entity in the environment that can be bound to an agent and charged at charge pods.
|
||||
|
||||
:param initial_charge_level: The current charge level of the battery, ranging from 0 to 1.
|
||||
:type initial_charge_level: float
|
||||
@ -45,7 +45,7 @@ class Battery(Object):
|
||||
|
||||
def do_charge_action(self, amount) -> bool:
|
||||
"""
|
||||
Updates the Battery's charge level accordingly.
|
||||
Updates the Battery's charge level according to the passed value.
|
||||
|
||||
:param amount: Amount added to the Battery's charge level.
|
||||
:returns: whether the battery could be charged. if not, it was already fully charged.
|
||||
@ -59,7 +59,7 @@ class Battery(Object):
|
||||
|
||||
def decharge(self, amount) -> bool:
|
||||
"""
|
||||
Decreases the charge value of a battery. Currently only riggered by the battery-decharge rule.
|
||||
Decreases the charge value of a battery. Currently only triggered by the battery-decharge rule.
|
||||
"""
|
||||
if self.charge_level != 0:
|
||||
# noinspection PyTypeChecker
|
||||
@ -84,11 +84,11 @@ class ChargePod(Entity):
|
||||
"""
|
||||
Represents a charging pod for batteries in the environment.
|
||||
|
||||
:param charge_rate: The rate at which the charging pod charges batteries. Default is 0.4.
|
||||
:param charge_rate: The rate at which the charging pod charges batteries. Defaults to 0.4.
|
||||
:type charge_rate: float
|
||||
|
||||
:param multi_charge: Indicates whether the charging pod supports charging multiple batteries simultaneously.
|
||||
Default is False.
|
||||
Defaults to False.
|
||||
:type multi_charge: bool
|
||||
"""
|
||||
super(ChargePod, self).__init__(*args, **kwargs)
|
||||
@ -97,7 +97,8 @@ class ChargePod(Entity):
|
||||
|
||||
def charge_battery(self, entity, state) -> bool:
|
||||
"""
|
||||
Checks whether the battery can be charged. If so, triggers the charge action.
|
||||
Triggers the battery charge action if possible. Impossible if the battery is at full charge level or more than one
|
||||
agent is at the charge pod's position.
|
||||
|
||||
:returns: whether the action was successful (valid) or not.
|
||||
"""
|
||||
|
@ -19,7 +19,7 @@ class Batteries(Collection):
|
||||
|
||||
def __init__(self, size, initial_charge_level=1.0, *args, **kwargs):
|
||||
"""
|
||||
A collection of batteries that can spawn batteries.
|
||||
A collection of batteries that is in charge of spawning batteries. (spawned batteries are bound to agents)
|
||||
|
||||
:param size: The maximum allowed size of the collection. Ensures that the collection does not exceed this size.
|
||||
:type size: int
|
||||
|
@ -12,7 +12,7 @@ class Clean(Action):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Attempts to reduce dirt amount on entity's position.
|
||||
Attempts to reduce dirt amount on entity's position. Fails if no dirt is found at the agent's position.
|
||||
"""
|
||||
super().__init__(d.CLEAN_UP, d.REWARD_CLEAN_UP_VALID, d.REWARD_CLEAN_UP_FAIL)
|
||||
|
||||
|
@ -18,7 +18,8 @@ class DirtPile(Entity):
|
||||
|
||||
def __init__(self, *args, amount=2, max_local_amount=5, **kwargs):
|
||||
"""
|
||||
Represents a pile of dirt at a specific position in the environment.
|
||||
Represents a pile of dirt at a specific position in the environment that agents can interact with. Agents can
|
||||
clean the dirt pile or, depending on activated rules, interact with it in different ways.
|
||||
|
||||
:param amount: The amount of dirt in the pile.
|
||||
:type amount: float
|
||||
|
@ -10,7 +10,7 @@ class DestAction(Action):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Attempts to wait at destination.
|
||||
The agent performing this action attempts to wait at the destination in order to receive a reward.
|
||||
"""
|
||||
super().__init__(d.DESTINATION, d.REWARD_WAIT_VALID, d.REWARD_WAIT_FAIL)
|
||||
|
||||
|
@ -38,7 +38,11 @@ class Destination(Entity):
|
||||
|
||||
def has_just_been_reached(self, state):
|
||||
"""
|
||||
Checks if the destination has just been reached based on the current state.
|
||||
Checks if the destination has been reached in the last environment step.
|
||||
|
||||
:return: the agent that has just reached the destination or whether any agent in the environment has
|
||||
performed actions equal to or exceeding the specified limit
|
||||
:rtype: Union[Agent, bool]
|
||||
"""
|
||||
if self.was_reached():
|
||||
return False
|
||||
|
@ -12,7 +12,7 @@ from marl_factory_grid.modules.destinations.entitites import Destination
|
||||
|
||||
ANY = 'any'
|
||||
ALL = 'all'
|
||||
SIMULTANEOUS = 'simultanious'
|
||||
SIMULTANEOUS = 'simultaneous'
|
||||
CONDITIONS = [ALL, ANY, SIMULTANEOUS]
|
||||
|
||||
|
||||
|
@ -10,7 +10,8 @@ class DoorUse(Action):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Attempts to interact with door (open/close it) and returns an action result if successful.
|
||||
The agent performing this action attempts to interact with door (open/close it), returning an action result if
|
||||
successful.
|
||||
"""
|
||||
super().__init__(d.ACTION_DOOR_USE, d.REWARD_USE_DOOR_VALID, d.REWARD_USE_DOOR_FAIL, **kwargs)
|
||||
|
||||
|
@ -19,7 +19,7 @@ class DoorIndicator(Entity):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
Is added around a door for agents to see.
|
||||
Is added as a padding around doors so agents can see doors earlier.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__delattr__('move')
|
||||
|
@ -9,7 +9,7 @@ class DoorAutoClose(Rule):
|
||||
|
||||
def __init__(self, close_frequency: int = 10):
|
||||
"""
|
||||
This rule closes doors, that have been opened automatically, when no entity is blocking the position.
|
||||
This rule closes doors that have been opened automatically when no entity is blocking the position.
|
||||
|
||||
:type close_frequency: int
|
||||
:param close_frequency: How many ticks after opening, should the door close?
|
||||
|
@ -1,3 +1,11 @@
|
||||
from .actions import ItemAction
|
||||
from .entitites import Item, DropOffLocation
|
||||
from .groups import DropOffLocations, Items, Inventory, Inventories
|
||||
|
||||
"""
|
||||
items
|
||||
=====
|
||||
|
||||
Todo
|
||||
|
||||
"""
|
@ -62,7 +62,7 @@ class Inventory(IsBoundMixin, Collection):
|
||||
|
||||
def __init__(self, agent, *args, **kwargs):
|
||||
"""
|
||||
An inventory that can hold items picked up by the agent this is bound to.
|
||||
An inventory that can hold items picked up by the agent it is bound to.
|
||||
|
||||
:param agent: The agent this inventory is bound to and belongs to.
|
||||
:type agent: Agent
|
||||
@ -96,7 +96,7 @@ class Inventory(IsBoundMixin, Collection):
|
||||
|
||||
def clear_temp_state(self):
|
||||
"""
|
||||
Entites need this, but inventories have no state.
|
||||
Entities need this, but inventories have no state.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -123,7 +123,7 @@ class Inventories(Objects):
|
||||
|
||||
def __init__(self, size: int, *args, **kwargs):
|
||||
"""
|
||||
TODO
|
||||
A collection of all inventories used to spawn an inventory per agent.
|
||||
"""
|
||||
super(Inventories, self).__init__(*args, **kwargs)
|
||||
self.size = size
|
||||
|
@ -1,2 +1,10 @@
|
||||
from .entitites import Machine
|
||||
from .groups import Machines
|
||||
|
||||
"""
|
||||
machines
|
||||
========
|
||||
|
||||
Todo
|
||||
|
||||
"""
|
@ -11,7 +11,8 @@ class MachineAction(Action):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Attempts to maintain the machine and returns an action result if successful.
|
||||
When performing this action, the maintainer attempts to maintain the machine at his current position, returning
|
||||
an action result if successful.
|
||||
"""
|
||||
super().__init__(m.MACHINE_ACTION, m.MAINTAIN_VALID, m.MAINTAIN_FAIL)
|
||||
|
||||
|
@ -14,7 +14,8 @@ class Machine(Entity):
|
||||
|
||||
def __init__(self, *args, work_interval: int = 10, pause_interval: int = 15, **kwargs):
|
||||
"""
|
||||
Represents a machine entity that the maintainer will try to maintain.
|
||||
Represents a machine entity that the maintainer will try to maintain by performing the maintenance action.
|
||||
Machines' health depletes over time.
|
||||
|
||||
:param work_interval: How long should the machine work before pausing.
|
||||
:type work_interval: int
|
||||
@ -31,7 +32,8 @@ class Machine(Entity):
|
||||
|
||||
def maintain(self) -> bool:
|
||||
"""
|
||||
Attempts to maintain the machine by increasing its health.
|
||||
Attempts to maintain the machine by increasing its health, which is only possible if the machine is at a maximum
|
||||
of 98/100 HP.
|
||||
"""
|
||||
if self.status == m.STATE_WORK:
|
||||
return c.NOT_VALID
|
||||
|
@ -1,2 +1,9 @@
|
||||
from .entities import Maintainer
|
||||
from .groups import Maintainers
|
||||
"""
|
||||
maintenance
|
||||
===========
|
||||
|
||||
Todo
|
||||
|
||||
"""
|
@ -16,8 +16,9 @@ from ..doors import DoorUse
|
||||
class Maintainer(Entity):
|
||||
|
||||
def __init__(self, objective, action, *args, **kwargs):
|
||||
"""
|
||||
Represents the maintainer entity that aims to maintain machines.
|
||||
self.action_ = """
|
||||
Represents the maintainer entity that aims to maintain machines. The maintainer calculates its route using nx
|
||||
shortest path and restores the health of machines it visits to 100.
|
||||
|
||||
:param objective: The maintainer's objective, e.g., "Machines".
|
||||
:type objective: str
|
||||
|
@ -27,7 +27,7 @@ class Maintainers(Collection):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
A collection of maintainers
|
||||
A collection of maintainers that is used to spawn them.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
@ -1,3 +0,0 @@
|
||||
from .entitites import Zone
|
||||
from .groups import Zones
|
||||
from .rules import AgentSingleZonePlacement
|
@ -1,4 +0,0 @@
|
||||
# Names / Identifiers
|
||||
|
||||
ZONES = 'Zones' # Identifier of Zone-objects and groups (groups).
|
||||
ZONE = 'Zone' # -||-
|
@ -1,19 +0,0 @@
|
||||
import random
|
||||
from typing import List, Tuple
|
||||
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
|
||||
|
||||
class Zone(Object):
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
return self.coords
|
||||
|
||||
def __init__(self, coords: List[Tuple[(int, int)]], *args, **kwargs):
|
||||
super(Zone, self).__init__(*args, **kwargs)
|
||||
self.coords = coords
|
||||
|
||||
@property
|
||||
def random_pos(self):
|
||||
return random.choice(self.coords)
|
@ -1,26 +0,0 @@
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.modules.zones import Zone
|
||||
|
||||
|
||||
class Zones(Objects):
|
||||
symbol = None
|
||||
_entity = Zone
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Zones, self).__init__(*args, can_collide=True, **kwargs)
|
||||
|
||||
def by_pos(self, pos):
|
||||
return self.pos_dict[pos]
|
||||
|
||||
def notify_add_entity(self, entity: Zone):
|
||||
self.pos_dict.update({key: [entity] for key in entity.positions})
|
||||
return True
|
||||
|
||||
def notify_del_entity(self, entity: Zone):
|
||||
for pos in entity.positions:
|
||||
self.pos_dict[pos].remove(entity)
|
||||
return True
|
@ -1,71 +0,0 @@
|
||||
from random import choices, choice
|
||||
|
||||
from . import constants as z, Zone
|
||||
from .. import Destination
|
||||
from ..destinations import constants as d
|
||||
from ...environment.rules import Rule
|
||||
from ...environment import constants as c
|
||||
|
||||
|
||||
class ZoneInit(Rule):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._zones = list()
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
z_idx = 1
|
||||
|
||||
while z_idx:
|
||||
zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
|
||||
if len(zone_positions):
|
||||
self._zones.append(Zone(zone_positions))
|
||||
z_idx += 1
|
||||
else:
|
||||
z_idx = 0
|
||||
|
||||
def on_reset(self, state):
|
||||
state[z.ZONES].add_items(self._zones)
|
||||
return []
|
||||
|
||||
|
||||
class AgentSingleZonePlacement(Rule):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def on_reset(self, state):
|
||||
n_agents = len(state[c.AGENT])
|
||||
assert len(state[z.ZONES]) >= n_agents
|
||||
|
||||
z_idxs = choices(list(range(len(state[z.ZONES]))), k=n_agents)
|
||||
for agent in state[c.AGENT]:
|
||||
agent.move(state[z.ZONES][z_idxs.pop()].random_pos, state)
|
||||
return []
|
||||
|
||||
|
||||
class IndividualDestinationZonePlacement(Rule):
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError("This is pretty new, and needs to be debugged, after the zones")
|
||||
super().__init__()
|
||||
|
||||
def on_reset(self, state):
|
||||
for agent in state[c.AGENT]:
|
||||
self.trigger_spawn(agent, state)
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def trigger_spawn(agent, state):
|
||||
agent_zones = state[z.ZONES].by_pos(agent.pos)
|
||||
other_zones = [x for x in state[z.ZONES] if x not in agent_zones]
|
||||
already_has_destination = True
|
||||
while already_has_destination:
|
||||
pos = choice(other_zones).random_pos
|
||||
if state[d.DESTINATION].by_pos(pos) is None:
|
||||
already_has_destination = False
|
||||
destination = Destination(pos, bind_to=agent)
|
||||
|
||||
state[d.DESTINATION].add_item(destination)
|
||||
continue
|
||||
return c.VALID
|
@ -1,3 +1,11 @@
|
||||
from . import helpers as h
|
||||
from . import helpers
|
||||
from .results import Result, DoneResult, ActionResult, TickResult
|
||||
|
||||
"""
|
||||
Utils
|
||||
=====
|
||||
|
||||
Todo
|
||||
|
||||
"""
|
@ -22,6 +22,12 @@ class FactoryConfigParser(object):
|
||||
default_observations = [c.WALLS, c.AGENT]
|
||||
|
||||
def __init__(self, config_path, custom_modules_path: Union[PathLike] = None):
|
||||
"""
|
||||
This class parses the factory env config file.
|
||||
|
||||
:param config_path: Path to where the 'config.yml' is.
|
||||
:param custom_modules_path: Additional search path for custom modules, levels, entities, etc.
|
||||
"""
|
||||
self.config_path = Path(config_path)
|
||||
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
|
||||
self.config = yaml.safe_load(self.config_path.open())
|
||||
@ -40,7 +46,6 @@ class FactoryConfigParser(object):
|
||||
self._n_abbr_dict = defaultdict(lambda: 'th', {1: 'st', 2: 'nd', 3: 'rd'})
|
||||
return self._n_abbr_dict[n]
|
||||
|
||||
|
||||
@property
|
||||
def agent_actions(self):
|
||||
return self._get_sub_list('Agents', "Actions")
|
||||
@ -129,13 +134,25 @@ class FactoryConfigParser(object):
|
||||
# Actions
|
||||
conf_actions = self.agents[name]['Actions']
|
||||
actions = list()
|
||||
# Actions:
|
||||
# Allowed
|
||||
# - Noop
|
||||
# - Move8
|
||||
# ----
|
||||
# Noop:
|
||||
# South:
|
||||
# reward_fail: 0.5
|
||||
# ----
|
||||
# Forbidden
|
||||
# - South:
|
||||
# reward_fail: 0.5
|
||||
|
||||
if isinstance(conf_actions, dict):
|
||||
conf_kwargs = conf_actions.copy()
|
||||
conf_actions = list(conf_actions.keys())
|
||||
elif isinstance(conf_actions, list):
|
||||
conf_kwargs = {}
|
||||
if isinstance(conf_actions, dict):
|
||||
if any(isinstance(x, dict) for x in conf_actions):
|
||||
raise ValueError
|
||||
pass
|
||||
for action in conf_actions:
|
||||
@ -152,11 +169,10 @@ class FactoryConfigParser(object):
|
||||
except AttributeError:
|
||||
class_or_classes = locate_and_import_class(action, self.custom_modules_path)
|
||||
try:
|
||||
# print(action)
|
||||
# Handle Lists of Actions (e.g., Move8, Move4, Default)
|
||||
parsed_actions.extend(class_or_classes)
|
||||
# print(parsed_actions)
|
||||
for actions_class in class_or_classes:
|
||||
# break
|
||||
conf_kwargs[actions_class.__name__] = conf_kwargs.get(action, {})
|
||||
except TypeError:
|
||||
parsed_actions.append(class_or_classes)
|
||||
@ -174,7 +190,7 @@ class FactoryConfigParser(object):
|
||||
['Actions', 'Observations', 'Positions', 'Clones']}
|
||||
parsed_agents_conf[name] = dict(
|
||||
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
|
||||
)
|
||||
)
|
||||
|
||||
clones = self.agents[name].get('Clones', 0)
|
||||
if clones:
|
||||
|
@ -10,14 +10,14 @@ from marl_factory_grid.environment import constants as c
|
||||
|
||||
"""
|
||||
This file is used for:
|
||||
1. string based definition
|
||||
Use a class like `Constants`, to define attributes, which then reveal strings.
|
||||
These can be used for naming convention along the environments as well as keys for mappings such as dicts etc.
|
||||
When defining new envs, use class inheritance.
|
||||
|
||||
2. utility function definition
|
||||
There are static utility functions which are not bound to a specific environment.
|
||||
In this file they are defined to be used across the entire package.
|
||||
1. string based definition
|
||||
Use a class like `Constants`, to define attributes, which then reveal strings.
|
||||
These can be used for naming convention along the environments as well as keys for mappings such as dicts etc.
|
||||
When defining new envs, use class inheritance.
|
||||
|
||||
2. utility function definition
|
||||
There are static utility functions which are not bound to a specific environment.
|
||||
In this file they are defined to be used across the entire package.
|
||||
"""
|
||||
|
||||
LEVELS_DIR = 'levels' # for use in studies and experiments
|
||||
@ -54,15 +54,9 @@ class ObservationTranslator:
|
||||
A string _identifier based approach is used.
|
||||
Currently, it is not possible to mix different obs shapes.
|
||||
|
||||
|
||||
:param this_named_observation_space: `Named observation space` of the joined environment.
|
||||
:type this_named_observation_space: Dict[str, dict]
|
||||
|
||||
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
|
||||
type per_agent_named_obs_spaces: Dict[str, dict]
|
||||
|
||||
:param placeholder_fill_value: Currently, not fully implemented!!!
|
||||
:type placeholder_fill_value: Union[int, str] = 'N'
|
||||
"""
|
||||
|
||||
if isinstance(placeholder_fill_value, str):
|
||||
|
@ -16,7 +16,10 @@ class LevelParser(object):
|
||||
@property
|
||||
def pomdp_d(self):
|
||||
"""
|
||||
Internal Usage
|
||||
Calculates the effective diameter of the POMDP observation space.
|
||||
|
||||
:return: The calculated effective diameter.
|
||||
:rtype: int
|
||||
"""
|
||||
return self.pomdp_r * 2 + 1
|
||||
|
||||
|
@ -0,0 +1,7 @@
|
||||
"""
|
||||
logging
|
||||
=======
|
||||
|
||||
Todo
|
||||
|
||||
"""
|
@ -17,6 +17,9 @@ class EnvMonitor(Wrapper):
|
||||
ext = 'png'
|
||||
|
||||
def __init__(self, env, filepath: Union[str, PathLike] = None):
|
||||
"""
|
||||
EnvMonitor is a wrapper for Gymnasium environments that monitors and logs key information during interactions.
|
||||
"""
|
||||
super(EnvMonitor, self).__init__(env)
|
||||
self._filepath = filepath
|
||||
self._monitor_df = pd.DataFrame()
|
||||
@ -52,6 +55,14 @@ class EnvMonitor(Wrapper):
|
||||
return
|
||||
|
||||
def save_monitor(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
|
||||
"""
|
||||
Saves the monitoring data to a file and optionally generates plots.
|
||||
|
||||
:param filepath: The path to save the monitoring data file.
|
||||
:type filepath: Union[Path, str, None]
|
||||
:param auto_plotting_keys: Keys to use for automatic plot generation.
|
||||
:type auto_plotting_keys: Any
|
||||
"""
|
||||
filepath = Path(filepath or self._filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with filepath.open('wb') as f:
|
||||
|
@ -11,6 +11,16 @@ class EnvRecorder(Wrapper):
|
||||
|
||||
def __init__(self, env, filepath: Union[str, PathLike] = None,
|
||||
episodes: Union[List[int], None] = None):
|
||||
"""
|
||||
EnvRecorder is a wrapper for OpenAI Gym environments that records state summaries during interactions.
|
||||
|
||||
:param env: The environment to record.
|
||||
:type env: gym.Env
|
||||
:param filepath: The path to save the recording data file.
|
||||
:type filepath: Union[str, PathLike]
|
||||
:param episodes: A list of episode numbers to record. If None, records all episodes.
|
||||
:type episodes: Union[List[int], None]
|
||||
"""
|
||||
super(EnvRecorder, self).__init__(env)
|
||||
self.filepath = filepath
|
||||
self.episodes = episodes
|
||||
@ -19,6 +29,9 @@ class EnvRecorder(Wrapper):
|
||||
self._recorder_out_list = list()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Overrides the reset method to reset the environment and recording lists.
|
||||
"""
|
||||
self._curr_ep_recorder = list()
|
||||
self._recorder_out_list = list()
|
||||
self._curr_episode += 1
|
||||
@ -26,10 +39,12 @@ class EnvRecorder(Wrapper):
|
||||
|
||||
def step(self, actions):
|
||||
"""
|
||||
Todo
|
||||
Overrides the step method to record state summaries during each step.
|
||||
|
||||
:param actions:
|
||||
:return:
|
||||
:param actions: The actions taken in the environment.
|
||||
:type actions: Any
|
||||
:return: The observation, reward, done flag, and additional information.
|
||||
:rtype: Tuple
|
||||
"""
|
||||
obs_type, obs, reward, done, info = self.env.step(actions)
|
||||
if not self.episodes or self._curr_episode in self.episodes:
|
||||
@ -55,6 +70,18 @@ class EnvRecorder(Wrapper):
|
||||
save_occupation_map=False,
|
||||
save_trajectory_map=False,
|
||||
):
|
||||
"""
|
||||
Saves the recorded data to a file.
|
||||
|
||||
:param filepath: The path to save the recording data file.
|
||||
:type filepath: Union[Path, str, None]
|
||||
:param only_deltas: If True, saves only the differences between consecutive episodes.
|
||||
:type only_deltas: bool
|
||||
:param save_occupation_map: If True, saves an occupation map as a heatmap.
|
||||
:type save_occupation_map: bool
|
||||
:param save_trajectory_map: If True, saves a trajectory map.
|
||||
:type save_trajectory_map: bool
|
||||
"""
|
||||
self._finalize()
|
||||
filepath = Path(filepath or self.filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
@ -73,7 +100,6 @@ class EnvRecorder(Wrapper):
|
||||
n_dests=0,
|
||||
dwell_time=0,
|
||||
spawn_frequency=0,
|
||||
spawn_in_other_zone=False,
|
||||
spawn_mode=''
|
||||
)
|
||||
rewards_dest = dict(
|
||||
|
@ -19,10 +19,10 @@ class OBSBuilder(object):
|
||||
@property
|
||||
def pomdp_d(self):
|
||||
"""
|
||||
TODO
|
||||
Calculates the effective diameter of the POMDP observation space.
|
||||
|
||||
|
||||
:return:
|
||||
:return: The calculated effective diameter.
|
||||
:rtype: int
|
||||
"""
|
||||
if self.pomdp_r:
|
||||
return (self.pomdp_r * 2) + 1
|
||||
@ -31,10 +31,17 @@ class OBSBuilder(object):
|
||||
|
||||
def __init__(self, level_shape: np.size, state: Gamestate, pomdp_r: int):
|
||||
"""
|
||||
TODO
|
||||
OBSBuilder
|
||||
==========
|
||||
|
||||
The OBSBuilder class is responsible for constructing observations in the environment.
|
||||
|
||||
:return:
|
||||
:param level_shape: The shape of the level or environment.
|
||||
:type level_shape: np.size
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.environment.state.Gamestate
|
||||
:param pomdp_r: The POMDP radius, influencing the size of the observation space.
|
||||
:type pomdp_r: int
|
||||
"""
|
||||
self.all_obs = dict()
|
||||
self.ray_caster = dict()
|
||||
@ -52,6 +59,9 @@ class OBSBuilder(object):
|
||||
self.reset(state)
|
||||
|
||||
def reset(self, state):
|
||||
"""
|
||||
Resets temporary information and constructs an empty observation array with possible placeholders.
|
||||
"""
|
||||
# Reset temporary information
|
||||
self.curr_lightmaps = dict()
|
||||
# Construct an empty obs (array) for possible placeholders
|
||||
@ -61,6 +71,11 @@ class OBSBuilder(object):
|
||||
return True
|
||||
|
||||
def observation_space(self, state):
|
||||
"""
|
||||
Returns the observation space for a single agent or a tuple of spaces for multiple agents.
|
||||
:returns: The observation space for the agent(s).
|
||||
:rtype: gym.Space|Tuple
|
||||
"""
|
||||
from gymnasium.spaces import Tuple, Box
|
||||
self.reset(state)
|
||||
obsn = self.build_for_all(state)
|
||||
@ -71,13 +86,29 @@ class OBSBuilder(object):
|
||||
return space
|
||||
|
||||
def named_observation_space(self, state):
|
||||
"""
|
||||
:returns: A dictionary of named observation spaces for all agents.
|
||||
:rtype: dict
|
||||
"""
|
||||
self.reset(state)
|
||||
return self.build_for_all(state)
|
||||
|
||||
def build_for_all(self, state) -> (dict, dict):
|
||||
"""
|
||||
Builds observations for all agents in the environment.
|
||||
|
||||
:returns: A dictionary of observations for all agents.
|
||||
:rtype: dict
|
||||
"""
|
||||
return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]}
|
||||
|
||||
def build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]:
|
||||
"""
|
||||
Builds named observations for all agents in the environment.
|
||||
|
||||
:returns: A dictionary containing named observations for all agents.
|
||||
:rtype: dict
|
||||
"""
|
||||
named_obs_dict = {}
|
||||
for agent in state[c.AGENT]:
|
||||
obs, names = self.build_for_agent(agent, state)
|
||||
@ -85,6 +116,16 @@ class OBSBuilder(object):
|
||||
return named_obs_dict
|
||||
|
||||
def place_entity_in_observation(self, obs_array, agent, e):
|
||||
"""
|
||||
Places the encoding of an entity in the observation array relative to the agent's position.
|
||||
|
||||
:param obs_array: The observation array.
|
||||
:type obs_array: np.ndarray
|
||||
:param agent: the associated agent
|
||||
:type agent: Agent
|
||||
:param e: The entity to be placed in the observation.
|
||||
:type e: Entity
|
||||
"""
|
||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||
if not min([y, x]) < 0:
|
||||
try:
|
||||
@ -95,6 +136,12 @@ class OBSBuilder(object):
|
||||
pass
|
||||
|
||||
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
|
||||
"""
|
||||
Builds observations for a specific agent.
|
||||
|
||||
:returns: A tuple containing the observation array and the corresponding list of observation names
|
||||
:rtype: Tuple[np.ndarray, List[str]]
|
||||
"""
|
||||
try:
|
||||
agent_want_obs = self.obs_layers[agent.name]
|
||||
except KeyError:
|
||||
@ -190,8 +237,8 @@ class OBSBuilder(object):
|
||||
def _sort_and_name_observation_conf(self, agent):
|
||||
"""
|
||||
Builds the useable observation scheme per agent from conf.yaml.
|
||||
:param agent:
|
||||
:return:
|
||||
|
||||
:param agent: The agent for whom the observation scheme is built.
|
||||
"""
|
||||
# Fixme: no asymmetric shapes possible.
|
||||
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
|
||||
|
@ -0,0 +1,7 @@
|
||||
"""
|
||||
Plotting
|
||||
========
|
||||
|
||||
Todo
|
||||
|
||||
"""
|
@ -13,6 +13,16 @@ MODEL_MAP = None
|
||||
|
||||
|
||||
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||
"""
|
||||
|
||||
Compare multiple runs with different seeds by generating a line plot that shows the evolution of scores (step rewards)
|
||||
across episodes.
|
||||
|
||||
:param run_path: The path to the directory containing the monitor files for each run.
|
||||
:type run_path: Union[str, PathLike]
|
||||
:param use_tex: A boolean indicating whether to use TeX formatting in the plot.
|
||||
:type use_tex: bool
|
||||
"""
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
for run, monitor_file in enumerate(run_path.rglob('monitor*.pick')):
|
||||
@ -23,7 +33,7 @@ def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}).sort_values(['Run', 'Episode'])
|
||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
|
||||
@ -49,6 +59,19 @@ def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||
|
||||
def compare_model_runs(run_path: Path, run_identifier: Union[str, int], parameter: Union[str, List[str]],
|
||||
use_tex: bool = False):
|
||||
"""
|
||||
Compares multiple model runs based on specified parameters by generating a line plot showing the evolution of scores (step rewards)
|
||||
across episodes.
|
||||
|
||||
:param run_path: The path to the directory containing the monitor files for each model run.
|
||||
:type run_path: Path
|
||||
:param run_identifier: A string or integer identifying the runs to compare.
|
||||
:type run_identifier: Union[str, int]
|
||||
:param parameter: A single parameter or a list of parameters to compare.
|
||||
:type parameter: Union[str, List[str]]
|
||||
:param use_tex: A boolean indicating whether to use TeX formatting in the plot.
|
||||
:type use_tex: bool
|
||||
"""
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
parameter = [parameter] if isinstance(parameter, str) else parameter
|
||||
@ -89,6 +112,20 @@ def compare_model_runs(run_path: Path, run_identifier: Union[str, int], paramete
|
||||
|
||||
def compare_all_parameter_runs(run_root_path: Path, parameter: Union[str, List[str]],
|
||||
param_names: Union[List[str], None] = None, str_to_ignore='', use_tex: bool = False):
|
||||
"""
|
||||
Compares model runs across different parameter settings by generating a line plot showing the evolution of scores across episodes.
|
||||
|
||||
:param run_root_path: The root path to the directory containing the monitor files for all model runs.
|
||||
:type run_root_path: Path
|
||||
:param parameter: The parameter(s) to compare across different runs.
|
||||
:type parameter: Union[str, List[str]]
|
||||
:param param_names: A list of custom names for the parameters to be used as labels in the plot. If None, default names will be assigned.
|
||||
:type param_names: Union[List[str], None]
|
||||
:param str_to_ignore: A string pattern to ignore in parameter names.
|
||||
:type str_to_ignore: str
|
||||
:param use_tex: A boolean indicating whether to use TeX formatting in the plot.
|
||||
:type use_tex: bool
|
||||
"""
|
||||
run_root_path = Path(run_root_path)
|
||||
df_list = list()
|
||||
parameter = [parameter] if isinstance(parameter, str) else parameter
|
||||
|
@ -10,7 +10,21 @@ from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
|
||||
|
||||
|
||||
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
|
||||
file_key: str ='monitor', file_ext: str ='pkl'):
|
||||
file_key: str = 'monitor', file_ext: str = 'pkl'):
|
||||
"""
|
||||
Plots the Epoch score (step reward) over a single run based on monitoring data stored in a file.
|
||||
|
||||
:param run_path: The path to the directory containing monitoring data or directly to the monitoring file.
|
||||
:type run_path: Union[str, PathLike]
|
||||
:param use_tex: Flag indicating whether to use TeX for plotting.
|
||||
:type use_tex: bool, optional
|
||||
:param column_keys: Specific columns to include in the plot. If None, includes all columns except ignored ones.
|
||||
:type column_keys: list or None, optional
|
||||
:param file_key: The keyword to identify the monitoring file.
|
||||
:type file_key: str, optional
|
||||
:param file_ext: The extension of the monitoring file.
|
||||
:type file_ext: str, optional
|
||||
"""
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
if run_path.is_dir():
|
||||
@ -26,7 +40,7 @@ def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, colum
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
|
||||
if column_keys is not None:
|
||||
columns = [col for col in column_keys if col in df.columns]
|
||||
|
@ -1,7 +1,6 @@
|
||||
import seaborn as sns
|
||||
import matplotlib as mpl
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
PALETTE = 10 * (
|
||||
"#377eb8",
|
||||
"#4daf4a",
|
||||
@ -20,6 +19,14 @@ PALETTE = 10 * (
|
||||
|
||||
|
||||
def plot(filepath, ext='png'):
|
||||
"""
|
||||
Saves the current plot to a file and displays it.
|
||||
|
||||
:param filepath: The path to save the plot file.
|
||||
:type filepath: str
|
||||
:param ext: The file extension of the saved plot. Default is 'png'.
|
||||
:type ext: str
|
||||
"""
|
||||
plt.tight_layout()
|
||||
figure = plt.gcf()
|
||||
ax = plt.gca()
|
||||
@ -35,6 +42,20 @@ def plot(filepath, ext='png'):
|
||||
|
||||
|
||||
def prepare_tex(df, hue, style, hue_order):
|
||||
"""
|
||||
Prepares a line plot for rendering in LaTeX.
|
||||
|
||||
:param df: The DataFrame containing the data to be plotted.
|
||||
:type df: pandas.DataFrame
|
||||
:param hue: Grouping variable that will produce lines with different colors.
|
||||
:type hue: str
|
||||
:param style: Grouping variable that will produce lines with different styles.
|
||||
:type style: str
|
||||
:param hue_order: Order for the levels of the hue variable in the plot.
|
||||
:type hue_order: list
|
||||
:return: The prepared line plot.
|
||||
:rtype: matplotlib.axes._subplots.AxesSubplot
|
||||
"""
|
||||
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
||||
hue_order=hue_order, hue=hue, style=style)
|
||||
@ -45,6 +66,20 @@ def prepare_tex(df, hue, style, hue_order):
|
||||
|
||||
|
||||
def prepare_plt(df, hue, style, hue_order):
|
||||
"""
|
||||
Prepares a line plot using matplotlib.
|
||||
|
||||
:param df: The DataFrame containing the data to be plotted.
|
||||
:type df: pandas.DataFrame
|
||||
:param hue: Grouping variable that will produce lines with different colors.
|
||||
:type hue: str
|
||||
:param style: Grouping variable that will produce lines with different styles.
|
||||
:type style: str
|
||||
:param hue_order: Order for the levels of the hue variable in the plot.
|
||||
:type hue_order: list
|
||||
:return: The prepared line plot.
|
||||
:rtype: matplotlib.axes._subplots.AxesSubplot
|
||||
"""
|
||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||
plt.close('all')
|
||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||
@ -57,6 +92,20 @@ def prepare_plt(df, hue, style, hue_order):
|
||||
|
||||
|
||||
def prepare_center_double_column_legend(df, hue, style, hue_order):
|
||||
"""
|
||||
Prepares a line plot with a legend centered at the bottom and spread across two columns.
|
||||
|
||||
:param df: The DataFrame containing the data to be plotted.
|
||||
:type df: pandas.DataFrame
|
||||
:param hue: Grouping variable that will produce lines with different colors.
|
||||
:type hue: str
|
||||
:param style: Grouping variable that will produce lines with different styles.
|
||||
:type style: str
|
||||
:param hue_order: Order for the levels of the hue variable in the plot.
|
||||
:type hue_order: list
|
||||
:return: The prepared line plot.
|
||||
:rtype: matplotlib.axes._subplots.AxesSubplot
|
||||
"""
|
||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||
plt.close('all')
|
||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||
@ -70,6 +119,23 @@ def prepare_center_double_column_legend(df, hue, style, hue_order):
|
||||
|
||||
|
||||
def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None, use_tex=False):
|
||||
"""
|
||||
Prepares a line plot for visualization. Based on the use_tex parameter, calls the prepare_tex or prepare_plt
|
||||
function accordingly, followed by the plot function to save the plot.
|
||||
|
||||
:param filepath: The file path where the plot will be saved.
|
||||
:type filepath: str
|
||||
:param results_df: The DataFrame containing the data to be plotted.
|
||||
:type results_df: pandas.DataFrame
|
||||
:param ext: The file extension of the saved plot (default is 'png').
|
||||
:type ext: str
|
||||
:param hue: The variable to determine the color of the lines in the plot.
|
||||
:type hue: str
|
||||
:param style: The variable to determine the style of the lines in the plot (default is None).
|
||||
:type style: str or None
|
||||
:param use_tex: Whether to use LaTeX for text rendering (default is False).
|
||||
:type use_tex: bool
|
||||
"""
|
||||
df = results_df.copy()
|
||||
df[hue] = df[hue].str.replace('_', '-')
|
||||
hue_order = sorted(list(df[hue].unique()))
|
||||
|
@ -8,10 +8,17 @@ from numba import njit
|
||||
class RayCaster:
|
||||
def __init__(self, agent, pomdp_r, degs=360):
|
||||
"""
|
||||
TODO
|
||||
The RayCaster class enables agents in the environment to simulate field-of-view visibility,
|
||||
providing methods for calculating visible entities and outlining the field of view based on
|
||||
Bresenham's algorithm.
|
||||
|
||||
|
||||
:return:
|
||||
:param agent: The agent for which the RayCaster is initialized.
|
||||
:type agent: Agent
|
||||
:param pomdp_r: The range of the partially observable Markov decision process (POMDP).
|
||||
:type pomdp_r: int
|
||||
:param degs: The degrees of the field of view (FOV). Defaults to 360.
|
||||
:type degs: int
|
||||
:return: None
|
||||
"""
|
||||
self.agent = agent
|
||||
self.pomdp_r = pomdp_r
|
||||
@ -25,6 +32,12 @@ class RayCaster:
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
|
||||
def build_ray_targets(self):
|
||||
"""
|
||||
Builds the targets for the rays based on the field of view (FOV).
|
||||
|
||||
:return: The targets for the rays.
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
north = np.array([0, -1]) * self.pomdp_r
|
||||
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
||||
rot_M = [
|
||||
@ -36,11 +49,31 @@ class RayCaster:
|
||||
return rot_M.astype(int)
|
||||
|
||||
def ray_block_cache(self, key, callback):
|
||||
"""
|
||||
Retrieves or caches a value in the cache dictionary.
|
||||
|
||||
:param key: The key for the cache dictionary.
|
||||
:type key: any
|
||||
:param callback: The callback function to obtain the value if not present in the cache.
|
||||
:type callback: callable
|
||||
:return: The cached or newly computed value.
|
||||
:rtype: any
|
||||
"""
|
||||
if key not in self._cache_dict:
|
||||
self._cache_dict[key] = callback()
|
||||
return self._cache_dict[key]
|
||||
|
||||
def visible_entities(self, pos_dict, reset_cache=True):
|
||||
"""
|
||||
Returns a list of visible entities based on the agent's field of view.
|
||||
|
||||
:param pos_dict: The dictionary containing positions of entities.
|
||||
:type pos_dict: dict
|
||||
:param reset_cache: Flag to reset the cache. Defaults to True.
|
||||
:type reset_cache: bool
|
||||
:return: A list of visible entities.
|
||||
:rtype: list
|
||||
"""
|
||||
visible = list()
|
||||
if reset_cache:
|
||||
self._cache_dict = dict()
|
||||
@ -71,15 +104,33 @@ class RayCaster:
|
||||
return visible
|
||||
|
||||
def get_rays(self):
|
||||
"""
|
||||
Gets the rays for the agent.
|
||||
|
||||
:return: The rays for the agent.
|
||||
:rtype: list
|
||||
"""
|
||||
a_pos = self.agent.pos
|
||||
outline = self.ray_targets + a_pos
|
||||
return self.bresenham_loop(a_pos, outline)
|
||||
|
||||
# todo do this once and cache the points!
|
||||
def get_fov_outline(self) -> np.ndarray:
|
||||
"""
|
||||
Gets the field of view (FOV) outline.
|
||||
|
||||
:return: The FOV outline.
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
return self.ray_targets + self.agent.pos
|
||||
|
||||
def get_square_outline(self):
|
||||
"""
|
||||
Gets the square outline for the agent.
|
||||
|
||||
:return: The square outline.
|
||||
:rtype: list
|
||||
"""
|
||||
agent = self.agent
|
||||
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
|
||||
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
|
||||
@ -90,6 +141,16 @@ class RayCaster:
|
||||
@staticmethod
|
||||
@njit
|
||||
def bresenham_loop(a_pos, points):
|
||||
"""
|
||||
Applies Bresenham's algorithm to calculate the points between two positions.
|
||||
|
||||
:param a_pos: The starting position.
|
||||
:type a_pos: list
|
||||
:param points: The ending positions.
|
||||
:type points: list
|
||||
:return: The list of points between the starting and ending positions.
|
||||
:rtype: list
|
||||
"""
|
||||
results = []
|
||||
for end in points:
|
||||
x1, y1 = a_pos
|
||||
|
@ -34,12 +34,26 @@ class Renderer:
|
||||
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
|
||||
grid_lines: bool = True, view_radius: int = 2):
|
||||
"""
|
||||
TODO
|
||||
The Renderer class initializes and manages the rendering environment for the simulation,
|
||||
providing methods for preparing entities for display, loading assets, calculating visibility rectangles and
|
||||
rendering the entities on the screen with specified parameters.
|
||||
|
||||
|
||||
:return:
|
||||
:param lvl_shape: Tuple representing the shape of the level.
|
||||
:type lvl_shape: Tuple[int, int]
|
||||
:param lvl_padded_shape: Optional Tuple representing the padded shape of the level.
|
||||
:type lvl_padded_shape: Union[Tuple[int, int], None]
|
||||
:param cell_size: Size of each cell in pixels.
|
||||
:type cell_size: int
|
||||
:param fps: Frames per second for rendering.
|
||||
:type fps: int
|
||||
:param factor: Factor for resizing assets.
|
||||
:type factor: float
|
||||
:param grid_lines: Boolean indicating whether to display grid lines.
|
||||
:type grid_lines: bool
|
||||
:param view_radius: Radius for agent's field of view.
|
||||
:type view_radius: int
|
||||
"""
|
||||
# TODO: Customn_assets paths
|
||||
# TODO: Custom_assets paths
|
||||
self.grid_h, self.grid_w = lvl_shape
|
||||
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
|
||||
self.cell_size = cell_size
|
||||
@ -60,6 +74,9 @@ class Renderer:
|
||||
print('Loading System font with pygame.font.Font took', time.time() - now)
|
||||
|
||||
def fill_bg(self):
|
||||
"""
|
||||
Fills the background of the screen with the specified BG color.
|
||||
"""
|
||||
self.screen.fill(Renderer.BG_COLOR)
|
||||
if self.grid_lines:
|
||||
w, h = self.screen_size
|
||||
@ -69,6 +86,16 @@ class Renderer:
|
||||
pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)
|
||||
|
||||
def blit_params(self, entity):
|
||||
"""
|
||||
Prepares parameters for blitting an entity on the screen. Blitting refers to the process of combining or copying
|
||||
rectangular blocks of pixels from one part of a graphical buffer to another and is often used to efficiently
|
||||
update the display by copying pre-drawn or cached images onto the screen.
|
||||
|
||||
:param entity: The entity to be blitted.
|
||||
:type entity: Entity
|
||||
:return: Dictionary containing source and destination information for blitting.
|
||||
:rtype: dict
|
||||
"""
|
||||
offset_r, offset_c = (self.lvl_padded_shape[0] - self.grid_h) // 2, \
|
||||
(self.lvl_padded_shape[1] - self.grid_w) // 2
|
||||
|
||||
@ -90,12 +117,31 @@ class Renderer:
|
||||
return dict(source=img, dest=rect)
|
||||
|
||||
def load_asset(self, path, factor=1.0):
|
||||
"""
|
||||
Loads and resizes an asset from the specified path.
|
||||
|
||||
:param path: Path to the asset.
|
||||
:type path: str
|
||||
:param factor: Resizing factor for the asset.
|
||||
:type factor: float
|
||||
:return: Resized asset.
|
||||
"""
|
||||
s = int(factor*self.cell_size)
|
||||
asset = pygame.image.load(path).convert_alpha()
|
||||
asset = pygame.transform.smoothscale(asset, (s, s))
|
||||
return asset
|
||||
|
||||
def visibility_rects(self, bp, view):
|
||||
"""
|
||||
Calculates the visibility rectangles for an agent.
|
||||
|
||||
:param bp: Blit parameters for the agent.
|
||||
:type bp: dict
|
||||
:param view: Agent's field of view.
|
||||
:type view: np.ndarray
|
||||
:return: List of visibility rectangles.
|
||||
:rtype: List[dict]
|
||||
"""
|
||||
rects = []
|
||||
for i, j in product(range(-self.view_radius, self.view_radius+1),
|
||||
range(-self.view_radius, self.view_radius+1)):
|
||||
@ -111,6 +157,14 @@ class Renderer:
|
||||
return rects
|
||||
|
||||
def render(self, entities):
|
||||
"""
|
||||
Renders the entities on the screen.
|
||||
|
||||
:param entities: List of entities to be rendered.
|
||||
:type entities: List[Entity]
|
||||
:return: Transposed RGB observation array.
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
pygame.quit()
|
||||
|
@ -15,10 +15,12 @@ from marl_factory_grid.utils.results import Result
|
||||
class StepRules:
|
||||
def __init__(self, *args):
|
||||
"""
|
||||
TODO
|
||||
Manages a collection of rules to be applied at each step of the environment.
|
||||
|
||||
The StepRules class allows you to organize and apply custom rules during the simulation, ensuring that the
|
||||
corresponding hooks for all rules are called at the appropriate times.
|
||||
|
||||
:return:
|
||||
:param args: Optional Rule objects to initialize the StepRules with.
|
||||
"""
|
||||
if args:
|
||||
self.rules = list(args)
|
||||
@ -92,10 +94,18 @@ class Gamestate(object):
|
||||
|
||||
def __init__(self, entities, agents_conf, rules: List[Rule], tests: [Test], lvl_shape, env_seed=69, verbose=False):
|
||||
"""
|
||||
TODO
|
||||
The `Gamestate` class represents the state of the game environment.
|
||||
|
||||
|
||||
:return:
|
||||
:param lvl_shape: The shape of the game level.
|
||||
:type lvl_shape: tuple
|
||||
:param entities: The entities present in the environment.
|
||||
:type entities: Entities
|
||||
:param agents_conf: Agent configurations for the environment.
|
||||
:type agents_conf: Any
|
||||
:param verbose: Controls verbosity in the environment.
|
||||
:type verbose: bool
|
||||
:param rules: Organizes and applies custom rules during the simulation.
|
||||
:type rules: StepRules
|
||||
"""
|
||||
self.lvl_shape = lvl_shape
|
||||
self.entities = entities
|
||||
@ -162,7 +172,7 @@ class Gamestate(object):
|
||||
|
||||
def tick(self, actions) -> list[Result]:
|
||||
"""
|
||||
Performs a single **Gamestate Tick**by calling the inner rule hooks in sequential order.
|
||||
Performs a single **Gamestate Tick** by calling the inner rule hooks in sequential order.
|
||||
- tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
|
||||
- agent tick: Agents do their actions.
|
||||
- tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...
|
||||
|
@ -15,7 +15,7 @@ OBSERVATIONS = 'Observations'
|
||||
RULES = 'Rule'
|
||||
TESTS = 'Tests'
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls', 'Gamestate', 'Path',
|
||||
'Iterable', 'Move', 'Result', 'TemplateRule', 'Entities', 'EnvObjects', 'Zones', 'Collection',
|
||||
'Iterable', 'Move', 'Result', 'TemplateRule', 'Entities', 'EnvObjects', 'Collection',
|
||||
'State', 'Object', 'default_valid_reward', 'default_fail_reward', 'size']
|
||||
|
||||
|
||||
|
@ -6,7 +6,10 @@ import numpy as np
|
||||
|
||||
|
||||
class MarlFrameStack(gym.ObservationWrapper):
|
||||
"""todo @romue404"""
|
||||
"""
|
||||
|
||||
todo @romue404
|
||||
"""
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
|
Reference in New Issue
Block a user