Tiresias/Inference/infrence.py

127 lines
4.2 KiB
Python

import io
import os
import re
import sqlite3
import time
from datetime import datetime, timezone

import numpy as np
import redis
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVideoClassification
# -----------------------------
# ENVIRONMENT CONFIG
# -----------------------------
# Each setting is read from the (lower-case) environment variable named in
# the first argument; the second argument is the fallback value.
redis_host = os.getenv("redis_host", "localhost")
redis_port = int(os.getenv("redis_port", 6379))
stream_label = os.getenv("stream_label", "cropped_stream")
output_folder = os.getenv("out_folder", "/app/out_folder")
model_folder = os.getenv("model_folder", "base_model")
threshold = float(os.getenv("threshold", "0.5"))

# Names and paths derived from the settings above.
model_path = os.path.join(output_folder, model_folder)
queue_label = f"{stream_label}_cubes"
sqlite_path = os.path.join(output_folder, f"{stream_label}_results.db")
stream_folder = os.path.join(output_folder, stream_label)

# Ensure the output root and the per-stream GIF folder exist.
for _folder in (output_folder, stream_folder):
    os.makedirs(_folder, exist_ok=True)
# -----------------------------
# CONNECT TO REDIS
# -----------------------------
# Process-wide client. decode_responses=False keeps replies as raw bytes,
# which is what from_redis_list feeds to np.load.
redis_conn = redis.Redis(
    host=redis_host,
    port=redis_port,
    db=0,
    decode_responses=False,
)
# -----------------------------
# LOAD MODEL + PROCESSOR
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForVideoClassification.from_pretrained(model_path).to(device)
processor = AutoProcessor.from_pretrained(model_path)
model.eval()

# Warm up with one dummy clip (presumably so the first real clip does not
# pay one-off initialisation costs).
# Hugging Face video classifiers (e.g. VideoMAE, TimeSformer) take
# pixel_values as (batch, num_frames, num_channels, height, width), so frames
# come BEFORE channels. The clip geometry is read from the model config,
# falling back to 16 frames x 3 channels x 224 x 224.
_warmup_frames = getattr(model.config, "num_frames", 16)
_warmup_channels = getattr(model.config, "num_channels", 3)
_warmup_size = getattr(model.config, "image_size", 224)
if isinstance(_warmup_size, int):
    _warmup_h = _warmup_w = _warmup_size
else:
    _warmup_h, _warmup_w = tuple(_warmup_size)
with torch.no_grad():
    dummy = torch.randn(
        1, _warmup_frames, _warmup_channels, _warmup_h, _warmup_w,
        device=device,
    )
    _ = model(pixel_values=dummy)
# -----------------------------
# SETUP SQLITE
# -----------------------------
# One row per processed clip: GIF name, the timestamp string used to build
# it, and the thresholded 0/1 prediction.
_CREATE_RESULTS_TABLE = """
CREATE TABLE IF NOT EXISTS inference_results (
id INTEGER PRIMARY KEY AUTOINCREMENT,
gif_name TEXT,
timestamp TEXT,
prediction INTEGER
)
"""
conn = sqlite3.connect(sqlite_path)
cursor = conn.cursor()
cursor.execute(_CREATE_RESULTS_TABLE)
conn.commit()
# -----------------------------
# REDIS UTILS
# -----------------------------
def _frame_key_order(key):
    """Natural-sort key for archive entry names, so ``arr_2`` sorts before ``arr_10``."""
    # re.split with a capturing group alternates text (even index) and digit
    # runs (odd index), so element types line up across keys when compared.
    return [int(part) if i % 2 else part for i, part in enumerate(re.split(r"(\d+)", key))]


