mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
named tuples working
This commit is contained in:
parent
1e87c4807f
commit
8ba7c418f0
@ -1,6 +1,6 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Iterable
|
||||
from typing import List, Union, Iterable, NamedTuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@ -11,6 +11,12 @@ import yaml
|
||||
from environments import helpers as h
|
||||
|
||||
|
||||
class MovementProperties(NamedTuple):
|
||||
allow_square_movement: bool = False
|
||||
allow_diagonal_movement: bool = False
|
||||
allow_no_op: bool = False
|
||||
|
||||
|
||||
class AgentState:
|
||||
|
||||
def __init__(self, i: int, action: int):
|
||||
@ -78,16 +84,17 @@ class Actions(Register):
|
||||
def movement_actions(self):
|
||||
return self._movement_actions
|
||||
|
||||
def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
|
||||
def __init__(self, movement_properties: MovementProperties):
|
||||
self.allow_no_op = movement_properties.allow_no_op
|
||||
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
|
||||
self.allow_square_movement = movement_properties.allow_square_movement
|
||||
# FIXME: There is a bug in helpers because actions are ints there, and the order matters.
|
||||
assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
|
||||
assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), "There is a bug in helpers!!!"
|
||||
super(Actions, self).__init__()
|
||||
self.allow_no_op = allow_no_op
|
||||
self.allow_diagonal_movement = allow_diagonal_movement
|
||||
self.allow_square_movement = allow_square_movement
|
||||
if allow_square_movement:
|
||||
|
||||
if self.allow_square_movement:
|
||||
self + ['north', 'east', 'south', 'west']
|
||||
if allow_diagonal_movement:
|
||||
if self.allow_diagonal_movement:
|
||||
self + ['north-east', 'south-east', 'south-west', 'north-west']
|
||||
self._movement_actions = self._register.copy()
|
||||
if self.allow_no_op:
|
||||
@ -124,20 +131,18 @@ class BaseFactory(gym.Env):
|
||||
return self._actions.movement_actions
|
||||
|
||||
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
|
||||
allow_square_movement=True, allow_diagonal_movement=True, allow_no_op=True,
|
||||
movement_properties: MovementProperties = MovementProperties(),
|
||||
omit_agent_slice_in_obs=False, **kwargs):
|
||||
self.allow_no_op = allow_no_op
|
||||
self.allow_diagonal_movement = allow_diagonal_movement
|
||||
self.allow_square_movement = allow_square_movement
|
||||
|
||||
self.movement_properties = movement_properties
|
||||
|
||||
self.n_agents = n_agents
|
||||
self.max_steps = max_steps
|
||||
self.pomdp_radius = pomdp_radius
|
||||
self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
|
||||
|
||||
self.done_at_collision = False
|
||||
_actions = Actions(allow_square_movement=self.allow_square_movement,
|
||||
allow_diagonal_movement=self.allow_diagonal_movement,
|
||||
allow_no_op=allow_no_op)
|
||||
_actions = Actions(self.movement_properties)
|
||||
self._actions = _actions + self.additional_actions
|
||||
|
||||
self._level = h.one_hot_level(
|
||||
|
@ -1,6 +1,6 @@
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
from typing import List, Union, NamedTuple
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
@ -15,8 +15,7 @@ DIRT_INDEX = -1
|
||||
CLEAN_UP_ACTION = 'clean_up'
|
||||
|
||||
|
||||
@dataclass
|
||||
class DirtProperties:
|
||||
class DirtProperties(NamedTuple):
|
||||
clean_amount: int = 2 # How much does the robot clean with one action.
|
||||
max_spawn_ratio: float = 0.2 # On max how much tiles does the dirt spawn in percent.
|
||||
gain_amount: float = 0.5 # How much dirt does spawn per tile
|
||||
|
6
main.py
6
main.py
@ -12,6 +12,7 @@ from gym.wrappers import FrameStack
|
||||
from stable_baselines3.common.callbacks import CallbackList
|
||||
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
|
||||
|
||||
from environments.factory.base_factory import MovementProperties
|
||||
from environments.factory.simple_factory import DirtProperties, SimpleFactory
|
||||
from environments.helpers import IGNORED_DF_COLUMNS
|
||||
from environments.logging.monitor import MonitorCallback
|
||||
@ -94,6 +95,9 @@ if __name__ == '__main__':
|
||||
# from sb3_contrib import QRDQN
|
||||
|
||||
dirt_props = DirtProperties()
|
||||
move_props = MovementProperties(allow_diagonal_movement=False,
|
||||
allow_square_movement=True,
|
||||
allow_no_op=False)
|
||||
time_stamp = int(time.time())
|
||||
|
||||
out_path = None
|
||||
@ -104,7 +108,7 @@ if __name__ == '__main__':
|
||||
for seed in range(3):
|
||||
|
||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
|
||||
allow_diagonal_movement=True, allow_no_op=False, verbose=False,
|
||||
movement_properties=move_props,
|
||||
omit_agent_slice_in_obs=True)
|
||||
env.save_params(Path('debug_out', 'yaml.txt'))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user