diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py
index 4ac7b0d..2553245 100644
--- a/environments/factory/base_factory.py
+++ b/environments/factory/base_factory.py
@@ -13,9 +13,9 @@ class AgentState:
         self.i = i
         self.action = action
 
-        self.pos = None
         self.collision_vector = None
         self.action_valid = None
+        self.pos = None
 
     @property
     def collisions(self):
@@ -34,6 +34,7 @@ class FactoryMonitor:
     def __init__(self, env):
         self._env = env
         self._monitor = defaultdict(lambda: defaultdict(lambda: 0))
+        self._last_vals = defaultdict(lambda: 0)
 
     def __iter__(self):
         for key, value in self._monitor.items():
@@ -42,19 +43,23 @@
     def add(self, key, value, step=None):
         assert step is None or step >= 1  # Is this good practice?
         step = step or self._env.steps
-        self._monitor[key][step] = list(self._monitor[key].values())[-1] + value
-        return self._monitor[key][step]
+        self._last_vals[key] = self._last_vals[key] + value
+        self._monitor[key][step] = self._last_vals[key]
+        return self._last_vals[key]
 
     def set(self, key, value, step=None):
         assert step is None or step >= 1  # Is this good practice?
         step = step or self._env.steps
-        self._monitor[key][step] = value
-        return self._monitor[key][step]
+        self._last_vals[key] = value
+        self._monitor[key][step] = self._last_vals[key]
+        return self._last_vals[key]
 
-    def reduce(self, key, value, step=None):
+    def remove(self, key, value, step=None):
         assert step is None or step >= 1  # Is this good practice?
         step = step or self._env.steps
-        self._monitor[key][step] = list(self._monitor[key].values())[-1] - value
+        self._last_vals[key] = self._last_vals[key] - value
+        self._monitor[key][step] = self._last_vals[key]
+        return self._last_vals[key]
 
     def to_dict(self):
         return dict(self)
@@ -83,13 +88,13 @@
             h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
         )
         self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
-        self.monitor = FactoryMonitor(self)
         self.reset()
 
     def reset(self):
         self.done = False
         self.steps = 0
         self.cumulative_reward = 0
+        self.monitor = FactoryMonitor(self)
         # Agent placement ...
         agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
         floor_tiles = np.argwhere(self.level == h.IS_FREE_CELL)
diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py
index 3c2d495..b9b743e 100644
--- a/environments/factory/simple_factory.py
+++ b/environments/factory/simple_factory.py
@@ -1,5 +1,5 @@
 import numpy as np
-from environments.factory.base_factory import BaseFactory
+from environments.factory.base_factory import BaseFactory, FactoryMonitor
 
 
 class SimpleFactory(BaseFactory):
@@ -14,28 +14,36 @@ class SimpleFactory(BaseFactory):
         self.state[-1, x, y] = 1
 
     def reset(self):
-        super().reset()
+        state, r, done, _ = super().reset()
         dirt_slice = np.zeros((1, *self.state.shape[1:]))
         self.state = np.concatenate((self.state, dirt_slice))  # dirt is now the last slice
         self.spawn_dirt()
+        # Always: This should return state, r, done, info
+        return self.state, r, done, _
 
     def calculate_reward(self, agent_states):
         for agent_state in agent_states:
             collisions = agent_state.collisions
             entities = [self.slice_strings[entity] for entity in collisions]
-            for entity in entities:
-                self.monitor.add(f'{entity}_collisions', 1)
-            print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
-                  f'{entities}')
+            if entities:
+                for entity in entities:
+                    self.monitor.add(f'agent_{agent_state.i}_collision_{entity}', 1)
+                print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
+                      f'{entities}')
         return 0, {}
 
 
 if __name__ == '__main__':
     import random
     factory = SimpleFactory(n_agents=1, max_dirt=8)
 
-    random_actions = [random.randint(0, 7) for _ in range(200)]
-    for action in random_actions:
-        state, r, done, _ = factory.step(action)
-    print(f'Factory run done, reward is:\n {r}')
-    print(f'There have been the following collisions: \n {dict(factory.monitor)}')
+    monitor_list = list()
+    for epoch in range(100):
+        random_actions = [random.randint(0, 7) for _ in range(200)]
+        state, r, done, _ = factory.reset()
+        for action in random_actions:
+            state, r, done, info = factory.step(action)
+        monitor_list.append(factory.monitor)
+
+        print(f'Factory run done, reward is:\n {r}')
+        print(f'There have been the following collisions: \n {dict(factory.monitor)}')