This commit is contained in:
Steffen Illium
2021-07-29 09:31:32 +02:00
parent b0aeb6f94f
commit 8631f11502
5 changed files with 123 additions and 25 deletions

View File

@ -26,6 +26,16 @@ class DirtProperties(NamedTuple):
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place
def softmax(x):
"""Compute softmax values for each sets of scores in x."""
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum()
def entropy(x):
return -(x * np.log(x + 1e-8)).sum()
# noinspection PyAttributeOutsideInit
class SimpleFactory(BaseFactory):
@ -46,9 +56,8 @@ class SimpleFactory(BaseFactory):
action = self._actions.by_name(action)
return self._actions[action].name == CLEAN_UP_ACTION
def __init__(self, *args, dirt_properties: DirtProperties = DirtProperties(), verbose=False, **kwargs):
def __init__(self, *args, dirt_properties: DirtProperties = DirtProperties(), **kwargs):
self.dirt_properties = dirt_properties
self.verbose = verbose
self._renderer = None # expensive - don't use it when not required !
super(SimpleFactory, self).__init__(*args, **kwargs)
@ -108,8 +117,8 @@ class SimpleFactory(BaseFactory):
def clean_up(self, agent: Agent) -> bool:
dirt_slice = self._slices.by_name(DIRT).slice
if dirt_slice[agent.pos]:
new_dirt_amount = dirt_slice[agent.pos] - self.dirt_properties.clean_amount
if old_dirt_amount := dirt_slice[agent.pos]:
new_dirt_amount = old_dirt_amount - self.dirt_properties.clean_amount
dirt_slice[agent.pos] = max(new_dirt_amount, c.FREE_CELL.value)
return True
else:
@ -135,14 +144,11 @@ class SimpleFactory(BaseFactory):
return {}
def do_additional_actions(self, agent: Agent, action: int) -> bool:
if action != self._actions.is_moving_action(action):
if self._is_clean_up_action(action):
valid = self.clean_up(agent)
return valid
else:
raise RuntimeError('This should not happen!!!')
if self._is_clean_up_action(action):
valid = self.clean_up(agent)
return valid
else:
raise RuntimeError('This should not happen!!!')
return c.NOT_VALID.value
def do_additional_reset(self) -> None:
self.spawn_dirt()
@ -155,13 +161,18 @@ class SimpleFactory(BaseFactory):
dirty_tiles = [dirt_slice[tile.pos] for tile in self._tiles if dirt_slice[tile.pos]]
current_dirt_amount = sum(dirty_tiles)
dirty_tile_count = len(dirty_tiles)
if dirty_tile_count:
dirt_distribution_score = entropy(softmax(dirt_slice)) / dirty_tile_count
else:
dirt_distribution_score = 0
info_dict.update(dirt_amount=current_dirt_amount)
info_dict.update(dirty_tile_count=dirty_tile_count)
info_dict.update(dirt_distribution_score=dirt_distribution_score)
try:
# penalty = current_dirt_amount
reward = 0
reward = dirt_distribution_score
except (ZeroDivisionError, RuntimeWarning):
reward = 0
@ -213,10 +224,6 @@ class SimpleFactory(BaseFactory):
# track the last reward , minus the current reward = potential
return reward, info_dict
def print(self, string):
if self.verbose:
print(string)
if __name__ == '__main__':
render = True