1
0
mirror of https://github.com/illiumst/marl-factory-grid.git synced 2026-04-23 23:37:32 +02:00

Merge branch 'main' into unit_testing

This commit is contained in:
Chanumask
2023-11-13 11:00:14 +01:00
22 changed files with 205 additions and 114 deletions
+9 -3
View File
@@ -97,20 +97,26 @@ class Factory(gym.Env):
return self.state.entities[item]
def reset(self) -> (dict, dict):
# Reset the environment: clear the state and its entity collection, run the
# reset rules, then return what the observation builder produces.
# NOTE(review): this text has no indentation and contains two alternative
# endings (see below); it looks like a rendered diff with removed and added
# lines interleaved -- confirm against the repository before relying on it.
# Reset information the state holds
self.state.reset()
# Reset Information the GlobalEntity collection holds.
self.state.entities.reset()
# All is set up, trigger entity spawn with variable pos
self.state.rules.do_all_reset(self.state)
# Build initial observations for all agents
# Variant A: a single refresh-and-build call on the observation builder.
return self.obs_builder.refresh_and_build_for_all(self.state)
# Variant B: reset the observation builder explicitly, then build.
# NOTE(review): presumably this pair replaces variant A -- TODO confirm.
self.obs_builder.reset(self.state)
return self.obs_builder.build_for_all(self.state)
def manual_step_init(self) -> List[Result]:
# Start a manually driven step: advance the step counter, run the pre-step
# rules, reset the observation builder and return the pre-step results.
self.state.curr_step += 1
# Main Agent Step
# Note that the environment itself (self), not self.state, is handed to the rules here.
pre_step_result = self.state.rules.tick_pre_step_all(self)
# NOTE(review): the next two lines look like the old and the new form of the
# same observation-builder reset (interleaved diff lines) -- TODO confirm which
# one is current.
self.obs_builder.reset_struc_obs_block(self.state)
self.obs_builder.reset(self.state)
return pre_step_result
def manual_get_named_agent_obs(self, agent_name: str) -> (List[str], np.ndarray):
@@ -164,7 +170,7 @@ class Factory(gym.Env):
info.update(step_reward=sum(reward), step=self.state.curr_step)
obs = self.obs_builder.refresh_and_build_for_all(self.state)
obs = self.obs_builder.build_for_all(self.state)
return None, [x for x in obs.values()], reward, done, info
def summarize_step_results(self, tick_results: list, done_check_results: list) -> (int, dict, bool):
@@ -1,6 +1,5 @@
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.environment.rules import SpawnAgents
class Agents(Collection):
@@ -8,7 +7,7 @@ class Agents(Collection):
@property
def spawn_rule(self):
# Mapping of spawn-rule class name to its keyword arguments for this collection.
# NOTE(review): two alternative return values follow (interleaved diff lines).
# Variant A: register SpawnAgents with no arguments.
return {SpawnAgents.__name__: {}}
# Variant B: register no spawn rule at all.
# NOTE(review): presumably the current version, since the hunk covering the
# import section drops one line -- TODO confirm that SpawnAgents is no longer imported.
return {}
@property
def var_is_blocking_light(self):
@@ -27,7 +27,7 @@ class Entities(Objects):
@property
def floorlist(self):
# Floor positions in random order.
# The shuffle is in place, so the stored self._floor_positions is reordered
# on every access.
shuffle(self._floor_positions)
# NOTE(review): two alternative returns follow (interleaved diff lines).
# Variant A: hand out the internal list itself (callers could mutate it).
return self._floor_positions
# Variant B: hand out a shallow copy, keeping the internal list private.
return [x for x in self._floor_positions]
def __init__(self, floor_positions):
self._floor_positions = floor_positions
+11 -17
View File
@@ -70,28 +70,22 @@ class SpawnAgents(Rule):
def on_reset(self, state):
# Spawn one Agent per entry of state.agents_conf and add it to the agent
# collection stored under c.AGENT. Returns an empty list.
# Raises ValueError when an agent has configured positions but none is usable.
# NOTE(review): this text has no indentation and contains two interleaved
# spawn strategies (a shuffle/pop retry loop and a "first free configured
# position" lookup); it looks like a rendered diff with removed and added
# lines mixed -- confirm against the repository which lines are current.
agents = state[c.AGENT]
# Candidate free positions, limited to the number of configured agents.
empty_positions = state.entities.empty_positions[:len(state.agents_conf)]
for agent_name, agent_conf in state.agents_conf.items():
# NOTE(review): alternative, per-agent lookup of all empty positions;
# presumably superseded by the sliced assignment above -- TODO confirm.
empty_positions = state.entities.empty_positions
# Work on copies so that the stored agent configuration is not mutated.
actions = agent_conf['actions'].copy()
observations = agent_conf['observations'].copy()
positions = agent_conf['positions'].copy()
other = agent_conf['other'].copy()
# Strategy A: try the configured positions in random order until one is
# neither occupied by another agent nor invalid.
if positions:
shuffle(positions)
while True:
try:
pos = positions.pop()
except IndexError:
# All configured positions were tried without success.
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
f'\n{agent_conf["positions"].copy()}')
if bool(agents.by_pos(pos)) or not state.check_pos_validity(pos):
continue
else:
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name, **other))
break
# Strategy B: take a configured position that is also in empty_positions.
# NOTE(review): h.get_first is not visible here; presumably it returns the
# first item of the iterable or a falsy value when there is none -- TODO confirm.
if position := h.get_first(x for x in positions if x in empty_positions):
assert state.check_pos_validity(position), 'smth went wrong....'
agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other))
elif positions:
# Positions were configured, but none of them is currently free.
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
f'\n{agent_conf["positions"].copy()}')
else:
# No positions configured: take one from the empty positions.
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other))
pass
return []
class DoneAtMaxStepsReached(Rule):
@@ -103,7 +97,7 @@ class DoneAtMaxStepsReached(Rule):
def on_check_done(self, state):
# Report the episode as done once the current step has reached self.max_steps.
if self.max_steps <= state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name)]
# NOTE(review): two alternative fall-through returns follow (interleaved diff lines).
# Variant A: report an explicit NOT_VALID result while below the step limit.
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
# Variant B: report nothing while below the step limit.
# NOTE(review): presumably the current version -- TODO confirm.
return []
class AssignGlobalPositions(Rule):
@@ -130,7 +124,7 @@ class WatchCollisions(Rule):
def tick_post_step(self, state) -> List[TickResult]:
self.curr_done = False
pos_with_collisions = state.get_all_pos_with_collisions()
pos_with_collisions = state.get_collision_positions()
results = list()
for pos in pos_with_collisions:
guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide]