mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-10-31 12:37:27 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			410 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			410 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from collections import defaultdict
 | |
| from typing import Union
 | |
| 
 | |
| import networkx as nx
 | |
| import numpy as np
 | |
| 
 | |
| from environments import helpers as h
 | |
| from environments.helpers import Constants as c
 | |
| import itertools
 | |
| 
 | |
| ##########################################################################
 | |
| # ##################### Base Object Building Blocks ######################### #
 | |
| ##########################################################################
 | |
| 
 | |
| 
 | |
class Object:
    """General-purpose base object for organisation and maintenance (Actions etc.).

    Every instance receives a display name at construction time:

    * ``ClassName[str_ident]`` when ``str_ident`` is given, or
    * ``ClassName#<n>`` otherwise, where ``<n>`` is a running counter kept
      per class name in ``Object._u_idx``.

    Equality is defined against the object's ``identifier`` so that an
    instance compares equal to its identifier string (and to other objects
    sharing that identifier). ``__hash__`` follows the same identifier so that
    instances stay hashable and equal objects hash equally.
    """

    # Running counter per class name, used to build unique default names.
    _u_idx = defaultdict(lambda: 0)

    def __init__(self, str_ident: Union[str, None] = None, **kwargs):
        """Set up the identifier and name of the object.

        :param str_ident: Optional explicit identifier. When ``None``, a
            numbered name is generated from the class name instead.
        :param kwargs: Not used; any leftover keyword arguments are reported
            on stdout and otherwise ignored.
        """
        self._str_ident = str_ident

        cls_name = self.__class__.__name__
        if self._str_ident is not None:
            self._name = f'{cls_name}[{self._str_ident}]'
        else:
            # The counter lives on ``Object`` itself, keyed by the concrete class name.
            self._name = f'{cls_name}#{Object._u_idx[cls_name]}'
            Object._u_idx[cls_name] += 1

        if kwargs:
            print(f'Following kwargs were passed, but ignored: {kwargs}')

    def __bool__(self):
        # Always truthy, even for subclasses that define ``__len__``.
        return True

    @property
    def name(self):
        """Display name built in ``__init__``."""
        return self._name

    @property
    def identifier(self):
        """The explicit ``str_ident`` if one was given, otherwise the generated name."""
        if self._str_ident is None:
            return self._name
        return self._str_ident

    def __repr__(self):
        return f'{self.name}'

    def __eq__(self, other) -> bool:
        """Compare ``other`` against this object's identifier."""
        return other == self.identifier

    def __hash__(self):
        # Defining ``__eq__`` alone would set ``__hash__`` to None and make
        # instances unhashable; hash the same value that equality is based on.
        return hash(self.identifier)
 | |
| # Base
 | |
| 
 | |
| 
 | |
class EnvObject(Object):
    """Observable object that holds information but has no position on the env grid.

    Examples are inventories. Each instance keeps a reference to the register
    it was created with.
    """

    _u_idx = defaultdict(int)

    def __init__(self, register, **kwargs):
        """Store the owning ``register``; remaining kwargs go to ``Object``."""
        super().__init__(**kwargs)
        self._register = register

    @property
    def can_collide(self):
        """Plain env objects never collide."""
        return False

    @property
    def encoding(self):
        """Value used to encode this object (``c.OCCUPIED_CELL``)."""
        return c.OCCUPIED_CELL

    def change_register(self, register):
        """Move this object from its current register to ``register``.

        The object is first registered with the new register, then deleted
        from the previous one.

        :param register: Target register; must offer ``register_item``.
        :return: Result of comparing the stored register with ``register``.
        """
        register.register_item(self)
        previous_register = self._register
        previous_register.delete_env_object(self)
        self._register = register
        return self._register == register
 | |
| # With Rendering
 | |
| 
 | |
| 
 | |
