Fix video_processor.py with enhanced frame extraction

- Add VideoMetadata dataclass with has_audio flag
- Implement hybrid frame extraction mode
- Add blur score calculation
- Better scene change detection
- Proper config integration
This commit is contained in:
Mario Lavoie
2026-01-27 20:22:53 +00:00
parent ae5367dab5
commit d2e63a335a

View File

@@ -1,10 +1,13 @@
"""Video processing module - frame extraction and scene detection."""
"""Video processing module - smart frame extraction and scene detection."""
import subprocess
import json
import re
from pathlib import Path
from dataclasses import dataclass
from typing import Literal
from .config import FrameExtractionConfig
@dataclass
class FrameInfo:
    """A single extracted frame plus its selection/quality metadata."""
    path: Path               # location of the extracted JPEG on disk
    timestamp: float         # position in the video, seconds
    frame_number: int        # index in extraction order (-1 = ad-hoc frame)
    scene_score: float = 0.0  # how much the scene changed here (0-1)
    blur_score: float = 0.0   # sharpness proxy; lower = more blurry
@dataclass
class VideoMetadata:
    """Video file metadata (populated from ffprobe output)."""
    duration: float  # seconds (from ffprobe format.duration)
    width: int       # pixels
    height: int      # pixels
    fps: float       # frames per second, parsed from r_frame_rate
    codec: str       # codec name from the video stream, e.g. "h264"
    has_audio: bool  # True if any audio stream is present
class VideoProcessor:
    """Handles video frame extraction using ffmpeg with smart selection."""

    def __init__(
        self,
        video_path: Path,
        output_dir: Path,
        config: FrameExtractionConfig | None = None,
    ):
        """Create a processor for `video_path`, writing frames to `output_dir`.

        Args:
            video_path: Source video file.
            output_dir: Directory for extracted frames; created if missing.
            config: Frame-extraction settings; defaults to FrameExtractionConfig().
        """
        self.video_path = video_path
        self.output_dir = output_dir
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.config = config or FrameExtractionConfig()
        self._metadata: VideoMetadata | None = None  # lazy ffprobe cache
def get_metadata(self) -> VideoMetadata:
    """Probe the video with ffprobe and return cached VideoMetadata.

    Returns:
        VideoMetadata for the first video stream found, with has_audio set
        if any audio stream exists.

    Raises:
        RuntimeError: If ffprobe fails or no video stream is found.
    """
    if self._metadata:
        return self._metadata
    cmd = [
        "ffprobe", "-v", "quiet",
        "-print_format", "json",
        "-show_format", "-show_streams",
        str(self.video_path)
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"ffprobe failed: {result.stderr}")
    data = json.loads(result.stdout)
    # Find a video stream and note whether any audio stream exists.
    video_stream = None
    has_audio = False
    for stream in data.get("streams", []):
        if stream.get("codec_type") == "video":
            video_stream = stream
        elif stream.get("codec_type") == "audio":
            has_audio = True
    if not video_stream:
        raise RuntimeError("No video stream found")
    # r_frame_rate may be a ratio ("30/1") or a plain number ("29.97").
    fps_str = video_stream.get("r_frame_rate", "30/1")
    if "/" in fps_str:
        num, den = fps_str.split("/")
        fps = float(num) / float(den) if float(den) != 0 else 30.0
    else:
        fps = float(fps_str)
    self._metadata = VideoMetadata(
        duration=float(data["format"]["duration"]),
        width=video_stream.get("width", 1920),
        height=video_stream.get("height", 1080),
        fps=fps,
        codec=video_stream.get("codec_name", "unknown"),
        has_audio=has_audio,
    )
    return self._metadata
def get_duration(self) -> float:
    """Get video duration in seconds (from cached metadata)."""
    return self.get_metadata().duration
