checked rule returns

This commit is contained in:
Steffen Illium
2023-11-11 15:33:33 +01:00
parent 80247cb56a
commit 1dfeb7ae4a
7 changed files with 10 additions and 11 deletions
marl_factory_grid
environment
modules
batteries
clean_up
destinations
maintenance
zones
setup.py

@ -85,7 +85,7 @@ class SpawnAgents(Rule):
f'\n{agent_conf["positions"].copy()}')
else:
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other))
pass
return []
class DoneAtMaxStepsReached(Rule):
@ -97,7 +97,7 @@ class DoneAtMaxStepsReached(Rule):
def on_check_done(self, state):
if self.max_steps <= state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
return []
class AssignGlobalPositions(Rule):

@ -60,7 +60,7 @@ class BatteryDecharge(Rule):
batteries.by_entity(agent).decharge(energy_consumption)
results.append(TickResult(self.name, entity=agent, validity=c.VALID))
results.append(TickResult(self.name, entity=agent, validity=c.VALID, value=energy_consumption))
return results

@ -22,7 +22,7 @@ class DoneOnAllDirtCleaned(Rule):
def on_check_done(self, state) -> [DoneResult]:
if len(state[d.DIRT]) == 0 and state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
return []
class RespawnDirt(Rule):
@ -81,5 +81,6 @@ class EntitiesSmearDirtOnMove(Rule):
old_pos_dirt = next(iter(old_pos_dirt))
if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2):
if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt):
results.append(TickResult(identifier=self.name, entity=entity, validity=c.VALID))
results.append(TickResult(identifier=self.name, entity=entity,
validity=c.VALID, value=smeared_dirt))
return results

@ -90,6 +90,8 @@ class DoneAtDestinationReach(DestinationReachReward):
pass
else:
dest.unmark_as_reached()
return [DoneResult(f'{dest}_unmarked_as_reached',
validity=c.NOT_VALID, entity=dest)]
else:
pass
else:

@ -1,6 +1,5 @@
from typing import List
import marl_factory_grid.modules.maintenance.constants
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
@ -31,5 +30,5 @@ class DoneAtMaintainerCollision(Rule):
for agent in agents:
if agent.pos in m_pos:
done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
reward=marl_factory_grid.modules.maintenance.constants.MAINTAINER_COLLISION_REWARD))
reward=M.MAINTAINER_COLLISION_REWARD))
return done_results

@ -43,9 +43,6 @@ class AgentSingleZonePlacement(Rule):
agent.move(state[z.ZONES][z_idxs.pop()].random_pos, state)
return []
def tick_step(self, state):
return []
class IndividualDestinationZonePlacement(Rule):

@ -5,7 +5,7 @@ long_description = (this_directory / "README.md").read_text()
setup(name='Marl-Factory-Grid',
version='0.1.2',
version='0.2.0',
description='A framework to research MARL agents in various setings.',
author='Steffen Illium',
author_email='steffen.illium@ifi.lmu.de',