import os

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import cv2  # only needed for the commented-out OpenCV conversion in VideoNPZDataset
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from transformers import AutoProcessor, AutoModelForVideoClassification
from tqdm import tqdm
from PIL import Image

# -----------------------------
# Paths
# -----------------------------
dataset_folder = os.getenv("dataset_folder", "/app/dataset_folder")
output_folder = os.getenv("output_folder", "/app/output_folder")
model_folder = os.getenv("model_folder", "base_model")
model_path = os.path.join(output_folder, model_folder)

# -----------------------------
# Load Model + Processor
# -----------------------------
model = AutoModelForVideoClassification.from_pretrained(model_path)
processor = AutoProcessor.from_pretrained(model_path)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# -----------------------------
# Dataset class
# -----------------------------
class VideoNPZDataset(Dataset):
    def __init__(self, files, labels, processor):
        self.files = files
        self.labels = labels
        self.processor = processor

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file = self.files[idx]
        label = self.labels[idx]

        # Each .npz archive holds one video as a set of BGR frames.
        # Note: sorted() is lexicographic, so frame keys should be
        # zero-padded (e.g. "frame_001") to preserve temporal order.
        data = np.load(file)
        frames = [data[key] for key in sorted(data.files)]
        # Reverse the channel axis to convert BGR -> RGB.
        frames_rgb = [Image.fromarray(frame[:, :, ::-1]) for frame in frames]
        # frames_rgb = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]
        pixel_values = self.processor(images=frames_rgb, return_tensors="pt")["pixel_values"][0]
        label_tensor = torch.tensor(label, dtype=torch.long)
        return {
            "pixel_values": pixel_values,
            "labels": label_tensor,
        }

# -----------------------------
# Collate Function
# -----------------------------
def collate_fn(batch):
    pixel_values = [item["pixel_values"] for item in batch]
    labels = [item["labels"] for item in batch]

    pixel_values = torch.stack(pixel_values)  # [B, T, C, H, W], e.g. [B, 16, 3, 224, 224]
    labels = torch.tensor(labels, dtype=torch.long)  # [B]

    return {
        "pixel_values": pixel_values,
        "labels": labels,
    }

# -----------------------------
# Load Dataset
# -----------------------------
def load_dataset_from_npz(root_dir):
    files = []
    labels = []
    label_map = {}
    # Enumerate only directories so stray files at the dataset root
    # cannot shift the class indices.
    class_names = sorted(
        name for name in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, name))
    )
    for i, class_name in enumerate(class_names):
        label_map[i] = class_name
        class_dir = os.path.join(root_dir, class_name)
        for file in sorted(os.listdir(class_dir)):
            if file.endswith(".npz"):
                files.append(os.path.join(class_dir, file))
                labels.append(i)
    return files, labels, label_map

files, labels, label_map = load_dataset_from_npz(dataset_folder)
dataset = VideoNPZDataset(files, labels, processor)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)

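# Optional sanity check: a minimal sketch, assuming the fine-tuned checkpoint
# saved human-readable class names in config.id2label. Fresh checkpoints often
# only carry defaults like "LABEL_0", so treat a mismatch as a warning, not an error.
id2label = {int(k): v for k, v in model.config.id2label.items()}
if id2label != label_map:
    print(f"Warning: dataset label_map {label_map} does not match model id2label {id2label}")
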
# -----------------------------
# Evaluation Loop
# -----------------------------
all_preds = []
all_labels = []
all_probs = []

with torch.no_grad():
    for batch in tqdm(dataloader, desc="Evaluating"):
        inputs = batch["pixel_values"].to(device)
        batch_labels = batch["labels"].to(device)

        outputs = model(pixel_values=inputs)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=-1)

        preds = torch.argmax(probs, dim=-1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch_labels.cpu().numpy())
        all_probs.extend(probs[:, 1].cpu().numpy())  # positive-class probability (assumes binary classification)

# -----------------------------
# Metrics
# -----------------------------
accuracy = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds, average="binary")
roc_auc = roc_auc_score(all_labels, all_probs)
conf_matrix = confusion_matrix(all_labels, all_preds)

print("\n=== Evaluation Metrics ===")
print(f"Accuracy : {accuracy:.4f}")
print(f"F1 Score : {f1:.4f}")
print(f"ROC AUC : {roc_auc:.4f}")
print(f"Confusion Matrix:\n{conf_matrix}")
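
# Persisting results is optional; a minimal sketch, assuming metrics should land
# in output_folder alongside the model (the "metrics.json" filename is hypothetical).
# It also uses label_map to report which class each confusion-matrix row refers to.
import json

class_names = [label_map[i] for i in sorted(label_map)]
print(f"Classes (row/column order of the confusion matrix): {class_names}")

metrics = {
    "accuracy": float(accuracy),
    "f1": float(f1),
    "roc_auc": float(roc_auc),
    "confusion_matrix": conf_matrix.tolist(),
    "classes": class_names,
}
with open(os.path.join(output_folder, "metrics.json"), "w") as f:
    json.dump(metrics, f, indent=2)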