reward, done -> fixed
This commit is contained in:
parent
e7d31aa272
commit
3114cdffc3
@ -68,9 +68,7 @@ class BaseFactory(gym.Env):
|
||||
raise NotImplementedError('Please register additional actions ')
|
||||
|
||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||
self.done = False
|
||||
self.steps = 0
|
||||
self.cumulative_reward = 0
|
||||
self.monitor = FactoryMonitor(self)
|
||||
self.agent_states = []
|
||||
# Agent placement ...
|
||||
@ -95,6 +93,7 @@ class BaseFactory(gym.Env):
|
||||
actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
|
||||
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
||||
self.steps += 1
|
||||
done = False
|
||||
|
||||
# Move this in a seperate function?
|
||||
agent_states = list()
|
||||
@ -113,15 +112,15 @@ class BaseFactory(gym.Env):
|
||||
for i, collision_vec in enumerate(self.check_all_collisions(agent_states, self.state.shape[0])):
|
||||
agent_states[i].update(collision_vector=collision_vec)
|
||||
if self.done_at_collision and collision_vec.any():
|
||||
self.done = True
|
||||
done = True
|
||||
|
||||
self.agent_states = agent_states
|
||||
reward, info = self.calculate_reward(agent_states)
|
||||
|
||||
if self.steps >= self.max_steps:
|
||||
self.done = True
|
||||
done = True
|
||||
self.monitor.add('step_reward', reward)
|
||||
return self.state, reward, self.done, info
|
||||
return self.state, reward, done, info
|
||||
|
||||
def _is_moving_action(self, action):
|
||||
return action < self.movement_actions
|
||||
|
@ -82,16 +82,17 @@ class SimpleFactory(BaseFactory):
|
||||
return pos, cleanup_was_sucessfull
|
||||
|
||||
def step(self, actions):
|
||||
# TODO: For debugging only!!!! Remove at times.....
|
||||
if self.state[h.LEVEL_IDX][self.agent_i_position(0)] == h.IS_OCCUPIED_CELL:
|
||||
print(f'fAgent placed on wall!!!!, step is :{self.steps}')
|
||||
raise Exception('Agent placed on wall!!!!')
|
||||
_, _, _, info = super(SimpleFactory, self).step(actions)
|
||||
_, r, done, info = super(SimpleFactory, self).step(actions)
|
||||
if not self.next_dirt_spawn:
|
||||
self.spawn_dirt()
|
||||
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
|
||||
else:
|
||||
self.next_dirt_spawn -= 1
|
||||
return self.state, self.cumulative_reward, self.done, info
|
||||
return self.state, r, done, info
|
||||
|
||||
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||
if action != self._is_moving_action(action):
|
||||
|
Loading…
x
Reference in New Issue
Block a user