# File listing: 127 lines, 4.2 KiB, Python.
import io
import os
import re
import sqlite3
import time
from datetime import datetime, timezone

import numpy as np
import redis
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVideoClassification

# -----------------------------
# ENVIRONMENT CONFIG
# -----------------------------
# Each setting comes from a (lower-case) environment variable with a default.
redis_host = os.getenv("redis_host", "localhost")
redis_port = int(os.getenv("redis_port", 6379))
stream_label = os.getenv("stream_label", "cropped_stream")
output_folder = os.getenv("out_folder", "/app/out_folder")
model_folder = os.getenv("model_folder", "base_model")
threshold = float(os.getenv("threshold", "0.5"))

# Derived values: the model directory lives under the output folder, and the
# Redis list name, SQLite file and GIF directory are all keyed by stream label.
model_path = os.path.join(output_folder, model_folder)
queue_label = stream_label + "_cubes"
sqlite_path = os.path.join(output_folder, stream_label + "_results.db")
stream_folder = os.path.join(output_folder, stream_label)

# Make sure both directories exist before anything is written into them.
os.makedirs(output_folder, exist_ok=True)
os.makedirs(stream_folder, exist_ok=True)

# -----------------------------
# CONNECT TO REDIS
# -----------------------------
# decode_responses=False keeps payloads as raw bytes: the queue items are
# binary NumPy archives that are handed straight to np.load.
redis_conn = redis.Redis(
    host=redis_host,
    port=redis_port,
    db=0,
    decode_responses=False,
)

# -----------------------------
# LOAD MODEL + PROCESSOR
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForVideoClassification.from_pretrained(model_path).to(device)
processor = AutoProcessor.from_pretrained(model_path)
model.eval()

# Warm up
# The dummy clip is produced by the same processor call the main loop uses, so
# its pixel_values layout is exactly what the model expects. A hand-made
# torch.randn(1, 3, 16, 224, 224) is (batch, channels, frames, H, W), whereas
# Hugging Face video classifiers such as VideoMAE/TimeSformer read
# (batch, frames, channels, H, W) and reject an input with 16 "channels".
# Not every video config defines num_frames; 16 matches the previous dummy.
with torch.no_grad():
    warmup_frames = [np.zeros((224, 224, 3), dtype=np.uint8)] * getattr(
        model.config, "num_frames", 16
    )
    dummy = processor(images=warmup_frames, return_tensors="pt")["pixel_values"].to(device)
    _ = model(pixel_values=dummy)

# -----------------------------
# SETUP SQLITE
# -----------------------------
# One row per processed batch: the GIF's relative name, its timestamp string
# and the thresholded 0/1 prediction. The table is only created if missing.
conn = sqlite3.connect(sqlite_path)
cursor = conn.cursor()
with conn:  # commits the DDL on success, like an explicit conn.commit()
    cursor.execute("""
    CREATE TABLE IF NOT EXISTS inference_results (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        gif_name TEXT,
        timestamp TEXT,
        prediction INTEGER
    )
    """)

# -----------------------------
# REDIS UTILS
# -----------------------------
def _natural_key(name):
    """Sort key that compares embedded integers numerically ("arr_2" < "arr_10")."""
    # re.split with a capturing group alternates text / digit-run pieces and
    # always starts with a (possibly empty) text piece, so odd positions are
    # digit runs: str compares with str and int with int.
    parts = re.split(r"(\d+)", name)
    return [int(part) if index % 2 else part for index, part in enumerate(parts)]


def from_redis_list(queue_label):
    """Pop one frame batch from the Redis list ``queue_label``.

    Polls ``LPOP`` every 1/50 s until an item arrives. The payload is read with
    ``np.load`` as an ``.npz`` archive (pickle stays disabled, NumPy's default),
    and its arrays are returned in key order using a natural sort. That way
    ``np.savez``'s default names ``arr_0 .. arr_15`` come back in numeric order;
    a plain ``sorted`` would place ``arr_10`` before ``arr_2`` and scramble the clip.

    Args:
        queue_label: name of the Redis list to pop from.

    Returns:
        list of the archive's arrays, one per key.

    Raises:
        TimeoutError: after more than 2000 empty polls (about 40 s).
    """
    retry = 0
    while True:
        compressed_data = redis_conn.lpop(queue_label)
        if compressed_data:
            # NpzFile is a context manager; close it once the arrays are read.
            with np.load(io.BytesIO(compressed_data)) as loaded_data:
                ordered_keys = sorted(loaded_data.files, key=_natural_key)
                return [loaded_data[key] for key in ordered_keys]

        retry += 1
        if retry % 50 == 0:
            print(f"[WAIT] Queue {queue_label} empty for {retry/50:.1f} seconds...")
        time.sleep(1/50.0)
        if retry > 2000:
            raise TimeoutError(f"Queue {queue_label} empty for over 40s")