class Entity(EnvObject):
    """Full env entity that lives on the env grid (doors, items, dirt etc.).

    An entity is bound to a tile; its position is always the position of
    that tile.
    """

    def __init__(self, tile, *args, **kwargs):
        """Create the entity and place it on ``tile``.

        :param tile: Tile the entity occupies; it is notified via ``tile.enter``.
        """
        super().__init__(*args, **kwargs)
        self._tile = tile
        tile.enter(self)

    @property
    def can_collide(self):
        """Entities do not collide unless a subclass says otherwise."""
        return False

    @property
    def tile(self):
        """The tile this entity currently occupies."""
        return self._tile

    @property
    def pos(self):
        """Position of the occupied tile."""
        return self._tile.pos

    @property
    def x(self):
        """First component of ``pos``."""
        return self.pos[0]

    @property
    def y(self):
        """Second component of ``pos``."""
        return self.pos[1]

    def summarize_state(self, **_) -> dict:
        """Return a dict of plain Python values describing the entity."""
        return {
            'name': str(self.name),
            'x': int(self.x),
            'y': int(self.y),
            'tile': str(self.tile.name),
            'can_collide': bool(self.can_collide),
        }

    def __repr__(self):
        # Base representation followed by the current position.
        return f'{super().__repr__()}(@{self.pos})'
 | |
| # With Position in Env
 | |
| 
 | |
| 
 | |
class MoveableEntity(Entity):
    """Entity that can change tiles and remembers the tile it came from."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # No previous tile until the first successful move.
        self._last_tile = None

    @property
    def last_tile(self):
        """Tile occupied before the most recent move, or ``None``."""
        return self._last_tile

    @property
    def last_pos(self):
        """Position of the previous tile, or ``c.NO_POS`` if there is none."""
        return self._last_tile.pos if self._last_tile else c.NO_POS

    @property
    def direction_of_view(self):
        """Component-wise difference ``last_pos - pos``."""
        prev_x, prev_y = self.last_pos
        now_x, now_y = self.pos
        return prev_x - now_x, prev_y - now_y

    def move(self, next_tile):
        """Move the entity onto ``next_tile``.

        :param next_tile: Destination tile.
        :return: ``c.VALID`` if the entity changed tiles, ``c.NOT_VALID`` if
            ``next_tile`` equals the current tile.
        """
        origin_tile = self.tile
        tile_changes = origin_tile != next_tile
        if not tile_changes:
            return c.NOT_VALID

        # Enter the destination before leaving the origin, then update the
        # bookkeeping and inform the register about the change.
        next_tile.enter(self)
        origin_tile.leave(self)
        self._tile = next_tile
        self._last_tile = origin_tile
        self._register.notify_change_to_value(self)
        return c.VALID
 | |
| # Can Move
 | |
| 
 | |
| 
 | |
class BoundingMixin(Object):
    """Mixin for objects that are bound to another entity."""

    def __init__(self, entity_to_be_bound, *args, **kwargs):
        """Bind this object to ``entity_to_be_bound``.

        :param entity_to_be_bound: The entity this object belongs to; must
            not be ``None``.
        """
        super().__init__(*args, **kwargs)
        assert entity_to_be_bound is not None
        self._bound_entity = entity_to_be_bound

    @property
    def bound_entity(self):
        """The entity this object is bound to."""
        return self._bound_entity

    @property
    def name(self):
        """Inherited name followed by the bound entity's name in parentheses."""
        base_name = super().name
        return f'{base_name}({self._bound_entity.name})'

    def belongs_to_entity(self, entity):
        """Return whether ``entity`` compares equal to the bound entity."""
        return entity == self.bound_entity
 | |
| 
 | |
| 
 | |
| ##########################################################################
 | |
| # ####################### Objects and Entities ######################### #
 | |
| ##########################################################################
 | |
| 
 | |
| 
 | |
class Action(Object):
    """Named action object; adds nothing on top of ``Object``."""

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the base initialiser.
        super(Action, self).__init__(*args, **kwargs)
 | |
| 
 | |
| 
 | |
class PlaceHolder(Object):
    """Stand-in object whose encoding is a fixed fill value."""

    def __init__(self, *args, fill_value=0, **kwargs):
        """Create the placeholder.

        :param fill_value: Value reported by ``encoding``.
        """
        super().__init__(*args, **kwargs)
        self._fill_value = fill_value

    @property
    def name(self):
        """Constant name, independent of the generated object name."""
        return 'PlaceHolder'

    @property
    def encoding(self):
        """The fill value given at construction."""
        return self._fill_value

    @property
    def can_collide(self):
        """Placeholders never collide."""
        return False
 | |
| 
 | |
| 
 | |
