Merge remote-tracking branch 'origin/main'

This commit is contained in:
Steffen Illium 2021-11-11 18:42:55 +01:00
commit f625b9d8a5
6 changed files with 53 additions and 50 deletions

View File

@ -6,6 +6,5 @@ Tackling emergent dysfunctions (EDYs) in cooperation with Fraunhofer-IKS
1. Make sure to install `virtualenv` using `pip install virtualenv`
2. Create a new virtual environment `virtualenv venv`
3. Activate the virtual environment `source venv/bin/activate`
4. Install the required dependencies `pip install requirements.txt`
4. Install the required dependencies `pip install -r requirements.txt`
##

View File

@ -11,7 +11,6 @@ from gym import spaces
from gym.wrappers import FrameStack
from environments.factory.base.shadow_casting import Map
from environments.factory.renderer import Renderer, RenderEntity
from environments.helpers import Constants as c, Constants
from environments import helpers as h
from environments.factory.base.objects import Agent, Tile, Action
@ -545,6 +544,8 @@ class BaseFactory(gym.Env):
def render(self, mode='human'):
if not self._renderer: # lazy init
from environments.factory.renderer import Renderer, RenderEntity
global Renderer, RenderEntity
height, width = self._obs_cube.shape[1:]
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)

View File

@ -0,0 +1,13 @@
############
#----------#
#--######--#
#----------#
#--######--#
#----------#
#--######--#
#----------#
#--######--#
#----------#
#--######--#
#----------#
############

View File

@ -1,32 +1,12 @@
appdirs==1.4.4
click==7.1.2
cycler==0.10.0
distlib==0.3.1
filelock==3.0.12
kiwisolver==1.3.1
matplotlib==3.4.1
numpy==1.20.2
pandas~=1.2.3
pygame~=2.0.1
pep517==0.10.0
Pillow==8.2.0
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2021.1
scipy==1.6.3
seaborn==0.11.1
six==1.16.0
toml==0.10.2
torch==1.8.1
torchaudio==0.8.1
torchvision==0.9.1
typing-extensions==3.10.0.0
virtualenv==20.4.6
gym~=0.18.0
PyYAML~=5.3.1
pyglet~=1.5.0
optuna~=2.7.0
natsort~=7.1.1
tqdm~=4.60.0
networkx~=2.6.1
numpy
scipy
tqdm
seaborn>=0.11.1
matplotlib>=3.4.1
stable-baselines3>=1.0
pygame>=2.1.0
gym>=0.18.0
networkx>=2.6.1
simplejson>=3.17.5
PyYAML>=6.0
git+https://github.com/facebookresearch/salina.git@main#egg=salina

View File

@ -1,7 +0,0 @@
# setup.py
from setuptools import setup, find_packages
setup(
name='F_IKS',
packages=find_packages()
)

View File

@ -1,12 +1,29 @@
from environments.factory import make
import random
import salina
import torch
from gym.wrappers import FrameStack
n_agents = 4
env = make('DirtyFactory-v0', n_agents=n_agents)
env = FrameStack(env, num_stack=3)
state, *_ = env.reset()
for i in range(1000):
class MyAgent(salina.TAgent):
def __init__(self):
super(MyAgent, self).__init__()
def forward(self, t, **kwargs):
self.set(('timer', t), torch.tensor([t]))
if __name__ == '__main__':
n_agents = 1
env = make('DirtyFactory-v0', n_agents=n_agents)
env = FrameStack(env, num_stack=3)
env.reset()
agent = MyAgent()
workspace = salina.Workspace()
agent(workspace, t=0, n_steps=10)
print(workspace)
for i in range(1000):
state, *_ = env.step([env.unwrapped.action_space.sample() for _ in range(n_agents)])
env.render()
#env.render()