named tuples working

This commit is contained in:
steffen-illium 2021-06-04 16:43:34 +02:00
parent 1e87c4807f
commit 8ba7c418f0
3 changed files with 27 additions and 19 deletions

View File

@ -1,6 +1,6 @@
import pickle
from pathlib import Path
from typing import List, Union, Iterable
from typing import List, Union, Iterable, NamedTuple
import gym
import numpy as np
@ -11,6 +11,12 @@ import yaml
from environments import helpers as h
# Bundle of the three movement switches that Actions.__init__ (and, via its
# `movement_properties` argument, BaseFactory.__init__) receives as one
# NamedTuple instead of three separate keyword arguments.
# Every switch defaults to False, so MovementProperties() enables nothing.
class MovementProperties(NamedTuple):
# When True, Actions registers 'north', 'east', 'south' and 'west'.
allow_square_movement: bool = False
# When True, Actions registers 'north-east', 'south-east', 'south-west' and 'north-west'.
# Actions asserts that this is not enabled while allow_square_movement is False.
allow_diagonal_movement: bool = False
# Checked by Actions after the movement actions have been registered; presumably
# adds a do-nothing action (the registering code is not visible here - confirm).
allow_no_op: bool = False
class AgentState:
def __init__(self, i: int, action: int):
@ -78,16 +84,17 @@ class Actions(Register):
def movement_actions(self):
return self._movement_actions
def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
def __init__(self, movement_properties: MovementProperties):
self.allow_no_op = movement_properties.allow_no_op
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
self.allow_square_movement = movement_properties.allow_square_movement
# FIXME: There is a bug in helpers: actions there are ints, and their order matters.
assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), "There is a bug in helpers!!!"
super(Actions, self).__init__()
self.allow_no_op = allow_no_op
self.allow_diagonal_movement = allow_diagonal_movement
self.allow_square_movement = allow_square_movement
if allow_square_movement:
if self.allow_square_movement:
self + ['north', 'east', 'south', 'west']
if allow_diagonal_movement:
if self.allow_diagonal_movement:
self + ['north-east', 'south-east', 'south-west', 'north-west']
self._movement_actions = self._register.copy()
if self.allow_no_op:
@ -124,20 +131,18 @@ class BaseFactory(gym.Env):
return self._actions.movement_actions
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
allow_square_movement=True, allow_diagonal_movement=True, allow_no_op=True,
movement_properties: MovementProperties = MovementProperties(),
omit_agent_slice_in_obs=False, **kwargs):
self.allow_no_op = allow_no_op
self.allow_diagonal_movement = allow_diagonal_movement
self.allow_square_movement = allow_square_movement
self.movement_properties = movement_properties
self.n_agents = n_agents
self.max_steps = max_steps
self.pomdp_radius = pomdp_radius
self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
self.done_at_collision = False
_actions = Actions(allow_square_movement=self.allow_square_movement,
allow_diagonal_movement=self.allow_diagonal_movement,
allow_no_op=allow_no_op)
_actions = Actions(self.movement_properties)
self._actions = _actions + self.additional_actions
self._level = h.one_hot_level(

View File

@ -1,6 +1,6 @@
from collections import OrderedDict
from dataclasses import dataclass
from typing import List, Union
from typing import List, Union, NamedTuple
import random
import numpy as np
@ -15,8 +15,7 @@ DIRT_INDEX = -1
CLEAN_UP_ACTION = 'clean_up'
@dataclass
class DirtProperties:
class DirtProperties(NamedTuple):
clean_amount: int = 2 # How much does the robot clean with one action.
max_spawn_ratio: float = 0.2 # On max how much tiles does the dirt spawn in percent.
gain_amount: float = 0.5 # How much dirt does spawn per tile

View File

@ -12,6 +12,7 @@ from gym.wrappers import FrameStack
from stable_baselines3.common.callbacks import CallbackList
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
from environments.factory.base_factory import MovementProperties
from environments.factory.simple_factory import DirtProperties, SimpleFactory
from environments.helpers import IGNORED_DF_COLUMNS
from environments.logging.monitor import MonitorCallback
@ -94,6 +95,9 @@ if __name__ == '__main__':
# from sb3_contrib import QRDQN
dirt_props = DirtProperties()
move_props = MovementProperties(allow_diagonal_movement=False,
allow_square_movement=True,
allow_no_op=False)
time_stamp = int(time.time())
out_path = None
@ -104,7 +108,7 @@ if __name__ == '__main__':
for seed in range(3):
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
allow_diagonal_movement=True, allow_no_op=False, verbose=False,
movement_properties=move_props,
omit_agent_slice_in_obs=True)
env.save_params(Path('debug_out', 'yaml.txt'))