named tuples working

This commit is contained in:
steffen-illium 2021-06-04 16:43:34 +02:00
parent 1e87c4807f
commit 8ba7c418f0
3 changed files with 27 additions and 19 deletions

View File

@ -1,6 +1,6 @@
import pickle import pickle
from pathlib import Path from pathlib import Path
from typing import List, Union, Iterable from typing import List, Union, Iterable, NamedTuple
import gym import gym
import numpy as np import numpy as np
@ -11,6 +11,12 @@ import yaml
from environments import helpers as h from environments import helpers as h
class MovementProperties(NamedTuple):
    """Immutable bundle of the movement switches consumed by ``Actions``.

    All switches default to ``False``, so ``MovementProperties()`` enables
    no movement at all. ``Actions.__init__`` asserts that diagonal movement
    is not enabled while square movement is disabled.
    """

    # Registers the actions 'north', 'east', 'south' and 'west'.
    allow_square_movement: bool = False
    # Registers 'north-east', 'south-east', 'south-west' and 'north-west'.
    allow_diagonal_movement: bool = False
    # Presumably registers an additional "do nothing" action; the branch that
    # handles it is not fully visible here -- TODO confirm.
    allow_no_op: bool = False
class AgentState: class AgentState:
def __init__(self, i: int, action: int): def __init__(self, i: int, action: int):
@ -78,16 +84,17 @@ class Actions(Register):
def movement_actions(self): def movement_actions(self):
return self._movement_actions return self._movement_actions
def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False): def __init__(self, movement_properties: MovementProperties):
self.allow_no_op = movement_properties.allow_no_op
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
self.allow_square_movement = movement_properties.allow_square_movement
# FIXME: There is a bug in helpers because there actions are ints. and the order matters. # FIXME: There is a bug in helpers because there actions are ints. and the order matters.
assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!" assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), "There is a bug in helpers!!!"
super(Actions, self).__init__() super(Actions, self).__init__()
self.allow_no_op = allow_no_op
self.allow_diagonal_movement = allow_diagonal_movement if self.allow_square_movement:
self.allow_square_movement = allow_square_movement
if allow_square_movement:
self + ['north', 'east', 'south', 'west'] self + ['north', 'east', 'south', 'west']
if allow_diagonal_movement: if self.allow_diagonal_movement:
self + ['north-east', 'south-east', 'south-west', 'north-west'] self + ['north-east', 'south-east', 'south-west', 'north-west']
self._movement_actions = self._register.copy() self._movement_actions = self._register.copy()
if self.allow_no_op: if self.allow_no_op:
@ -124,20 +131,18 @@ class BaseFactory(gym.Env):
return self._actions.movement_actions return self._actions.movement_actions
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None, def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
allow_square_movement=True, allow_diagonal_movement=True, allow_no_op=True, movement_properties: MovementProperties = MovementProperties(),
omit_agent_slice_in_obs=False, **kwargs): omit_agent_slice_in_obs=False, **kwargs):
self.allow_no_op = allow_no_op
self.allow_diagonal_movement = allow_diagonal_movement self.movement_properties = movement_properties
self.allow_square_movement = allow_square_movement
self.n_agents = n_agents self.n_agents = n_agents
self.max_steps = max_steps self.max_steps = max_steps
self.pomdp_radius = pomdp_radius self.pomdp_radius = pomdp_radius
self.omit_agent_slice_in_obs = omit_agent_slice_in_obs self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
self.done_at_collision = False self.done_at_collision = False
_actions = Actions(allow_square_movement=self.allow_square_movement, _actions = Actions(self.movement_properties)
allow_diagonal_movement=self.allow_diagonal_movement,
allow_no_op=allow_no_op)
self._actions = _actions + self.additional_actions self._actions = _actions + self.additional_actions
self._level = h.one_hot_level( self._level = h.one_hot_level(

View File

@ -1,6 +1,6 @@
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Union from typing import List, Union, NamedTuple
import random import random
import numpy as np import numpy as np
@ -15,8 +15,7 @@ DIRT_INDEX = -1
CLEAN_UP_ACTION = 'clean_up' CLEAN_UP_ACTION = 'clean_up'
@dataclass class DirtProperties(NamedTuple):
class DirtProperties:
clean_amount: int = 2 # How much does the robot clean with one action. clean_amount: int = 2 # How much does the robot clean with one action.
max_spawn_ratio: float = 0.2 # On max how much tiles does the dirt spawn in percent. max_spawn_ratio: float = 0.2 # On max how much tiles does the dirt spawn in percent.
gain_amount: float = 0.5 # How much dirt does spawn per tile gain_amount: float = 0.5 # How much dirt does spawn per tile

View File

@ -12,6 +12,7 @@ from gym.wrappers import FrameStack
from stable_baselines3.common.callbacks import CallbackList from stable_baselines3.common.callbacks import CallbackList
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
from environments.factory.base_factory import MovementProperties
from environments.factory.simple_factory import DirtProperties, SimpleFactory from environments.factory.simple_factory import DirtProperties, SimpleFactory
from environments.helpers import IGNORED_DF_COLUMNS from environments.helpers import IGNORED_DF_COLUMNS
from environments.logging.monitor import MonitorCallback from environments.logging.monitor import MonitorCallback
@ -94,6 +95,9 @@ if __name__ == '__main__':
# from sb3_contrib import QRDQN # from sb3_contrib import QRDQN
dirt_props = DirtProperties() dirt_props = DirtProperties()
move_props = MovementProperties(allow_diagonal_movement=False,
allow_square_movement=True,
allow_no_op=False)
time_stamp = int(time.time()) time_stamp = int(time.time())
out_path = None out_path = None
@ -104,7 +108,7 @@ if __name__ == '__main__':
for seed in range(3): for seed in range(3):
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400, env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
allow_diagonal_movement=True, allow_no_op=False, verbose=False, movement_properties=move_props,
omit_agent_slice_in_obs=True) omit_agent_slice_in_obs=True)
env.save_params(Path('debug_out', 'yaml.txt')) env.save_params(Path('debug_out', 'yaml.txt'))