Items
This commit is contained in:
@ -26,6 +26,16 @@ class DirtProperties(NamedTuple):
|
||||
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place
|
||||
|
||||
|
||||
def softmax(x, axis=None):
    """Compute softmax values for the scores in x.

    Args:
        x: Array-like of scores.
        axis: Axis along which to normalise. With the default ``None`` the
            whole array is treated as a single set of scores, so all entries
            of the result sum to 1. Pass an integer to normalise each slice
            along that axis independently.

    Returns:
        An array with the same shape as ``x`` holding the softmax values.
    """
    x = np.asarray(x)
    # Shifting by the maximum does not change the result, but keeps the
    # arguments of np.exp non-positive so it cannot overflow.
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)
|
||||
|
||||
|
||||
def entropy(x, eps=1e-8):
    """Return the entropy (natural log) of the values in x.

    Args:
        x: Array-like of values; all entries are summed over, regardless of
            shape.
        eps: Small constant added inside the logarithm so that zero entries
            contribute 0 instead of producing ``nan`` from ``0 * log(0)``.

    Returns:
        The scalar ``-sum(x * log(x + eps))``.
    """
    # Coerce first so plain Python sequences work with the element-wise math.
    x = np.asarray(x)
    return -(x * np.log(x + eps)).sum()
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
class SimpleFactory(BaseFactory):
|
||||
|
||||
@ -46,9 +56,8 @@ class SimpleFactory(BaseFactory):
|
||||
action = self._actions.by_name(action)
|
||||
return self._actions[action].name == CLEAN_UP_ACTION
|
||||
|
||||
def __init__(self, *args, dirt_properties: DirtProperties = DirtProperties(), verbose=False, **kwargs):
|
||||
def __init__(self, *args, dirt_properties: DirtProperties = DirtProperties(), **kwargs):
|
||||
self.dirt_properties = dirt_properties
|
||||
self.verbose = verbose
|
||||
self._renderer = None # expensive - don't use it when not required !
|
||||
super(SimpleFactory, self).__init__(*args, **kwargs)
|
||||
|
||||
@ -108,8 +117,8 @@ class SimpleFactory(BaseFactory):
|
||||
|
||||
def clean_up(self, agent: Agent) -> bool:
|
||||
dirt_slice = self._slices.by_name(DIRT).slice
|
||||
if dirt_slice[agent.pos]:
|
||||
new_dirt_amount = dirt_slice[agent.pos] - self.dirt_properties.clean_amount
|
||||
if old_dirt_amount := dirt_slice[agent.pos]:
|
||||
new_dirt_amount = old_dirt_amount - self.dirt_properties.clean_amount
|
||||
dirt_slice[agent.pos] = max(new_dirt_amount, c.FREE_CELL.value)
|
||||
return True
|
||||
else:
|
||||
@ -135,14 +144,11 @@ class SimpleFactory(BaseFactory):
|
||||
return {}
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: int) -> bool:
|
||||
if action != self._actions.is_moving_action(action):
|
||||
if self._is_clean_up_action(action):
|
||||
valid = self.clean_up(agent)
|
||||
return valid
|
||||
else:
|
||||
raise RuntimeError('This should not happen!!!')
|
||||
if self._is_clean_up_action(action):
|
||||
valid = self.clean_up(agent)
|
||||
return valid
|
||||
else:
|
||||
raise RuntimeError('This should not happen!!!')
|
||||
return c.NOT_VALID.value
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
self.spawn_dirt()
|
||||
@ -155,13 +161,18 @@ class SimpleFactory(BaseFactory):
|
||||
dirty_tiles = [dirt_slice[tile.pos] for tile in self._tiles if dirt_slice[tile.pos]]
|
||||
current_dirt_amount = sum(dirty_tiles)
|
||||
dirty_tile_count = len(dirty_tiles)
|
||||
if dirty_tile_count:
|
||||
dirt_distribution_score = entropy(softmax(dirt_slice)) / dirty_tile_count
|
||||
else:
|
||||
dirt_distribution_score = 0
|
||||
|
||||
info_dict.update(dirt_amount=current_dirt_amount)
|
||||
info_dict.update(dirty_tile_count=dirty_tile_count)
|
||||
info_dict.update(dirt_distribution_score=dirt_distribution_score)
|
||||
|
||||
try:
|
||||
# penalty = current_dirt_amount
|
||||
reward = 0
|
||||
reward = dirt_distribution_score
|
||||
except (ZeroDivisionError, RuntimeWarning):
|
||||
reward = 0
|
||||
|
||||
@ -213,10 +224,6 @@ class SimpleFactory(BaseFactory):
|
||||
# track the last reward , minus the current reward = potential
|
||||
return reward, info_dict
|
||||
|
||||
def print(self, string):
|
||||
if self.verbose:
|
||||
print(string)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
render = True
|
||||
|
Reference in New Issue
Block a user