mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
added test hooks (like rules)
This commit is contained in:
parent
64c0d0e4e9
commit
3d18fe9839
42
marl_factory_grid/environment/tests.py
Normal file
42
marl_factory_grid/environment/tests.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import abc
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||||
|
|
||||||
|
|
||||||
|
class Test(abc.ABC):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'{self.name}'
|
||||||
|
|
||||||
|
def on_init(self, state, lvl_map):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def on_reset(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def tick_pre_step(self, state) -> List[TickResult]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def tick_step(self, state) -> List[TickResult]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def tick_post_step(self, state) -> List[TickResult]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def on_check_done(self, state) -> List[DoneResult]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class FirstTest(Test):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
pass
|
81
marl_factory_grid/testing/test_config.yaml
Normal file
81
marl_factory_grid/testing/test_config.yaml
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
Agents:
|
||||||
|
Wolfgang:
|
||||||
|
Actions:
|
||||||
|
- Noop
|
||||||
|
- BtryCharge
|
||||||
|
- CleanUp
|
||||||
|
- DestAction
|
||||||
|
- DoorUse
|
||||||
|
- ItemAction
|
||||||
|
- Move8
|
||||||
|
Observations:
|
||||||
|
- Combined:
|
||||||
|
- Other
|
||||||
|
- Walls
|
||||||
|
- GlobalPosition
|
||||||
|
- Battery
|
||||||
|
- ChargePods
|
||||||
|
- DirtPiles
|
||||||
|
- Destinations
|
||||||
|
- Doors
|
||||||
|
- Items
|
||||||
|
- Inventory
|
||||||
|
- DropOffLocations
|
||||||
|
- Maintainers
|
||||||
|
Entities:
|
||||||
|
Batteries:
|
||||||
|
initial_charge: 0.8
|
||||||
|
per_action_costs: 0.02
|
||||||
|
ChargePods: {}
|
||||||
|
Destinations: {}
|
||||||
|
DirtPiles:
|
||||||
|
clean_amount: 1
|
||||||
|
dirt_spawn_r_var: 0.1
|
||||||
|
initial_amount: 2
|
||||||
|
initial_dirt_ratio: 0.05
|
||||||
|
max_global_amount: 20
|
||||||
|
max_local_amount: 5
|
||||||
|
Doors: {}
|
||||||
|
DropOffLocations: {}
|
||||||
|
GlobalPositions: {}
|
||||||
|
Inventories: {}
|
||||||
|
Items: {}
|
||||||
|
Machines: {}
|
||||||
|
Maintainers: {}
|
||||||
|
Zones: {}
|
||||||
|
|
||||||
|
General:
|
||||||
|
env_seed: 69
|
||||||
|
individual_rewards: true
|
||||||
|
level_name: large
|
||||||
|
pomdp_r: 3
|
||||||
|
verbose: false
|
||||||
|
tests: true
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
SpawnAgents: {}
|
||||||
|
DoneAtBatteryDischarge: {}
|
||||||
|
Collision:
|
||||||
|
done_at_collisions: false
|
||||||
|
AssignGlobalPositions: {}
|
||||||
|
DoneAtDestinationReachAny: {}
|
||||||
|
DestinationReachReward: {}
|
||||||
|
SpawnDestinations:
|
||||||
|
n_dests: 1
|
||||||
|
spawn_mode: GROUPED
|
||||||
|
DoneOnAllDirtCleaned: {}
|
||||||
|
SpawnDirt:
|
||||||
|
spawn_freq: 15
|
||||||
|
EntitiesSmearDirtOnMove:
|
||||||
|
smear_ratio: 0.2
|
||||||
|
DoorAutoClose:
|
||||||
|
close_frequency: 10
|
||||||
|
ItemRules:
|
||||||
|
max_dropoff_storage_size: 0
|
||||||
|
n_items: 5
|
||||||
|
n_locations: 5
|
||||||
|
spawn_frequency: 15
|
||||||
|
MaxStepsReached:
|
||||||
|
max_steps: 10
|
||||||
|
# AgentSingleZonePlacement:
|
||||||
|
# n_zones: 4
|
32
marl_factory_grid/testing/test_run.py
Normal file
32
marl_factory_grid/testing/test_run.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from random import randint
|
||||||
|
|
||||||
|
from tqdm import trange
|
||||||
|
|
||||||
|
from marl_factory_grid.environment.factory import Factory
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Render at each step?
|
||||||
|
render = True
|
||||||
|
|
||||||
|
# Path to config File
|
||||||
|
path = Path('test_config.yaml')
|
||||||
|
|
||||||
|
# Env Init
|
||||||
|
factory = Factory(path)
|
||||||
|
|
||||||
|
for episode in trange(5):
|
||||||
|
_ = factory.reset()
|
||||||
|
done = False
|
||||||
|
if render:
|
||||||
|
factory.render()
|
||||||
|
action_spaces = factory.action_space
|
||||||
|
while not done:
|
||||||
|
a = [randint(0, x.n - 1) for x in action_spaces]
|
||||||
|
obs_type, _, _, done, info = factory.step(a)
|
||||||
|
if render:
|
||||||
|
factory.render()
|
||||||
|
if done:
|
||||||
|
print(f'Episode {episode} done...')
|
||||||
|
break
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user