# Tiresias/preprocess/preprocess.py
#
# (File-viewer header from the original paste: 206 lines, 7.0 KiB, Python.)

import io
import os
import struct
import time
from typing import Iterator, Tuple

import cv2
import mediapipe as mp
import numpy as np
import redis
# Runtime configuration, read from environment variables with local defaults.
stream_label = os.getenv("stream_label", "default_stream")
redis_host = os.getenv("redis_host", "localhost")
redis_port = int(os.getenv("redis_port", "6379"))
# Output queue for this stream: "<stream_label>_cubes".
stream_label_queue = "{}_cubes".format(stream_label)
# decode_responses=False keeps replies as raw bytes: payloads are binary .npz archives.
redis_conn = redis.Redis(
    host=redis_host,
    port=redis_port,
    db=0,
    decode_responses=False,
)
# Empty the output queue at import time (presumably to drop output left over
# from an earlier run).
redis_conn.delete(stream_label_queue)
def toRedis(queue_label, data):
    """Serialize one array as an .npz payload and append it to a Redis list.

    The array is stored under the key ``data`` (the key ``fromRedis`` reads).
    ``np.savez`` writes the archive uncompressed.

    Args:
        queue_label: name of the Redis list to push onto.
        data: array (or array-like) to store.

    Returns:
        The result of ``redis_conn.rpush`` (for RPUSH, the list length after
        the push).
    """
    print(f"Pushed data to queue: {queue_label}")
    with io.BytesIO() as sink:
        np.savez(sink, data=data)
        payload = sink.getvalue()
    return redis_conn.rpush(queue_label, payload)
def fromRedis(queue_label):
    """Pop one array pushed by ``toRedis`` from the head of a Redis list.

    Polls ``redis_conn.lpop`` every 1/50 s until an item is available.
    A progress message is printed about once per second while waiting.

    Args:
        queue_label: name of the Redis list to pop from.

    Returns:
        The array stored under ``data`` in the .npz payload.

    Raises:
        TimeoutError: if the queue is still empty after more than 1000 polls
            (roughly 20 s).
    """
    retry = 0
    while True:
        compressed_data = redis_conn.lpop(queue_label)
        # lpop returns None for an empty list; compare explicitly so that a
        # (malformed) empty payload is not mistaken for "no data".
        if compressed_data is not None:
            print(f"Popped data from queue: {queue_label}")
            # np.load defaults to allow_pickle=False, so object arrays in a
            # foreign payload are rejected rather than unpickled.
            # The context manager closes the NpzFile once the array is read.
            with np.load(io.BytesIO(compressed_data)) as loaded_data:
                return loaded_data['data']
        retry += 1
        if retry % 50 == 0:
            print(f"Queue {queue_label} empty for {retry/50} seconds")
        time.sleep(1/50.0)
        if retry > 1000:
            # Must raise an exception instance; raising a bare str is itself a TypeError.
            raise TimeoutError(f'Queue {queue_label} 20s empty')
def toRedisList(queue_label, data_list):
    """Serialize several arrays into one .npz payload and append it to a Redis list.

    Counterpart of ``fromRedisList``. ``np.savez`` stores positional arrays
    uncompressed, under numpy's default names ``arr_0``, ``arr_1``, ...

    Args:
        queue_label: name of the Redis list to push onto.
        data_list: sequence of arrays (or array-likes) to store together.

    Returns:
        The result of ``redis_conn.rpush`` (for RPUSH, the list length after
        the push), matching what ``toRedis`` returns.
    """
    print(f"Pushed data to queue: {queue_label}")
    buffer = io.BytesIO()
    # Unpack the list so each element becomes its own archive member.
    np.savez(buffer, *data_list)
    return redis_conn.rpush(queue_label, buffer.getvalue())
