import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
import pandas as pd
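
# Extracts fixed feature vectors from mel-spectrogram images (here: the MIMII
# machine-sound recordings referenced in __main__) using frozen,
# ImageNet-pretrained torchvision backbones.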
class FeatureExtractor:
    """Wraps a pretrained torchvision backbone and turns images into flat feature vectors."""

    supported_extractors = ['resnet18', 'resnet34', 'resnet50',
                            'alexnet_fc6', 'alexnet_fc7', 'vgg16',
                            'densenet121', 'inception_v3', 'squeezenet']

    def __init__(self, version='resnet18', device='cpu'):
        assert version.lower() in self.supported_extractors
        self.device = device
        self.version = version
        self.F = self.__choose_feature_extractor(version)
        # Freeze the backbone: features are only extracted, never fine-tuned.
        for param in self.F.parameters():
            param.requires_grad = False
        self.F.eval()
        # Inception v3 expects 299x299 inputs, all other backbones 224x224.
        self.input_size = (299, 299) if version.lower() == 'inception_v3' else (224, 224)
        self.transforms = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    def to(self, device):
        self.device = device
        self.F = self.F.to(self.device)
        return self
    def __choose_feature_extractor(self, version):
        # Build the requested backbone and strip its classification head so that
        # the forward pass returns a feature map / feature vector.
        if 'resnet' in version.lower():
            v = int(version[-2:])
            if v == 18:
                resnet = torchvision.models.resnet18(pretrained=True)
            elif v == 34:
                resnet = torchvision.models.resnet34(pretrained=True)
            elif v == 50:
                resnet = torchvision.models.resnet50(pretrained=True)
            # Drop the final fc layer, keep everything up to global average pooling.
            return nn.Sequential(*list(resnet.children())[:-1])
        elif 'alexnet' in version.lower():
            v = int(version[-1])
            alexnet = torchvision.models.alexnet(pretrained=True)
            if v == 7:
                # Keep the classifier up to (and including) the fc7 linear layer.
                f = nn.Sequential(*list(alexnet.classifier.children())[:-2])
            elif v == 6:
                # Keep the classifier up to (and including) the fc6 linear layer.
                f = nn.Sequential(*list(alexnet.classifier.children())[:-5])
            alexnet.classifier = f
            return alexnet
        elif version.lower() == 'vgg16':
            # Note: the batch-norm variant of VGG16 is used here.
            vgg = torchvision.models.vgg16_bn(pretrained=True)
            # Keep the classifier up to the second 4096-dim linear layer.
            classifier = list(vgg.classifier.children())[:4]
            vgg.classifier = nn.Sequential(*classifier)
            return vgg
        elif version.lower() == 'densenet121':
            densenet = torchvision.models.densenet121(pretrained=True)
            avg_pool = nn.AvgPool2d(kernel_size=7)
            # Drop the classifier and pool the 7x7 feature map down to a vector.
            densenet = nn.Sequential(*list(densenet.children())[:-1])
            densenet.add_module('avg_pool', avg_pool)
            return densenet
        elif version.lower() == 'inception_v3':
            inception = torchvision.models.inception_v3(pretrained=True)
            # Drop the final fc layer and the auxiliary classifier, then pool.
            f = nn.Sequential(*list(inception.children())[:-1])
            f._modules.pop('13')
            f.add_module('global_average', nn.AvgPool2d(26))
            return f
        elif version.lower() == 'squeezenet':
            squeezenet = torchvision.models.squeezenet1_1(pretrained=True)
            f = torch.nn.Sequential(
                squeezenet.features,
                torch.nn.AdaptiveAvgPool2d(output_size=(2, 2))
            )
            return f
        else:
            raise NotImplementedError('The feature extractor you requested is not yet supported')
    @property
    def feature_size(self):
        # Infer the flattened feature dimension by pushing a dummy image through the backbone.
        x = torch.randn(size=(1, 3, *self.input_size)).to(self.device)
        return self.F(x).squeeze().shape[0]

    def __call__(self, batch):
        # Accepts a single PIL image; the transforms yield a 3-D tensor that is
        # expanded to a batch of one before the forward pass.
        batch = self.transforms(batch).to(self.device)
        if len(batch.shape) <= 3:
            batch = batch.unsqueeze(0)
        return self.F(batch).view(batch.shape[0], -1).squeeze()
    def from_image_folder(self, folder_path, extension='jpg'):
        # Extract features for every image in a folder; file names are expected
        # to follow the pattern <name>_<seq_id>.<extension>.
        sorted_files = sorted(list(folder_path.glob(f'*.{extension}')))
        split_names = [x.stem.split('_') for x in sorted_files]
        names = [x[0] for x in split_names]
        seq_ids = [x[1] for x in split_names]
        X = []
        for i, p_img in enumerate(tqdm(sorted_files)):
            x = Image.open(p_img)
            features = self(x)
            X.append([names[i], seq_ids[i]] + features.tolist())
        return pd.DataFrame(X, columns=['name', 'seq_id',
                                        *(f'feature_{i}' for i in range(self.feature_size))])
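
# Minimal usage sketch (illustrative only; the image path below is a placeholder):
# extract one image's feature vector, or a whole folder into a DataFrame.
#
#   extractor = FeatureExtractor(version='resnet18', device='cpu')
#   vec = extractor(Image.open('example_melspec.jpg'))  # 1-D tensor, length == extractor.feature_size
#   df = extractor.from_image_folder(Path('melspec_images/'), extension='jpg')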
class AudioTransferLearningImageDataset(Dataset):
    """Dataset over spectrogram image files; yields (transformed image, name, seq_id)."""

    def __init__(self, root_or_files, extension='jpg', input_size=224):
        self.root_or_files = root_or_files
        # Either an explicit list of files or a root directory to glob.
        if isinstance(root_or_files, list):
            self.files = root_or_files
        else:
            self.files = list(self.root_or_files.glob(f'*.{extension}'))
        self.transforms = transforms.Compose([
            transforms.Resize(input_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def process_name(self, name):
        # File names follow the pattern <name>_<seq_id>.<extension>.
        split_name = name.stem.split('_')
        return split_name[0], split_name[1]  # name, seq_id

    def __getitem__(self, item):
        p_img = self.files[item]
        x = Image.open(p_img)
        x = self.transforms(x)
        name, seq_id = self.process_name(p_img)
        return x, name, seq_id

    def __len__(self):
        return len(self.files)
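
# Hypothetical helper (a sketch, not part of the original pipeline): shows how the
# dataset above could be combined with a DataLoader to extract features in batches
# rather than one image at a time. The function name and arguments are illustrative;
# folder_path is assumed to be a pathlib.Path, as in from_image_folder.
def extract_features_batched(extractor, folder_path, batch_size=32, extension='jpg'):
    from torch.utils.data import DataLoader
    files = sorted(folder_path.glob(f'*.{extension}'))
    dataset = AudioTransferLearningImageDataset(files, extension=extension,
                                                input_size=extractor.input_size)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    rows = []
    with torch.no_grad():
        for x, names, seq_ids in loader:
            # The dataset already applies the ImageNet transforms, so call the
            # frozen backbone directly instead of extractor(...).
            feats = extractor.F(x.to(extractor.device)).view(x.shape[0], -1)
            for name, seq_id, f in zip(names, seq_ids, feats):
                rows.append([name, seq_id] + f.tolist())
    return pd.DataFrame(rows, columns=['name', 'seq_id',
                                       *(f'feature_{i}' for i in range(extractor.feature_size))])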
if __name__ == '__main__':
    from pathlib import Path

    version = 'resnet18'
    extractor = FeatureExtractor(version=version)
    # Extract features for the normal and abnormal recordings of each MIMII machine type / id.
    models = ['slider', 'pump', 'fan']
    model_ids = [0, 2, 4, 6]
    for model in models:
        for model_id in model_ids:
            df = extractor.from_image_folder(Path(f'/home/robert/coding/audio_anomaly_detection/data/mimii/-6_dB_{model}/id_0{model_id}/normal/melspec_images/'))
            df.to_csv(Path(f'/home/robert/coding/audio_anomaly_detection/data/mimii/-6_dB_{model}/id_0{model_id}/normal/{version}_features.csv'), index=False)
            del df
            df = extractor.from_image_folder(Path(f'/home/robert/coding/audio_anomaly_detection/data/mimii/-6_dB_{model}/id_0{model_id}/abnormal/melspec_images/'))
            df.to_csv(Path(f'/home/robert/coding/audio_anomaly_detection/data/mimii/-6_dB_{model}/id_0{model_id}/abnormal/{version}_features.csv'), index=False)