mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2026-04-23 23:37:32 +02:00
Merge branch 'main' into unit_testing
This commit is contained in:
@@ -97,20 +97,26 @@ class Factory(gym.Env):
|
||||
return self.state.entities[item]
|
||||
|
||||
def reset(self) -> (dict, dict):
|
||||
|
||||
# Reset information the state holds
|
||||
self.state.reset()
|
||||
|
||||
# Reset Information the GlobalEntity collection holds.
|
||||
self.state.entities.reset()
|
||||
|
||||
# All is set up, trigger entity spawn with variable pos
|
||||
self.state.rules.do_all_reset(self.state)
|
||||
|
||||
# Build initial observations for all agents
|
||||
return self.obs_builder.refresh_and_build_for_all(self.state)
|
||||
self.obs_builder.reset(self.state)
|
||||
return self.obs_builder.build_for_all(self.state)
|
||||
|
||||
def manual_step_init(self) -> List[Result]:
|
||||
self.state.curr_step += 1
|
||||
|
||||
# Main Agent Step
|
||||
pre_step_result = self.state.rules.tick_pre_step_all(self)
|
||||
self.obs_builder.reset_struc_obs_block(self.state)
|
||||
self.obs_builder.reset(self.state)
|
||||
return pre_step_result
|
||||
|
||||
def manual_get_named_agent_obs(self, agent_name: str) -> (List[str], np.ndarray):
|
||||
@@ -164,7 +170,7 @@ class Factory(gym.Env):
|
||||
|
||||
info.update(step_reward=sum(reward), step=self.state.curr_step)
|
||||
|
||||
obs = self.obs_builder.refresh_and_build_for_all(self.state)
|
||||
obs = self.obs_builder.build_for_all(self.state)
|
||||
return None, [x for x in obs.values()], reward, done, info
|
||||
|
||||
def summarize_step_results(self, tick_results: list, done_check_results: list) -> (int, dict, bool):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from marl_factory_grid.environment.rules import SpawnAgents
|
||||
|
||||
|
||||
class Agents(Collection):
|
||||
@@ -8,7 +7,7 @@ class Agents(Collection):
|
||||
|
||||
@property
|
||||
def spawn_rule(self):
|
||||
return {SpawnAgents.__name__: {}}
|
||||
return {}
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
|
||||
@@ -27,7 +27,7 @@ class Entities(Objects):
|
||||
@property
|
||||
def floorlist(self):
|
||||
shuffle(self._floor_positions)
|
||||
return self._floor_positions
|
||||
return [x for x in self._floor_positions]
|
||||
|
||||
def __init__(self, floor_positions):
|
||||
self._floor_positions = floor_positions
|
||||
|
||||
@@ -70,28 +70,22 @@ class SpawnAgents(Rule):
|
||||
|
||||
def on_reset(self, state):
|
||||
agents = state[c.AGENT]
|
||||
empty_positions = state.entities.empty_positions[:len(state.agents_conf)]
|
||||
for agent_name, agent_conf in state.agents_conf.items():
|
||||
empty_positions = state.entities.empty_positions
|
||||
actions = agent_conf['actions'].copy()
|
||||
observations = agent_conf['observations'].copy()
|
||||
positions = agent_conf['positions'].copy()
|
||||
other = agent_conf['other'].copy()
|
||||
if positions:
|
||||
shuffle(positions)
|
||||
while True:
|
||||
try:
|
||||
pos = positions.pop()
|
||||
except IndexError:
|
||||
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
|
||||
f'\n{agent_conf["positions"].copy()}')
|
||||
if bool(agents.by_pos(pos)) or not state.check_pos_validity(pos):
|
||||
continue
|
||||
else:
|
||||
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name, **other))
|
||||
break
|
||||
|
||||
if position := h.get_first(x for x in positions if x in empty_positions):
|
||||
assert state.check_pos_validity(position), 'smth went wrong....'
|
||||
agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other))
|
||||
elif positions:
|
||||
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
|
||||
f'\n{agent_conf["positions"].copy()}')
|
||||
else:
|
||||
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other))
|
||||
pass
|
||||
return []
|
||||
|
||||
|
||||
class DoneAtMaxStepsReached(Rule):
|
||||
@@ -103,7 +97,7 @@ class DoneAtMaxStepsReached(Rule):
|
||||
def on_check_done(self, state):
|
||||
if self.max_steps <= state.curr_step:
|
||||
return [DoneResult(validity=c.VALID, identifier=self.name)]
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
|
||||
return []
|
||||
|
||||
|
||||
class AssignGlobalPositions(Rule):
|
||||
@@ -130,7 +124,7 @@ class WatchCollisions(Rule):
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
self.curr_done = False
|
||||
pos_with_collisions = state.get_all_pos_with_collisions()
|
||||
pos_with_collisions = state.get_collision_positions()
|
||||
results = list()
|
||||
for pos in pos_with_collisions:
|
||||
guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide]
|
||||
|
||||
Reference in New Issue
Block a user