mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 15:26:43 +02:00
Items
This commit is contained in:
parent
b0aeb6f94f
commit
8631f11502
@ -87,20 +87,22 @@ class BaseFactory(gym.Env):
|
|||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = 0,
|
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_r: Union[None, int] = 0,
|
||||||
movement_properties: MovementProperties = MovementProperties(), parse_doors=False,
|
movement_properties: MovementProperties = MovementProperties(), parse_doors=False,
|
||||||
combin_agent_slices_in_obs: bool = False, frames_to_stack=0, record_episodes=False,
|
combin_agent_slices_in_obs: bool = False, frames_to_stack=0, record_episodes=False,
|
||||||
omit_agent_slice_in_obs=False, done_at_collision=False, cast_shadows=True, **kwargs):
|
omit_agent_slice_in_obs=False, done_at_collision=False, cast_shadows=True,
|
||||||
|
verbose=False, **kwargs):
|
||||||
assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
|
assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
|
||||||
|
|
||||||
# Attribute Assignment
|
# Attribute Assignment
|
||||||
self.movement_properties = movement_properties
|
self.movement_properties = movement_properties
|
||||||
self.level_name = level_name
|
self.level_name = level_name
|
||||||
self._level_shape = None
|
self._level_shape = None
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
self.n_agents = n_agents
|
self.n_agents = n_agents
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
self.pomdp_r = pomdp_radius
|
self.pomdp_r = pomdp_r
|
||||||
self.combin_agent_slices_in_obs = combin_agent_slices_in_obs
|
self.combin_agent_slices_in_obs = combin_agent_slices_in_obs
|
||||||
self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
|
self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
|
||||||
self.cast_shadows = cast_shadows
|
self.cast_shadows = cast_shadows
|
||||||
@ -115,6 +117,7 @@ class BaseFactory(gym.Env):
|
|||||||
if additional_actions := self.additional_actions:
|
if additional_actions := self.additional_actions:
|
||||||
self._actions.register_additional_items(additional_actions)
|
self._actions.register_additional_items(additional_actions)
|
||||||
|
|
||||||
|
# Reset
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def _init_state_slices(self) -> StateSlices:
|
def _init_state_slices(self) -> StateSlices:
|
||||||
@ -345,7 +348,7 @@ class BaseFactory(gym.Env):
|
|||||||
else:
|
else:
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def do_additional_actions(self, agent_i: int, action: int) -> bool:
|
def do_additional_actions(self, agent: Agent, action: int) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_all_tiles_with_collisions(self) -> List[Tile]:
|
def get_all_tiles_with_collisions(self) -> List[Tile]:
|
||||||
@ -418,3 +421,7 @@ class BaseFactory(gym.Env):
|
|||||||
if hasattr(entity, 'summarize_state'):
|
if hasattr(entity, 'summarize_state'):
|
||||||
summary.update({f'{REC_TAC}_{entity.name}': entity.summarize_state()})
|
summary.update({f'{REC_TAC}_{entity.name}': entity.summarize_state()})
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
def print(self, string):
|
||||||
|
if self.verbose:
|
||||||
|
print(string)
|
||||||
|
84
environments/factory/item_pickup.py
Normal file
84
environments/factory/item_pickup.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
from typing import List, Union, NamedTuple
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from environments.helpers import Constants as c
|
||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
|
from environments.factory.base.objects import Agent, Action, Object, Slice
|
||||||
|
from environments.factory.base.registers import Entities
|
||||||
|
|
||||||
|
from environments.factory.renderer import Renderer, Entity
|
||||||
|
from environments.utility_classes import MovementProperties
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ITEM = 'item'
|
||||||
|
INVENTORY = 'inventory'
|
||||||
|
PICK_UP = 'pick_up'
|
||||||
|
DROP_DOWN = 'drop_down'
|
||||||
|
ITEM_ACTION = 'item_action'
|
||||||
|
NO_ITEM = 0
|
||||||
|
ITEM_DROP_OFF = -1
|
||||||
|
|
||||||
|
|
||||||
|
class ItemProperties(NamedTuple):
    """Immutable configuration record for the item mechanics of ``ItemFactory``."""

    # Number of items handed to the item spawner on reset.
    # NOTE(review): the previous comment here described dirt cleaning and was a
    # copy-paste leftover; confirm whether this is "per spawn" or "in total".
    n_items: int = 1
    # Spawn frequency in steps.
    spawn_frequency: int = 5
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyAttributeOutsideInit
|
||||||
|
class ItemFactory(BaseFactory):
    """Factory environment that adds a single item action on top of ``BaseFactory``.

    Two extra state slices are registered: ``ITEM`` (what lies on each cell) and
    ``INVENTORY`` (what has been picked up, stored at the agent's position).
    """

    def __init__(self, item_properties: ItemProperties, *args, **kwargs):
        """Store the item configuration, then run the base initializer.

        ``item_properties`` has to be assigned *before* ``super().__init__``:
        the base initializer ends with ``self.reset()``, and the reset hook of
        this class reads ``self.item_properties``.
        """
        self.item_properties = item_properties
        super(ItemFactory, self).__init__(*args, **kwargs)

    @property
    def additional_actions(self) -> Union[str, List[str]]:
        """Names of the actions this factory adds to the base action set."""
        return [ITEM_ACTION]

    @property
    def additional_entities(self) -> Union[Entities, List[Entities]]:
        """This factory registers no additional entities."""
        return []

    @property
    def additional_slices(self) -> Union[Slice, List[Slice]]:
        """Zero-initialised ``ITEM`` and ``INVENTORY`` slices in level shape."""
        return [Slice(ITEM, np.zeros(self._level_shape)), Slice(INVENTORY, np.zeros(self._level_shape))]

    def _is_item_action(self, action):
        """Return True if ``action`` (name or index) refers to the item action."""
        if isinstance(action, str):
            action = self._actions.by_name(action)
        return self._actions[action].name == ITEM_ACTION

    def do_item_action(self, agent):
        """Pick up the item on the agent's cell.

        Returns:
            True if an item was moved from the ``ITEM`` slice into the
            ``INVENTORY`` slice, False if there was nothing to pick up.
        """
        item_slice = self._slices.by_name(ITEM).slice
        item = item_slice[agent.pos]
        if not item:
            return False
        if item == ITEM_DROP_OFF:
            # NOTE(review): the original branch body was empty (drop-off logic
            # not written yet). Until it exists the action is treated as
            # invalid so the drop-off marker is not picked up and erased.
            # Confirm the intended behaviour.
            return False
        self._slices.by_name(INVENTORY).slice[agent.pos] = item
        item_slice[agent.pos] = NO_ITEM
        return True

    def do_additional_actions(self, agent: Agent, action: int) -> bool:
        """Execute a non-movement action; only the item action is supported.

        Raises:
            RuntimeError: if ``action`` is not the item action.
        """
        if self._is_item_action(action):
            valid = self.do_item_action(agent)
            return valid
        else:
            raise RuntimeError('This should not happen!!!')

    def do_additional_reset(self) -> None:
        """Place the drop-off location and the initial items for a new episode."""
        # The item count lives on the properties tuple; the instance itself
        # never defines ``n_items``.
        n_items = self.item_properties.n_items
        # NOTE(review): ``spawn_drop_off_location`` and ``spawn_items`` are not
        # defined in the visible source - they must be provided before this
        # class can be instantiated.
        self.spawn_drop_off_location()
        self.spawn_items(n_items)
        if n_items > 1:
            self._next_item_spawn = self.item_properties.spawn_frequency

    def calculate_reward(self) -> (int, dict):
        """Return a neutral ``(reward, info)`` pair; no item reward exists yet.

        The original stub returned ``None`` despite the annotated pair.
        NOTE(review): presumably callers unpack the result - confirm.
        """
        return 0, {}

    def render(self, mode='human'):
        """Rendering is not implemented for this factory."""
        pass
