2021-06-04 17:17:29 +02:00

121 lines
2.9 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import cv2
import skvideo.io
class ImageSource(object):
    """
    Base class for providers of natural-image backgrounds that are added
    to a simulated environment.

    Subclasses override ``get_image`` and, when they keep per-episode
    state, ``reset``. The versions defined here do nothing and return None.
    """

    def get_image(self):
        """
        Produce the current background frame.

        Returns:
            an RGB image of [h, w, 3] with a fixed shape. Subclasses must
            override this; the base implementation returns None.
        """
        return None

    def reset(self):
        """ Hook invoked when an episode ends; a no-op by default. """
        return None
class FixedColorSource(ImageSource):
    """A background that is the same solid color on every frame."""

    def __init__(self, shape, color):
        """
        Args:
            shape: [h, w] of the produced image.
            color: a 3-tuple, broadcast over every pixel.
        """
        height, width = shape[0], shape[1]
        canvas = np.zeros((height, width, 3))
        canvas[...] = color
        self.arr = canvas

    def get_image(self):
        # Return a copy so callers may modify the frame without altering
        # the stored template.
        return self.arr.copy()
class RandomColorSource(ImageSource):
    """A solid-color background whose color is re-drawn on every reset."""

    def __init__(self, shape):
        """
        Args:
            shape: [h, w] of the produced image.
        """
        self.shape = shape
        self.reset()

    def reset(self):
        # One random integer in [0, 255] per channel.
        self._color = np.random.randint(0, 256, size=(3,))

    def get_image(self):
        rows, cols = self.shape[0], self.shape[1]
        frame = np.zeros((rows, cols, 3))
        frame[...] = self._color
        return frame
class NoiseSource(ImageSource):
    """A background made of fresh per-pixel Gaussian noise on every frame."""

    def __init__(self, shape, strength=50):
        """
        Args:
            shape: [h, w] of the produced image.
            strength (int): the strength of noise, in range [0, 255]; it
                scales a standard-normal sample.
        """
        self.shape = shape
        self.strength = strength

    def get_image(self):
        # Draw N(0, 1) for every pixel/channel, scale it, then clamp the
        # negative half to zero so the result is non-negative.
        sample = np.random.randn(self.shape[0], self.shape[1], 3)
        scaled = sample * self.strength
        return np.maximum(scaled, 0)
class RandomImageSource(ImageSource):
    """
    A static background image, re-drawn at random from a list of files on
    every reset and resized to a fixed shape.
    """

    def __init__(self, shape, filelist):
        """
        Args:
            shape: [h, w] of the produced image.
            filelist: a list of image files to sample from.
        """
        # cv2.resize takes dsize as (width, height); normalize to a tuple of
        # Python ints so list or NumPy-integer shapes are accepted too.
        self.shape_wh = tuple(int(d) for d in shape[::-1])
        self.filelist = filelist
        self.reset()

    def reset(self):
        """
        Pick a new image file uniformly at random, load it, convert it to
        RGB and resize it to ``shape``.

        Raises:
            IOError: if the chosen file cannot be read as an image.
        """
        fname = np.random.choice(self.filelist)
        im = cv2.imread(fname, cv2.IMREAD_COLOR)
        if im is None:
            # cv2.imread reports a missing/unreadable/unsupported file by
            # returning None instead of raising; fail loudly with the path
            # rather than with a confusing TypeError on the slice below.
            raise IOError('Could not read image file: {}'.format(fname))
        # cv2 loads channels as BGR; reverse the last axis to get RGB.
        im = im[:, :, ::-1]
        self._im = cv2.resize(im, self.shape_wh)

    def get_image(self):
        """
        Returns:
            the current [h, w, 3] RGB image. A copy is returned so in-place
            edits by the caller cannot corrupt the cached image that is
            reused for every frame until the next reset.
        """
        return np.copy(self._im)
class RandomVideoSource(ImageSource):
    """
    A background that plays back a video frame by frame. The video is drawn
    at random from a list of files on every reset, and a new one is drawn
    whenever the current video runs out of frames.
    """

    def __init__(self, shape, filelist):
        """
        Args:
            shape: [h, w] of the produced image.
            filelist: a list of video files to sample from.
        """
        # cv2.resize takes dsize as (width, height); normalize to a tuple of
        # Python ints so list or NumPy-integer shapes are accepted too.
        self.shape_wh = tuple(int(d) for d in shape[::-1])
        self.filelist = filelist
        self.reset()

    def reset(self):
        """
        Pick a new video file uniformly at random, decode all of its frames
        into memory and rewind playback to the first frame.

        Raises:
            IOError: if the decoded video contains no frames.
        """
        fname = np.random.choice(self.filelist)
        # NOTE(review): decodes the whole video at once. The code below
        # assumes a (frames, h, w, channels) array, which is skvideo's
        # documented vread output (RGB by default) -- confirm if vread
        # options are ever changed.
        self.frames = skvideo.io.vread(fname)
        self.frame_idx = 0
        if self.frames.shape[0] == 0:
            # Otherwise get_image would fail later with a bare IndexError.
            raise IOError('Video file contains no frames: {}'.format(fname))

    def get_image(self):
        """
        Returns:
            the next frame, resized to [h, w, 3]. Frames are used in the
            channel order produced by ``skvideo.io.vread``; no BGR->RGB swap
            is applied (unlike the cv2.imread-based image source).
        """
        if self.frame_idx >= self.frames.shape[0]:
            # Current video exhausted: move on to a freshly drawn one.
            self.reset()
        frame = self.frames[self.frame_idx]
        self.frame_idx += 1
        return cv2.resize(frame, self.shape_wh)