AgentState Object

dataclass
class AgentState:
    i: int
    action: int

    pos = None
    collision_vector = None
    action_valid = None
This commit is contained in:
steffen-illium
2021-05-14 09:09:20 +02:00
parent 14741aa5a5
commit 86204a6266
3 changed files with 74 additions and 37 deletions

View File

@ -19,11 +19,11 @@ class SimpleFactory(BaseFactory):
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
self.spawn_dirt()
def calculate_reward(self, collisions_vecs, actions, r):
for agent_i, cols in enumerate(collisions_vecs):
cols = np.argwhere(cols != 0).flatten()
print(f't = {self.steps}\tAgent {agent_i} has collisions with '
f'{[self.slice_strings[entity] for entity in cols]}')
def calculate_reward(self, agent_states):
for agent_state in agent_states:
collisions = agent_state.collisions
print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
f'{[self.slice_strings[entity] for entity in collisions]}')
return 0, {}