# -----------------------------
# SAVE GIF
# -----------------------------
def save_gif(frames, gif_path, duration=50):
    """Write ``frames`` to ``gif_path`` as an animated GIF that loops forever.

    Args:
        frames: sequence of image arrays; each is cast with ``astype(np.uint8)``
            (no rescaling) and handed to ``PIL.Image.fromarray``.
        gif_path: destination file path; its directory must already exist.
        duration: display time of each frame in milliseconds (default 50).

    Raises:
        ValueError: if ``frames`` is empty (a GIF needs at least one frame).
    """
    # len() rather than truthiness so a stacked ndarray of frames also works.
    if len(frames) == 0:
        raise ValueError("save_gif needs at least one frame")
    first, *rest = [Image.fromarray(frame.astype(np.uint8)) for frame in frames]
    first.save(gif_path, save_all=True, append_images=rest, duration=duration, loop=0)

# -----------------------------
# MAIN LOOP
# -----------------------------
print(f"[INFO] Listening on Redis queue: {queue_label}")

while True:
    try:
        frames = from_redis_list(queue_label)

        # all() over an empty list is True, so an empty batch would slip past
        # the shape check and crash in save_gif / the processor.
        if not frames:
            print("[WARN] Skipped empty frame batch")
            continue
        if not all(frame.shape == (224, 224, 3) for frame in frames):
            print("[WARN] Skipped frame batch due to incorrect shape")
            continue

        # UTC timestamp truncated to milliseconds, e.g. "250131_142502_123".
        # datetime.utcnow() is deprecated since Python 3.12; the aware form
        # formats to the identical string.
        timestamp = datetime.now(timezone.utc).strftime("%y%m%d_%H%M%S_%f")[:-3]

        # gif_filename (stored in SQLite) is relative to output_folder and
        # already contains the "<stream_label>/" directory. It must therefore
        # be joined with output_folder: joining with stream_folder duplicated
        # that directory and pointed at a path that was never created.
        gif_filename = f"{stream_label}/{stream_label}_{timestamp}.gif"
        gif_path = os.path.join(output_folder, gif_filename)

        # Frames arrive BGR (as the model-input conversion assumes); PIL and the
        # processor both expect RGB, so convert once and use it for both.
        # NOTE(review): confirm the producer really emits BGR.
        rgb_frames = [np.ascontiguousarray(frame[:, :, ::-1]) for frame in frames]

        # Save GIF
        save_gif(rgb_frames, gif_path)

        # Preprocess and predict
        with torch.no_grad():
            inputs = processor(images=rgb_frames, return_tensors="pt")
            pixel_values = inputs["pixel_values"].to(device)
            logits = model(pixel_values=pixel_values).logits
            # Probability of class index 1 for the single clip in the batch;
            # assumes a two-class head - TODO confirm against the model config.
            prob = torch.softmax(logits, dim=-1)[0][1].item()
        prediction = int(prob > threshold)

        # Insert into SQLite (parameterized; committed per batch)
        cursor.execute(
            "INSERT INTO inference_results (gif_name, timestamp, prediction) VALUES (?, ?, ?)",
            (gif_filename, timestamp, prediction),
        )
        conn.commit()

        print(f"[INFO] Saved {gif_filename} | Class={prediction} | Prob={prob:.3f}")

    except Exception as e:
        # Top-level worker boundary: log and keep consuming. This also absorbs
        # the TimeoutError from_redis_list raises while the queue stays empty.
        print(f"[ERROR] {e}")
        time.sleep(1)