mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
Items
This commit is contained in:
parent
b0aeb6f94f
commit
8631f11502
@ -87,20 +87,22 @@ class BaseFactory(gym.Env):
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = 0,
|
||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_r: Union[None, int] = 0,
|
||||
movement_properties: MovementProperties = MovementProperties(), parse_doors=False,
|
||||
combin_agent_slices_in_obs: bool = False, frames_to_stack=0, record_episodes=False,
|
||||
omit_agent_slice_in_obs=False, done_at_collision=False, cast_shadows=True, **kwargs):
|
||||
omit_agent_slice_in_obs=False, done_at_collision=False, cast_shadows=True,
|
||||
verbose=False, **kwargs):
|
||||
assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
|
||||
|
||||
# Attribute Assignment
|
||||
self.movement_properties = movement_properties
|
||||
self.level_name = level_name
|
||||
self._level_shape = None
|
||||
self.verbose = verbose
|
||||
|
||||
self.n_agents = n_agents
|
||||
self.max_steps = max_steps
|
||||
self.pomdp_r = pomdp_radius
|
||||
self.pomdp_r = pomdp_r
|
||||
self.combin_agent_slices_in_obs = combin_agent_slices_in_obs
|
||||
self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
|
||||
self.cast_shadows = cast_shadows
|
||||
@ -115,6 +117,7 @@ class BaseFactory(gym.Env):
|
||||
if additional_actions := self.additional_actions:
|
||||
self._actions.register_additional_items(additional_actions)
|
||||
|
||||
# Reset
|
||||
self.reset()
|
||||
|
||||
def _init_state_slices(self) -> StateSlices:
|
||||
@ -345,7 +348,7 @@ class BaseFactory(gym.Env):
|
||||
else:
|
||||
return obs
|
||||
|
||||
def do_additional_actions(self, agent_i: int, action: int) -> bool:
|
||||
def do_additional_actions(self, agent: Agent, action: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_all_tiles_with_collisions(self) -> List[Tile]:
|
||||
@ -418,3 +421,7 @@ class BaseFactory(gym.Env):
|
||||
if hasattr(entity, 'summarize_state'):
|
||||
summary.update({f'{REC_TAC}_{entity.name}': entity.summarize_state()})
|
||||
return summary
|
||||
|
||||
def print(self, string):
|
||||
if self.verbose:
|
||||
print(string)
|
||||
|
84
environments/factory/item_pickup.py
Normal file
84
environments/factory/item_pickup.py
Normal file
@ -0,0 +1,84 @@
|
||||
from typing import List, Union, NamedTuple
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.helpers import Constants as c
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Object, Slice
|
||||
from environments.factory.base.registers import Entities
|
||||
|
||||
from environments.factory.renderer import Renderer, Entity
|
||||
from environments.utility_classes import MovementProperties
|
||||
|
||||
|
||||
|
||||
ITEM = 'item'
|
||||
INVENTORY = 'inventory'
|
||||
PICK_UP = 'pick_up'
|
||||
DROP_DOWN = 'drop_down'
|
||||
ITEM_ACTION = 'item_action'
|
||||
NO_ITEM = 0
|
||||
ITEM_DROP_OFF = -1
|
||||
|
||||
|
||||
class ItemProperties(NamedTuple):
|
||||
n_items: int = 1 # How much does the robot clean with one actions.
|
||||
spawn_frequency: int = 5 # Spawn Frequency in Steps
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
class ItemFactory(BaseFactory):
|
||||
def __init__(self, item_properties: ItemProperties, *args, **kwargs):
|
||||
super(ItemFactory, self).__init__(*args, **kwargs)
|
||||
self.item_properties = item_properties
|
||||
|
||||
@property
|
||||
def additional_actions(self) -> Union[str, List[str]]:
|
||||
return [ITEM_ACTION]
|
||||
|
||||
@property
|
||||
def additional_entities(self) -> Union[Entities, List[Entities]]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def additional_slices(self) -> Union[Slice, List[Slice]]:
|
||||
return [Slice(ITEM, np.zeros(self._level_shape)), Slice(INVENTORY, np.zeros(self._level_shape))]
|
||||
|
||||
def _is_item_action(self, action):
|
||||
if isinstance(action, str):
|
||||
action = self._actions.by_name(action)
|
||||
return self._actions[action].name == ITEM_ACTION
|
||||
|
||||
def do_item_action(self, agent):
|
||||
item_slice = self._slices.by_name(ITEM).slice
|
||||
if item := item_slice[agent.pos]:
|
||||
if item == ITEM_DROP_OFF:
|
||||
|
||||
self._slices.by_name(INVENTORY).slice[agent.pos] = item
|
||||
item_slice[agent.pos] = NO_ITEM
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: int) -> bool:
|
||||
if self._is_item_action(action):
|
||||
valid = self.do_item_action(agent)
|
||||
return valid
|
||||
else:
|
||||
raise RuntimeError('This should not happen!!!')
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
self.spawn_drop_off_location()
|
||||
self.spawn_items(self.n_items)
|
||||
if self.n_items > 1:
|
||||
self._next_item_spawn = self.item_properties.spawn_frequency
|
||||
|
||||
def calculate_reward(self) -> (int, dict):
|
||||
pass
|
||||
|
||||
def render(self, mode='human'):
|
||||
pass
|
||||
|
||||
|
@ -26,6 +26,16 @@ class DirtProperties(NamedTuple):
|
||||
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place
|
||||
|
||||
|
||||
def softmax(x):
|
||||
"""Compute softmax values for each sets of scores in x."""
|
||||
e_x = np.exp(x - np.max(x))
|
||||
return e_x / e_x.sum()
|
||||
|
||||
|
||||
def entropy(x):
|
||||
return -(x * np.log(x + 1e-8)).sum()
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
class SimpleFactory(BaseFactory):
|
||||
|
||||
@ -46,9 +56,8 @@ class SimpleFactory(BaseFactory):
|
||||
action = self._actions.by_name(action)
|
||||
return self._actions[action].name == CLEAN_UP_ACTION
|
||||
|
||||
def __init__(self, *args, dirt_properties: DirtProperties = DirtProperties(), verbose=False, **kwargs):
|
||||
def __init__(self, *args, dirt_properties: DirtProperties = DirtProperties(), **kwargs):
|
||||
self.dirt_properties = dirt_properties
|
||||
self.verbose = verbose
|
||||
self._renderer = None # expensive - don't use it when not required !
|
||||
super(SimpleFactory, self).__init__(*args, **kwargs)
|
||||
|
||||
@ -108,8 +117,8 @@ class SimpleFactory(BaseFactory):
|
||||
|
||||
def clean_up(self, agent: Agent) -> bool:
|
||||
dirt_slice = self._slices.by_name(DIRT).slice
|
||||
if dirt_slice[agent.pos]:
|
||||
new_dirt_amount = dirt_slice[agent.pos] - self.dirt_properties.clean_amount
|
||||
if old_dirt_amount := dirt_slice[agent.pos]:
|
||||
new_dirt_amount = old_dirt_amount - self.dirt_properties.clean_amount
|
||||
dirt_slice[agent.pos] = max(new_dirt_amount, c.FREE_CELL.value)
|
||||
return True
|
||||
else:
|
||||
@ -135,14 +144,11 @@ class SimpleFactory(BaseFactory):
|
||||
return {}
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: int) -> bool:
|
||||
if action != self._actions.is_moving_action(action):
|
||||
if self._is_clean_up_action(action):
|
||||
valid = self.clean_up(agent)
|
||||
return valid
|
||||
else:
|
||||
raise RuntimeError('This should not happen!!!')
|
||||
if self._is_clean_up_action(action):
|
||||
valid = self.clean_up(agent)
|
||||
return valid
|
||||
else:
|
||||
raise RuntimeError('This should not happen!!!')
|
||||
return c.NOT_VALID.value
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
self.spawn_dirt()
|
||||
@ -155,13 +161,18 @@ class SimpleFactory(BaseFactory):
|
||||
dirty_tiles = [dirt_slice[tile.pos] for tile in self._tiles if dirt_slice[tile.pos]]
|
||||
current_dirt_amount = sum(dirty_tiles)
|
||||
dirty_tile_count = len(dirty_tiles)
|
||||
if dirty_tile_count:
|
||||
dirt_distribution_score = entropy(softmax(dirt_slice)) / dirty_tile_count
|
||||
else:
|
||||
dirt_distribution_score = 0
|
||||
|
||||
info_dict.update(dirt_amount=current_dirt_amount)
|
||||
info_dict.update(dirty_tile_count=dirty_tile_count)
|
||||
info_dict.update(dirt_distribution_score=dirt_distribution_score)
|
||||
|
||||
try:
|
||||
# penalty = current_dirt_amount
|
||||
reward = 0
|
||||
reward = dirt_distribution_score
|
||||
except (ZeroDivisionError, RuntimeWarning):
|
||||
reward = 0
|
||||
|
||||
@ -213,10 +224,6 @@ class SimpleFactory(BaseFactory):
|
||||
# track the last reward , minus the current reward = potential
|
||||
return reward, info_dict
|
||||
|
||||
def print(self, string):
|
||||
if self.verbose:
|
||||
print(string)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
render = True
|
||||
|
8
main.py
8
main.py
@ -98,16 +98,16 @@ if __name__ == '__main__':
|
||||
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||
allow_square_movement=True,
|
||||
allow_no_op=False)
|
||||
train_steps = 5e6
|
||||
train_steps = 2.5e6
|
||||
time_stamp = int(time.time())
|
||||
|
||||
out_path = None
|
||||
|
||||
for modeL_type in [A2C, PPO, RegDQN, DQN]: # , QRDQN]:
|
||||
for modeL_type in [A2C, PPO, DQN]: # ,RegDQN, QRDQN]:
|
||||
for seed in range(3):
|
||||
|
||||
with SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=4, max_steps=400, parse_doors=True,
|
||||
movement_properties=move_props, level_name='rooms', frames_to_stack=0,
|
||||
with SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, parse_doors=True,
|
||||
movement_properties=move_props, level_name='rooms', frames_to_stack=3,
|
||||
omit_agent_slice_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False,
|
||||
cast_shadows=True,
|
||||
) as env:
|
||||
|
@ -14,7 +14,7 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model_name = 'A2C_1627393138'
|
||||
model_name = 'A2C_1627491061'
|
||||
run_id = 0
|
||||
out_path = Path(__file__).parent / 'debug_out'
|
||||
model_path = out_path / model_name
|
||||
|
Loading…
x
Reference in New Issue
Block a user