# Tiresias/transfer/transfer.py
# Fine-tunes a pretrained video-classification model (HuggingFace) on a
# class-per-folder dataset of .npz frame archives, producing a binary classifier.
import os
import cv2
import torch
import numpy as np
import random
import matplotlib.pyplot as plt
from torch import nn
from transformers import AutoModelForVideoClassification, AutoProcessor, TrainingArguments, Trainer
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from PIL import Image
# Resolve the dataset and output locations from the environment, with
# container-style defaults, and make sure all output directories exist.
dataset_folder = os.getenv("dataset_folder", "/app/dataset_folder")
output_folder = os.getenv("output_folder", "/app/output_folder")
os.makedirs(dataset_folder, exist_ok=True)
os.makedirs(f"{output_folder}/base_model", exist_ok=True)
os.makedirs(f"{output_folder}/tuned", exist_ok=True)

# Load the locally cached base checkpoint and its matching processor.
model = AutoModelForVideoClassification.from_pretrained(f"{output_folder}/base_model")
processor = AutoProcessor.from_pretrained(f"{output_folder}/base_model")

# Optionally freeze the backbone to train the head only:
# for param in model.base_model.parameters():
#     param.requires_grad = False

# Replace the classifier head with a fresh binary (2-class) output layer.
model.classifier = nn.Linear(model.classifier.in_features, 2)
# Keep the config in sync with the new head: HF video-classification models
# compute their CrossEntropyLoss via logits.view(-1, config.num_labels), so a
# stale num_labels inherited from the pretrained checkpoint would break (or
# silently corrupt) the loss once the head emits 2 logits.
model.config.num_labels = 2
class VideoNPZDataset(Dataset):
    """Dataset of videos stored as .npz frame archives, paired with class ids.

    Each .npz file holds one video as a set of per-frame arrays; frame order
    is the sorted order of the archive's key names.
    """

    def __init__(self, files, labels, processor):
        self.files = files
        self.labels = labels
        self.processor = processor

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        archive = np.load(path)
        # Frames are assumed to be stored in BGR order (OpenCV convention);
        # reversing the channel axis converts each one to RGB for PIL.
        rgb_frames = [
            Image.fromarray(archive[key][:, :, ::-1])
            for key in sorted(archive.files)
        ]
        # Let the model's processor resize/normalize/stack the frames, then
        # drop the batch dimension it adds around the single video.
        pixel_values = self.processor(images=rgb_frames, return_tensors="pt")["pixel_values"][0]
        return {
            "pixel_values": pixel_values,  # per-video frame tensor from the processor
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),  # scalar class id
        }
class myTrainer(Trainer):
    """Trainer subclass that prints batch tensor shapes before each loss step.

    NOTE(review): this subclass is defined but never instantiated — the script
    below builds a plain ``Trainer`` — so the debug prints never run.
    Presumably a leftover debugging aid.
    """

    def compute_loss(self, model, inputs, return_outputs=False):
        # Debug: show the shapes the model is about to receive.
        print(f"[DEBUG] pixel_values: {inputs['pixel_values'].shape}")
        print(f"[DEBUG] labels: {inputs['labels'].shape}")
        return super().compute_loss(model, inputs, return_outputs)
def collate_fn(batch):
    """Collate dataset items into one batched model input.

    Args:
        batch: list of dicts with "pixel_values" (a per-video frame tensor,
            e.g. [num_frames, 3, H, W]) and "labels" (a 0-d long tensor).

    Returns:
        dict with:
          - "pixel_values": stacked videos, shape [batch_size, *video_shape]
          - "labels": 1-D long tensor of class ids, shape [batch_size]
            (the original comment claiming the video shape here was wrong)
    """
    videos = [item["pixel_values"] for item in batch]
    labels = [item["labels"] for item in batch]
    return {
        "pixel_values": torch.stack(videos),
        # torch.stack, not torch.tensor: building a tensor from a list of
        # 0-d tensors triggers an extra copy and a warning; stacking is the
        # intended operation and yields the same [batch_size] result.
        "labels": torch.stack(labels).to(torch.long),
    }
