#!/usr/bin/env python3
"""
Pose Analyzer using MediaPipe for real-time workout form analysis.
"""

import logging
import math
import random
import threading
import time
from typing import Optional, Dict, Any, List, Tuple
from dataclasses import dataclass, field
from enum import Enum

import cv2
import numpy as np

try:
    import mediapipe as mp
    MEDIAPIPE_AVAILABLE = True
except ImportError:
    MEDIAPIPE_AVAILABLE = False
    
from exercises import get_exercise_config, EXERCISE_CONFIGS

logger = logging.getLogger(__name__)


class Phase(Enum):
    """Movement phase of a single repetition.

    Members are declared in the order a rep passes through them; the value
    is the lowercase string sent to clients.
    """

    STANDING = "standing"      # top position, joint extended
    DESCENDING = "descending"  # moving from the top towards the bottom
    BOTTOM = "bottom"          # lowest point of the movement
    ASCENDING = "ascending"    # returning from the bottom to the top


@dataclass
class RepState:
    """Mutable rep-counting state for the exercise currently being tracked.

    ``count`` is the number of completed reps and ``phase`` the current
    movement phase. ``last_rep_time`` holds the timestamp of the most recent
    counted rep (0.0 until the first one). ``last_phase`` and
    ``phase_history`` are part of the record but are not updated by the
    analyzer in this module.
    """

    count: int = field(default=0)
    phase: Phase = field(default=Phase.STANDING)
    last_phase: Phase = field(default=Phase.STANDING)
    # default_factory so each instance gets its own list
    phase_history: List[Phase] = field(default_factory=list)
    last_rep_time: float = field(default=0.0)
    

@dataclass
class FormAnalysis:
    """Outcome of the most recent form check.

    ``score`` starts at 100.0 (perfect), ``feedback`` holds human-readable
    correction messages, and ``angles`` maps each configured angle name to
    its measured value in degrees.
    """

    score: float = field(default=100.0)
    # default_factory so instances never share the same list/dict
    feedback: List[str] = field(default_factory=list)
    angles: Dict[str, float] = field(default_factory=dict)


