Rework of Observations and Entity Differentiation, lazy obs build by notification
This commit is contained in:
@ -6,11 +6,11 @@ import random
|
||||
import numpy as np
|
||||
|
||||
from algorithms.TSP_dirt_agent import TSPDirtAgent
|
||||
from environments.helpers import Constants as c
|
||||
from environments.helpers import Constants as c, Constants
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Entity, Tile
|
||||
from environments.factory.base.registers import Entities, MovingEntityObjectRegister
|
||||
from environments.factory.base.registers import Entities, MovingEntityObjectRegister, EntityRegister
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
from environments.utility_classes import ObservationProperties
|
||||
@ -42,6 +42,7 @@ class Dirt(Entity):
|
||||
def amount(self):
|
||||
return self._amount
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
# Edit this if you want items to be drawn in the ops differntly
|
||||
return self._amount
|
||||
@ -52,6 +53,7 @@ class Dirt(Entity):
|
||||
|
||||
def set_new_amount(self, amount):
|
||||
self._amount = amount
|
||||
self._register.notify_change_to_value(self)
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
state_dict = super().summarize_state(**kwargs)
|
||||
@ -59,18 +61,7 @@ class Dirt(Entity):
|
||||
return state_dict
|
||||
|
||||
|
||||
class DirtRegister(MovingEntityObjectRegister):
|
||||
|
||||
def as_array(self):
|
||||
if self._array is not None:
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
for dirt in list(self.values()):
|
||||
if dirt.amount == 0:
|
||||
self.delete_entity(dirt)
|
||||
self._array[0, dirt.x, dirt.y] = dirt.amount
|
||||
else:
|
||||
self._array = np.zeros((1, *self._level_shape))
|
||||
return self._array
|
||||
class DirtRegister(EntityRegister):
|
||||
|
||||
_accepted_objects = Dirt
|
||||
|
||||
@ -93,7 +84,7 @@ class DirtRegister(MovingEntityObjectRegister):
|
||||
if not self.amount > self.dirt_properties.max_global_amount:
|
||||
dirt = self.by_pos(tile.pos)
|
||||
if dirt is None:
|
||||
dirt = Dirt(tile, amount=self.dirt_properties.max_spawn_amount)
|
||||
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
||||
self.register_item(dirt)
|
||||
else:
|
||||
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
||||
@ -155,7 +146,7 @@ class DirtFactory(BaseFactory):
|
||||
new_dirt_amount = dirt.amount - self.dirt_prop.clean_amount
|
||||
|
||||
if new_dirt_amount <= 0:
|
||||
self[c.DIRT].delete_entity(dirt)
|
||||
self[c.DIRT].delete_env_object(dirt)
|
||||
else:
|
||||
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
|
||||
return c.VALID
|
||||
@ -224,6 +215,11 @@ class DirtFactory(BaseFactory):
|
||||
done = self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0)
|
||||
return super_done or done
|
||||
|
||||
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
additional_observations.update({c.DIRT: self[c.DIRT].as_array()})
|
||||
return additional_observations
|
||||
|
||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||
reward, info_dict = super().calculate_additional_reward(agent)
|
||||
dirt = [dirt.amount for dirt in self[c.DIRT]]
|
||||
@ -278,41 +274,52 @@ if __name__ == '__main__':
|
||||
)
|
||||
|
||||
obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True,
|
||||
pomdp_r=2, additional_agent_placeholder=None)
|
||||
pomdp_r=2, additional_agent_placeholder=None, cast_shadows=True)
|
||||
|
||||
move_props = {'allow_square_movement': True,
|
||||
'allow_diagonal_movement': False,
|
||||
'allow_no_op': False}
|
||||
global_timings = []
|
||||
for i in range(20):
|
||||
|
||||
factory = DirtFactory(n_agents=1, done_at_collision=False,
|
||||
level_name='rooms', max_steps=400,
|
||||
doors_have_area=False,
|
||||
obs_prop=obs_props, parse_doors=True,
|
||||
record_episodes=True, verbose=True,
|
||||
mv_prop=move_props, dirt_prop=dirt_props,
|
||||
inject_agents=[TSPDirtAgent]
|
||||
)
|
||||
factory = DirtFactory(n_agents=2, done_at_collision=False,
|
||||
level_name='rooms', max_steps=1000,
|
||||
doors_have_area=False,
|
||||
obs_prop=obs_props, parse_doors=True,
|
||||
record_episodes=True, verbose=True,
|
||||
mv_prop=move_props, dirt_prop=dirt_props,
|
||||
# inject_agents=[TSPDirtAgent],
|
||||
)
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
n_actions = factory.action_space.n - 1
|
||||
_ = factory.observation_space
|
||||
|
||||
for epoch in range(10):
|
||||
random_actions = [[random.randint(0, n_actions) for _
|
||||
in range(factory.n_agents)] for _
|
||||
in range(factory.max_steps+1)]
|
||||
env_state = factory.reset()
|
||||
if render:
|
||||
factory.render()
|
||||
tsp_agent = factory.get_injected_agents()[0]
|
||||
|
||||
r = 0
|
||||
for agent_i_action in random_actions:
|
||||
env_state, step_r, done_bool, info_obj = factory.step(tsp_agent.predict())
|
||||
r += step_r
|
||||
# noinspection DuplicatedCode
|
||||
n_actions = factory.action_space.n - 1
|
||||
_ = factory.observation_space
|
||||
obs_space = factory.observation_space
|
||||
obs_space_named = factory.named_observation_space
|
||||
times = []
|
||||
import time
|
||||
for epoch in range(10):
|
||||
start_time = time.time()
|
||||
random_actions = [[random.randint(0, n_actions) for _
|
||||
in range(factory.n_agents)] for _
|
||||
in range(factory.max_steps+1)]
|
||||
env_state = factory.reset()
|
||||
if render:
|
||||
factory.render()
|
||||
if done_bool:
|
||||
break
|
||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||
# tsp_agent = factory.get_injected_agents()[0]
|
||||
|
||||
r = 0
|
||||
for agent_i_action in random_actions:
|
||||
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||
r += step_r
|
||||
if render:
|
||||
factory.render()
|
||||
if done_bool:
|
||||
break
|
||||
times.append(time.time() - start_time)
|
||||
# print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||
print('Time Taken: ', sum(times) / 10)
|
||||
global_timings.append(sum(times) / 10)
|
||||
print('Time Taken: ', sum(global_timings[10:]) / 10)
|
||||
|
||||
pass
|
||||
|
Reference in New Issue
Block a user