AgentState Object

@dataclass
class AgentState:
    i: int
    action: int

    pos = None
    collision_vector = None
    action_valid = None
This commit is contained in:
steffen-illium
2021-05-14 09:09:20 +02:00
parent 14741aa5a5
commit 86204a6266
3 changed files with 74 additions and 37 deletions

View File

@ -1,12 +1,17 @@
import numpy as np
from attr import dataclass
from environments.factory.base_factory import BaseFactory
from collections import namedtuple
from typing import Iterable
from environments import helpers as h
DIRT_INDEX = -1
DirtProperties = namedtuple('DirtProperties', ['clean_amount', 'max_spawn_ratio', 'gain_amount'],
defaults=[0.25, 0.1, 0.1])
@dataclass
class DirtProperties:
    """Container for the dirt-related settings handed to ``GettingDirty``.

    Each attribute is annotated on purpose: the ``dataclass`` decorator only
    turns *annotated* class attributes into fields. Without the annotations
    the values would be plain class attributes, the generated ``__init__``
    would accept no arguments, and ``__repr__``/``__eq__`` would ignore them.
    With them, the settings can be overridden per instance (positionally or
    by keyword), matching the namedtuple-with-defaults this class replaces.
    """

    # NOTE(review): the exact meaning of each value is defined by the
    # environment code that consumes it; the names suggest amount cleaned per
    # clean-up action, upper bound on the spawned-dirt ratio, and amount of
    # dirt gained per spawn -- confirm against the usage in GettingDirty.
    clean_amount: float = 0.25
    max_spawn_ratio: float = 0.1
    gain_amount: float = 0.1
class GettingDirty(BaseFactory):
@ -15,7 +20,7 @@ class GettingDirty(BaseFactory):
def _clean_up_action(self):
return self.movement_actions + 1 - 1
def __init__(self, *args, dirt_properties:DirtProperties, **kwargs):
def __init__(self, *args, dirt_properties: DirtProperties, **kwargs):
self._dirt_properties = dirt_properties
super(GettingDirty, self).__init__(*args, **kwargs)
self.slice_strings.update({self.state.shape[0]-1: 'dirt'})
@ -58,7 +63,7 @@ class GettingDirty(BaseFactory):
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
self.spawn_dirt()
def calculate_reward(self, collisions_vecs: np.ndarray, actions: Iterable[int], r: int) -> (int, dict):
def calculate_reward(self, collisions_vecs: np.ndarray, actions: Iterable[int]) -> (int, dict):
for agent_i, cols in enumerate(collisions_vecs):
cols = np.argwhere(cols != 0).flatten()
print(f't = {self.steps}\tAgent {agent_i} has collisions with '
@ -68,8 +73,8 @@ class GettingDirty(BaseFactory):
if __name__ == '__main__':
import random
dirt_properties = DirtProperties()
factory = GettingDirty(n_agents=1, dirt_properties=dirt_properties)
dirt_props = DirtProperties()
factory = GettingDirty(n_agents=1, dirt_properties=dirt_props)
random_actions = [random.randint(0, 8) for _ in range(200)]
for action in random_actions:
state, r, done, _ = factory.step(action)