40 lines
1.4 KiB
Python

from common import BaseLearner, TrajectoryBuffer
class AWRLearner(BaseLearner):
    """Learner skeleton for AWR (presumably Advantage-Weighted Regression).

    Only the episodic training setting is supported: construction fails unless
    ``self.train_every[0] == 'episode'``. Collected experience is meant to be
    held in a ``TrajectoryBuffer``.
    """

    def __init__(self, *args, buffer_size=100_000, **kwargs):
        """Initialise the learner and its trajectory buffer.

        Args:
            *args, **kwargs: Forwarded unchanged to ``BaseLearner.__init__``.
            buffer_size: Capacity passed to ``TrajectoryBuffer``. The default is
                an int (equal to the previous ``1e5``) because a float capacity
                is typically rejected by size-based containers.

        Raises:
            ValueError: If ``self.train_every[0]`` is not ``'episode'``.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): train_every is presumably set by BaseLearner.__init__
        # as a (unit, count)-like sequence; only element 0 is inspected here.
        # An explicit raise (not assert) keeps this check active under -O.
        if self.train_every[0] != 'episode':
            raise ValueError('AWR only supports the episodic RL setting!')
        self.buffer = TrajectoryBuffer(buffer_size)

    def train(self):
        """Run one training update.

        Not implemented yet: currently does nothing and returns None.
        """
        # TODO: convert buffered experience to trajectory format and fit.
        pass
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns
# Plot a confusion matrix for the Simple-CNN mask classifier and save it to
# cnn.pdf. Each cell is annotated with its raw count and its percentage of all
# samples. text.usetex=True needs a working LaTeX install; the percent sign is
# written as $\%$ because a bare '%' starts a LaTeX comment.
sns.set(font_scale=1.25, rc={'text.usetex': True})

# Rows are true labels, columns are predicted labels (matching the axis titles).
data = np.array([[689, 74], [71, 647]])
cats = ['Mask', 'No Mask']
total = np.sum(data)

# One "<count>\n<percent>%" string per cell, laid out in the same shape as data
# so the heatmap can use it directly as annotation text.
labels = np.asarray(
    [f'{count:.0f}\n{count / total * 100:.2f}' + r'$\%$' for count in data.flat]
).reshape(data.shape)

with sns.axes_style("white"):
    # fmt='' is required because the annotations are pre-formatted strings.
    sns.heatmap(data, annot=labels, fmt='', cmap='Set2_r', square=True,
                cbar=False, xticklabels=cats, yticklabels=cats)
plt.title('Simple-CNN')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
plt.savefig('cnn.pdf', bbox_inches='tight')