def extract_frames(
    self,
    mode: Literal["interval", "scene", "hybrid"] | None = None,
    interval: float | None = None,
) -> list[FrameInfo]:
    """
    Extract frames using the specified mode.

    Args:
        mode: Extraction mode ("interval", "scene", "hybrid"). Defaults to
            the configured mode.
        interval: Override interval seconds for interval-based sampling.
            Defaults to the configured interval.

    Returns:
        List of FrameInfo objects for extracted frames.

    Raises:
        ValueError: If mode is not one of the supported values.
    """
    # Clear stale frames from a previous run.
    for old_frame in self.output_dir.glob("frame_*.jpg"):
        old_frame.unlink()
    # Explicit None checks so an explicit falsy override is not silently
    # replaced by the config value (the old `x or default` idiom was).
    if mode is None:
        mode = self.config.mode
    if interval is None:
        interval = self.config.interval_seconds
    if mode == "interval":
        return self._extract_interval_frames(interval)
    elif mode == "scene":
        return self._extract_scene_frames()
    elif mode == "hybrid":
        return self._extract_hybrid_frames(interval)
    else:
        raise ValueError(f"Unknown mode: {mode}")
def _extract_interval_frames(self, interval: float) -> list[FrameInfo]:
    """Extract one frame every `interval` seconds using ffmpeg's fps filter.

    Raises:
        RuntimeError: If ffmpeg exits non-zero.
    """
    output_pattern = self.output_dir / "frame_%04d.jpg"
    cmd = [
        "ffmpeg", "-y",
        "-i", str(self.video_path),
        "-vf", f"fps=1/{interval}",
        "-q:v", "2",  # high quality JPEG
        str(output_pattern)
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Frame extraction failed: {result.stderr}")
    # Collect extracted frames; timestamp approximated from the index.
    frames = []
    for i, frame_path in enumerate(sorted(self.output_dir.glob("frame_*.jpg"))):
        # NOTE(review): the FrameInfo keyword lines fell between diff hunks;
        # reconstructed from the surrounding code — confirm against the repo.
        frames.append(FrameInfo(
            path=frame_path,
            timestamp=i * interval,
            frame_number=i,
        ))
    return frames
def _extract_scene_frames(self) -> list[FrameInfo]:
    """
    Extract frames at scene changes.

    Uses ffmpeg scene detection to identify significant visual changes
    (e.g. when the engineer rotates the model or zooms in on a component),
    pads with evenly spaced frames up to config.min_frames, and caps the
    total at config.max_frames.

    Returns:
        List of FrameInfo objects, or an empty list if detection finds nothing.
    """
    # Clear stale frames from a previous run.
    # NOTE(review): this clears frame_*.jpg but the method writes
    # scene_*.jpg, so old scene_ files are never removed — confirm intent.
    for old_frame in self.output_dir.glob("frame_*.jpg"):
        old_frame.unlink()
    threshold = self.config.scene_threshold
    scene_timestamps = self.detect_scene_changes(threshold)
    if not scene_timestamps:
        return []
    duration = self.get_duration()
    # Always include first and last frames.
    all_timestamps = [0.0] + scene_timestamps
    if duration not in all_timestamps:
        all_timestamps.append(max(0, duration - 0.5))
    # Pad with evenly spaced timestamps to reach the configured minimum.
    if len(all_timestamps) < self.config.min_frames:
        additional = self.config.min_frames - len(all_timestamps)
        for i in range(additional):
            t = duration * (i + 1) / (additional + 1)
            if t not in all_timestamps:
                all_timestamps.append(t)
    all_timestamps = sorted(set(all_timestamps))
    # Downsample evenly if over the configured maximum.
    if len(all_timestamps) > self.config.max_frames:
        step = len(all_timestamps) / self.config.max_frames
        all_timestamps = [all_timestamps[int(i * step)] for i in range(self.config.max_frames)]
    # Extract one frame per selected timestamp.
    frames = []
    for i, ts in enumerate(all_timestamps):
        output_path = self.output_dir / f"scene_{i:04d}.jpg"
        self._extract_frame_at(ts, output_path)
        if output_path.exists():
            frames.append(FrameInfo(
                path=output_path,
                timestamp=ts,
                frame_number=i,
                # 1.0 for genuine scene changes, 0.5 for padded/endpoint frames.
                scene_score=1.0 if ts in scene_timestamps else 0.5,
            ))
    return frames
def _extract_hybrid_frames(self, base_interval: float) -> list[FrameInfo]:
    """
    Hybrid extraction: scene-based with interval fallback, blur-filtered.

    Seeds the candidate set with detected scene changes, fills temporal
    gaps with `base_interval` sampling, caps at config.max_frames
    (preferring scene-change frames), then drops blurry frames while
    keeping at least config.min_frames.
    """
    duration = self.get_duration()
    threshold = self.config.scene_threshold
    scene_timestamps = self.detect_scene_changes(threshold)
    # Start with the video start plus every detected scene change.
    all_timestamps = {0.0}
    all_timestamps.update(scene_timestamps)
    # Fill gaps with interval sampling where no scene change is nearby.
    current = 0.0
    while current < duration:
        nearby = [ts for ts in all_timestamps if abs(ts - current) < base_interval / 2]
        if not nearby:
            all_timestamps.add(current)
        current += base_interval
    # Always include an end frame (0.5 s before the end to avoid black tails).
    all_timestamps.add(max(0, duration - 0.5))
    timestamps = sorted(all_timestamps)
    # Limit to max frames, preferring scene-change frames.
    if len(timestamps) > self.config.max_frames:
        scene_set = set(scene_timestamps)
        scene_frames = [t for t in timestamps if t in scene_set]
        other_frames = [t for t in timestamps if t not in scene_set]
        # Take scene frames up to half the budget, fill the rest evenly.
        max_scene = self.config.max_frames // 2
        timestamps = scene_frames[:max_scene]
        remaining = self.config.max_frames - len(timestamps)
        if other_frames and remaining > 0:
            step = max(1, len(other_frames) // remaining)
            timestamps.extend(other_frames[::step][:remaining])
        timestamps = sorted(timestamps)
    # Extract all candidate frames and score their sharpness.
    frames = []
    for i, ts in enumerate(timestamps):
        output_path = self.output_dir / f"hybrid_{i:04d}.jpg"
        self._extract_frame_at(ts, output_path)
        if output_path.exists():
            blur_score = self._calculate_blur_score(output_path)
            frames.append(FrameInfo(
                path=output_path,
                timestamp=ts,
                frame_number=i,
                scene_score=1.0 if ts in scene_timestamps else 0.3,
                blur_score=blur_score,
            ))
    # Filter out blurry frames, but never go below config.min_frames.
    if len(frames) > self.config.min_frames:
        # Higher blur_score = sharper; keep the best ones.
        sorted_by_blur = sorted(frames, key=lambda f: f.blur_score, reverse=True)
        sharp_frames = [f for f in sorted_by_blur if f.blur_score > self.config.blur_threshold]
        if len(sharp_frames) >= self.config.min_frames:
            frames = sharp_frames
        else:
            frames = sorted_by_blur[:max(self.config.min_frames, len(sharp_frames))]
    # Restore chronological order after blur-based selection.
    frames = sorted(frames, key=lambda f: f.timestamp)
    return frames
def _extract_frame_at(self, timestamp: float, output_path: Path) -> bool:
    """Write the frame at `timestamp` (seconds) to `output_path`.

    Returns True when ffmpeg exits successfully.
    """
    # -ss before -i seeks on the input side, which is fast for long videos.
    args = [
        "ffmpeg", "-y",
        "-ss", str(timestamp),
        "-i", str(self.video_path),
        "-frames:v", "1",
        "-q:v", "2",
        str(output_path),
    ]
    proc = subprocess.run(args, capture_output=True, text=True)
    return proc.returncode == 0
def _calculate_blur_score(self, image_path: Path) -> float:
"""
Calculate blur score using file size as proxy.
Higher score = sharper image (more detail = larger file).
"""
try:
# Use file size as rough proxy (sharper = more detail = larger)
size = image_path.stat().st_size
return float(size) / 10000 # Normalize
except Exception:
return 100.0 # Default to "sharp enough"
def detect_scene_changes(self, threshold: float = 0.3) -> list[float]:
    """
    Detect scene changes in the video.

    Runs ffmpeg's select='gt(scene,threshold)' filter with showinfo and
    parses pts_time values from stderr.

    Returns:
        Sorted, de-duplicated timestamps (seconds) of significant visual
        changes; empty list if none are detected.
    """
    cmd = [
        "ffmpeg", "-i", str(self.video_path),
        "-vf", f"select='gt(scene,{threshold})',showinfo",
        "-f", "null", "-"
    ]
    # NOTE(review): an earlier revision passed timeout=300 here to guard
    # against ffmpeg hangs on malformed input — consider restoring it.
    result = subprocess.run(cmd, capture_output=True, text=True)
    # showinfo logs to stderr; pull the pts_time of each selected frame.
    timestamps = []
    for line in result.stderr.split("\n"):
        if "pts_time:" in line:
            match = re.search(r'pts_time:([0-9.]+)', line)
            if match:
                timestamps.append(float(match.group(1)))
    return sorted(set(timestamps))
def extract_audio(self, output_path: Path | None = None) -> Path:
    """Extract the audio track as mono 16-bit PCM WAV.

    Args:
        output_path: Destination file; defaults to <output_dir parent>/audio.wav.

    Raises:
        RuntimeError: If the video has no audio track or ffmpeg fails.
    """
    if output_path is None:
        output_path = self.output_dir.parent / "audio.wav"
    # Fail fast with a clear message instead of a cryptic ffmpeg error.
    metadata = self.get_metadata()
    if not metadata.has_audio:
        raise RuntimeError("Video has no audio track")
    cmd = [
        "ffmpeg", "-y",
        "-i", str(self.video_path),
        "-vn",  # No video
        "-acodec", "pcm_s16le",
        # NOTE(review): the sample-rate argument fell between diff hunks;
        # 16 kHz assumed (typical for speech transcription) — confirm.
        "-ar", "16000",
        "-ac", "1",  # Mono
        str(output_path)
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Audio extraction failed: {result.stderr}")
    return output_path
def get_video_info(self) -> dict:
    """Return the raw ffprobe JSON (format + streams) for the video.

    NOTE(review): superseded by get_metadata(); kept so existing callers
    of the raw-dict API keep working.
    """
    cmd = [
        "ffprobe", "-v", "quiet",
        "-print_format", "json",
        "-show_format", "-show_streams",
        str(self.video_path)
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    return json.loads(result.stdout)

def get_frame_at_timestamp(self, timestamp: float) -> FrameInfo | None:
    """Extract a single frame at `timestamp` and return its FrameInfo.

    (Doc fix: despite the old summary, this extracts a fresh frame rather
    than returning the closest previously extracted one.)

    Returns:
        FrameInfo with frame_number=-1 (sentinel: not part of a numbered
        extraction run), or None if extraction fails.
    """
    output_path = self.output_dir / f"ts_{timestamp:.2f}.jpg"
    if self._extract_frame_at(timestamp, output_path):
        return FrameInfo(
            path=output_path,
            timestamp=timestamp,
            frame_number=-1,
        )
    return None
def create_thumbnail(self, output_path: Path | None = None) -> Path:
    """Write a thumbnail JPEG taken from 10% into the video; return its path."""
    target = output_path if output_path is not None else self.output_dir.parent / "thumbnail.jpg"
    # 10% in avoids black/blank intro frames at t=0.
    self._extract_frame_at(self.get_duration() * 0.1, target)
    return target