def load_dataset_from_npz(root_dir):
    """Scan a class-per-subdirectory tree of .npz video archives.

    Expected layout: ``root_dir/<class_name>/<clipid>_<n>.npz``; the class id
    is the directory's position in sorted order.

    Args:
        root_dir: dataset root containing one subdirectory per class.

    Returns:
        Tuple ``(files, labels, label_map, groups)`` where:
          - files: full paths of all .npz files found
          - labels: parallel list of integer class ids
          - label_map: {class_id: class_name}, for inference-time decoding
          - groups: parallel list of group ids; clips sharing the same
            filename prefix (text before the first '_') within one class get
            the same id, so related clips can be kept in the same split.
    """
    files = []
    labels = []
    groups = []
    label_map = {}  # for inference class map
    group_map = {}
    group_counter = 0
    # Only directories are classes: a stray file at the root would previously
    # crash the inner os.listdir and shift class ids.
    class_names = sorted(
        entry for entry in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, entry))
    )
    for class_id, class_name in enumerate(class_names):
        label_map[class_id] = class_name
        class_dir = os.path.join(root_dir, class_name)
        # Sorted for a deterministic file/group order across filesystems.
        for file in sorted(os.listdir(class_dir)):
            if not file.endswith(".npz"):
                continue
            group_id = f"{file.split('_')[0]}_{class_id}"
            if group_id not in group_map:
                group_counter += 1
                group_map[group_id] = group_counter
            groups.append(group_map[group_id])
            files.append(os.path.join(class_dir, file))
            labels.append(class_id)
    return files, labels, label_map, groups
# Build the file/label lists from the class-per-folder dataset tree.
files, labels, label_map, groups = load_dataset_from_npz(dataset_folder)
# Debug: dump each file with its label and class name.
# for file, label in zip(files, labels):
#     print(f"{file} {label} {label_map[label]}")
print(f" files: {len(files)}")
print(f" labels: {len(labels)}")
# CUDA diagnostics, kept for debugging:
# print("CUDA available:", torch.cuda.is_available())
# print("Device count:", torch.cuda.device_count())
# print("Current device:", torch.cuda.current_device())
# print("Device name:", torch.cuda.get_device_name(0))
# NOTE(review): stratify=groups forces every group to appear in BOTH train and
# val — the opposite of group-wise splitting, so clips of the same video leak
# across the split; it also raises ValueError for any group with < 2 members.
# If the intent was leak-free splits, GroupShuffleSplit is the tool; if the
# intent was class balance, stratify should be `labels`. TODO confirm intent.
# NOTE(review): random_state=random.randint(1, 5000) makes every run's split
# different — a fixed seed would make results reproducible.
train_files, val_files, train_labels, val_labels = train_test_split(files, labels, test_size=0.2, stratify=groups, random_state=random.randint(1,5000))
train_dataset = VideoNPZDataset(train_files, train_labels, processor)
val_dataset = VideoNPZDataset(val_files, val_labels, processor)
def compute_metrics(eval_pred):
    """Compute accuracy for the 2-class softmax head.

    Args:
        eval_pred: ``(logits, labels)`` from the Trainer's evaluation loop;
            logits is [N, 2], labels is [N].

    Returns:
        ``{"accuracy": fraction_correct}``.

    NOTE(review): this function is defined but not passed to the Trainer
    below, so it currently never runs.
    """
    logits, labels = eval_pred
    # The head is a 2-logit CrossEntropy classifier, so the prediction is the
    # argmax over the class axis. (The previous sigmoid(logits) > 0.5 followed
    # by flatten() produced a [2N] array that cannot be compared element-wise
    # to the [N] label vector.)
    preds = np.asarray(logits).argmax(axis=-1)
    accuracy = float((preds == np.asarray(labels)).mean())
    return {"accuracy": accuracy}
# Hyperparameters and bookkeeping for the fine-tuning run.
training_args = TrainingArguments(
    output_dir="./results",            # checkpoints land here, not in output_folder
    evaluation_strategy="epoch",       # evaluate once per epoch
    save_strategy="epoch",             # must match evaluation_strategy for load_best_model_at_end
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    logging_dir="./logs",
    logging_steps=2,                   # log training loss every 2 optimizer steps
    save_total_limit=2,                # keep only the two most recent checkpoints
    load_best_model_at_end=True,       # "best" defaults to lowest eval_loss
)
# NOTE(review): the plain Trainer is used here — neither myTrainer (debug
# shape logging) nor compute_metrics defined above is wired in, so evaluation
# reports only eval_loss.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
)
trainer.train()

# Plot the training and evaluation loss recorded in the trainer's log history.
# NOTE(review): the two series are logged at different frequencies (every
# logging_steps vs once per epoch), so sharing one "Log Steps" x-axis is only
# a rough visual comparison.
logs = trainer.state.log_history
train_loss = [x["loss"] for x in logs if "loss" in x]
eval_loss = [x["eval_loss"] for x in logs if "eval_loss" in x]
plt.plot(train_loss, label="Train Loss")
plt.plot(eval_loss, label="Eval Loss")
plt.xlabel("Log Steps")
plt.ylabel("Loss")
plt.legend()
plt.title("Loss Curve")
plt.savefig(f"{output_folder}/tuned/loss_curve.png")

# Persist the fine-tuned model and its processor for later inference.
trainer.save_model(f"{output_folder}/tuned")
processor.save_pretrained(f"{output_folder}/tuned")