|
||||||
|
|
||||||
|
|
@ -26,6 +26,16 @@ class DirtProperties(NamedTuple):
|
|||||||
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place
|
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place
|
||||||
|
|
||||||
|
|
||||||
|
def softmax(x, axis=None):
    """Return the numerically stable softmax of ``x``.

    Args:
        x: Array-like of scores (an ndarray or a plain sequence).
        axis: Axis along which to normalise. The default ``None`` normalises
            over the whole array, i.e. all entries together sum to 1.

    Returns:
        An ndarray with the same shape as ``x``.
    """
    x = np.asarray(x)
    # Shifting by the maximum does not change the result but keeps exp() from overflowing.
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)
|
||||||
|
|
||||||
|
|
||||||
|
def entropy(x, eps=1e-8):
    """Return the Shannon entropy (natural log) summed over all entries of ``x``.

    Args:
        x: Array-like of probabilities (an ndarray or a plain sequence).
        eps: Small constant added inside the logarithm so that zero entries
            contribute 0 instead of producing ``nan``.
    """
    x = np.asarray(x)
    return -(x * np.log(x + eps)).sum()
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
class SimpleFactory(BaseFactory):
|
class SimpleFactory(BaseFactory):
|
||||||
|
|
||||||
@ -46,9 +56,8 @@ class SimpleFactory(BaseFactory):
|
|||||||
action = self._actions.by_name(action)
|
action = self._actions.by_name(action)
|
||||||
return self._actions[action].name == CLEAN_UP_ACTION
|
return self._actions[action].name == CLEAN_UP_ACTION
|
||||||
|
|
||||||
def __init__(self, *args, dirt_properties: DirtProperties = DirtProperties(), verbose=False, **kwargs):
|
def __init__(self, *args, dirt_properties: DirtProperties = DirtProperties(), **kwargs):
|
||||||
self.dirt_properties = dirt_properties
|
self.dirt_properties = dirt_properties
|
||||||
self.verbose = verbose
|
|
||||||
self._renderer = None # expensive - don't use it when not required !
|
self._renderer = None # expensive - don't use it when not required !
|
||||||
super(SimpleFactory, self).__init__(*args, **kwargs)
|
super(SimpleFactory, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
@ -108,8 +117,8 @@ class SimpleFactory(BaseFactory):
|
|||||||
|
|
||||||
def clean_up(self, agent: Agent) -> bool:
|
def clean_up(self, agent: Agent) -> bool:
|
||||||
dirt_slice = self._slices.by_name(DIRT).slice
|
dirt_slice = self._slices.by_name(DIRT).slice
|
||||||
if dirt_slice[agent.pos]:
|
if old_dirt_amount := dirt_slice[agent.pos]:
|
||||||
new_dirt_amount = dirt_slice[agent.pos] - self.dirt_properties.clean_amount
|
new_dirt_amount = old_dirt_amount - self.dirt_properties.clean_amount
|
||||||
dirt_slice[agent.pos] = max(new_dirt_amount, c.FREE_CELL.value)
|
dirt_slice[agent.pos] = max(new_dirt_amount, c.FREE_CELL.value)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
@ -135,14 +144,11 @@ class SimpleFactory(BaseFactory):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
def do_additional_actions(self, agent: Agent, action: int) -> bool:
|
def do_additional_actions(self, agent: Agent, action: int) -> bool:
|
||||||
if action != self._actions.is_moving_action(action):
|
if self._is_clean_up_action(action):
|
||||||
if self._is_clean_up_action(action):
|
valid = self.clean_up(agent)
|
||||||
valid = self.clean_up(agent)
|
return valid
|
||||||
return valid
|
|
||||||
else:
|
|
||||||
raise RuntimeError('This should not happen!!!')
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('This should not happen!!!')
|
return c.NOT_VALID.value
|
||||||
|
|
||||||
def do_additional_reset(self) -> None:
|
def do_additional_reset(self) -> None:
|
||||||
self.spawn_dirt()
|
self.spawn_dirt()
|
||||||
@ -155,13 +161,18 @@ class SimpleFactory(BaseFactory):
|
|||||||
dirty_tiles = [dirt_slice[tile.pos] for tile in self._tiles if dirt_slice[tile.pos]]
|
dirty_tiles = [dirt_slice[tile.pos] for tile in self._tiles if dirt_slice[tile.pos]]
|
||||||
current_dirt_amount = sum(dirty_tiles)
|
current_dirt_amount = sum(dirty_tiles)
|
||||||
dirty_tile_count = len(dirty_tiles)
|
dirty_tile_count = len(dirty_tiles)
|
||||||
|
if dirty_tile_count:
|
||||||
|
dirt_distribution_score = entropy(softmax(dirt_slice)) / dirty_tile_count
|
||||||
|
else:
|
||||||
|
dirt_distribution_score = 0
|
||||||
|
|
||||||
info_dict.update(dirt_amount=current_dirt_amount)
|
info_dict.update(dirt_amount=current_dirt_amount)
|
||||||
info_dict.update(dirty_tile_count=dirty_tile_count)
|
info_dict.update(dirty_tile_count=dirty_tile_count)
|
||||||
|
info_dict.update(dirt_distribution_score=dirt_distribution_score)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# penalty = current_dirt_amount
|
# penalty = current_dirt_amount
|
||||||
reward = 0
|
reward = dirt_distribution_score
|
||||||
except (ZeroDivisionError, RuntimeWarning):
|
except (ZeroDivisionError, RuntimeWarning):
|
||||||
reward = 0
|
reward = 0
|
||||||
|
|
||||||
@ -213,10 +224,6 @@ class SimpleFactory(BaseFactory):
|
|||||||
# track the last reward , minus the current reward = potential
|
# track the last reward , minus the current reward = potential
|
||||||
return reward, info_dict
|
return reward, info_dict
|
||||||
|
|
||||||
def print(self, string):
|
|
||||||
if self.verbose:
|
|
||||||
print(string)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
render = True
|
render = True
|
||||||
|
8
main.py
8
main.py
@ -98,16 +98,16 @@ if __name__ == '__main__':
|
|||||||
move_props = MovementProperties(allow_diagonal_movement=True,
|
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||||
allow_square_movement=True,
|
allow_square_movement=True,
|
||||||
allow_no_op=False)
|
allow_no_op=False)
|
||||||
train_steps = 5e6
|
train_steps = 2.5e6
|
||||||
time_stamp = int(time.time())
|
time_stamp = int(time.time())
|
||||||
|
|
||||||
out_path = None
|
out_path = None
|
||||||
|
|
||||||
for modeL_type in [A2C, PPO, RegDQN, DQN]: # , QRDQN]:
|
for modeL_type in [A2C, PPO, DQN]: # ,RegDQN, QRDQN]:
|
||||||
for seed in range(3):
|
for seed in range(3):
|
||||||
|
|
||||||
with SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=4, max_steps=400, parse_doors=True,
|
with SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, parse_doors=True,
|
||||||
movement_properties=move_props, level_name='rooms', frames_to_stack=0,
|
movement_properties=move_props, level_name='rooms', frames_to_stack=3,
|
||||||
omit_agent_slice_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False,
|
omit_agent_slice_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False,
|
||||||
cast_shadows=True,
|
cast_shadows=True,
|
||||||
) as env:
|
) as env:
|
||||||
|
@ -14,7 +14,7 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
model_name = 'A2C_1627393138'
|
model_name = 'A2C_1627491061'
|
||||||
run_id = 0
|
run_id = 0
|
||||||
out_path = Path(__file__).parent / 'debug_out'
|
out_path = Path(__file__).parent / 'debug_out'
|
||||||
model_path = out_path / model_name
|
model_path = out_path / model_name
|
||||||
|
Loading…
x
Reference in New Issue
Block a user