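"""Fine-tune a pretrained video-classification model on .npz frame clips.

Freezes the backbone, replaces the classifier with a 2-class head, trains
with the Hugging Face Trainer, and writes the tuned model and a loss curve
to output_folder/tuned.
"""
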
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import Dataset
from transformers import (
    AutoModelForVideoClassification,
    AutoProcessor,
    Trainer,
    TrainingArguments,
)
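
# Input and output locations are configurable via environment variables.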
dataset_folder = os.getenv("dataset_folder", "/app/dataset_folder")
output_folder = os.getenv("output_folder", "/app/output_folder")
os.makedirs(dataset_folder, exist_ok=True)
os.makedirs(f"{output_folder}/base_model", exist_ok=True)
os.makedirs(f"{output_folder}/tuned", exist_ok=True)
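
# Load the pretrained backbone and its processor saved under base_model.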
model = AutoModelForVideoClassification.from_pretrained(f"{output_folder}/base_model")
processor = AutoProcessor.from_pretrained(f"{output_folder}/base_model")

# Freeze the backbone so only the new classification head is trained.
for param in model.base_model.parameters():
    param.requires_grad = False

# Replace the classifier with a binary (2-class) output head, and keep the
# config in sync so the model's built-in loss sees the right label count.
model.classifier = nn.Linear(model.classifier.in_features, 2)
model.config.num_labels = 2
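
# Each sample is one .npz archive whose arrays are the frames of a single clip.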
class VideoNPZDataset(Dataset):
    def __init__(self, files, labels, processor):
        self.files = files
        self.labels = labels
        self.processor = processor

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file = self.files[idx]
        label = self.labels[idx]

        data = np.load(file)
        frames = [data[key] for key in sorted(data.files)]

        # Debug: inspect per-frame shapes and dtypes if needed.
        # for i, frame in enumerate(frames):
        #     print(f" Frame {i} shape: {frame.shape}, dtype: {frame.dtype}")

        # Convert to RGB (assumes frames are stored in BGR format)
        frames_rgb = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]

        # Process the video frames into a single tensor.
        pixel_values = self.processor(images=frames_rgb, return_tensors="pt")["pixel_values"][0]  # [16, 3, 224, 224]
        label_values = torch.tensor(label, dtype=torch.long)

        return {
            "pixel_values": pixel_values,  # [16, 3, 224, 224]
            "labels": label_values,  # scalar tensor
        }
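
# Optional debugging subclass: swap it in for Trainer below to log batch shapes.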
class myTrainer(Trainer):
    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer
    # transformers releases pass to compute_loss.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        print(f"[DEBUG] pixel_values: {inputs['pixel_values'].shape}")
        print(f"[DEBUG] labels: {inputs['labels'].shape}")
        return super().compute_loss(model, inputs, return_outputs=return_outputs, **kwargs)

def collate_fn(batch):
    pixel_values = []
    labels = []

    for item in batch:
        video = item["pixel_values"]  # shape: (16, 3, 224, 224)
        label = item["labels"]

        pixel_values.append(video)
        labels.append(label)

    pixel_values = torch.stack(pixel_values)  # (batch_size, 16, 3, 224, 224)
    labels_values = torch.stack(labels)  # (batch_size,)

    return {
        "pixel_values": pixel_values,
        "labels": labels_values,
    }
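
# dataset_folder layout: one subdirectory per class, each holding .npz clips.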
def load_dataset_from_npz(root_dir):
    files = []
    labels = []
    label_map = {}  # index -> class name, for the inference class map
    class_names = sorted(
        d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
    )
    for i, class_name in enumerate(class_names):
        label_map[i] = class_name
        class_dir = os.path.join(root_dir, class_name)
        for file in os.listdir(class_dir):
            if file.endswith(".npz"):
                files.append(os.path.join(class_dir, file))
                labels.append(i)
    return files, labels, label_map

files, labels, label_map = load_dataset_from_npz(dataset_folder)

# for file, label in zip(files, labels):
#     print(f"{file} {label} {label_map[label]}")

print(f" files: {len(files)}")
print(f" labels: {len(labels)}")

print("CUDA available:", torch.cuda.is_available())
print("Device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Current device:", torch.cuda.current_device())
    print("Device name:", torch.cuda.get_device_name(0))
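
# Stratified 80/20 split; random_state is randomized, so splits differ between runs.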
train_files, val_files, train_labels, val_labels = train_test_split(
    files, labels, test_size=0.2, stratify=labels, random_state=random.randint(1, 5000)
)

train_dataset = VideoNPZDataset(train_files, train_labels, processor)
val_dataset = VideoNPZDataset(val_files, val_labels, processor)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # The head emits two logits per clip, so take the argmax over the class
    # dimension rather than thresholding a sigmoid (which would produce two
    # booleans per sample and misalign with the labels).
    preds = np.argmax(logits, axis=-1)
    accuracy = (preds == labels).mean()
    return {"accuracy": accuracy}

training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
    save_strategy="epoch",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=10,
    logging_dir="./logs",
    logging_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

trainer.train()

# Pull the logged losses out of the trainer state for plotting.
logs = trainer.state.log_history
train_loss = [x["loss"] for x in logs if "loss" in x]
eval_loss = [x["eval_loss"] for x in logs if "eval_loss" in x]
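
# Note: train loss is logged every logging_steps while eval loss is logged once
# per epoch, so the two curves share an index axis rather than a true step axis.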
plt.plot(train_loss, label="Train Loss")
plt.plot(eval_loss, label="Eval Loss")
plt.xlabel("Log Steps")
plt.ylabel("Loss")
plt.legend()
plt.title("Loss Curve")
plt.savefig(f"{output_folder}/tuned/loss_curve.png")

trainer.save_model(f"{output_folder}/tuned")
processor.save_pretrained(f"{output_folder}/tuned")