# Mirror of https://github.com/illiumst/marl-factory-grid.git
# (synced 2025-05-23 15:26:43 +02:00)
# Python, executable file: 120 lines, 3.9 KiB.
#!/usr/bin/env python
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import os
|
|
import argparse
|
|
import glob
|
|
import gym
|
|
from gym.utils import play
|
|
|
|
from .matting import BackgroundMattingWithColor
|
|
from .imgsource import (
|
|
RandomImageSource,
|
|
RandomColorSource,
|
|
NoiseSource,
|
|
RandomVideoSource,
|
|
)
|
|
|
|
|
|
class ReplaceBackgroundEnv(gym.ObservationWrapper):
    """Observation wrapper that pastes a natural background into observations.

    For every observation, ``bg_matting.get_mask(obs)`` selects the pixels to
    replace, and those pixels are copied from ``natural_source.get_image()``.
    """

    # Image viewer used by ``render(mode="human")``; created lazily.
    viewer = None

    def __init__(self, env, bg_matting, natural_source):
        """
        The source must produce an image with a shape that's compatible with
        `env.observation_space`.

        Args:
            env: The environment to wrap.
            bg_matting: Object exposing ``get_mask(obs)``; the returned mask is
                used to index both the observation and the source image.
            natural_source: Object exposing ``get_image()`` and ``reset()``.
        """
        super(ReplaceBackgroundEnv, self).__init__(env)
        self._bg_matting = bg_matting
        self._natural_source = natural_source
        # Most recent processed observation. Initialised so that render()
        # called before the first observation does not hit a missing attribute.
        self._last_ob = None

    def observation(self, obs):
        """Replace the masked pixels of ``obs`` with the source image.

        Note: ``obs`` is modified in place and also returned.
        """
        mask = self._bg_matting.get_mask(obs)
        img = self._natural_source.get_image()
        obs[mask] = img[mask]
        self._last_ob = obs
        return obs

    def reset(self, **kwargs):
        """Reset the background source, then the wrapped environment.

        Keyword arguments (e.g. ``seed``) are forwarded to the wrapped env.
        """
        self._natural_source.reset()
        return super(ReplaceBackgroundEnv, self).reset(**kwargs)

    # modified from gym/envs/atari/atari_env.py
    # This makes the monitor work
    def render(self, mode="human"):
        """Render the last processed observation.

        ``"rgb_array"`` returns the last observation (``None`` before the
        first one). ``"human"`` shows it in a window and returns whether the
        window is still open. Any other mode returns ``None``.
        """
        img = self._last_ob
        if mode == "rgb_array":
            return img
        elif mode == "human":
            from gym.envs.classic_control import rendering

            if self.viewer is None:
                self.viewer = rendering.SimpleImageViewer()
            self.viewer.imshow(img)
            # Bug fix: previously read ``env.viewer``, an undefined or
            # unrelated global, instead of this wrapper's own viewer.
            return self.viewer.isopen
|
|
|
|
|
|
def make(name='Pong-v0', imgsource='color', files=None):
    """Build a gym env whose background pixels are replaced by a natural source.

    Args:
        name: Gym environment id forwarded to ``gym.make``.
        imgsource: Background source: ``'color'``, ``'noise'``, ``'images'``
            or ``'video'``. ``'videos'``, the spelling used by the CLI below,
            is accepted as an alias.
        files: Resource files. Required for ``'images'``. Optional for
            ``'video'``, where it defaults to a single hard-coded clip.

    Returns:
        A ``ReplaceBackgroundEnv`` wrapping the created environment.

    Raises:
        NotImplementedError: If ``imgsource`` is not one of the names above.
        ValueError: If ``imgsource == 'images'`` and ``files`` is empty/None.
    """
    # NOTE(review): machine-specific path, kept only as a backward-compatible
    # default for 'video'; prefer passing ``files`` explicitly.
    default_video = '/Users/romue/PycharmProjects/EDYS/environments/policy_adaption/natural_rl_environment/videos/stars.mp4'

    # Validate arguments before creating the environment so bad input fails fast.
    if imgsource not in ('video', 'videos', 'color', 'noise', 'images'):
        raise NotImplementedError(
            f'{imgsource} is not supported, use one of {{video, color, noise, images}}'
        )
    if imgsource == 'images' and not files:
        raise ValueError("imgsource='images' requires a non-empty `files` list")

    env = gym.make(name)  # gravitar, breakout, MsPacman, Space Invaders
    shape2d = env.observation_space.shape[:2]
    # Colour keyed out by the matting: (144, 72, 17) for Pong, black otherwise.
    color = (0, 0, 0) if 'Pong' not in name else (144, 72, 17)

    # Bind the source object to its own name rather than reusing the
    # ``imgsource`` string parameter.
    if imgsource in ('video', 'videos'):
        source = RandomVideoSource(shape2d, files if files else [default_video])
    elif imgsource == 'color':
        source = RandomColorSource(shape2d)
    elif imgsource == 'noise':
        source = NoiseSource(shape2d)
    else:  # 'images' (validated above)
        source = RandomImageSource(shape2d, files)

    return ReplaceBackgroundEnv(env, BackgroundMattingWithColor(color), source)
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo entry point: optionally replace the env's background, optionally
    # record video, then play the env interactively via gym.utils.play.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--env", required=True, help="The gym environment to base on"
    )
    parser.add_argument("--imgsource", choices=["color", "noise", "images", "videos"])
    parser.add_argument(
        "--resource-files", help="A glob pattern to obtain images or videos"
    )
    parser.add_argument("--dump-video", help="If given, a directory to dump video")
    args = parser.parse_args()

    env = gym.make(args.env)
    shape2d = env.observation_space.shape[:2]

    if args.imgsource:
        if args.imgsource == "color":
            imgsource = RandomColorSource(shape2d)
        elif args.imgsource == "noise":
            imgsource = NoiseSource(shape2d)
        else:
            # 'images' / 'videos' need resource files. Use parser.error rather
            # than assert so validation survives ``python -O``.
            if not args.resource_files:
                parser.error(
                    "--resource-files is required with --imgsource {}".format(
                        args.imgsource
                    )
                )
            files = glob.glob(os.path.expanduser(args.resource_files))
            if not files:
                parser.error(
                    "Pattern {} does not match any files".format(args.resource_files)
                )
            if args.imgsource == "images":
                imgsource = RandomImageSource(shape2d, files)
            else:
                imgsource = RandomVideoSource(shape2d, files)

        # Same matting-colour rule as make(): Pong's background colour is
        # (144, 72, 17); other games are assumed to have a black background.
        color = (144, 72, 17) if "Pong" in args.env else (0, 0, 0)
        wrapped_env = ReplaceBackgroundEnv(
            env, BackgroundMattingWithColor(color), imgsource
        )
    else:
        wrapped_env = env

    if args.dump_video:
        if not os.path.isdir(args.dump_video):
            parser.error(
                "--dump-video {} is not a directory".format(args.dump_video)
            )
        wrapped_env = gym.wrappers.Monitor(wrapped_env, args.dump_video)
    play.play(wrapped_env, zoom=4)
|