Rework for performance
This commit is contained in:
parent
78bf19f7f4
commit
435056f373
@ -13,6 +13,8 @@ from gym.wrappers import FrameStack
|
||||
from environments.factory.base.shadow_casting import Map
|
||||
from environments import helpers as h
|
||||
from environments.helpers import Constants as c
|
||||
from environments.helpers import EnvActions as a
|
||||
from environments.helpers import Rewards as r
|
||||
from environments.factory.base.objects import Agent, Tile, Action
|
||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders, \
|
||||
GlobalPositions
|
||||
@ -205,8 +207,9 @@ class BaseFactory(gym.Env):
|
||||
|
||||
if self.obs_prop.show_global_position_info:
|
||||
global_positions = GlobalPositions(self._level_shape)
|
||||
obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
|
||||
global_positions.spawn_global_position_objects(obs_shape_2d, self[c.AGENT])
|
||||
# This moved into the GlobalPosition object
|
||||
# obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
|
||||
global_positions.spawn_global_position_objects(self[c.AGENT])
|
||||
self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions})
|
||||
|
||||
# Return
|
||||
@ -232,37 +235,51 @@ class BaseFactory(gym.Env):
|
||||
# Pre step Hook for later use
|
||||
self.hook_pre_step()
|
||||
|
||||
# Move this in a seperate function?
|
||||
for action, agent in zip(actions, self[c.AGENT]):
|
||||
agent.clear_temp_state()
|
||||
action_obj = self._actions[int(action)]
|
||||
step_result = dict(collisions=[], rewards=[], info={}, action_name='', action_valid=False)
|
||||
# cls.print(f'Action #{action} has been resolved to: {action_obj}')
|
||||
if h.EnvActions.is_move(action_obj):
|
||||
valid = self._move_or_colide(agent, action_obj)
|
||||
elif h.EnvActions.NOOP == agent.temp_action:
|
||||
valid = c.VALID
|
||||
elif h.EnvActions.USE_DOOR == action_obj:
|
||||
valid = self._handle_door_interaction(agent)
|
||||
if a.is_move(action_obj):
|
||||
action_valid, reward = self._do_move_action(agent, action_obj)
|
||||
elif a.NOOP == action_obj:
|
||||
action_valid = c.VALID
|
||||
reward = dict(value=r.NOOP, reason=a.NOOP, info={f'{agent.pos}_NOOP': 1})
|
||||
elif a.USE_DOOR == action_obj:
|
||||
action_valid, reward = self._handle_door_interaction(agent)
|
||||
else:
|
||||
valid = self.do_additional_actions(agent, action_obj)
|
||||
assert valid is not None, 'This should not happen, every Action musst be detected correctly!'
|
||||
agent.temp_action = action_obj
|
||||
agent.temp_valid = valid
|
||||
|
||||
# In-between step Hook for later use
|
||||
info = self.do_additional_step()
|
||||
# noinspection PyTupleAssignmentBalance
|
||||
action_valid, reward = self.do_additional_actions(agent, action_obj)
|
||||
# Not needed any more sice the tuple assignment above will fail in case of a failing action resolvement.
|
||||
# assert step_result is not None, 'This should not happen, every Action musst be detected correctly!'
|
||||
step_result['action_name'] = action_obj.identifier
|
||||
step_result['action_valid'] = action_valid
|
||||
step_result['rewards'].append(reward)
|
||||
agent.step_result = step_result
|
||||
|
||||
# Additional step and Reward, Info Init
|
||||
rewards, info = self.do_additional_step()
|
||||
# Todo: Make this faster, so that only tiles of entities that can collide are searched.
|
||||
tiles_with_collisions = self.get_all_tiles_with_collisions()
|
||||
for tile in tiles_with_collisions:
|
||||
guests = tile.guests_that_can_collide
|
||||
for i, guest in enumerate(guests):
|
||||
# This does make a copy, but is faster than.copy()
|
||||
this_collisions = guests[:]
|
||||
del this_collisions[i]
|
||||
guest.temp_collisions = this_collisions
|
||||
assert hasattr(guest, 'step_result')
|
||||
for collision in this_collisions:
|
||||
guest.step_result['collisions'].append(collision)
|
||||
|
||||
done = self.done_at_collision and tiles_with_collisions
|
||||
done = False
|
||||
if self.done_at_collision:
|
||||
if done_at_col := bool(tiles_with_collisions):
|
||||
done = done_at_col
|
||||
info.update(COLLISION_DONE=done_at_col)
|
||||
|
||||
done = done or self.check_additional_done()
|
||||
additional_done, additional_done_info = self.check_additional_done()
|
||||
done = done or additional_done
|
||||
info.update(additional_done_info)
|
||||
|
||||
# Step the door close intervall
|
||||
if self.parse_doors:
|
||||
@ -270,7 +287,8 @@ class BaseFactory(gym.Env):
|
||||
doors.tick_doors()
|
||||
|
||||
# Finalize
|
||||
reward, reward_info = self.calculate_reward()
|
||||
reward, reward_info = self.build_reward_result()
|
||||
|
||||
info.update(reward_info)
|
||||
if self._steps >= self.max_steps:
|
||||
done = True
|
||||
@ -285,7 +303,7 @@ class BaseFactory(gym.Env):
|
||||
|
||||
return obs, reward, done, info
|
||||
|
||||
def _handle_door_interaction(self, agent) -> c:
|
||||
def _handle_door_interaction(self, agent) -> (bool, dict):
|
||||
if doors := self[c.DOORS]:
|
||||
# Check if agent really is standing on a door:
|
||||
if self.doors_have_area:
|
||||
@ -294,12 +312,21 @@ class BaseFactory(gym.Env):
|
||||
door = doors.by_pos(agent.pos)
|
||||
if door is not None:
|
||||
door.use()
|
||||
return c.VALID
|
||||
valid = c.VALID
|
||||
self.print(f'{agent.name} just used a door {door.name}')
|
||||
info_dict = {f'{agent.name}_door_use_{door.name}': 1}
|
||||
# When he doesn't...
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
valid = c.NOT_VALID
|
||||
info_dict = {f'{agent.name}_failed_door_use': 1}
|
||||
self.print(f'{agent.name} just tried to use a door at {agent.pos}, but there is none.')
|
||||
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
raise RuntimeError('This should not happen, since the door action should not be available.')
|
||||
reward = dict(value=r.USE_DOOR_VALID if valid else r.USE_DOOR_FAIL,
|
||||
reason=a.USE_DOOR, info=info_dict)
|
||||
|
||||
return valid, reward
|
||||
|
||||
def _build_observations(self) -> np.typing.ArrayLike:
|
||||
# Observation dict:
|
||||
@ -308,7 +335,7 @@ class BaseFactory(gym.Env):
|
||||
# Generel Observations
|
||||
lvl_obs = self[c.WALLS].as_array()
|
||||
door_obs = self[c.DOORS].as_array()
|
||||
agent_obs = self[c.AGENT].as_array() if self.obs_prop.render_agents != a_obs.NOT else None
|
||||
global_agent_obs = self[c.AGENT].as_array() if self.obs_prop.render_agents != a_obs.NOT else None
|
||||
placeholder_obs = self[c.AGENT_PLACEHOLDER].as_array() if self[c.AGENT_PLACEHOLDER] else None
|
||||
add_obs_dict = self._additional_observations()
|
||||
|
||||
@ -318,15 +345,20 @@ class BaseFactory(gym.Env):
|
||||
if self.obs_prop.render_agents != a_obs.NOT:
|
||||
if self.obs_prop.omit_agent_self:
|
||||
if self.obs_prop.render_agents == a_obs.SEPERATE:
|
||||
agent_obs = np.take(agent_obs, [x for x in range(self.n_agents) if x != agent_idx], axis=0)
|
||||
other_agent_obs_idx = [x for x in range(self.n_agents) if x != agent_idx]
|
||||
agent_obs = np.take(global_agent_obs, other_agent_obs_idx, axis=0)
|
||||
else:
|
||||
agent_obs = agent_obs.copy()
|
||||
agent_obs = global_agent_obs.copy()
|
||||
agent_obs[(0, *agent.pos)] -= agent.encoding
|
||||
else:
|
||||
agent_obs = global_agent_obs
|
||||
else:
|
||||
agent_obs = global_agent_obs
|
||||
|
||||
# Build Level Observations
|
||||
if self.obs_prop.render_agents == a_obs.LEVEL:
|
||||
lvl_obs = lvl_obs.copy()
|
||||
lvl_obs += agent_obs
|
||||
lvl_obs += global_agent_obs
|
||||
|
||||
obs_dict[c.WALLS] = lvl_obs
|
||||
if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED]:
|
||||
@ -340,11 +372,12 @@ class BaseFactory(gym.Env):
|
||||
obsn = self._do_pomdp_cutout(agent, obsn)
|
||||
|
||||
raw_obs = self._additional_per_agent_raw_observations(agent)
|
||||
obsn = np.vstack((obsn, *list(raw_obs.values())))
|
||||
raw_obs = {key: np.expand_dims(val, 0) if val.ndim != 3 else val for key, val in raw_obs.items()}
|
||||
obsn = np.vstack((obsn, *raw_obs.values()))
|
||||
|
||||
keys = list(chain(obs_dict.keys(), raw_obs.keys()))
|
||||
idxs = np.cumsum([x.shape[0] for x in chain(obs_dict.values(), raw_obs.values())]) - 1
|
||||
per_agent_expl_idx[agent.name] = {key: list(range(a, b)) for key, a, b in
|
||||
per_agent_expl_idx[agent.name] = {key: list(range(d, b)) for key, d, b in
|
||||
zip(keys, idxs, list(idxs[1:]) + [idxs[-1]+1, ])}
|
||||
|
||||
# Shadow Casting
|
||||
@ -390,7 +423,13 @@ class BaseFactory(gym.Env):
|
||||
if door_shadowing:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
light_block_map[xs, ys] = 0
|
||||
agent.temp_light_map = light_block_map.copy()
|
||||
if agent.step_result:
|
||||
agent.step_result['lightmap'] = light_block_map
|
||||
pass
|
||||
else:
|
||||
assert self._steps == 0
|
||||
agent.step_result = {'action_name': a.NOOP, 'action_valid': True,
|
||||
'collisions': [], 'lightmap': light_block_map}
|
||||
|
||||
obsn[shadowed_obs] = ((obsn[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map)
|
||||
else:
|
||||
@ -410,27 +449,27 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def _do_pomdp_cutout(self, agent, obs_to_be_padded):
|
||||
assert obs_to_be_padded.ndim == 3
|
||||
r, d = self._pomdp_r, self.pomdp_diameter
|
||||
x0, x1 = max(0, agent.x - r), min(agent.x + r + 1, self._level_shape[0])
|
||||
y0, y1 = max(0, agent.y - r), min(agent.y + r + 1, self._level_shape[1])
|
||||
ra, d = self._pomdp_r, self.pomdp_diameter
|
||||
x0, x1 = max(0, agent.x - ra), min(agent.x + ra + 1, self._level_shape[0])
|
||||
y0, y1 = max(0, agent.y - ra), min(agent.y + ra + 1, self._level_shape[1])
|
||||
oobs = obs_to_be_padded[:, x0:x1, y0:y1]
|
||||
if oobs.shape[1:] != (d, d):
|
||||
if xd := oobs.shape[1] % d:
|
||||
if agent.x > r:
|
||||
if agent.x > ra:
|
||||
x0_pad = 0
|
||||
x1_pad = (d - xd)
|
||||
else:
|
||||
x0_pad = r - agent.x
|
||||
x0_pad = ra - agent.x
|
||||
x1_pad = 0
|
||||
else:
|
||||
x0_pad, x1_pad = 0, 0
|
||||
|
||||
if yd := oobs.shape[2] % d:
|
||||
if agent.y > r:
|
||||
if agent.y > ra:
|
||||
y0_pad = 0
|
||||
y1_pad = (d - yd)
|
||||
else:
|
||||
y0_pad = r - agent.y
|
||||
y0_pad = ra - agent.y
|
||||
y1_pad = 0
|
||||
else:
|
||||
y0_pad, y1_pad = 0, 0
|
||||
@ -439,22 +478,39 @@ class BaseFactory(gym.Env):
|
||||
return oobs
|
||||
|
||||
def get_all_tiles_with_collisions(self) -> List[Tile]:
|
||||
tiles_with_collisions = list()
|
||||
for tile in self[c.FLOOR]:
|
||||
if tile.is_occupied():
|
||||
guests = tile.guests_that_can_collide
|
||||
if len(guests) >= 2:
|
||||
tiles_with_collisions.append(tile)
|
||||
return tiles_with_collisions
|
||||
tiles = [x.tile for y in self._entities for x in y if
|
||||
y.can_collide and not isinstance(y, WallTiles) and x.can_collide and len(x.tile.guests) > 1]
|
||||
if False:
|
||||
tiles_with_collisions = list()
|
||||
for tile in self[c.FLOOR]:
|
||||
if tile.is_occupied():
|
||||
guests = tile.guests_that_can_collide
|
||||
if len(guests) >= 2:
|
||||
tiles_with_collisions.append(tile)
|
||||
return tiles
|
||||
|
||||
def _move_or_colide(self, agent: Agent, action: Action) -> bool:
|
||||
def _do_move_action(self, agent: Agent, action: Action) -> (dict, dict):
|
||||
info_dict = dict()
|
||||
new_tile, valid = self._check_agent_move(agent, action)
|
||||
if valid:
|
||||
# Does not collide width level boundaries
|
||||
return agent.move(new_tile)
|
||||
valid = agent.move(new_tile)
|
||||
if valid:
|
||||
# This will spam your logs, beware!
|
||||
# self.print(f'{agent.name} just moved from {agent.last_pos} to {agent.pos}.')
|
||||
# info_dict.update({f'{agent.pos}_move': 1})
|
||||
pass
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
self.print(f'{agent.name} just hit the wall at {agent.pos}.')
|
||||
info_dict.update({f'{agent.pos}_wall_collide': 1})
|
||||
else:
|
||||
# Agent seems to be trying to collide in this step
|
||||
return c.NOT_VALID
|
||||
# Agent seems to be trying to Leave the level
|
||||
self.print(f'{agent.name} tried to leave the level {agent.pos}.')
|
||||
info_dict.update({f'{agent.pos}_wall_collide': 1})
|
||||
reward_value = r.MOVEMENTS_VALID if valid else r.MOVEMENTS_FAIL
|
||||
reward = {'value': reward_value, 'reason': action.identifier, 'info': info_dict}
|
||||
return valid, reward
|
||||
|
||||
def _check_agent_move(self, agent, action: Action) -> (Tile, bool):
|
||||
# Actions
|
||||
@ -474,7 +530,7 @@ class BaseFactory(gym.Env):
|
||||
if doors := self[c.DOORS]:
|
||||
if self.doors_have_area:
|
||||
if door := doors.by_pos(new_tile.pos):
|
||||
if door.is_open:
|
||||
if door.is_closed:
|
||||
return agent.tile, c.NOT_VALID
|
||||
else: # door.is_closed:
|
||||
pass
|
||||
@ -494,69 +550,46 @@ class BaseFactory(gym.Env):
|
||||
|
||||
return new_tile, valid
|
||||
|
||||
def calculate_reward(self) -> (int, dict):
|
||||
@abc.abstractmethod
|
||||
def additional_per_agent_rewards(self, agent) -> List[dict]:
|
||||
return []
|
||||
|
||||
def build_reward_result(self) -> (int, dict):
|
||||
# Returns: Reward, Info
|
||||
per_agent_info_dict = defaultdict(dict)
|
||||
reward = {}
|
||||
info = defaultdict(lambda: 0.0)
|
||||
|
||||
# Gather additional sub-env rewards and calculate collisions
|
||||
for agent in self[c.AGENT]:
|
||||
per_agent_reward = 0
|
||||
if self._actions.is_moving_action(agent.temp_action):
|
||||
if agent.temp_valid:
|
||||
# info_dict.update(movement=1)
|
||||
per_agent_reward -= 0.001
|
||||
pass
|
||||
else:
|
||||
per_agent_reward -= 0.05
|
||||
self.print(f'{agent.name} just hit the wall at {agent.pos}.')
|
||||
per_agent_info_dict[agent.name].update({f'{agent.name}_vs_LEVEL': 1})
|
||||
|
||||
elif h.EnvActions.USE_DOOR == agent.temp_action:
|
||||
if agent.temp_valid:
|
||||
# per_agent_reward += 0.00
|
||||
self.print(f'{agent.name} did just use the door at {agent.pos}.')
|
||||
per_agent_info_dict[agent.name].update(door_used=1)
|
||||
else:
|
||||
# per_agent_reward -= 0.00
|
||||
self.print(f'{agent.name} just tried to use a door at {agent.pos}, but failed.')
|
||||
per_agent_info_dict[agent.name].update({f'{agent.name}_failed_door_open': 1})
|
||||
elif h.EnvActions.NOOP == agent.temp_action:
|
||||
per_agent_info_dict[agent.name].update(no_op=1)
|
||||
# per_agent_reward -= 0.00
|
||||
|
||||
# EnvMonitor Notes
|
||||
if agent.temp_valid:
|
||||
per_agent_info_dict[agent.name].update(valid_action=1)
|
||||
per_agent_info_dict[agent.name].update({f'{agent.name}_valid_action': 1})
|
||||
rewards = self.additional_per_agent_rewards(agent)
|
||||
for reward in rewards:
|
||||
agent.step_result['rewards'].append(reward)
|
||||
if collisions := agent.step_result['collisions']:
|
||||
self.print(f't = {self._steps}\t{agent.name} has collisions with {collisions}')
|
||||
info[c.COLLISION] += 1
|
||||
reward = {'value': r.COLLISION, 'reason': c.COLLISION, 'info': {f'{agent.name}_{c.COLLISION}': 1}}
|
||||
agent.step_result['rewards'].append(reward)
|
||||
else:
|
||||
per_agent_info_dict[agent.name].update(failed_action=1)
|
||||
per_agent_info_dict[agent.name].update({f'{agent.name}_failed_action': 1})
|
||||
# No Collisions, nothing to do
|
||||
pass
|
||||
|
||||
additional_reward, additional_info_dict = self.calculate_additional_reward(agent)
|
||||
per_agent_reward += additional_reward
|
||||
per_agent_info_dict[agent.name].update(additional_info_dict)
|
||||
|
||||
if agent.temp_collisions:
|
||||
self.print(f't = {self._steps}\t{agent.name} has collisions with {agent.temp_collisions}')
|
||||
per_agent_info_dict[agent.name].update(collisions=1)
|
||||
|
||||
for other_agent in agent.temp_collisions:
|
||||
per_agent_info_dict[agent.name].update({f'{agent.name}_vs_{other_agent.name}': 1})
|
||||
reward[agent.name] = per_agent_reward
|
||||
comb_rewards = {agent.name: sum(x['value'] for x in agent.step_result['rewards']) for agent in self[c.AGENT]}
|
||||
|
||||
# Combine the per_agent_info_dict:
|
||||
combined_info_dict = defaultdict(lambda: 0)
|
||||
for info_dict in per_agent_info_dict.values():
|
||||
for key, value in info_dict.items():
|
||||
combined_info_dict[key] += value
|
||||
for agent in self[c.AGENT]:
|
||||
for reward in agent.step_result['rewards']:
|
||||
combined_info_dict.update(reward['info'])
|
||||
|
||||
combined_info_dict = dict(combined_info_dict)
|
||||
combined_info_dict.update(info)
|
||||
|
||||
if self.individual_rewards:
|
||||
self.print(f"rewards are {reward}")
|
||||
reward = list(reward.values())
|
||||
self.print(f"rewards are {comb_rewards}")
|
||||
reward = list(comb_rewards.values())
|
||||
return reward, combined_info_dict
|
||||
else:
|
||||
reward = sum(reward.values())
|
||||
reward = sum(comb_rewards.values())
|
||||
self.print(f"reward is {reward}")
|
||||
return reward, combined_info_dict
|
||||
|
||||
@ -574,7 +607,7 @@ class BaseFactory(gym.Env):
|
||||
agents = []
|
||||
for i, agent in enumerate(self[c.AGENT]):
|
||||
name, state = h.asset_str(agent)
|
||||
agents.append(RenderEntity(name, agent.pos, 1, 'none', state, i + 1, agent.temp_light_map))
|
||||
agents.append(RenderEntity(name, agent.pos, 1, 'none', state, i + 1, agent.step_result['lightmap']))
|
||||
doors = []
|
||||
if self.parse_doors:
|
||||
for i, door in enumerate(self[c.DOORS]):
|
||||
@ -637,16 +670,16 @@ class BaseFactory(gym.Env):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def do_additional_step(self) -> dict:
|
||||
return {}
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
return [], {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def check_additional_done(self) -> bool:
|
||||
return False
|
||||
def check_additional_done(self) -> (bool, dict):
|
||||
return False, {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
@ -660,8 +693,8 @@ class BaseFactory(gym.Env):
|
||||
return additional_raw_observations
|
||||
|
||||
@abc.abstractmethod
|
||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||
return 0, {}
|
||||
def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
|
||||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def render_additional_assets(self):
|
||||
|
@ -33,7 +33,7 @@ class Object:
|
||||
else:
|
||||
return self._name
|
||||
|
||||
def __init__(self, str_ident: Union[str, None] = None, is_blocking_light=False, **kwargs):
|
||||
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
|
||||
|
||||
self._str_ident = str_ident
|
||||
|
||||
@ -45,7 +45,6 @@ class Object:
|
||||
else:
|
||||
raise ValueError('Please use either of the idents.')
|
||||
|
||||
self._is_blocking_light = is_blocking_light
|
||||
if kwargs:
|
||||
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
||||
|
||||
@ -62,6 +61,10 @@ class EnvObject(Object):
|
||||
|
||||
_u_idx = defaultdict(lambda: 0)
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.OCCUPIED_CELL
|
||||
@ -71,7 +74,10 @@ class EnvObject(Object):
|
||||
self._register = register
|
||||
|
||||
def change_register(self, register):
|
||||
register.register_item(self)
|
||||
self._register.delete_env_object(self)
|
||||
self._register = register
|
||||
return self._register == register
|
||||
|
||||
|
||||
class BoundingMixin(Object):
|
||||
@ -85,11 +91,6 @@ class BoundingMixin(Object):
|
||||
assert entity_to_be_bound is not None
|
||||
self._bound_entity = entity_to_be_bound
|
||||
|
||||
def __repr__(self):
|
||||
s = super(BoundingMixin, self).__repr__()
|
||||
i = s[:s.find('(')]
|
||||
return f'{s[:i]}[{self.bound_entity.name}]{s[i:]}'
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f'{super(BoundingMixin, self).name}({self._bound_entity.name})'
|
||||
@ -101,13 +102,9 @@ class BoundingMixin(Object):
|
||||
class Entity(EnvObject):
|
||||
"""Full Env Entity that lives on the env Grid. Doors, Items, Dirt etc..."""
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
return self._is_blocking_light
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
@ -125,10 +122,9 @@ class Entity(EnvObject):
|
||||
def tile(self):
|
||||
return self._tile
|
||||
|
||||
def __init__(self, tile, *args, is_blocking_light=True, **kwargs):
|
||||
def __init__(self, tile, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._tile = tile
|
||||
self._is_blocking_light = is_blocking_light
|
||||
tile.enter(self)
|
||||
|
||||
def summarize_state(self, **_) -> dict:
|
||||
@ -170,9 +166,9 @@ class MoveableEntity(Entity):
|
||||
self._tile = next_tile
|
||||
self._last_tile = curr_tile
|
||||
self._register.notify_change_to_value(self)
|
||||
return True
|
||||
return c.VALID
|
||||
else:
|
||||
return False
|
||||
return c.NOT_VALID
|
||||
|
||||
|
||||
##########################################################################
|
||||
@ -284,6 +280,10 @@ class Tile(EnvObject):
|
||||
|
||||
class Wall(Tile):
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.OCCUPIED_CELL
|
||||
@ -381,6 +381,10 @@ class Door(Entity):
|
||||
|
||||
class Agent(MoveableEntity):
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Agent, self).__init__(*args, **kwargs)
|
||||
self.clear_temp_state()
|
||||
@ -389,12 +393,9 @@ class Agent(MoveableEntity):
|
||||
def clear_temp_state(self):
|
||||
# for attr in cls.__dict__:
|
||||
# if attr.startswith('temp'):
|
||||
self.temp_collisions = []
|
||||
self.temp_valid = None
|
||||
self.temp_action = None
|
||||
self.temp_light_map = None
|
||||
self.step_result = None
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
state_dict = super().summarize_state(**kwargs)
|
||||
state_dict.update(valid=bool(self.temp_valid), action=str(self.temp_action))
|
||||
state_dict.update(valid=bool(self.temp_action_result['valid']), action=str(self.temp_action_result['action']))
|
||||
return state_dict
|
||||
|
@ -85,19 +85,27 @@ class EnvObjectRegister(ObjectRegister):
|
||||
def encodings(self):
|
||||
return [x.encoding for x in self]
|
||||
|
||||
def __init__(self, obs_shape: (int, int), *args, individual_slices: bool = False, **kwargs):
|
||||
def __init__(self, obs_shape: (int, int), *args,
|
||||
individual_slices: bool = False,
|
||||
is_blocking_light: bool = False,
|
||||
can_collide: bool = False,
|
||||
can_be_shadowed: bool = True, **kwargs):
|
||||
super(EnvObjectRegister, self).__init__(*args, **kwargs)
|
||||
self._shape = obs_shape
|
||||
self._array = None
|
||||
self._individual_slices = individual_slices
|
||||
self._lazy_eval_transforms = []
|
||||
self.is_blocking_light = is_blocking_light
|
||||
self.can_be_shadowed = can_be_shadowed
|
||||
self.can_collide = can_collide
|
||||
|
||||
def register_item(self, other: EnvObject):
|
||||
super(EnvObjectRegister, self).register_item(other)
|
||||
if self._array is None:
|
||||
self._array = np.zeros((1, *self._shape))
|
||||
if self._individual_slices:
|
||||
self._array = np.vstack((self._array, np.zeros((1, *self._shape))))
|
||||
else:
|
||||
if self._individual_slices:
|
||||
self._array = np.vstack((self._array, np.zeros((1, *self._shape))))
|
||||
self.notify_change_to_value(other)
|
||||
|
||||
def as_array(self):
|
||||
@ -179,14 +187,9 @@ class EntityRegister(EnvObjectRegister, ABC):
|
||||
def tiles(self):
|
||||
return [entity.tile for entity in self]
|
||||
|
||||
def __init__(self, level_shape, *args,
|
||||
is_blocking_light: bool = False,
|
||||
can_be_shadowed: bool = True,
|
||||
**kwargs):
|
||||
def __init__(self, level_shape, *args, **kwargs):
|
||||
super(EntityRegister, self).__init__(level_shape, *args, **kwargs)
|
||||
self._lazy_eval_transforms = []
|
||||
self.can_be_shadowed = can_be_shadowed
|
||||
self.is_blocking_light = is_blocking_light
|
||||
|
||||
def __delitem__(self, name):
|
||||
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||
@ -220,7 +223,7 @@ class EntityRegister(EnvObjectRegister, ABC):
|
||||
return None
|
||||
|
||||
|
||||
class BoundRegisterMixin(EnvObjectRegister, ABC):
|
||||
class BoundEnvObjRegister(EnvObjectRegister, ABC):
|
||||
|
||||
def __init__(self, entity_to_be_bound, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -229,6 +232,21 @@ class BoundRegisterMixin(EnvObjectRegister, ABC):
|
||||
def belongs_to_entity(self, entity):
|
||||
return self._bound_entity == entity
|
||||
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def as_array_by_entity(self, entity):
|
||||
return self._array[self.idx_by_entity(entity)]
|
||||
|
||||
|
||||
class MovingEntityObjectRegister(EntityRegister, ABC):
|
||||
|
||||
@ -255,6 +273,7 @@ class GlobalPositions(EnvObjectRegister):
|
||||
|
||||
is_blocking_light = False
|
||||
can_be_shadowed = False
|
||||
can_collide = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||
@ -360,7 +379,6 @@ class Entities(ObjectRegister):
|
||||
|
||||
class WallTiles(EntityRegister):
|
||||
_accepted_objects = Wall
|
||||
_light_blocking = True
|
||||
|
||||
def as_array(self):
|
||||
if not np.any(self._array):
|
||||
@ -371,9 +389,10 @@ class WallTiles(EntityRegister):
|
||||
self._array[0, x, y] = self._value
|
||||
return self._array
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(WallTiles, self).__init__(*args, is_blocking_light=self._light_blocking, individual_slices=False,
|
||||
**kwargs)
|
||||
def __init__(self, *args, is_blocking_light=True, **kwargs):
|
||||
super(WallTiles, self).__init__(*args, individual_slices=False,
|
||||
can_collide=True,
|
||||
is_blocking_light=is_blocking_light, **kwargs)
|
||||
self._value = c.OCCUPIED_CELL
|
||||
|
||||
@classmethod
|
||||
@ -381,7 +400,7 @@ class WallTiles(EntityRegister):
|
||||
tiles = cls(*args, **kwargs)
|
||||
# noinspection PyTypeChecker
|
||||
tiles.register_additional_items(
|
||||
[cls._accepted_objects(pos, tiles, is_blocking_light=cls._light_blocking)
|
||||
[cls._accepted_objects(pos, tiles)
|
||||
for pos in argwhere_coordinates]
|
||||
)
|
||||
return tiles
|
||||
@ -399,10 +418,9 @@ class WallTiles(EntityRegister):
|
||||
|
||||
class FloorTiles(WallTiles):
|
||||
_accepted_objects = Tile
|
||||
_light_blocking = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FloorTiles, self).__init__(*args, **kwargs)
|
||||
def __init__(self, *args, is_blocking_light=False, **kwargs):
|
||||
super(FloorTiles, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs)
|
||||
self._value = c.FREE_CELL
|
||||
|
||||
@property
|
||||
@ -430,7 +448,7 @@ class Agents(MovingEntityObjectRegister):
|
||||
_accepted_objects = Agent
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__(*args, can_collide=True, **kwargs)
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
@ -446,7 +464,7 @@ class Agents(MovingEntityObjectRegister):
|
||||
class Doors(EntityRegister):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Doors, self).__init__(*args, is_blocking_light=True, **kwargs)
|
||||
super(Doors, self).__init__(*args, is_blocking_light=True, can_collide=True, **kwargs)
|
||||
|
||||
_accepted_objects = Door
|
||||
|
||||
|
@ -2,6 +2,7 @@ import numpy as np
|
||||
|
||||
from environments.helpers import Constants as c
|
||||
|
||||
# Multipliers for transforming coordinates to other octants:
|
||||
mult_array = np.asarray([
|
||||
[1, 0, 0, -1, -1, 0, 0, 1],
|
||||
[0, 1, -1, 0, 0, -1, 1, 0],
|
||||
@ -11,8 +12,6 @@ mult_array = np.asarray([
|
||||
|
||||
|
||||
class Map(object):
|
||||
# Multipliers for transforming coordinates to other octants:
|
||||
|
||||
def __init__(self, map_array: np.typing.ArrayLike, diamond_slope: float = 0.9):
|
||||
self.data = map_array
|
||||
self.width, self.height = map_array.shape
|
||||
@ -33,7 +32,7 @@ class Map(object):
|
||||
self.light[x, y] = self.flag
|
||||
|
||||
def _cast_light(self, cx, cy, row, start, end, radius, xx, xy, yx, yy, id):
|
||||
"Recursive lightcasting function"
|
||||
"""Recursive lightcasting function"""
|
||||
if start < end:
|
||||
return
|
||||
radius_squared = radius*radius
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Union, NamedTuple, Dict
|
||||
from typing import Union, NamedTuple, Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -6,13 +6,29 @@ from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
|
||||
from environments.factory.base.registers import EntityRegister, EnvObjectRegister
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
from environments.helpers import Constants as c, Constants
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
|
||||
from environments import helpers as h
|
||||
|
||||
|
||||
CHARGE_ACTION = h.EnvActions.CHARGE
|
||||
CHARGE_POD = 1
|
||||
class Constants(BaseConstants):
|
||||
# Battery Env
|
||||
CHARGE_PODS = 'Charge_Pod'
|
||||
BATTERIES = 'BATTERIES'
|
||||
BATTERY_DISCHARGED = 'DISCHARGED'
|
||||
CHARGE_POD = 1
|
||||
|
||||
|
||||
class Actions(BaseActions):
|
||||
CHARGE = 'do_charge_action'
|
||||
|
||||
|
||||
class Rewards(BaseRewards):
|
||||
CHARGE_VALID = 0.1
|
||||
CHARGE_FAIL = -0.1
|
||||
BATTERY_DISCHARGED = -1.0
|
||||
|
||||
|
||||
class BatteryProperties(NamedTuple):
|
||||
@ -24,7 +40,12 @@ class BatteryProperties(NamedTuple):
|
||||
multi_charge: bool = False
|
||||
|
||||
|
||||
class Battery(EnvObject, BoundingMixin):
|
||||
c = Constants
|
||||
a = Actions
|
||||
r = Rewards
|
||||
|
||||
|
||||
class Battery(BoundingMixin, EnvObject):
|
||||
|
||||
@property
|
||||
def is_discharged(self):
|
||||
@ -37,13 +58,13 @@ class Battery(EnvObject, BoundingMixin):
|
||||
def encoding(self):
|
||||
return self.charge_level
|
||||
|
||||
def charge(self, amount) -> c:
|
||||
def do_charge_action(self, amount):
|
||||
if self.charge_level < 1:
|
||||
# noinspection PyTypeChecker
|
||||
self.charge_level = min(1, amount + self.charge_level)
|
||||
return c.VALID
|
||||
return dict(valid=c.VALID, action=a.CHARGE, reward=r.CHARGE_VALID)
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
return dict(valid=c.NOT_VALID, action=a.CHARGE, reward=r.CHARGE_FAIL)
|
||||
|
||||
def decharge(self, amount) -> c:
|
||||
if self.charge_level != 0:
|
||||
@ -54,7 +75,7 @@ class Battery(EnvObject, BoundingMixin):
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
def summarize_state(self, **_):
|
||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update(dict(name=self.name))
|
||||
return attr_dict
|
||||
@ -63,53 +84,43 @@ class Battery(EnvObject, BoundingMixin):
|
||||
class BatteriesRegister(EnvObjectRegister):
|
||||
|
||||
_accepted_objects = Battery
|
||||
is_blocking_light = False
|
||||
can_be_shadowed = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(BatteriesRegister, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||
super(BatteriesRegister, self).__init__(*args, individual_slices=True,
|
||||
is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||
self.is_observable = True
|
||||
|
||||
def as_array(self):
|
||||
# ToDO: Make this Lazy
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
for inv_idx, battery in enumerate(self):
|
||||
self._array[inv_idx] = battery.as_array()
|
||||
return self._array
|
||||
|
||||
def spawn_batteries(self, agents, pomdp_r, initial_charge_level):
|
||||
batteries = [self._accepted_objects(pomdp_r, self._shape, agent,
|
||||
initial_charge_level)
|
||||
for _, agent in enumerate(agents)]
|
||||
def spawn_batteries(self, agents, initial_charge_level):
|
||||
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
|
||||
self.register_additional_items(batteries)
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, bat in enumerate(self) if bat.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((bat for bat in self if bat.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
# as dict with additional nesting
|
||||
# return dict(items=super(Inventories, cls).summarize_states())
|
||||
return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)
|
||||
|
||||
# Todo Move this to Mixin!
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def as_array_by_entity(self, entity):
|
||||
return self._array[self.idx_by_entity(entity)]
|
||||
|
||||
|
||||
class ChargePod(Entity):
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return CHARGE_POD
|
||||
return c.CHARGE_POD
|
||||
|
||||
def __init__(self, *args, charge_rate: float = 0.4,
|
||||
multi_charge: bool = False, **kwargs):
|
||||
@ -120,9 +131,9 @@ class ChargePod(Entity):
|
||||
def charge_battery(self, battery: Battery):
|
||||
if battery.charge_level == 1.0:
|
||||
return c.NOT_VALID
|
||||
if sum(guest for guest in self.tile.guests if c.AGENT.name in guest.name) > 1:
|
||||
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
|
||||
return c.NOT_VALID
|
||||
battery.charge(self.charge_rate)
|
||||
battery.do_charge_action(self.charge_rate)
|
||||
return c.VALID
|
||||
|
||||
def summarize_state(self, n_steps=None) -> dict:
|
||||
@ -135,14 +146,6 @@ class ChargePods(EntityRegister):
|
||||
|
||||
_accepted_objects = ChargePod
|
||||
|
||||
@DeprecationWarning
|
||||
def Xas_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
for item in self:
|
||||
if item.pos != c.NO_POS.value:
|
||||
self._array[0, item.x, item.y] = item.encoding
|
||||
return self._array
|
||||
|
||||
def __repr__(self):
|
||||
super(ChargePods, self).__repr__()
|
||||
|
||||
@ -155,14 +158,14 @@ class BatteryFactory(BaseFactory):
|
||||
self.btry_prop = btry_prop
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _additional_per_agent_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_raw_observations = super()._additional_per_agent_raw_observations(agent)
|
||||
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].by_entity(agent).as_array()})
|
||||
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
|
||||
return additional_raw_observations
|
||||
|
||||
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
additional_observations.update({c.CHARGE_POD: self[c.CHARGE_POD].as_array()})
|
||||
additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
|
||||
return additional_observations
|
||||
|
||||
@property
|
||||
@ -178,12 +181,12 @@ class BatteryFactory(BaseFactory):
|
||||
|
||||
batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
||||
)
|
||||
batteries.spawn_batteries(self[c.AGENT], self._pomdp_r, self.btry_prop.initial_charge)
|
||||
super_entities.update({c.BATTERIES: batteries, c.CHARGE_POD: charge_pods})
|
||||
batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
|
||||
super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
|
||||
return super_entities
|
||||
|
||||
def do_additional_step(self) -> dict:
|
||||
info_dict = super(BatteryFactory, self).do_additional_step()
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
super_reward_info = super(BatteryFactory, self).do_additional_step()
|
||||
|
||||
# Decharge
|
||||
batteries = self[c.BATTERIES]
|
||||
@ -196,65 +199,70 @@ class BatteryFactory(BaseFactory):
|
||||
|
||||
batteries.by_entity(agent).decharge(energy_consumption)
|
||||
|
||||
return info_dict
|
||||
return super_reward_info
|
||||
|
||||
def do_charge(self, agent) -> c:
|
||||
if charge_pod := self[c.CHARGE_POD].by_pos(agent.pos):
|
||||
return charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
|
||||
def do_charge_action(self, agent) -> (dict, dict):
|
||||
if charge_pod := self[c.CHARGE_PODS].by_pos(agent.pos):
|
||||
valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
|
||||
if valid:
|
||||
info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1}
|
||||
self.print(f'{agent.name} just charged batteries at {charge_pod.name}.')
|
||||
else:
|
||||
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||
self.print(f'{agent.name} failed to charged batteries at {charge_pod.name}.')
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
valid = c.NOT_VALID
|
||||
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||
# info_dict = {f'{agent.name}_no_charger': 1}
|
||||
self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
|
||||
reward = dict(value=r.CHARGE_VALID if valid else r.CHARGE_FAIL, reason=a.CHARGE, info=info_dict)
|
||||
return valid, reward
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
||||
valid = super().do_additional_actions(agent, action)
|
||||
if valid is None:
|
||||
if action == CHARGE_ACTION:
|
||||
valid = self.do_charge(agent)
|
||||
return valid
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
||||
action_result = super().do_additional_actions(agent, action)
|
||||
if action_result is None:
|
||||
if action == a.CHARGE:
|
||||
action_result = self.do_charge_action(agent)
|
||||
return action_result
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return valid
|
||||
return action_result
|
||||
pass
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
# There is Nothing to reset.
|
||||
pass
|
||||
|
||||
def check_additional_done(self) -> bool:
|
||||
super_done = super(BatteryFactory, self).check_additional_done()
|
||||
def check_additional_done(self) -> (bool, dict):
|
||||
super_done, super_dict = super(BatteryFactory, self).check_additional_done()
|
||||
if super_done:
|
||||
return super_done
|
||||
return super_done, super_dict
|
||||
else:
|
||||
return self.btry_prop.done_when_discharged and any(battery.is_discharged for battery in self[c.BATTERIES])
|
||||
if self.btry_prop.done_when_discharged:
|
||||
if btry_done := any(battery.is_discharged for battery in self[c.BATTERIES]):
|
||||
super_dict.update(DISCHARGE_DONE=1)
|
||||
return btry_done, super_dict
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
pass
|
||||
|
||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||
reward, info_dict = super(BatteryFactory, self).calculate_additional_reward(agent)
|
||||
if h.EnvActions.CHARGE == agent.temp_action:
|
||||
if agent.temp_valid:
|
||||
charge_pod = self[c.CHARGE_POD].by_pos(agent.pos)
|
||||
info_dict.update({f'{agent.name}_charge': 1})
|
||||
info_dict.update(agent_charged=1)
|
||||
self.print(f'{agent.name} just charged batteries at {charge_pod.pos}.')
|
||||
reward += 0.1
|
||||
else:
|
||||
self[c.DROP_OFF].by_pos(agent.pos)
|
||||
info_dict.update({f'{agent.name}_failed_charge': 1})
|
||||
info_dict.update(failed_charge=1)
|
||||
self.print(f'{agent.name} just tried to charge at {agent.pos}, but failed.')
|
||||
reward -= 0.1
|
||||
|
||||
def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
|
||||
reward_event_dict = super(BatteryFactory, self).additional_per_agent_reward(agent)
|
||||
if self[c.BATTERIES].by_entity(agent).is_discharged:
|
||||
info_dict.update({f'{agent.name}_discharged': 1})
|
||||
reward -= 1
|
||||
self.print(f'{agent.name} Battery is discharged!')
|
||||
info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
|
||||
reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': r.BATTERY_DISCHARGED, 'info': info_dict}})
|
||||
else:
|
||||
info_dict.update({f'{agent.name}_battery_level': self[c.BATTERIES].by_entity(agent).charge_level})
|
||||
return reward, info_dict
|
||||
# All Fine
|
||||
pass
|
||||
return reward_event_dict
|
||||
|
||||
def render_additional_assets(self):
|
||||
# noinspection PyUnresolvedReferences
|
||||
additional_assets = super().render_additional_assets()
|
||||
charge_pods = [RenderEntity(c.CHARGE_POD.value, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_POD]]
|
||||
charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
|
||||
additional_assets.extend(charge_pods)
|
||||
return additional_assets
|
||||
|
||||
|
@ -6,18 +6,32 @@ import numpy as np
|
||||
import random
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.helpers import Constants as c, Constants
|
||||
from environments import helpers as h
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
from environments.factory.base.objects import Agent, Entity, Action
|
||||
from environments.factory.base.registers import Entities, EntityRegister
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
|
||||
|
||||
class Constants(BaseConstants):
|
||||
# Destination Env
|
||||
DEST = 'Destination'
|
||||
DESTINATION = 1
|
||||
DESTINATION_DONE = 0.5
|
||||
DEST_REACHED = 'ReachedDestination'
|
||||
|
||||
|
||||
DESTINATION = 1
|
||||
DESTINATION_DONE = 0.5
|
||||
class Actions(BaseActions):
|
||||
WAIT_ON_DEST = 'WAIT'
|
||||
|
||||
|
||||
class Rewards(BaseRewards):
|
||||
|
||||
WAIT_VALID = 0.1
|
||||
WAIT_FAIL = -0.1
|
||||
DEST_REACHED = 5.0
|
||||
|
||||
|
||||
class Destination(Entity):
|
||||
@ -30,20 +44,16 @@ class Destination(Entity):
|
||||
def currently_dwelling_names(self):
|
||||
return self._per_agent_times.keys()
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return DESTINATION
|
||||
return c.DESTINATION
|
||||
|
||||
def __init__(self, *args, dwell_time: int = 0, **kwargs):
|
||||
super(Destination, self).__init__(*args, **kwargs)
|
||||
self.dwell_time = dwell_time
|
||||
self._per_agent_times = defaultdict(lambda: dwell_time)
|
||||
|
||||
def wait(self, agent: Agent):
|
||||
def do_wait_action(self, agent: Agent):
|
||||
self._per_agent_times[agent.name] -= 1
|
||||
return c.VALID
|
||||
|
||||
@ -52,7 +62,7 @@ class Destination(Entity):
|
||||
|
||||
@property
|
||||
def is_considered_reached(self):
|
||||
agent_at_position = any(c.AGENT.name.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
||||
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
||||
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
|
||||
|
||||
def agent_is_dwelling(self, agent: Agent):
|
||||
@ -67,15 +77,19 @@ class Destination(Entity):
|
||||
class Destinations(EntityRegister):
|
||||
|
||||
_accepted_objects = Destination
|
||||
_light_blocking = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.is_blocking_light = False
|
||||
self.can_be_shadowed = False
|
||||
|
||||
def as_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
self._array[:] = c.FREE_CELL
|
||||
# ToDo: Switch to new Style Array Put
|
||||
# indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls])))
|
||||
# np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings)
|
||||
for item in self:
|
||||
if item.pos != c.NO_POS.value:
|
||||
if item.pos != c.NO_POS:
|
||||
self._array[0, item.x, item.y] = item.encoding
|
||||
return self._array
|
||||
|
||||
@ -85,10 +99,11 @@ class Destinations(EntityRegister):
|
||||
|
||||
class ReachedDestinations(Destinations):
|
||||
_accepted_objects = Destination
|
||||
_light_blocking = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ReachedDestinations, self).__init__(*args, **kwargs)
|
||||
self.can_be_shadowed = False
|
||||
self.is_blocking_light = False
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
return {}
|
||||
@ -102,7 +117,7 @@ class DestModeOptions(object):
|
||||
|
||||
class DestProperties(NamedTuple):
|
||||
n_dests: int = 1 # How many destinations are there
|
||||
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
|
||||
dwell_time: int = 0 # How long does the agent need to "do_wait_action" on a destination
|
||||
spawn_frequency: int = 0
|
||||
spawn_in_other_zone: bool = True #
|
||||
spawn_mode: str = DestModeOptions.DONE
|
||||
@ -113,6 +128,11 @@ class DestProperties(NamedTuple):
|
||||
assert (spawn_mode == DestModeOptions.DONE) != bool(spawn_frequency)
|
||||
|
||||
|
||||
c = Constants
|
||||
a = Actions
|
||||
r = Rewards
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||
class DestFactory(BaseFactory):
|
||||
# noinspection PyMissingConstructor
|
||||
@ -131,7 +151,7 @@ class DestFactory(BaseFactory):
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_actions = super().additional_actions
|
||||
if self.dest_prop.dwell_time:
|
||||
super_actions.append(Action(enum_ident=h.EnvActions.WAIT_ON_DEST))
|
||||
super_actions.append(Action(enum_ident=a.WAIT_ON_DEST))
|
||||
return super_actions
|
||||
|
||||
@property
|
||||
@ -147,27 +167,32 @@ class DestFactory(BaseFactory):
|
||||
)
|
||||
reached_destinations = ReachedDestinations(level_shape=self._level_shape)
|
||||
|
||||
super_entities.update({c.DESTINATION: destinations, c.REACHEDDESTINATION: reached_destinations})
|
||||
super_entities.update({c.DEST: destinations, c.DEST_REACHED: reached_destinations})
|
||||
return super_entities
|
||||
|
||||
def wait(self, agent: Agent):
|
||||
if destiantion := self[c.DESTINATION].by_pos(agent.pos):
|
||||
valid = destiantion.wait(agent)
|
||||
return valid
|
||||
def do_wait_action(self, agent: Agent) -> (dict, dict):
|
||||
if destination := self[c.DEST].by_pos(agent.pos):
|
||||
valid = destination.do_wait_action(agent)
|
||||
self.print(f'{agent.name} just waited at {agent.pos}')
|
||||
info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_VALID': 1}
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
valid = c.NOT_VALID
|
||||
self.print(f'{agent.name} just tried to do_wait_action do_wait_action at {agent.pos} but failed')
|
||||
info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_FAIL': 1}
|
||||
reward = dict(value=r.WAIT_VALID if valid else r.WAIT_FAIL, reason=a.WAIT_ON_DEST, info=info_dict)
|
||||
return valid, reward
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
valid = super().do_additional_actions(agent, action)
|
||||
if valid is None:
|
||||
if action == h.EnvActions.WAIT_ON_DEST:
|
||||
valid = self.wait(agent)
|
||||
return valid
|
||||
super_action_result = super().do_additional_actions(agent, action)
|
||||
if super_action_result is None:
|
||||
if action == a.WAIT_ON_DEST:
|
||||
action_result = self.do_wait_action(agent)
|
||||
return action_result
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return valid
|
||||
return super_action_result
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
# noinspection PyUnresolvedReferences
|
||||
@ -180,14 +205,14 @@ class DestFactory(BaseFactory):
|
||||
if destinations_to_spawn:
|
||||
n_dest_to_spawn = len(destinations_to_spawn)
|
||||
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
|
||||
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||
self[c.DESTINATION].register_additional_items(destinations)
|
||||
destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||
self[c.DEST].register_additional_items(destinations)
|
||||
for dest in destinations_to_spawn:
|
||||
del self._dest_spawn_timer[dest]
|
||||
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||
elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
|
||||
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||
self[c.DESTINATION].register_additional_items(destinations)
|
||||
destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||
self[c.DEST].register_additional_items(destinations)
|
||||
for dest in destinations_to_spawn:
|
||||
del self._dest_spawn_timer[dest]
|
||||
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||
@ -197,15 +222,14 @@ class DestFactory(BaseFactory):
|
||||
else:
|
||||
self.print('No Items are spawning, limit is reached.')
|
||||
|
||||
def do_additional_step(self) -> dict:
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
info_dict = super().do_additional_step()
|
||||
super_reward_info = super().do_additional_step()
|
||||
for key, val in self._dest_spawn_timer.items():
|
||||
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
|
||||
for dest in list(self[c.DESTINATION].values()):
|
||||
for dest in list(self[c.DEST].values()):
|
||||
if dest.is_considered_reached:
|
||||
self[c.REACHEDDESTINATION].register_item(dest)
|
||||
self[c.DESTINATION].delete_env_object(dest)
|
||||
dest.change_register(self[c.DEST])
|
||||
self._dest_spawn_timer[dest.name] = 0
|
||||
self.print(f'{dest.name} is reached now, removing...')
|
||||
else:
|
||||
@ -218,41 +242,29 @@ class DestFactory(BaseFactory):
|
||||
dest.leave(agent)
|
||||
self.print(f'{agent.name} left the destination early.')
|
||||
self.trigger_destination_spawn()
|
||||
return info_dict
|
||||
return super_reward_info
|
||||
|
||||
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
additional_observations.update({c.DESTINATION: self[c.DESTINATION].as_array()})
|
||||
additional_observations.update({c.DEST: self[c.DEST].as_array()})
|
||||
return additional_observations
|
||||
|
||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||
def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
|
||||
# noinspection PyUnresolvedReferences
|
||||
reward, info_dict = super().calculate_additional_reward(agent)
|
||||
if h.EnvActions.WAIT_ON_DEST == agent.temp_action:
|
||||
if agent.temp_valid:
|
||||
info_dict.update({f'{agent.name}_waiting_at_dest': 1})
|
||||
info_dict.update(agent_waiting_at_dest=1)
|
||||
self.print(f'{agent.name} just waited at {agent.pos}')
|
||||
reward += 0.1
|
||||
else:
|
||||
info_dict.update({f'{agent.name}_tried_failed': 1})
|
||||
info_dict.update(agent_waiting_failed=1)
|
||||
self.print(f'{agent.name} just tried to wait wait at {agent.pos} but failed')
|
||||
reward -= 0.1
|
||||
if len(self[c.REACHEDDESTINATION]):
|
||||
for reached_dest in list(self[c.REACHEDDESTINATION]):
|
||||
reward_event_dict = super().additional_per_agent_reward(agent)
|
||||
if len(self[c.DEST_REACHED]):
|
||||
for reached_dest in list(self[c.DEST_REACHED]):
|
||||
if agent.pos == reached_dest.pos:
|
||||
info_dict.update({f'{agent.name}_reached_destination': 1})
|
||||
info_dict.update(agent_reached_destination=1)
|
||||
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
||||
reward += 0.5
|
||||
self[c.REACHEDDESTINATION].delete_env_object(reached_dest)
|
||||
return reward, info_dict
|
||||
self[c.DEST_REACHED].delete_env_object(reached_dest)
|
||||
info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1}
|
||||
reward_event_dict.update({c.DEST_REACHED: {'reward': r.DEST_REACHED, 'info': info_dict}})
|
||||
return reward_event_dict
|
||||
|
||||
def render_additional_assets(self, mode='human'):
|
||||
# noinspection PyUnresolvedReferences
|
||||
additional_assets = super().render_additional_assets()
|
||||
destinations = [RenderEntity(c.DESTINATION.value, dest.pos) for dest in self[c.DESTINATION]]
|
||||
destinations = [RenderEntity(c.DEST, dest.pos) for dest in self[c.DEST]]
|
||||
additional_assets.extend(destinations)
|
||||
return additional_assets
|
||||
|
||||
|
@ -8,6 +8,7 @@ import numpy as np
|
||||
# from algorithms.TSP_dirt_agent import TSPDirtAgent
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Entity, Tile
|
||||
@ -21,8 +22,14 @@ class Constants(BaseConstants):
|
||||
DIRT = 'Dirt'
|
||||
|
||||
|
||||
class EnvActions(BaseActions):
|
||||
CLEAN_UP = 'clean_up'
|
||||
class Actions(BaseActions):
|
||||
CLEAN_UP = 'do_cleanup_action'
|
||||
|
||||
|
||||
class Rewards(BaseRewards):
|
||||
CLEAN_UP_VALID = 0.5
|
||||
CLEAN_UP_FAIL = -0.1
|
||||
CLEAN_UP_LAST_PIECE = 4.5
|
||||
|
||||
|
||||
class DirtProperties(NamedTuple):
|
||||
@ -41,10 +48,6 @@ class DirtProperties(NamedTuple):
|
||||
|
||||
class Dirt(Entity):
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def amount(self):
|
||||
return self._amount
|
||||
@ -116,6 +119,8 @@ def entropy(x):
|
||||
|
||||
|
||||
c = Constants
|
||||
a = Actions
|
||||
r = Rewards
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||
@ -125,7 +130,7 @@ class DirtFactory(BaseFactory):
|
||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
||||
super_actions = super().additional_actions
|
||||
if self.dirt_prop.agent_can_interact:
|
||||
super_actions.append(Action(str_ident=EnvActions.CLEAN_UP))
|
||||
super_actions.append(Action(str_ident=a.CLEAN_UP))
|
||||
return super_actions
|
||||
|
||||
@property
|
||||
@ -151,7 +156,7 @@ class DirtFactory(BaseFactory):
|
||||
additional_assets.extend(dirt)
|
||||
return additional_assets
|
||||
|
||||
def clean_up(self, agent: Agent) -> c:
|
||||
def do_cleanup_action(self, agent: Agent) -> (dict, dict):
|
||||
if dirt := self[c.DIRT].by_pos(agent.pos):
|
||||
new_dirt_amount = dirt.amount - self.dirt_prop.clean_amount
|
||||
|
||||
@ -159,9 +164,21 @@ class DirtFactory(BaseFactory):
|
||||
self[c.DIRT].delete_env_object(dirt)
|
||||
else:
|
||||
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
|
||||
return c.VALID
|
||||
valid = c.VALID
|
||||
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1}
|
||||
reward = r.CLEAN_UP_VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
valid = c.NOT_VALID
|
||||
self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
|
||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1}
|
||||
reward = r.CLEAN_UP_FAIL
|
||||
|
||||
if valid and self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
|
||||
reward += r.CLEAN_UP_LAST_PIECE
|
||||
self.print(f'{agent.name} picked up the last piece of dirt!')
|
||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_LAST_PIECE': 1}
|
||||
return valid, dict(value=reward, reason=a.CLEAN_UP, info=info_dict)
|
||||
|
||||
def trigger_dirt_spawn(self, initial_spawn=False):
|
||||
dirt_rng = self._dirt_rng
|
||||
@ -177,8 +194,8 @@ class DirtFactory(BaseFactory):
|
||||
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
||||
self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles])
|
||||
|
||||
def do_additional_step(self) -> dict:
|
||||
info_dict = super().do_additional_step()
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
super_reward_info = super().do_additional_step()
|
||||
if smear_amount := self.dirt_prop.dirt_smear_amount:
|
||||
for agent in self[c.AGENT]:
|
||||
if agent.temp_valid and agent.last_pos != c.NO_POS:
|
||||
@ -199,42 +216,44 @@ class DirtFactory(BaseFactory):
|
||||
self._next_dirt_spawn = self.dirt_prop.spawn_frequency
|
||||
else:
|
||||
self._next_dirt_spawn -= 1
|
||||
return info_dict
|
||||
return super_reward_info
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
||||
valid = super().do_additional_actions(agent, action)
|
||||
if valid is None:
|
||||
if action == EnvActions.CLEAN_UP:
|
||||
if self.dirt_prop.agent_can_interact:
|
||||
valid = self.clean_up(agent)
|
||||
return valid
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||
action_result = super().do_additional_actions(agent, action)
|
||||
if action_result is None:
|
||||
if action == a.CLEAN_UP:
|
||||
return self.do_cleanup_action(agent)
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return valid
|
||||
return action_result
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
super().do_additional_reset()
|
||||
self.trigger_dirt_spawn(initial_spawn=True)
|
||||
self._next_dirt_spawn = self.dirt_prop.spawn_frequency if self.dirt_prop.spawn_frequency else -1
|
||||
|
||||
def check_additional_done(self):
|
||||
super_done = super().check_additional_done()
|
||||
done = self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0)
|
||||
return super_done or done
|
||||
def check_additional_done(self) -> (bool, dict):
|
||||
super_done, super_dict = super().check_additional_done()
|
||||
if self.dirt_prop.done_when_clean:
|
||||
if all_cleaned := len(self[c.DIRT]) == 0:
|
||||
super_dict.update(ALL_CLEAN_DONE=all_cleaned)
|
||||
return all_cleaned, super_dict
|
||||
return super_done, super_dict
|
||||
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
additional_observations.update({c.DIRT: self[c.DIRT].as_array()})
|
||||
return additional_observations
|
||||
|
||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||
reward, info_dict = super().calculate_additional_reward(agent)
|
||||
def gather_additional_info(self, agent: Agent) -> dict:
|
||||
event_reward_dict = super().additional_per_agent_reward(agent)
|
||||
info_dict = dict()
|
||||
|
||||
dirt = [dirt.amount for dirt in self[c.DIRT]]
|
||||
current_dirt_amount = sum(dirt)
|
||||
dirty_tile_count = len(dirt)
|
||||
|
||||
# if dirty_tile_count:
|
||||
# dirt_distribution_score = entropy(softmax(np.asarray(dirt)) / dirty_tile_count)
|
||||
# else:
|
||||
@ -242,33 +261,14 @@ class DirtFactory(BaseFactory):
|
||||
|
||||
info_dict.update(dirt_amount=current_dirt_amount)
|
||||
info_dict.update(dirty_tile_count=dirty_tile_count)
|
||||
# info_dict.update(dirt_distribution_score=dirt_distribution_score)
|
||||
|
||||
if agent.temp_action == EnvActions.CLEAN_UP:
|
||||
if agent.temp_valid:
|
||||
# Reward if pickup succeds,
|
||||
# 0.5 on every pickup
|
||||
reward += 0.5
|
||||
info_dict.update(dirt_cleaned=1)
|
||||
if self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
|
||||
# 0.5 additional reward for the very last pickup
|
||||
reward += 4.5
|
||||
info_dict.update(done_clean=1)
|
||||
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
||||
else:
|
||||
reward -= 0.01
|
||||
self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
|
||||
info_dict.update({f'{agent.name}_failed_dirt_cleanup': 1})
|
||||
info_dict.update(failed_dirt_clean=1)
|
||||
|
||||
# Potential based rewards ->
|
||||
# track the last reward , minus the current reward = potential
|
||||
return reward, info_dict
|
||||
event_reward_dict.update({'info': info_dict})
|
||||
return event_reward_dict
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from environments.utility_classes import AgentRenderOptions as aro
|
||||
render = True
|
||||
render = False
|
||||
|
||||
dirt_props = DirtProperties(
|
||||
initial_dirt_ratio=0.35,
|
||||
@ -289,14 +289,15 @@ if __name__ == '__main__':
|
||||
move_props = {'allow_square_movement': True,
|
||||
'allow_diagonal_movement': False,
|
||||
'allow_no_op': False}
|
||||
import time
|
||||
global_timings = []
|
||||
for i in range(20):
|
||||
for i in range(10):
|
||||
|
||||
factory = DirtFactory(n_agents=2, done_at_collision=False,
|
||||
level_name='rooms', max_steps=1000,
|
||||
doors_have_area=False,
|
||||
obs_prop=obs_props, parse_doors=True,
|
||||
record_episodes=True, verbose=True,
|
||||
verbose=False,
|
||||
mv_prop=move_props, dirt_prop=dirt_props,
|
||||
# inject_agents=[TSPDirtAgent],
|
||||
)
|
||||
@ -307,7 +308,6 @@ if __name__ == '__main__':
|
||||
obs_space = factory.observation_space
|
||||
obs_space_named = factory.named_observation_space
|
||||
times = []
|
||||
import time
|
||||
for epoch in range(10):
|
||||
start_time = time.time()
|
||||
random_actions = [[random.randint(0, n_actions) for _
|
||||
@ -318,18 +318,19 @@ if __name__ == '__main__':
|
||||
factory.render()
|
||||
# tsp_agent = factory.get_injected_agents()[0]
|
||||
|
||||
r = 0
|
||||
rwrd = 0
|
||||
for agent_i_action in random_actions:
|
||||
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||
r += step_r
|
||||
env_state, step_rwrd, done_bool, info_obj = factory.step(agent_i_action)
|
||||
rwrd += step_rwrd
|
||||
if render:
|
||||
factory.render()
|
||||
if done_bool:
|
||||
break
|
||||
times.append(time.time() - start_time)
|
||||
# print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||
print('Time Taken: ', sum(times) / 10)
|
||||
global_timings.append(sum(times) / 10)
|
||||
print('Time Taken: ', sum(global_timings[10:]) / 10)
|
||||
print('Mean Time Taken: ', sum(times) / 10)
|
||||
global_timings.extend(times)
|
||||
print('Mean Time Taken: ', sum(global_timings) / len(global_timings))
|
||||
print('Median Time Taken: ', global_timings[len(global_timings)//2])
|
||||
|
||||
pass
|
||||
|
@ -7,9 +7,10 @@ import random
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Entity, Action, Tile
|
||||
from environments.factory.base.registers import Entities, EntityRegister, BoundRegisterMixin, ObjectRegister
|
||||
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
|
||||
@ -23,10 +24,17 @@ class Constants(BaseConstants):
|
||||
DROP_OFF = 'Drop_Off'
|
||||
|
||||
|
||||
class EnvActions(BaseActions):
|
||||
class Actions(BaseActions):
|
||||
ITEM_ACTION = 'item_action'
|
||||
|
||||
|
||||
class Rewards(BaseRewards):
|
||||
DROP_OFF_VALID = 0.1
|
||||
DROP_OFF_FAIL = -0.1
|
||||
PICK_UP_FAIL = -0.1
|
||||
PICK_UP_VALID = 0.1
|
||||
|
||||
|
||||
class Item(Entity):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@ -37,10 +45,6 @@ class Item(Entity):
|
||||
def auto_despawn(self):
|
||||
return self._auto_despawn
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
# Edit this if you want items to be drawn in the ops differently
|
||||
@ -68,7 +72,7 @@ class ItemRegister(EntityRegister):
|
||||
del self[item]
|
||||
|
||||
|
||||
class Inventory(BoundRegisterMixin):
|
||||
class Inventory(BoundEnvObjRegister):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@ -131,10 +135,6 @@ class Inventories(ObjectRegister):
|
||||
|
||||
class DropOffLocation(Entity):
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return Constants.ITEM_DROP_OFF
|
||||
@ -176,7 +176,8 @@ class ItemProperties(NamedTuple):
|
||||
|
||||
|
||||
c = Constants
|
||||
a = EnvActions
|
||||
a = Actions
|
||||
r = Rewards
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||
@ -230,37 +231,43 @@ class ItemFactory(BaseFactory):
|
||||
additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()})
|
||||
return additional_observations
|
||||
|
||||
def do_item_action(self, agent: Agent):
|
||||
def do_item_action(self, agent: Agent) -> (dict, dict):
|
||||
inventory = self[c.INVENTORY].by_entity(agent)
|
||||
if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
|
||||
if inventory:
|
||||
valid = drop_off.place_item(inventory.pop())
|
||||
return valid
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
valid = c.NOT_VALID
|
||||
if valid:
|
||||
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
|
||||
info_dict = {f'{agent.name}_DROPOFF_VALID': 1}
|
||||
else:
|
||||
self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
|
||||
info_dict = {f'{agent.name}_DROPOFF_FAIL': 1}
|
||||
reward = dict(value=r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
||||
return valid, reward
|
||||
elif item := self[c.ITEM].by_pos(agent.pos):
|
||||
try:
|
||||
inventory.register_item(item)
|
||||
item.change_register(inventory)
|
||||
self[c.ITEM].delete_env_object(item)
|
||||
item.set_tile_to(self._NO_POS_TILE)
|
||||
return c.VALID
|
||||
except RuntimeError:
|
||||
return c.NOT_VALID
|
||||
item.change_register(inventory)
|
||||
item.set_tile_to(self._NO_POS_TILE)
|
||||
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1}
|
||||
return c.VALID, dict(value=r.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict)
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1}
|
||||
return c.NOT_VALID, dict(value=r.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
valid = super().do_additional_actions(agent, action)
|
||||
if valid is None:
|
||||
action_result = super().do_additional_actions(agent, action)
|
||||
if action_result is None:
|
||||
if action == a.ITEM_ACTION:
|
||||
valid = self.do_item_action(agent)
|
||||
return valid
|
||||
action_result = self.do_item_action(agent)
|
||||
return action_result
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return valid
|
||||
return action_result
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
# noinspection PyUnresolvedReferences
|
||||
@ -277,9 +284,9 @@ class ItemFactory(BaseFactory):
|
||||
else:
|
||||
self.print('No Items are spawning, limit is reached.')
|
||||
|
||||
def do_additional_step(self) -> dict:
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
info_dict = super().do_additional_step()
|
||||
super_reward_info = super().do_additional_step()
|
||||
for item in list(self[c.ITEM].values()):
|
||||
if item.auto_despawn >= 1:
|
||||
item.set_auto_despawn(item.auto_despawn-1)
|
||||
@ -292,35 +299,7 @@ class ItemFactory(BaseFactory):
|
||||
self.trigger_item_spawn()
|
||||
else:
|
||||
self._next_item_spawn = max(0, self._next_item_spawn-1)
|
||||
return info_dict
|
||||
|
||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
reward, info_dict = super().calculate_additional_reward(agent)
|
||||
if a.ITEM_ACTION == agent.temp_action:
|
||||
if agent.temp_valid:
|
||||
if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
|
||||
info_dict.update({f'{agent.name}_item_drop_off': 1})
|
||||
info_dict.update(item_drop_off=1)
|
||||
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
|
||||
reward += 1
|
||||
else:
|
||||
info_dict.update({f'{agent.name}_item_pickup': 1})
|
||||
info_dict.update(item_pickup=1)
|
||||
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
||||
reward += 0.2
|
||||
else:
|
||||
if self[c.DROP_OFF].by_pos(agent.pos):
|
||||
info_dict.update({f'{agent.name}_failed_drop_off': 1})
|
||||
info_dict.update(failed_drop_off=1)
|
||||
self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
|
||||
reward -= 0.1
|
||||
else:
|
||||
info_dict.update({f'{agent.name}_failed_item_action': 1})
|
||||
info_dict.update(failed_pick_up=1)
|
||||
self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
|
||||
reward -= 0.1
|
||||
return reward, info_dict
|
||||
return super_reward_info
|
||||
|
||||
def render_additional_assets(self, mode='human'):
|
||||
# noinspection PyUnresolvedReferences
|
||||
@ -335,9 +314,9 @@ class ItemFactory(BaseFactory):
|
||||
if __name__ == '__main__':
|
||||
from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties
|
||||
|
||||
render = True
|
||||
render = False
|
||||
|
||||
item_probs = ItemProperties(n_items=30)
|
||||
item_probs = ItemProperties(n_items=30, n_drop_off_locations=6)
|
||||
|
||||
obs_props = ObservationProperties(render_agents=aro.SEPERATE, omit_agent_self=True, pomdp_r=2)
|
||||
|
||||
@ -345,7 +324,7 @@ if __name__ == '__main__':
|
||||
'allow_diagonal_movement': True,
|
||||
'allow_no_op': False}
|
||||
|
||||
factory = ItemFactory(n_agents=2, done_at_collision=False,
|
||||
factory = ItemFactory(n_agents=6, done_at_collision=False,
|
||||
level_name='rooms', max_steps=400,
|
||||
obs_prop=obs_props, parse_doors=True,
|
||||
record_episodes=True, verbose=True,
|
||||
|
@ -1,6 +1,6 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from typing import Tuple, Union, Dict, List
|
||||
from typing import Tuple, Union, Dict, List, NamedTuple
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@ -38,37 +38,27 @@ class Constants:
|
||||
OPEN_DOOR = 'open'
|
||||
|
||||
ACTION = 'action'
|
||||
COLLISIONS = 'collision'
|
||||
VALID = 'valid'
|
||||
NOT_VALID = 'not_valid'
|
||||
|
||||
# Battery Env
|
||||
CHARGE_POD = 'Charge_Pod'
|
||||
BATTERIES = 'BATTERIES'
|
||||
|
||||
# Destination Env
|
||||
DESTINATION = 'Destination'
|
||||
REACHEDDESTINATION = 'ReachedDestination'
|
||||
COLLISION = 'collision'
|
||||
VALID = True
|
||||
NOT_VALID = False
|
||||
|
||||
|
||||
class EnvActions:
|
||||
# Movements
|
||||
NORTH = 'north'
|
||||
EAST = 'east'
|
||||
SOUTH = 'south'
|
||||
WEST = 'west'
|
||||
NORTHEAST = 'north_east'
|
||||
SOUTHEAST = 'south_east'
|
||||
SOUTHWEST = 'south_west'
|
||||
NORTHWEST = 'north_west'
|
||||
NORTH = 'north'
|
||||
EAST = 'east'
|
||||
SOUTH = 'south'
|
||||
WEST = 'west'
|
||||
NORTHEAST = 'north_east'
|
||||
SOUTHEAST = 'south_east'
|
||||
SOUTHWEST = 'south_west'
|
||||
NORTHWEST = 'north_west'
|
||||
|
||||
# Other
|
||||
NOOP = 'no_op'
|
||||
# MOVE = 'move'
|
||||
NOOP = 'no_op'
|
||||
USE_DOOR = 'use_door'
|
||||
|
||||
CHARGE = 'charge'
|
||||
WAIT_ON_DEST = 'wait'
|
||||
|
||||
@classmethod
|
||||
def is_move(cls, other):
|
||||
return any([other == direction for direction in cls.movement_actions()])
|
||||
@ -86,8 +76,19 @@ class EnvActions:
|
||||
return list(itertools.chain(cls.square_move(), cls.diagonal_move()))
|
||||
|
||||
|
||||
class Rewards:
|
||||
|
||||
MOVEMENTS_VALID = -0.001
|
||||
MOVEMENTS_FAIL = -0.001
|
||||
NOOP = -0.1
|
||||
USE_DOOR_VALID = -0.001
|
||||
USE_DOOR_FAIL = -0.001
|
||||
COLLISION = -1
|
||||
|
||||
|
||||
m = EnvActions
|
||||
c = Constants
|
||||
r = Rewards
|
||||
|
||||
ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), m.NORTHEAST: (-1, +1),
|
||||
m.EAST: (0, 1), m.SOUTHEAST: (1, 1),
|
||||
@ -184,15 +185,20 @@ def asset_str(agent):
|
||||
# What does this abonimation do?
|
||||
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
|
||||
# print('error')
|
||||
col_names = [x.name for x in agent.temp_collisions]
|
||||
if any(c.AGENT in name for name in col_names):
|
||||
return 'agent_collision', 'blank'
|
||||
elif not agent.temp_valid or c.LEVEL in col_names or c.AGENT in col_names:
|
||||
return c.AGENT, 'invalid'
|
||||
elif agent.temp_valid and not EnvActions.is_move(agent.temp_action):
|
||||
return c.AGENT, 'valid'
|
||||
elif agent.temp_valid and EnvActions.is_move(agent.temp_action):
|
||||
return c.AGENT, 'move'
|
||||
if step_result := agent.step_result:
|
||||
action = step_result['action_name']
|
||||
valid = step_result['action_valid']
|
||||
col_names = [x.name for x in step_result['collisions']]
|
||||
if any(c.AGENT in name for name in col_names):
|
||||
return 'agent_collision', 'blank'
|
||||
elif not valid or c.LEVEL in col_names or c.AGENT in col_names:
|
||||
return c.AGENT, 'invalid'
|
||||
elif valid and not EnvActions.is_move(action):
|
||||
return c.AGENT, 'valid'
|
||||
elif valid and EnvActions.is_move(action):
|
||||
return c.AGENT, 'move'
|
||||
else:
|
||||
return c.AGENT, 'idle'
|
||||
else:
|
||||
return c.AGENT, 'idle'
|
||||
|
||||
|
@ -134,8 +134,7 @@ if __name__ == '__main__':
|
||||
max_spawn_amount=0.1, max_global_amount=20,
|
||||
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
|
||||
dirt_smear_amount=0.0, agent_can_interact=True)
|
||||
item_props = ItemProperties(n_items=10, agent_can_interact=True,
|
||||
spawn_frequency=30, n_drop_off_locations=2,
|
||||
item_props = ItemProperties(n_items=10, spawn_frequency=30, n_drop_off_locations=2,
|
||||
max_agent_inventory_capacity=15)
|
||||
dest_props = DestProperties(n_dests=4, spawn_mode=DestModeOptions.GROUPED, spawn_frequency=1)
|
||||
factory_kwargs = dict(n_agents=1, max_steps=400, parse_doors=True,
|
||||
|
Loading…
x
Reference in New Issue
Block a user