paper preperations and notebooks, optuna callbacks
This commit is contained in:
@@ -5,6 +5,7 @@ from abc import ABC
|
||||
|
||||
import torch
|
||||
import pandas as pd
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from util.loss_mixin import LossMixin
|
||||
@@ -59,7 +60,7 @@ class ValMixin:
|
||||
|
||||
target_y = torch.stack(tuple(sorted_batch_y.values())).long()
|
||||
if self.params.n_classes <= 2:
|
||||
mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()])
|
||||
mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x.squeeze(-1) for x in sorted_y.values()])
|
||||
self.metrics.update(mean_sorted_y, target_y)
|
||||
else:
|
||||
y_max = torch.stack(
|
||||
@@ -97,8 +98,11 @@ class ValMixin:
|
||||
sorted_y.update({file_name: torch.stack(sorted_y[file_name])})
|
||||
|
||||
if self.params.n_classes <= 2:
|
||||
mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
|
||||
mean_sorted_y = [x.mean(dim=0) if x.shape[0] > 1 else x.squeeze().unsqueeze(-1) for x in sorted_y.values()]
|
||||
mean_sorted_y = torch.stack(mean_sorted_y).squeeze(1)
|
||||
# mean_sorted_y = mean_sorted_y if mean_sorted_y.numel() > 1 else mean_sorted_y.unsqueeze(-1)
|
||||
max_vote_loss = self.bce_loss(mean_sorted_y.float(), sorted_batch_y.float())
|
||||
|
||||
# Sklearn Scores
|
||||
additional_scores = self.additional_scores(dict(y=mean_sorted_y, batch_y=sorted_batch_y))
|
||||
|
||||
@@ -129,6 +133,7 @@ class ValMixin:
|
||||
|
||||
for name, image in pl_images.items():
|
||||
self.logger.log_image(name, image, step=self.global_step)
|
||||
plt.close(image)
|
||||
pass
|
||||
|
||||
|
||||
@@ -162,12 +167,13 @@ class TestMixin:
|
||||
class_names = {val: key for val, key in
|
||||
enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
|
||||
else:
|
||||
pred = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
|
||||
pred = [x.mean(dim=0) if x.shape[0] > 1 else x.squeeze().unsqueeze(-1) for x in sorted_y.values()]
|
||||
pred = torch.stack(pred).squeeze()
|
||||
pred = torch.where(pred > 0.5, 1, 0)
|
||||
class_names = {val: key for val, key in enumerate(['negative', 'positive'])}
|
||||
|
||||
|
||||
df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
|
||||
df = pd.DataFrame(data=dict(filename=[Path(x).name.replace('.npy', '.wav') for x in sorted_y.keys()],
|
||||
prediction=[class_names[x.item()] for x in pred.cpu()]))
|
||||
result_file = Path(self.logger.log_dir / 'predictions.csv')
|
||||
if result_file.exists():
|
||||
@@ -178,12 +184,13 @@ class TestMixin:
|
||||
pass
|
||||
with result_file.open(mode='wb') as csv_file:
|
||||
df.to_csv(index=False, path_or_buf=csv_file)
|
||||
with result_file.open(mode='rb') as csv_file:
|
||||
try:
|
||||
self.logger.neptunelogger.log_artifact(csv_file)
|
||||
except:
|
||||
print('No possible to send to neptune')
|
||||
pass
|
||||
if False:
|
||||
with result_file.open(mode='rb') as csv_file:
|
||||
try:
|
||||
self.logger.neptunelogger.log_artifact(csv_file)
|
||||
except:
|
||||
print('No possible to send to neptune')
|
||||
pass
|
||||
|
||||
|
||||
class CombinedModelMixins(LossMixin,
|
||||
|
||||
Reference in New Issue
Block a user