def fromRedisList(queue_label):
    """Pop one multi-array payload pushed by ``toRedisList`` from a Redis list.

    Polls ``redis_conn.lpop`` every 1/50 s until an item is available.
    A progress message is printed about once per second while waiting.

    Args:
        queue_label: name of the Redis list to pop from.

    Returns:
        A list of the arrays in the .npz payload, in archive member order.

    Raises:
        TimeoutError: if the queue is still empty after more than 1000 polls
            (roughly 20 s).
    """
    retry = 0
    while True:
        compressed_data = redis_conn.lpop(queue_label)
        # lpop returns None for an empty list; compare explicitly so that a
        # (malformed) empty payload is not mistaken for "no data".
        if compressed_data is not None:
            print(f"Popped data from queue: {queue_label}")
            # np.load defaults to allow_pickle=False; the context manager
            # closes the NpzFile after every member has been read.
            with np.load(io.BytesIO(compressed_data)) as loaded_data:
                return [loaded_data[key] for key in loaded_data.files]
        retry += 1
        if retry % 50 == 0:
            print(f"Queue {queue_label} empty for {retry/50} seconds")
        time.sleep(1/50.0)
        if retry > 1000:
            # Must raise an exception instance; raising a bare str is itself a TypeError.
            raise TimeoutError(f'Queue {queue_label} 20s empty')
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import numpy as np
# Person-only object detector, built once at import time and used by
# detect_person_bbox. The EfficientDet-Lite2 model file must be present at
# the path below.
_base_options = python.BaseOptions(
    model_asset_path='/app/efficientdet_lite2.tflite',  # You might need to download a model
)
options = vision.ObjectDetectorOptions(
    base_options=_base_options,
    score_threshold=0.5,  # drop detections scored below 0.5
    category_allowlist=['person'],  # only report the "person" class
)
detector = vision.ObjectDetector.create_from_options(options)
def detect_person_bbox(image_frame: np.ndarray) -> Iterator[Tuple[int, int, int, int]]:
    """Yield a padded square box for every person detected in ``image_frame``.

    Runs the module-level ``detector`` on the frame. Each detection whose
    categories include ``'person'`` is processed as follows:

    1. The detector box is squared around its center, with side equal to the
       longer edge.
    2. The square is grown by ``max(10, side // 10)`` pixels on every side.
    3. The side is capped at the frame size.
    4. The square is shifted so it lies fully inside the frame.

    Args:
        image_frame: image array whose first two dimensions are (height, width).
            It is handed to MediaPipe as ``ImageFormat.SRGB``.
            NOTE(review): SRGB means RGB channel order; confirm callers do not
            pass OpenCV-style BGR frames.

    Yields:
        ``(left, top, width, height)`` as ints, with ``width == height``.
    """
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_frame)
    detection_result = detector.detect(mp_image)
    if not detection_result or not detection_result.detections:
        return
    # Bounds come from the frame being analysed (not from any outer variable).
    frame_h, frame_w = image_frame.shape[:2]
    for detection in detection_result.detections:
        if not any(c.category_name == 'person' for c in detection.categories):
            continue
        bbox = detection.bounding_box
        # Squaring around the center is a no-op for already-square boxes, so
        # one code path handles both cases.
        size = max(bbox.width, bbox.height)
        center_x = bbox.origin_x + bbox.width / 2
        center_y = bbox.origin_y + bbox.height / 2
        padding = int(max(10, size // 10))
        size += 2 * padding
        # A square larger than the frame can never be shifted inside it.
        size = min(size, frame_w, frame_h)
        # Center the padded square, then shift it back inside the frame so the
        # origin is never negative (negative indices would wrap when cropping).
        left = min(max(0, center_x - size / 2), frame_w - size)
        top = min(max(0, center_y - size / 2), frame_h - size)
        yield int(left), int(top), int(size), int(size)
def get_enclosing_box(frame, padding=10):
    """Yield a square box enclosing the body pose detected in ``frame``.

    Runs MediaPipe Pose once on ``frame``. If a pose is found, exactly one
    ``(left, top, width, height)`` tuple of ints is yielded, with
    ``width == height`` and the box fully inside the frame. Nothing is yielded
    when no pose is found, or when the landmarks give an empty box.

    Args:
        frame: image array whose first two dimensions are (height, width).
            NOTE(review): MediaPipe Pose expects RGB input; confirm callers do
            not pass OpenCV-style BGR frames.
        padding: pixels added around the landmark extent before squaring.

    NOTE(review): a new Pose instance is built on every call; consider reusing
    one if this ends up on a per-frame path.
    """
    mp_pose = mp.solutions.pose
    with mp_pose.Pose(min_detection_confidence=0.25, min_tracking_confidence=0) as pose:
        results = pose.process(frame)
        if not results.pose_landmarks:
            return
        frame_h, frame_w = frame.shape[:2]
        landmarks = results.pose_landmarks.landmark
        # Landmark coordinates are scaled by the frame size to get pixels.
        x_coords = [lm.x * frame_w for lm in landmarks]
        y_coords = [lm.y * frame_h for lm in landmarks]
    # Only plain floats are used from here on, so the Pose graph is released
    # before control is handed back to the consumer of this generator.
    min_x = max(0, min(x_coords) - padding)
    min_y = max(0, min(y_coords) - padding)
    max_x = min(frame_w, max(x_coords) + padding)
    max_y = min(frame_h, max(y_coords) + padding)
    # Square on the longer side, capped so it can always be placed in-frame.
    size = min(max(max_x - min_x, max_y - min_y), frame_w, frame_h)
    if size <= 0:
        return
    center_x = (min_x + max_x) / 2
    center_y = (min_y + max_y) / 2
    # Center the square on the padded extent, then shift it back inside the
    # frame so the origin is never negative.
    left = min(max(0, center_x - size / 2), frame_w - size)
    top = min(max(0, center_y - size / 2), frame_h - size)
    yield int(left), int(top), int(size), int(size)
def crop_and_resize_frames(frames, box, target_size=224):
    """Crop the same box out of every frame and resize each crop to a square.

    If the box extends past a frame edge, it is shifted back inside that frame
    rather than sliced with negative indices. Negative indices would wrap
    around and give an empty or wrong crop.

    Args:
        frames: iterable of image arrays whose first two dimensions are
            (height, width).
        box: ``(x, y, w, h)`` in pixels.
        target_size: side length of the output crops. Crops already at
            ``target_size x target_size`` are returned without resizing.

    Returns:
        A list with one crop per input frame, in input order.
    """
    x, y, w, h = box
    cropped = []
    for frame in frames:
        frame_h, frame_w = frame.shape[:2]
        # Clamp the origin into [0, frame_dim - box_dim] (or 0 when the box is
        # larger than the frame) so slicing never wraps around.
        x0 = min(max(0, x), max(0, frame_w - w))
        y0 = min(max(0, y), max(0, frame_h - h))
        crop = frame[y0:y0 + h, x0:x0 + w]
        if crop.shape[0] != target_size or crop.shape[1] != target_size:
            crop = cv2.resize(crop, (target_size, target_size))
        cropped.append(crop)
    return cropped
if __name__ == "__main__":
frame_list = []
frame_count = 0
last_hits = 0
frame_hits = 0
frame_lag = 0
start_time = time.time()
lap_time = start_time
print(f"[INFO] Starting consumer for stream: {stream_label}")
while frame_lag < 100:
frame = fromRedis(stream_label)
if frame is None:
frame_lag += 1
wait_time = frame_lag * 0.01
time.sleep(wait_time)
continue
frame_lag = 0
frame_list.append(frame)
frame_count += 1
if len(frame_list) == 16:
for box in detect_person_bbox(frame_list[0]):
frame_hits += 1
cropped_frames = crop_and_resize_frames(frame_list, box)
toRedisList(stream_label_queue, cropped_frames)
frame_list.pop(0)
if frame_count % 15 == 0:
current_hits = frame_hits - last_hits
last_hits = frame_hits
last_lap = lap_time
lap_time = time.time()
elapsed = lap_time - last_lap
print(f"[INFO] {frame_count} frames, {frame_hits} hits. {current_hits} in {elapsed:.2f} seconds.")
total_time = time.time() - start_time
print(f"[INFO] Finished. Total frames: {frame_count}. Total Hits {frame_hits}. Total time: {total_time:.2f} seconds.")