Merge remote-tracking branch 'origin/main' into main
This commit is contained in:
@ -1,16 +1,14 @@
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Union, NamedTuple
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.base_factory import BaseFactory, AgentState, MovementProperties
|
||||
from environments.factory.base_factory import BaseFactory
|
||||
from environments import helpers as h
|
||||
|
||||
from environments.logging.monitor import MonitorCallback
|
||||
from environments.factory.renderer import Renderer, Entity
|
||||
from environments.utility_classes import AgentState, MovementProperties
|
||||
|
||||
DIRT_INDEX = -1
|
||||
CLEAN_UP_ACTION = 'clean_up'
|
||||
@ -25,13 +23,16 @@ class DirtProperties(NamedTuple):
|
||||
max_global_amount: int = 20 # Max dirt amount in the whole environment.
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
class SimpleFactory(BaseFactory):
|
||||
|
||||
@property
|
||||
def additional_actions(self) -> Union[str, List[str]]:
|
||||
return CLEAN_UP_ACTION
|
||||
|
||||
def _is_clean_up_action(self, action):
|
||||
def _is_clean_up_action(self, action: Union[str, int]):
|
||||
if isinstance(action, str):
|
||||
action = self._actions.by_name(action)
|
||||
return self._actions[action] == CLEAN_UP_ACTION
|
||||
|
||||
def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs):
|
||||
@ -47,9 +48,9 @@ class SimpleFactory(BaseFactory):
|
||||
height, width = self._state.shape[1:]
|
||||
self._renderer = Renderer(width, height, view_radius=self.pomdp_radius)
|
||||
|
||||
dirt = [Entity('dirt', [x, y], min(0.15 + self._state[DIRT_INDEX, x, y], 1.5), 'scale')
|
||||
for x, y in np.argwhere(self._state[DIRT_INDEX] > h.IS_FREE_CELL)]
|
||||
walls = [Entity('wall', pos) for pos in np.argwhere(self._state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
|
||||
dirt = [Entity('dirt', [x, y], min(0.15 + self._state[DIRT_INDEX, x, y], 1.5), 'scale')
|
||||
for x, y in np.argwhere(self._state[DIRT_INDEX] > h.IS_FREE_CELL)]
|
||||
walls = [Entity('wall', pos) for pos in np.argwhere(self._state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
|
||||
|
||||
def asset_str(agent):
|
||||
if any([x is None for x in [self._state_slices[j] for j in agent.collisions]]):
|
||||
@ -94,17 +95,18 @@ class SimpleFactory(BaseFactory):
|
||||
return pos, cleanup_was_sucessfull
|
||||
|
||||
def step(self, actions):
|
||||
_, r, done, info = super(SimpleFactory, self).step(actions)
|
||||
_, reward, done, info = super(SimpleFactory, self).step(actions)
|
||||
if not self._next_dirt_spawn:
|
||||
self.spawn_dirt()
|
||||
self._next_dirt_spawn = self.dirt_properties.spawn_frequency
|
||||
else:
|
||||
self._next_dirt_spawn -= 1
|
||||
obs = self._return_state()
|
||||
return obs, r, done, info
|
||||
|
||||
obs = self._get_observations()
|
||||
return obs, reward, done, info
|
||||
|
||||
def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||
if action != self._is_moving_action(action):
|
||||
if action != self._actions.is_moving_action(action):
|
||||
if self._is_clean_up_action(action):
|
||||
agent_i_pos = self.agent_i_position(agent_i)
|
||||
_, valid = self.clean_up(agent_i_pos)
|
||||
@ -120,7 +122,7 @@ class SimpleFactory(BaseFactory):
|
||||
self._state = np.concatenate((self._state, dirt_slice)) # dirt is now the last slice
|
||||
self.spawn_dirt()
|
||||
self._next_dirt_spawn = self.dirt_properties.spawn_frequency
|
||||
obs = self._return_state()
|
||||
obs = self._get_observations()
|
||||
return obs
|
||||
|
||||
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
||||
@ -142,7 +144,7 @@ class SimpleFactory(BaseFactory):
|
||||
if entity != self._state_slices.by_name("dirt")]
|
||||
|
||||
if list_of_collisions:
|
||||
self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
||||
self.print(f't = {self._steps}\tAgent {agent_state.i} has collisions with '
|
||||
f'{list_of_collisions}')
|
||||
|
||||
if self._is_clean_up_action(agent_state.action):
|
||||
@ -156,7 +158,7 @@ class SimpleFactory(BaseFactory):
|
||||
f'at {agent_state.pos}, but was unsucsessfull.')
|
||||
info_dict.update(failed_cleanup_attempt=1)
|
||||
|
||||
elif self._is_moving_action(agent_state.action):
|
||||
elif self._actions.is_moving_action(agent_state.action):
|
||||
if agent_state.action_valid:
|
||||
# info_dict.update(movement=1)
|
||||
reward -= 0.00
|
||||
@ -186,10 +188,11 @@ class SimpleFactory(BaseFactory):
|
||||
|
||||
if __name__ == '__main__':
|
||||
render = True
|
||||
import yaml
|
||||
with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f:
|
||||
env_kwargs = yaml.load(f)
|
||||
factory = SimpleFactory(**env_kwargs)
|
||||
|
||||
move_props = MovementProperties(allow_diagonal_movement=True, allow_square_movement=True)
|
||||
dirt_props = DirtProperties()
|
||||
factory = SimpleFactory(movement_properties=move_props, dirt_properties=dirt_props, n_agents=2,
|
||||
combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=False)
|
||||
|
||||
# dirt_props = DirtProperties()
|
||||
# move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
|
||||
@ -201,10 +204,12 @@ if __name__ == '__main__':
|
||||
for epoch in range(100):
|
||||
random_actions = [[random.randint(0, n_actions) for _ in range(factory.n_agents)] for _ in range(200)]
|
||||
env_state = factory.reset()
|
||||
r = 0
|
||||
for agent_i_action in random_actions:
|
||||
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
||||
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||
r += step_r
|
||||
if render:
|
||||
factory.render()
|
||||
if done_bool:
|
||||
break
|
||||
print(f'Factory run {epoch} done, reward is:\n {reward}')
|
||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||
|
Reference in New Issue
Block a user