diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 850f794..2918ed7 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -1,6 +1,6 @@ import pickle from pathlib import Path -from typing import List, Union, Iterable +from typing import List, Union, Iterable, NamedTuple import gym import numpy as np @@ -11,6 +11,12 @@ import yaml from environments import helpers as h +class MovementProperties(NamedTuple): + allow_square_movement: bool = False + allow_diagonal_movement: bool = False + allow_no_op: bool = False + + class AgentState: def __init__(self, i: int, action: int): @@ -78,16 +84,17 @@ class Actions(Register): def movement_actions(self): return self._movement_actions - def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False): + def __init__(self, movement_properties: MovementProperties): + self.allow_no_op = movement_properties.allow_no_op + self.allow_diagonal_movement = movement_properties.allow_diagonal_movement + self.allow_square_movement = movement_properties.allow_square_movement # FIXME: There is a bug in helpers because there actions are ints. and the order matters. - assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!" + assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), "There is a bug in helpers!!!" super(Actions, self).__init__() - self.allow_no_op = allow_no_op - self.allow_diagonal_movement = allow_diagonal_movement - self.allow_square_movement = allow_square_movement - if allow_square_movement: + + if self.allow_square_movement: self + ['north', 'east', 'south', 'west'] - if allow_diagonal_movement: + if self.allow_diagonal_movement: self + ['north-east', 'south-east', 'south-west', 'north-west'] self._movement_actions = self._register.copy() if self.allow_no_op: @@ -124,20 +131,18 @@ class BaseFactory(gym.Env): return self._actions.movement_actions def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None, - allow_square_movement=True, allow_diagonal_movement=True, allow_no_op=True, + movement_properties: MovementProperties = MovementProperties(), omit_agent_slice_in_obs=False, **kwargs): - self.allow_no_op = allow_no_op - self.allow_diagonal_movement = allow_diagonal_movement - self.allow_square_movement = allow_square_movement + + self.movement_properties = movement_properties + self.n_agents = n_agents self.max_steps = max_steps self.pomdp_radius = pomdp_radius self.omit_agent_slice_in_obs = omit_agent_slice_in_obs self.done_at_collision = False - _actions = Actions(allow_square_movement=self.allow_square_movement, - allow_diagonal_movement=self.allow_diagonal_movement, - allow_no_op=allow_no_op) + _actions = Actions(self.movement_properties) self._actions = _actions + self.additional_actions self._level = h.one_hot_level( diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index 8b89116..2dde6df 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -1,6 +1,6 @@ from collections import OrderedDict from dataclasses import dataclass -from typing import List, Union +from typing import List, Union, NamedTuple import random import numpy as np @@ -15,8 +15,7 @@ DIRT_INDEX = -1 CLEAN_UP_ACTION = 'clean_up' -@dataclass -class DirtProperties: +class DirtProperties(NamedTuple): clean_amount: int = 2 # How much does the robot clean with one action. max_spawn_ratio: float = 0.2 # On max how much tiles does the dirt spawn in percent. gain_amount: float = 0.5 # How much dirt does spawn per tile diff --git a/main.py b/main.py index 89c9781..d06d030 100644 --- a/main.py +++ b/main.py @@ -12,6 +12,7 @@ from gym.wrappers import FrameStack from stable_baselines3.common.callbacks import CallbackList from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv +from environments.factory.base_factory import MovementProperties from environments.factory.simple_factory import DirtProperties, SimpleFactory from environments.helpers import IGNORED_DF_COLUMNS from environments.logging.monitor import MonitorCallback @@ -94,6 +95,9 @@ if __name__ == '__main__': # from sb3_contrib import QRDQN dirt_props = DirtProperties() + move_props = MovementProperties(allow_diagonal_movement=False, + allow_square_movement=True, + allow_no_op=False) time_stamp = int(time.time()) out_path = None @@ -104,7 +108,7 @@ if __name__ == '__main__': for seed in range(3): env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400, - allow_diagonal_movement=True, allow_no_op=False, verbose=False, + movement_properties=move_props, omit_agent_slice_in_obs=True) env.save_params(Path('debug_out', 'yaml.txt'))