class GlobalPosition(BoundingMixin, EnvObject):
    """Exposes the position of the bound entity, optionally normalised."""

    def __init__(self, level_shape: (int, int), *args, normalized: bool = True, **kwargs):
        """Create the position object.

        :param level_shape: Divisor used for the normalised encoding.
        :param normalized: When true, ``encoding`` divides the position by
            ``level_shape``.
        """
        super().__init__(*args, **kwargs)
        self._normalized = normalized
        self._level_shape = level_shape

    @property
    def encoding(self):
        """Position of the bound entity, divided by ``level_shape`` if normalised."""
        if not self._normalized:
            return self.bound_entity.pos
        scaled = np.divide(self._bound_entity.pos, self._level_shape)
        return tuple(scaled)
 | |
| 
 | |
| 
 | |
class Floor(EnvObject):
    """A tile with a fixed position that keeps track of its guests.

    Guests are stored in a dict keyed by their ``name``.
    """

    def __init__(self, pos, *args, **kwargs):
        """Create the tile.

        :param pos: Position of the tile; stored as a tuple.
        """
        super().__init__(*args, **kwargs)
        self._pos = tuple(pos)
        self._guests = {}

    @property
    def encoding(self):
        """Value used to encode a floor tile (``c.FREE_CELL``)."""
        return c.FREE_CELL

    @property
    def pos(self):
        """Position of the tile as a tuple."""
        return self._pos

    @property
    def x(self):
        """First component of ``pos``."""
        return self.pos[0]

    @property
    def y(self):
        """Second component of ``pos``."""
        return self.pos[1]

    @property
    def guests(self):
        """View on all guests currently on this tile."""
        return self._guests.values()

    @property
    def guests_that_can_collide(self):
        """List of the guests whose ``can_collide`` is truthy."""
        return [guest for guest in self.guests if guest.can_collide]

    def __len__(self):
        # Number of guests on the tile.
        return len(self._guests)

    def is_empty(self):
        """Return True when no guest is on the tile."""
        return len(self._guests) == 0

    def is_occupied(self):
        """Return True when at least one guest is on the tile."""
        return len(self._guests) > 0

    def enter(self, guest):
        """Add ``guest`` to the tile.

        :return: True if the guest was added, False if a guest with the same
            name is already present.
        """
        guest_key = guest.name
        if guest_key in self._guests:
            return False
        self._guests[guest_key] = guest
        return True

    def leave(self, guest):
        """Remove ``guest`` from the tile.

        :return: True if the guest was removed, False if it was not present.
        """
        try:
            del self._guests[guest.name]
        except (ValueError, KeyError):
            return False
        else:
            return True

    def __repr__(self):
        return f'{self.name}(@{self.pos})'

    def summarize_state(self, **_):
        """Return name and integer coordinates of the tile as a dict."""
        return {'name': self.name, 'x': int(self.x), 'y': int(self.y)}
 | |
| 
 | |
| 
 | |
class Wall(Floor):
    """Tile variant that collides and is encoded as an occupied cell."""

    @property
    def encoding(self):
        """Value used to encode a wall (``c.OCCUPIED_CELL``)."""
        return c.OCCUPIED_CELL

    @property
    def can_collide(self):
        """Walls always collide."""
        return True
 | |
| 
 | |
| 
 | |
