added domain shift envs and adjusted reload_agent.py

This commit is contained in:
romue
2021-06-02 09:38:15 +02:00
parent dfca68cbeb
commit 38a3ef7687
10 changed files with 1246 additions and 8 deletions

View File

@ -0,0 +1,120 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import cv2
import skvideo.io
class ImageSource(object):
"""
Source of natural images to be added to a simulated environment.
"""
def get_image(self):
"""
Returns:
an RGB image of [h, w, 3] with a fixed shape.
"""
pass
def reset(self):
""" Called when an episode ends. """
pass
class FixedColorSource(ImageSource):
def __init__(self, shape, color):
"""
Args:
shape: [h, w]
color: a 3-tuple
"""
self.arr = np.zeros((shape[0], shape[1], 3))
self.arr[:, :] = color
def get_image(self):
return np.copy(self.arr)
class RandomColorSource(ImageSource):
def __init__(self, shape):
"""
Args:
shape: [h, w]
"""
self.shape = shape
self.reset()
def reset(self):
self._color = np.random.randint(0, 256, size=(3,))
def get_image(self):
arr = np.zeros((self.shape[0], self.shape[1], 3))
arr[:, :] = self._color
return arr
class NoiseSource(ImageSource):
def __init__(self, shape, strength=50):
"""
Args:
shape: [h, w]
strength (int): the strength of noise, in range [0, 255]
"""
self.shape = shape
self.strength = strength
def get_image(self):
return np.maximum(np.random.randn(
self.shape[0], self.shape[1], 3) * self.strength, 0)
class RandomImageSource(ImageSource):
def __init__(self, shape, filelist):
"""
Args:
shape: [h, w]
filelist: a list of image files
"""
self.shape_wh = shape[::-1]
self.filelist = filelist
self.reset()
def reset(self):
fname = np.random.choice(self.filelist)
im = cv2.imread(fname, cv2.IMREAD_COLOR)
im = im[:, :, ::-1]
im = cv2.resize(im, self.shape_wh)
self._im = im
def get_image(self):
return self._im
class RandomVideoSource(ImageSource):
def __init__(self, shape, filelist):
"""
Args:
shape: [h, w]
filelist: a list of video files
"""
self.shape_wh = shape[::-1]
self.filelist = filelist
self.reset()
def reset(self):
fname = np.random.choice(self.filelist)
self.frames = skvideo.io.vread(fname)
self.frame_idx = 0
def get_image(self):
if self.frame_idx >= self.frames.shape[0]:
self.reset()
im = self.frames[self.frame_idx][:, :, ::-1]
self.frame_idx += 1
im = im[:, :, ::-1]
im = cv2.resize(im, self.shape_wh)
return im

View File

@ -0,0 +1,32 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
class BackgroundMatting(object):
"""
Produce a mask of a given image which will be replaced by natural signals.
"""
def get_mask(self, img):
"""
Take an image of [H, W, 3]. Returns a mask of [H, W]
"""
raise NotImplementedError()
class BackgroundMattingWithColor(BackgroundMatting):
"""
Produce a mask by masking the given color. This is a simple strategy
but effective for many games.
"""
def __init__(self, color):
"""
Args:
color: a (r, g, b) tuple
"""
self._color = color
def get_mask(self, img):
return img == self._color

View File

@ -0,0 +1,99 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import argparse
import glob
import gym
from gym.utils import play
from matting import BackgroundMattingWithColor
from imgsource import (
RandomImageSource,
RandomColorSource,
NoiseSource,
RandomVideoSource,
)
class ReplaceBackgroundEnv(gym.ObservationWrapper):
viewer = None
def __init__(self, env, bg_matting, natural_source):
"""
The source must produce a image with a shape that's compatible to
`env.observation_space`.
"""
super(ReplaceBackgroundEnv, self).__init__(env)
self._bg_matting = bg_matting
self._natural_source = natural_source
def observation(self, obs):
mask = self._bg_matting.get_mask(obs)
img = self._natural_source.get_image()
obs[mask] = img[mask]
self._last_ob = obs
return obs
def reset(self):
self._natural_source.reset()
return super(ReplaceBackgroundEnv, self).reset()
# modified from gym/envs/atari/atari_env.py
# This makes the monitor work
def render(self, mode="human"):
img = self._last_ob
if mode == "rgb_array":
return img
elif mode == "human":
from gym.envs.classic_control import rendering
if self.viewer is None:
self.viewer = rendering.SimpleImageViewer()
self.viewer.imshow(img)
return env.viewer.isopen
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--env", help="The gym environment to base on")
parser.add_argument("--imgsource", choices=["color", "noise", "images", "videos"])
parser.add_argument(
"--resource-files", help="A glob pattern to obtain images or videos"
)
parser.add_argument("--dump-video", help="If given, a directory to dump video")
args = parser.parse_args()
env = gym.make(args.env)
shape2d = env.observation_space.shape[:2]
if args.imgsource:
if args.imgsource == "color":
imgsource = RandomColorSource(shape2d)
elif args.imgsource == "noise":
imgsource = NoiseSource(shape2d)
else:
files = glob.glob(os.path.expanduser(args.resource_files))
assert len(files), "Pattern {} does not match any files".format(
args.resource_files
)
if args.imgsource == "images":
imgsource = RandomImageSource(shape2d, files)
else:
imgsource = RandomVideoSource(shape2d, files)
wrapped_env = ReplaceBackgroundEnv(
env, BackgroundMattingWithColor((0, 0, 0)), imgsource
)
else:
wrapped_env = env
if args.dump_video:
assert os.path.isdir(args.dump_video)
wrapped_env = gym.wrappers.Monitor(wrapped_env, args.dump_video)
play.play(wrapped_env, zoom=4)