mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 15:26:43 +02:00
reward, done -> fixed
This commit is contained in:
parent
e7d31aa272
commit
3114cdffc3
@ -68,9 +68,7 @@ class BaseFactory(gym.Env):
|
|||||||
raise NotImplementedError('Please register additional actions ')
|
raise NotImplementedError('Please register additional actions ')
|
||||||
|
|
||||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||||
self.done = False
|
|
||||||
self.steps = 0
|
self.steps = 0
|
||||||
self.cumulative_reward = 0
|
|
||||||
self.monitor = FactoryMonitor(self)
|
self.monitor = FactoryMonitor(self)
|
||||||
self.agent_states = []
|
self.agent_states = []
|
||||||
# Agent placement ...
|
# Agent placement ...
|
||||||
@ -95,6 +93,7 @@ class BaseFactory(gym.Env):
|
|||||||
actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
|
actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
|
||||||
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
||||||
self.steps += 1
|
self.steps += 1
|
||||||
|
done = False
|
||||||
|
|
||||||
# Move this in a seperate function?
|
# Move this in a seperate function?
|
||||||
agent_states = list()
|
agent_states = list()
|
||||||
@ -113,15 +112,15 @@ class BaseFactory(gym.Env):
|
|||||||
for i, collision_vec in enumerate(self.check_all_collisions(agent_states, self.state.shape[0])):
|
for i, collision_vec in enumerate(self.check_all_collisions(agent_states, self.state.shape[0])):
|
||||||
agent_states[i].update(collision_vector=collision_vec)
|
agent_states[i].update(collision_vector=collision_vec)
|
||||||
if self.done_at_collision and collision_vec.any():
|
if self.done_at_collision and collision_vec.any():
|
||||||
self.done = True
|
done = True
|
||||||
|
|
||||||
self.agent_states = agent_states
|
self.agent_states = agent_states
|
||||||
reward, info = self.calculate_reward(agent_states)
|
reward, info = self.calculate_reward(agent_states)
|
||||||
|
|
||||||
if self.steps >= self.max_steps:
|
if self.steps >= self.max_steps:
|
||||||
self.done = True
|
done = True
|
||||||
self.monitor.add('step_reward', reward)
|
self.monitor.add('step_reward', reward)
|
||||||
return self.state, reward, self.done, info
|
return self.state, reward, done, info
|
||||||
|
|
||||||
def _is_moving_action(self, action):
|
def _is_moving_action(self, action):
|
||||||
return action < self.movement_actions
|
return action < self.movement_actions
|
||||||
|
@ -82,16 +82,17 @@ class SimpleFactory(BaseFactory):
|
|||||||
return pos, cleanup_was_sucessfull
|
return pos, cleanup_was_sucessfull
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
|
# TODO: For debugging only!!!! Remove at times.....
|
||||||
if self.state[h.LEVEL_IDX][self.agent_i_position(0)] == h.IS_OCCUPIED_CELL:
|
if self.state[h.LEVEL_IDX][self.agent_i_position(0)] == h.IS_OCCUPIED_CELL:
|
||||||
print(f'fAgent placed on wall!!!!, step is :{self.steps}')
|
print(f'fAgent placed on wall!!!!, step is :{self.steps}')
|
||||||
raise Exception('Agent placed on wall!!!!')
|
raise Exception('Agent placed on wall!!!!')
|
||||||
_, _, _, info = super(SimpleFactory, self).step(actions)
|
_, r, done, info = super(SimpleFactory, self).step(actions)
|
||||||
if not self.next_dirt_spawn:
|
if not self.next_dirt_spawn:
|
||||||
self.spawn_dirt()
|
self.spawn_dirt()
|
||||||
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
|
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
|
||||||
else:
|
else:
|
||||||
self.next_dirt_spawn -= 1
|
self.next_dirt_spawn -= 1
|
||||||
return self.state, self.cumulative_reward, self.done, info
|
return self.state, r, done, info
|
||||||
|
|
||||||
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||||
if action != self._is_moving_action(action):
|
if action != self._is_moving_action(action):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user