class PoseAnalyzer:
    """Real-time pose analysis with MediaPipe.

    A background thread reads frames from an OpenCV camera, runs MediaPipe
    pose estimation on them and updates rep-counting and form-analysis state.
    When no camera can be opened, the same thread produces simulated data
    instead (demo/mock mode).

    Shared state (``rep_state``, ``form_analysis``, ``current_exercise``,
    ``frame_count``) is written under ``self._lock``; the public getters
    return copies of the mutable parts so callers can use them freely.
    """

    # Emit at most one coach message per this many analysis-loop ticks
    # (~3.3 s at the loop's ~30 Hz tick rate).
    COACH_MESSAGE_INTERVAL = 100

    # Minimum time between two counted reps, in seconds.
    REP_DEBOUNCE_SECONDS = 0.5

    def __init__(self, camera_id: int = 0):
        """Create an analyzer for the OpenCV camera index ``camera_id``.

        No camera is opened and no thread is started until ``start()``.
        """
        self.camera_id = camera_id
        self.cap: Optional[cv2.VideoCapture] = None
        self._running = False
        self._thread: Optional[threading.Thread] = None
        self._lock = threading.Lock()

        # Current state
        self.current_exercise = "squat"
        self.rep_state = RepState()
        self.form_analysis = FormAnalysis()
        self.last_landmarks = None
        self.frame_count = 0
        # frame_count at which the last coach message was emitted; starting
        # one interval in the past lets the very first poll emit a message.
        self._last_coach_frame = -self.COACH_MESSAGE_INTERVAL

        # Initialize MediaPipe
        if MEDIAPIPE_AVAILABLE:
            self.mp_pose = mp.solutions.pose
            self.pose = self.mp_pose.Pose(
                static_image_mode=False,
                model_complexity=1,
                smooth_landmarks=True,
                min_detection_confidence=0.5,
                min_tracking_confidence=0.5,
            )
        else:
            logger.warning("MediaPipe not available - using mock data")
            self.mp_pose = None
            self.pose = None

        # Coach messages, keyed by form-score tier (see _pick_coach_message)
        self.coach_messages = {
            "excellent": [
                "Perfect form! Keep it up! 💪",
                "You're crushing it! Championship form!",
                "That's exactly right! Great depth!",
            ],
            "good": [
                "Good rep! Stay focused!",
                "Nice work! Keep that form tight!",
                "You're doing great!",
            ],
            "needs_work": [
                "Watch your form on that one!",
                "Let's tighten up the form!",
                "Focus on the technique!",
            ],
        }

    @property
    def is_running(self) -> bool:
        """True between a successful ``start()`` and the next ``stop()``."""
        return self._running

    def start(self) -> bool:
        """Start camera capture and the background analysis thread.

        If the camera cannot be opened the analyzer still starts, in mock
        mode. Returns True on success (or if already running), False if
        startup raised; in that case any capture opened here is released.
        """
        if self._running:
            return True

        try:
            cap = cv2.VideoCapture(self.camera_id)
            if not cap.isOpened():
                logger.warning(f"Could not open camera {self.camera_id}")
                # Release the failed handle rather than just dropping it.
                cap.release()
                cap = None
                # Continue without camera for demo/mock mode
            self.cap = cap

            self._running = True
            self._thread = threading.Thread(target=self._analysis_loop, daemon=True)
            self._thread.start()
            logger.info("Pose analyzer started")
            return True

        except Exception as e:
            logger.error(f"Failed to start pose analyzer: {e}")
            self._running = False
            if self.cap is not None:
                self.cap.release()
                self.cap = None
            return False

    def stop(self):
        """Stop the analysis thread (waiting up to 1 s) and release the camera."""
        self._running = False
        if self._thread:
            self._thread.join(timeout=1.0)
            if self._thread.is_alive():
                logger.warning("Pose analysis thread did not stop within 1s")
        if self.cap:
            self.cap.release()
            self.cap = None
        logger.info("Pose analyzer stopped")

    def reset(self):
        """Reset rep counter and form statistics."""
        with self._lock:
            self.rep_state = RepState()
            self.form_analysis = FormAnalysis()

    def set_exercise(self, exercise: str):
        """Switch to ``exercise`` and clear its rep and form state.

        Unknown exercise names (no config from ``get_exercise_config``) are
        ignored with a warning.
        """
        config = get_exercise_config(exercise)
        if not config:
            logger.warning(f"Unknown exercise: {exercise}")
            return
        with self._lock:
            self.current_exercise = exercise
            self.rep_state = RepState()
            # Angles/feedback were measured against the previous exercise's
            # config and would otherwise be reported for the new one.
            self.form_analysis = FormAnalysis()
        logger.info(f"Switched to exercise: {exercise}")

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the current workout statistics."""
        with self._lock:
            return {
                "exercise": self.current_exercise,
                "reps": self.rep_state.count,
                "sets_completed": 0,  # TODO: track sets
                "form_score": self.form_analysis.score,
                # Copy so callers cannot mutate analyzer state.
                "feedback": list(self.form_analysis.feedback),
                "phase": self.rep_state.phase.value,
            }

    def get_frame_data(self) -> Optional[Dict[str, Any]]:
        """Return the latest analysis snapshot for WebSocket streaming.

        ``coach_message`` is set at most once per ``COACH_MESSAGE_INTERVAL``
        analysis-loop ticks, regardless of how often this method is polled;
        otherwise it is None.
        """
        with self._lock:
            coach_msg = None
            if self.frame_count - self._last_coach_frame >= self.COACH_MESSAGE_INTERVAL:
                self._last_coach_frame = self.frame_count
                coach_msg = self._pick_coach_message(self.form_analysis.score)

            return {
                "reps": self.rep_state.count,
                "phase": self.rep_state.phase.value,
                "form_score": round(self.form_analysis.score, 1),
                "feedback": self.form_analysis.feedback[:3],  # Top 3 feedback items
                "angles": dict(self.form_analysis.angles),
                "coach_message": coach_msg,
                "timestamp": time.time(),
            }

    def _pick_coach_message(self, score: float) -> str:
        """Pick a random coach message for the tier that ``score`` falls into."""
        if score >= 90:
            tier = "excellent"
        elif score >= 70:
            tier = "good"
        else:
            tier = "needs_work"
        return random.choice(self.coach_messages[tier])

    def _analysis_loop(self):
        """Main analysis loop running in the background thread (~30 ticks/s)."""
        last_mock_update = time.time()

        while self._running:
            # Snapshot: stop() may set self.cap to None from another thread
            # between the check and the read.
            cap = self.cap
            try:
                if cap is not None and cap.isOpened():
                    ret, frame = cap.read()
                    if ret:
                        self._analyze_frame(frame)
                elif time.time() - last_mock_update > 0.1:
                    # Mock mode - simulate pose data
                    self._simulate_pose()
                    last_mock_update = time.time()
            except Exception:
                # Keep the daemon thread alive; an uncaught error here would
                # silently stop all analysis.
                logger.exception("Error in pose analysis loop")

            with self._lock:
                self.frame_count += 1
            time.sleep(0.033)  # ~30 FPS

    def _analyze_frame(self, frame: np.ndarray):
        """Run pose estimation on one BGR frame and update exercise state."""
        if not self.pose:
            return

        # Convert to RGB for MediaPipe
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = self.pose.process(rgb_frame)
        if not results.pose_landmarks:
            return

        landmarks = results.pose_landmarks.landmark
        self.last_landmarks = landmarks

        config = get_exercise_config(self.current_exercise)
        if config:
            # Landmark x/y are normalized by width and height respectively;
            # the aspect ratio is needed to measure true image-plane angles.
            height, width = frame.shape[:2]
            aspect_ratio = width / height if height else 1.0
            self._analyze_exercise(landmarks, config, aspect_ratio)

    def _analyze_exercise(self, landmarks, config: Dict, aspect_ratio: float = 1.0):
        """Measure the configured joint angles, then update phase and form.

        Args:
            landmarks: MediaPipe pose landmarks for one frame.
            config: exercise config with ``angles`` (name -> {"points", ...}),
                optional ``primary_angle`` and ``thresholds``.
            aspect_ratio: frame width / height, used to undo the per-axis
                normalization of landmark coordinates.
        """
        # Angle computation only reads its inputs, so do it outside the lock.
        angles: Dict[str, float] = {}
        for angle_name, angle_config in config.get("angles", {}).items():
            angle = self._calculate_angle_from_landmarks(
                landmarks, angle_config["points"], aspect_ratio
            )
            # Unmeasurable angles are skipped rather than recorded as 0°,
            # which would look like a maximal bend.
            if angle is not None:
                angles[angle_name] = angle

        with self._lock:
            self.form_analysis.angles = angles

            # Determine phase based on primary angle
            primary_angle = config.get("primary_angle", "knee_angle")
            if primary_angle in angles:
                self._update_phase(angles[primary_angle], config)

            # Analyze form and generate feedback
            self._analyze_form(angles, config)

    def _calculate_angle_from_landmarks(
        self, landmarks, points: Tuple[str, str, str], aspect_ratio: float = 1.0
    ) -> Optional[float]:
        """Return the angle in degrees at ``points[1]`` between the other two.

        ``points`` are MediaPipe ``PoseLandmark`` member names such as
        ``"LEFT_KNEE"``. x is multiplied by ``aspect_ratio`` (width / height)
        so both axes share the same scale before measuring.

        Returns:
            The angle, ``90.0`` when MediaPipe is unavailable, or ``None``
            when a name is unknown or a landmark index is missing.
        """
        if not MEDIAPIPE_AVAILABLE:
            return 90.0

        pose_landmark = mp.solutions.pose.PoseLandmark

        try:
            p1 = landmarks[pose_landmark[points[0]].value]
            p2 = landmarks[pose_landmark[points[1]].value]
            p3 = landmarks[pose_landmark[points[2]].value]
        except (KeyError, IndexError) as e:
            logger.debug(f"Could not calculate angle: {e}")
            return None

        return self._calculate_angle(
            (p1.x * aspect_ratio, p1.y),
            (p2.x * aspect_ratio, p2.y),
            (p3.x * aspect_ratio, p3.y),
        )

    @staticmethod
    def _calculate_angle(a: Tuple[float, float], b: Tuple[float, float], c: Tuple[float, float]) -> float:
        """Return the angle in degrees (0-180) at vertex ``b`` of a-b-c.

        Returns 0.0 when ``a`` or ``c`` coincides with ``b``.
        """
        ba = (a[0] - b[0], a[1] - b[1])
        bc = (c[0] - b[0], c[1] - b[1])

        dot_product = ba[0] * bc[0] + ba[1] * bc[1]
        magnitude_ba = math.hypot(ba[0], ba[1])
        magnitude_bc = math.hypot(bc[0], bc[1])

        if magnitude_ba * magnitude_bc == 0:
            return 0.0

        cos_angle = dot_product / (magnitude_ba * magnitude_bc)
        cos_angle = max(-1, min(1, cos_angle))  # Guard acos against rounding error
        return math.degrees(math.acos(cos_angle))

    def _update_phase(self, angle: float, config: Dict):
        """Advance the movement phase from the primary angle and count reps.

        Above ``thresholds["standing"]`` (default 160) is STANDING, below
        ``thresholds["bottom"]`` (default 100) is BOTTOM; in between, the
        direction is inferred from the previous phase. A rep is counted on
        returning to STANDING after reaching BOTTOM. Caller holds ``self._lock``.
        """
        thresholds = config.get("thresholds", {})
        standing_threshold = thresholds.get("standing", 160)
        bottom_threshold = thresholds.get("bottom", 100)

        old_phase = self.rep_state.phase

        if angle > standing_threshold:
            new_phase = Phase.STANDING
        elif angle < bottom_threshold:
            new_phase = Phase.BOTTOM
        elif old_phase in (Phase.STANDING, Phase.DESCENDING):
            new_phase = Phase.DESCENDING
        else:
            new_phase = Phase.ASCENDING
        self.rep_state.phase = new_phase

        # ASCENDING is only reachable via BOTTOM, so both old phases mean the
        # bottom was reached. Accepting BOTTOM directly keeps a fast rep that
        # skips the middle band between two frames from being lost.
        if new_phase == Phase.STANDING and old_phase in (Phase.BOTTOM, Phase.ASCENDING):
            now = time.time()
            if now - self.rep_state.last_rep_time > self.REP_DEBOUNCE_SECONDS:
                self.rep_state.count += 1
                self.rep_state.last_rep_time = now
                logger.info(f"Rep counted! Total: {self.rep_state.count}")

    def _analyze_form(self, angles: Dict[str, float], config: Dict):
        """Score form against each angle's [min_good, max_good] range.

        Every out-of-range angle costs 10 points and adds its configured
        ``too_low``/``too_high`` message. The score is clamped to 0-100.
        Caller holds ``self._lock``.
        """
        feedback = []
        score = 100.0

        for angle_name, angle_config in config.get("angles", {}).items():
            if angle_name not in angles:
                continue

            angle = angles[angle_name]
            min_good = angle_config.get("min_good", 0)
            max_good = angle_config.get("max_good", 180)
            feedback_msgs = angle_config.get("feedback", {})

            if angle < min_good:
                feedback.append(feedback_msgs.get("too_low", f"{angle_name}: too low"))
                score -= 10
            elif angle > max_good:
                feedback.append(feedback_msgs.get("too_high", f"{angle_name}: too high"))
                score -= 10

        self.form_analysis.feedback = feedback
        self.form_analysis.score = max(0, min(100, score))

    def _simulate_pose(self):
        """Generate fake rep/form data for demo mode (no camera).

        Each call has a ~3% chance to advance the phase cycle
        STANDING -> DESCENDING -> BOTTOM -> ASCENDING -> STANDING, counting a
        rep on the last step. The score is re-randomized to 85-100 on every
        call, and ~10% of calls replace feedback with two random tips.
        """
        with self._lock:
            if random.random() < 0.03:  # ~3% chance per tick
                if self.rep_state.phase == Phase.STANDING:
                    self.rep_state.phase = Phase.DESCENDING
                elif self.rep_state.phase == Phase.DESCENDING:
                    self.rep_state.phase = Phase.BOTTOM
                elif self.rep_state.phase == Phase.BOTTOM:
                    self.rep_state.phase = Phase.ASCENDING
                else:  # ASCENDING
                    self.rep_state.phase = Phase.STANDING
                    now = time.time()
                    if now - self.rep_state.last_rep_time > self.REP_DEBOUNCE_SECONDS:
                        self.rep_state.count += 1
                        self.rep_state.last_rep_time = now

            # Simulate varying form scores
            self.form_analysis.score = round(85 + random.random() * 15, 1)

            # Random feedback
            if random.random() < 0.1:
                possible_feedback = [
                    "Great depth!",
                    "Keep your chest up",
                    "Good knee tracking",
                    "Strong back position",
                ]
                self.form_analysis.feedback = random.sample(
                    possible_feedback,
                    k=min(2, len(possible_feedback)),
                )
