multi_agent observation when n_agent more then 1

This commit is contained in:
steffen-illium
2021-06-09 13:12:49 +02:00
parent 62c141aa1c
commit cf2378a734
6 changed files with 271 additions and 159 deletions

View File

@ -1,16 +1,14 @@
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import List, Union, NamedTuple
import random
import numpy as np
from environments.factory.base_factory import BaseFactory, AgentState, MovementProperties
from environments.factory.base_factory import BaseFactory
from environments import helpers as h
from environments.logging.monitor import MonitorCallback
from environments.factory.renderer import Renderer, Entity
from environments.utility_classes import AgentState, MovementProperties
DIRT_INDEX = -1
CLEAN_UP_ACTION = 'clean_up'
@ -25,13 +23,16 @@ class DirtProperties(NamedTuple):
max_global_amount: int = 20 # Max dirt amount in the whole environment.
# noinspection PyAttributeOutsideInit
class SimpleFactory(BaseFactory):
@property
def additional_actions(self) -> Union[str, List[str]]:
return CLEAN_UP_ACTION
def _is_clean_up_action(self, action):
def _is_clean_up_action(self, action: Union[str, int]):
if isinstance(action, str):
action = self._actions.by_name(action)
return self._actions[action] == CLEAN_UP_ACTION
def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs):
@ -47,9 +48,9 @@ class SimpleFactory(BaseFactory):
height, width = self._state.shape[1:]
self._renderer = Renderer(width, height, view_radius=self.pomdp_radius)
dirt = [Entity('dirt', [x, y], min(0.15 + self._state[DIRT_INDEX, x, y], 1.5), 'scale')
for x, y in np.argwhere(self._state[DIRT_INDEX] > h.IS_FREE_CELL)]
walls = [Entity('wall', pos) for pos in np.argwhere(self._state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
dirt = [Entity('dirt', [x, y], min(0.15 + self._state[DIRT_INDEX, x, y], 1.5), 'scale')
for x, y in np.argwhere(self._state[DIRT_INDEX] > h.IS_FREE_CELL)]
walls = [Entity('wall', pos) for pos in np.argwhere(self._state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
def asset_str(agent):
if any([x is None for x in [self._state_slices[j] for j in agent.collisions]]):
@ -93,17 +94,18 @@ class SimpleFactory(BaseFactory):
return pos, cleanup_was_sucessfull
def step(self, actions):
_, r, done, info = super(SimpleFactory, self).step(actions)
_, reward, done, info = super(SimpleFactory, self).step(actions)
if not self._next_dirt_spawn:
self.spawn_dirt()
self._next_dirt_spawn = self.dirt_properties.spawn_frequency
else:
self._next_dirt_spawn -= 1
obs = self._return_state()
return obs, r, done, info
obs = self._get_observations()
return obs, reward, done, info
def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
if action != self._is_moving_action(action):
if action != self._actions.is_moving_action(action):
if self._is_clean_up_action(action):
agent_i_pos = self.agent_i_position(agent_i)
_, valid = self.clean_up(agent_i_pos)
@ -119,7 +121,7 @@ class SimpleFactory(BaseFactory):
self._state = np.concatenate((self._state, dirt_slice)) # dirt is now the last slice
self.spawn_dirt()
self._next_dirt_spawn = self.dirt_properties.spawn_frequency
obs = self._return_state()
obs = self._get_observations()
return obs
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
@ -141,7 +143,7 @@ class SimpleFactory(BaseFactory):
if entity != self._state_slices.by_name("dirt")]
if list_of_collisions:
self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
self.print(f't = {self._steps}\tAgent {agent_state.i} has collisions with '
f'{list_of_collisions}')
if self._is_clean_up_action(agent_state.action):
@ -155,7 +157,7 @@ class SimpleFactory(BaseFactory):
f'at {agent_state.pos}, but was unsucsessfull.')
info_dict.update(failed_cleanup_attempt=1)
elif self._is_moving_action(agent_state.action):
elif self._actions.is_moving_action(agent_state.action):
if agent_state.action_valid:
# info_dict.update(movement=1)
reward -= 0.00
@ -185,10 +187,11 @@ class SimpleFactory(BaseFactory):
if __name__ == '__main__':
render = True
import yaml
with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f:
env_kwargs = yaml.load(f)
factory = SimpleFactory(**env_kwargs)
move_props = MovementProperties(allow_diagonal_movement=True, allow_square_movement=True)
dirt_props = DirtProperties()
factory = SimpleFactory(movement_properties=move_props, dirt_properties=dirt_props, n_agents=2,
combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=False)
# dirt_props = DirtProperties()
# move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
@ -200,10 +203,12 @@ if __name__ == '__main__':
for epoch in range(100):
random_actions = [[random.randint(0, n_actions) for _ in range(factory.n_agents)] for _ in range(200)]
env_state = factory.reset()
r = 0
for agent_i_action in random_actions:
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
r += step_r
if render:
factory.render()
if done_bool:
break
print(f'Factory run {epoch} done, reward is:\n {reward}')
print(f'Factory run {epoch} done, reward is:\n {r}')