Item and Dirt Factory Working again

This commit is contained in:
Steffen Illium
2021-12-23 13:19:31 +01:00
parent b43f595207
commit 78bf19f7f4
11 changed files with 257 additions and 321 deletions

View File

@ -1,5 +1,4 @@
from collections import defaultdict
from enum import Enum
from typing import Union
import networkx as nx
@ -29,24 +28,18 @@ class Object:
@property
def identifier(self):
if self._enum_ident is not None:
return self._enum_ident
elif self._str_ident is not None:
if self._str_ident is not None:
return self._str_ident
else:
return self._name
def __init__(self, str_ident: Union[str, None] = None, enum_ident: Union[Enum, None] = None,
is_blocking_light=False, **kwargs):
def __init__(self, str_ident: Union[str, None] = None, is_blocking_light=False, **kwargs):
self._str_ident = str_ident
self._enum_ident = enum_ident
if self._enum_ident is not None and self._str_ident is None:
self._name = f'{self.__class__.__name__}[{self._enum_ident.name}]'
elif self._str_ident is not None and self._enum_ident is None:
if self._str_ident is not None:
self._name = f'{self.__class__.__name__}[{self._str_ident}]'
elif self._str_ident is None and self._enum_ident is None:
elif self._str_ident is None:
self._name = f'{self.__class__.__name__}#{Object._u_idx[self.__class__.__name__]}'
Object._u_idx[self.__class__.__name__] += 1
else:
@ -60,16 +53,7 @@ class Object:
return f'{self.name}'
def __eq__(self, other) -> bool:
if self._enum_ident is not None:
if isinstance(other, Enum):
return other == self._enum_ident
elif isinstance(other, Object):
return other._enum_ident == self._enum_ident
else:
raise ValueError('Must be evaluated against an Enunm Identifier or Object with such.')
else:
assert isinstance(other, Object), ' This Object can only be compared to other Objects.'
return other.name == self.name
return other == self.identifier
class EnvObject(Object):
@ -80,14 +64,17 @@ class EnvObject(Object):
@property
def encoding(self):
return c.OCCUPIED_CELL.value
return c.OCCUPIED_CELL
def __init__(self, register, **kwargs):
super(EnvObject, self).__init__(**kwargs)
self._register = register
def change_register(self, register):
self._register = register
class BoundingMixin:
class BoundingMixin(Object):
@property
def bound_entity(self):
@ -163,7 +150,7 @@ class MoveableEntity(Entity):
if self._last_tile:
return self._last_tile.pos
else:
return c.NO_POS.value
return c.NO_POS
@property
def direction_of_view(self):
@ -218,30 +205,27 @@ class PlaceHolder(Object):
return "PlaceHolder"
class GlobalPosition(EnvObject):
class GlobalPosition(EnvObject, BoundingMixin):
def belongs_to_entity(self, entity):
return self._agent == entity
@property
def encoding(self):
if self._normalized:
return tuple(np.diff(self._bound_entity.pos, self._level_shape))
else:
return self.bound_entity.pos
def __init__(self, level_shape, *args, normalized: bool = True, **kwargs):
super(GlobalPosition, self).__init__(self, *args, **kwargs)
def __init__(self, level_shape, obs_shape, agent, normalized: bool = True):
super(GlobalPosition, self).__init__(self)
self._obs_shape = (1, *obs_shape) if len(obs_shape) == 2 else obs_shape
self._agent = agent
self._level_shape = level_shape
self._normalized = normalized
def as_array(self):
pos_array = np.zeros(self._obs_shape)
for xy in range(1):
pos_array[0, 0, xy] = self._agent.pos[xy] / self._level_shape[xy]
return pos_array
class Tile(EnvObject):
@property
def encoding(self):
return c.FREE_CELL.value
return c.FREE_CELL
@property
def guests_that_can_collide(self):
@ -302,7 +286,7 @@ class Wall(Tile):
@property
def encoding(self):
return c.OCCUPIED_CELL.value
return c.OCCUPIED_CELL
pass
@ -319,7 +303,7 @@ class Door(Entity):
@property
def encoding(self):
# This is important as it shadow is checked by occupation value
return c.OCCUPIED_CELL.value if self.is_closed else 2
return c.OCCUPIED_CELL if self.is_closed else 2
@property
def str_state(self):
@ -403,7 +387,7 @@ class Agent(MoveableEntity):
# noinspection PyAttributeOutsideInit
def clear_temp_state(self):
# for attr in self.__dict__:
# for attr in cls.__dict__:
# if attr.startswith('temp'):
self.temp_collisions = []
self.temp_valid = None