class Door(Entity):
    """Door entity that can be opened and closed and closes itself after a while.

    Each door owns a small graph (``connectivity``) of its neighbouring tile
    positions. The connected components of that graph are numbered, and each
    number is added as an extra node linked to the positions of its
    component. Opening the door links the door position to every component
    node; closing it removes the door position from the graph again.
    """

    def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10, has_area=False, **kwargs):
        """Create the door and build the connectivity graph of its surroundings.

        :param context: Object offering ``by_pos(pos)`` to look up the tile at
            a position; falsy results are skipped.
        :param closed_on_init: When False, the door is opened right away.
        :param auto_close_interval: Value ``time_to_close`` is reset to every
            time the door opens.
        :param has_area: When falsy, the door never collides.
        """
        super(Door, self).__init__(*args, **kwargs)
        self._state = c.CLOSED_DOOR
        self.has_area = has_area
        self.auto_close_interval = auto_close_interval
        self.time_to_close = -1

        # All offsets in {-1, 0, 1}^2; (0, 0) is produced last and is dropped.
        neighbor_offsets = list(itertools.product([-1, 1, 0], repeat=2))[:-1]
        neighbor_tiles = [context.by_pos(tuple(sum(x) for x in zip(self.pos, diff)))
                          for diff in neighbor_offsets]
        neighbor_pos = [tile.pos for tile in neighbor_tiles if tile]

        # NOTE(review): ``h.points_to_graph`` is defined elsewhere; presumably it
        # connects adjacent positions - confirm against the helper.
        self.connectivity = h.points_to_graph(neighbor_pos)
        self.connectivity_subgroups = list(nx.algorithms.components.connected_components(self.connectivity))
        # Add one integer node per component and link it to the component's positions.
        for idx, group in enumerate(self.connectivity_subgroups):
            for tile_pos in group:
                self.connectivity.add_edge(tile_pos, idx)
        if not closed_on_init:
            self._open()

    @property
    def can_collide(self):
        """True only for a door with an area that is not open."""
        if not self.has_area:
            return False
        return not self.is_open

    @property
    def encoding(self):
        """``c.CLOSED_DOOR_CELL`` while closed, ``c.OPEN_DOOR_CELL`` otherwise."""
        # Note from the original author: this value is important because
        # shadowing is checked by occupation value.
        return c.CLOSED_DOOR_CELL if self.is_closed else c.OPEN_DOOR_CELL

    @property
    def str_state(self):
        """``'open'`` or ``'closed'``."""
        return 'open' if self.is_open else 'closed'

    @property
    def access_area(self):
        """Graph nodes that are neither component-index nodes nor the door position."""
        return [node for node in self.connectivity.nodes
                if node not in range(len(self.connectivity_subgroups)) and node != self.pos]

    @property
    def is_closed(self):
        """Whether the state equals ``c.CLOSED_DOOR``."""
        return self._state == c.CLOSED_DOOR

    @property
    def is_open(self):
        """Whether the state equals ``c.OPEN_DOOR``."""
        return self._state == c.OPEN_DOOR

    @property
    def status(self):
        """Raw door state."""
        return self._state

    def summarize_state(self, **kwargs):
        """Entity summary extended by the door state and the closing countdown."""
        state_dict = super().summarize_state(**kwargs)
        state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
        return state_dict

    def use(self):
        """Toggle the door: close it when open, open it otherwise."""
        if self._state == c.OPEN_DOOR:
            self._close()
        else:
            self._open()

    def tick(self):
        """Advance the auto-close countdown by one step.

        Only acts while the door is open and ``len(self.tile) == 1``
        (presumably: the door itself is the only guest on its tile). The
        countdown is decremented while non-zero; once it is zero the door is
        toggled via ``use``.
        """
        if not (self.is_open and len(self.tile) == 1):
            return
        if self.time_to_close:
            self.time_to_close -= 1
        else:
            self.use()

    def _open(self):
        """Link the door position to all component nodes and mark the door open."""
        self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
        self._state = c.OPEN_DOOR
        self._register.notify_change_to_value(self)
        self.time_to_close = self.auto_close_interval

    def _close(self):
        """Remove the door position from the graph and mark the door closed."""
        self.connectivity.remove_node(self.pos)
        self._state = c.CLOSED_DOOR
        self._register.notify_change_to_value(self)

    def is_linked(self, old_pos, new_pos):
        """Return whether ``old_pos`` and ``new_pos`` are connected in the door graph.

        Positions that are not part of the graph count as not linked. This
        includes the door position itself while the door is closed, since
        ``_close`` removes that node.
        """
        try:
            nx.shortest_path(self.connectivity, old_pos, new_pos)
        except (nx.exception.NetworkXNoPath, nx.exception.NodeNotFound):
            return False
        return True
 | |
| 
 | |
| 
 | |
class Agent(MoveableEntity):
    """Moveable entity that collides and carries the result of its last step."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.clear_temp_state()

    @property
    def can_collide(self):
        """Agents always collide."""
        return True

    # noinspection PyAttributeOutsideInit
    def clear_temp_state(self):
        """Reset the per-step state (``step_result``) to ``None``."""
        self.step_result = None

    def summarize_state(self, **kwargs):
        """Entity summary extended by validity and name of the last action.

        NOTE(review): ``step_result`` is subscripted directly, so this raises
        ``TypeError`` while it is still ``None`` - confirm callers only
        summarise after a step result has been assigned.
        """
        summary = super().summarize_state(**kwargs)
        action_valid = bool(self.step_result['action_valid'])
        action_name = str(self.step_result['action_name'])
        summary.update(valid=action_valid, action=action_name)
        return summary
 | 