def from_redis_list(queue_label, client=None):
    """Block until a frame bundle can be popped from a Redis list, then decode it.

    Polls ``LPOP queue_label`` every 1/50 s. Each payload is loaded with
    ``np.load`` as an ``.npz`` archive (presumably written by ``np.savez`` /
    ``np.savez_compressed`` on the producer side). Its arrays are returned in
    natural-sort order of their keys, so ``arr_0, arr_1, ..., arr_10`` come
    back in numeric order rather than lexicographic order.

    Args:
        queue_label: Name of the Redis list to pop from.
        client: Object exposing ``lpop(name)``. Defaults to the module-level
            ``redis_conn``.

    Returns:
        List of numpy arrays, one per archive entry.

    Raises:
        TimeoutError: If the queue stays empty for more than 2000 polls
            (about 40 s of sleeping, plus Redis round-trip time).
    """
    if client is None:
        client = redis_conn
    retry = 0
    while True:
        compressed_data = client.lpop(queue_label)
        if compressed_data:
            # NpzFile wraps a zip archive; the context manager closes it.
            # Arrays read via [] are independent copies, safe to return.
            with np.load(io.BytesIO(compressed_data)) as loaded_data:
                keys = sorted(loaded_data.files, key=_frame_key_order)
                return [loaded_data[key] for key in keys]
        retry += 1
        if retry % 50 == 0:
            print(f"[WAIT] Queue {queue_label} empty for {retry/50:.1f} seconds...")
        time.sleep(1/50.0)
        if retry > 2000:
            raise TimeoutError(f"Queue {queue_label} empty for over 40s")
# -----------------------------
# SAVE GIF
# -----------------------------
def save_gif(frames, gif_path, duration=50):
    """Write ``frames`` to ``gif_path`` as an endlessly looping animated GIF.

    Args:
        frames: Non-empty sequence of image arrays (presumably HxWx3). Each
            is cast to uint8 and handed to ``PIL.Image.fromarray`` unchanged,
            i.e. no colour-order conversion is done here.
        gif_path: Destination file path; its parent directory must exist.
        duration: Display time per frame in milliseconds (default 50).

    Raises:
        ValueError: If ``frames`` is empty.
    """
    # len() rather than truthiness so a numpy array of frames also works.
    if len(frames) == 0:
        raise ValueError("save_gif() needs at least one frame")
    first, *rest = [Image.fromarray(frame.astype(np.uint8)) for frame in frames]
    # loop=0 -> repeat forever.
    first.save(gif_path, save_all=True, append_images=rest, duration=duration, loop=0)
# -----------------------------
# MAIN LOOP
# -----------------------------
print(f"[INFO] Listening on Redis queue: {queue_label}")
while True:
    try:
        frames = from_redis_list(queue_label)
        # Reject empty bundles explicitly: all() over nothing is True.
        if not frames or not all(frame.shape == (224, 224, 3) for frame in frames):
            print("[WARN] Skipped frame batch due to incorrect shape")
            continue

        # Frames arrive in BGR channel order (presumably OpenCV output).
        # Convert once so the saved GIF and the model input both get RGB.
        rgb_frames = [np.ascontiguousarray(frame[:, :, ::-1]) for frame in frames]

        # UTC timestamp with microseconds trimmed to milliseconds.
        timestamp = datetime.now(timezone.utc).strftime("%y%m%d_%H%M%S_%f")[:-3]
        # gif_filename is relative to output_folder (where the SQLite DB also
        # lives) and lands inside stream_folder. That folder is created at
        # startup; joining with stream_folder would nest a non-existent
        # second "<stream_label>/" directory.
        gif_filename = f"{stream_label}/{stream_label}_{timestamp}.gif"
        gif_path = os.path.join(output_folder, gif_filename)

        save_gif(rgb_frames, gif_path)

        # Preprocess and predict.
        with torch.no_grad():
            inputs = processor(images=rgb_frames, return_tensors="pt")
            # Layout: (batch, num_frames, num_channels, height, width).
            pixel_values = inputs["pixel_values"].to(device)
            logits = model(pixel_values=pixel_values).logits
            # Probability of class index 1 for the single clip in the batch.
            prob = torch.softmax(logits, dim=-1)[0, 1].item()
        prediction = int(prob > threshold)

        cursor.execute(
            "INSERT INTO inference_results (gif_name, timestamp, prediction) VALUES (?, ?, ?)",
            (gif_filename, timestamp, prediction),
        )
        conn.commit()
        print(f"[INFO] Saved {gif_filename} | Class={prediction} | Prob={prob:.3f}")
    except Exception as e:
        # Top-level boundary: log and keep serving (this includes the
        # TimeoutError that from_redis_list raises for a long-empty queue).
        print(f"[ERROR] {e}")
        time.sleep(1)