#!/usr/bin/env python3
"""
AI Fitness Buddy - Pose Analysis Service
FastAPI server with WebSocket for real-time pose analysis and rep counting.
"""

import asyncio
import json
import logging
from typing import Optional
from contextlib import asynccontextmanager

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from pose_analyzer import PoseAnalyzer
from exercises import EXERCISE_CONFIGS, get_exercise_config

# Configure root logging at INFO level (runs once, at import time).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global state
# Shared analyzer instance: None until `lifespan` creates it on startup.
pose_analyzer: Optional[PoseAnalyzer] = None
# WebSockets currently served by `websocket_endpoint`; each connection appends
# itself after accept() and removes itself in the endpoint's `finally`.
connected_clients: list[WebSocket] = []


class WorkoutStats(BaseModel):
    """Snapshot of the current workout, returned by ``GET /stats``.

    Built as ``WorkoutStats(**pose_analyzer.get_stats())``; none of the fields
    has a default, so the analyzer's stats dict must supply all of them.
    """
    # Name of the active exercise (presumably a key of EXERCISE_CONFIGS — confirm).
    exercise: str
    # Rep count reported by the analyzer.
    reps: int
    # Completed-set count reported by the analyzer.
    sets_completed: int
    # Form quality score; presumably 0-100 like the /ws `form_score` — confirm.
    form_score: float
    # Human-readable form feedback messages.
    feedback: list[str]
    # Current movement phase; presumably one of the phases listed for /ws — confirm.
    phase: str


class ExerciseRequest(BaseModel):
    """JSON request body for ``POST /exercise``."""
    # Exercise name; the endpoint rejects names unknown to get_exercise_config().
    exercise: str


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the pose analyzer on startup; stop it and close clients on shutdown.

    Because this function is passed to ``FastAPI(lifespan=...)``, FastAPI does
    not run ``@app.on_event("startup"/"shutdown")`` handlers. Every piece of
    shutdown cleanup therefore has to happen here:
    - the analyzer is stopped;
    - any WebSocket still in ``connected_clients`` is closed.

    The cleanup sits in ``finally`` so it also runs if the app exits abnormally.
    """
    global pose_analyzer
    logger.info("Starting Pose Analysis Service...")
    pose_analyzer = PoseAnalyzer()
    try:
        yield
    finally:
        logger.info("Shutting down Pose Analysis Service...")
        if pose_analyzer:
            pose_analyzer.stop()

        # Iterate over a snapshot: each endpoint removes its own socket from
        # connected_clients in its `finally`, which may run while we await.
        for client in list(connected_clients):
            try:
                await client.close()
            except Exception:
                # Best-effort: the server may already have closed the socket,
                # and one failure must not stop us closing the others.
                logger.debug("Ignoring error while closing WebSocket", exc_info=True)


# ASGI application; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(
    title="AI Fitness Buddy - Pose Service",
    description="Real-time pose analysis and rep counting for workouts",
    version="0.1.0",
    lifespan=lifespan,
)

# CORS for dashboard connection
# NOTE(review): a wildcard origin combined with allow_credentials=True
# effectively lets any web page make credentialed cross-origin requests
# (Starlette reflects the caller's Origin in that case). The server binds
# 0.0.0.0 when run directly, so consider listing the dashboard origin(s)
# explicitly. Confirm the dashboard URL(s) before tightening this.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Liveness probe.

    Always reports ``status: running``. ``camera_active`` mirrors the
    analyzer's ``is_running`` flag and is False while no analyzer exists.
    """
    analyzer = pose_analyzer
    camera_active = analyzer.is_running if analyzer else False
    return dict(
        service="AI Fitness Buddy - Pose Analysis",
        status="running",
        camera_active=camera_active,
    )


@app.get("/stats")
async def get_stats() -> WorkoutStats:
    """Return the analyzer's current workout statistics.

    Raises:
        HTTPException: 503 while the pose analyzer has not been created.
    """
    if pose_analyzer:
        return WorkoutStats(**pose_analyzer.get_stats())
    raise HTTPException(status_code=503, detail="Pose analyzer not initialized")


@app.post("/exercise")
async def set_exercise(request: ExerciseRequest):
    """Switch the analyzer to the exercise named in the request body.

    Raises:
        HTTPException: 503 while the pose analyzer has not been created;
            400 when ``get_exercise_config`` does not know the name (the
            detail lists the available exercises).
    """
    if not pose_analyzer:
        raise HTTPException(status_code=503, detail="Pose analyzer not initialized")

    name = request.exercise
    if not get_exercise_config(name):
        available = list(EXERCISE_CONFIGS.keys())
        raise HTTPException(
            status_code=400,
            detail=f"Unknown exercise: {name}. Available: {available}",
        )

    pose_analyzer.set_exercise(name)
    return {"status": "ok", "exercise": name}


@app.post("/reset")
async def reset_workout():
    """Clear the rep counter and workout statistics.

    Raises:
        HTTPException: 503 while the pose analyzer has not been created.
    """
    if pose_analyzer:
        pose_analyzer.reset()
        return {"status": "ok", "message": "Workout reset"}
    raise HTTPException(status_code=503, detail="Pose analyzer not initialized")


@app.get("/exercises")
async def list_exercises():
    """Return every supported exercise name plus the full config mapping."""
    names = list(EXERCISE_CONFIGS)
    payload = {"exercises": names, "configs": EXERCISE_CONFIGS}
    return payload


# Target delay between frames pushed to each WebSocket client (~20 FPS).
_WS_FRAME_INTERVAL_S = 0.05
# Minimum time spent polling for client commands per frame. This keeps
# commands flowing even when producing a frame overruns the frame budget.
_WS_MIN_POLL_S = 0.001


def _handle_ws_command(message: str) -> None:
    """Apply one client command received on ``/ws``; malformed input is ignored.

    Recognized commands are JSON objects:
    - ``{"action": "reset"}`` resets the analyzer.
    - ``{"action": "set_exercise", "exercise": NAME}`` switches exercise.
      NAME defaults to ``"squat"``. As with ``POST /exercise``, it must be
      known to ``get_exercise_config``.
    """
    try:
        cmd = json.loads(message)
    except json.JSONDecodeError:
        logger.warning("Ignoring non-JSON WebSocket message: %r", message[:200])
        return
    if not isinstance(cmd, dict):
        logger.warning("Ignoring WebSocket command that is not a JSON object: %r", cmd)
        return
    if not pose_analyzer:
        logger.warning("Ignoring WebSocket command: pose analyzer not initialized")
        return

    action = cmd.get("action")
    if action == "reset":
        pose_analyzer.reset()
    elif action == "set_exercise":
        exercise = cmd.get("exercise", "squat")
        if isinstance(exercise, str) and get_exercise_config(exercise):
            pose_analyzer.set_exercise(exercise)
        else:
            logger.warning("Ignoring unknown exercise from WebSocket: %r", exercise)
    else:
        logger.debug("Ignoring unknown WebSocket action: %r", action)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for real-time pose data streaming.

    About every 50 ms (~20 FPS), the result of
    ``pose_analyzer.get_frame_data()`` is sent as JSON whenever it is truthy.
    Frames are expected to contain:
    - reps: Current rep count
    - phase: Current movement phase (standing, descending, bottom, ascending)
    - form_score: Form quality percentage (0-100)
    - feedback: List of form feedback strings
    - landmarks: Optional raw landmark data

    Between frames the client may send JSON commands; see
    ``_handle_ws_command``. If the camera is not running when a client
    connects, it is started.
    """
    await websocket.accept()
    connected_clients.append(websocket)
    logger.info(f"Client connected. Total clients: {len(connected_clients)}")

    loop = asyncio.get_running_loop()
    try:
        # Start camera if not already running
        if pose_analyzer and not pose_analyzer.is_running:
            pose_analyzer.start()

        while True:
            frame_deadline = loop.time() + _WS_FRAME_INTERVAL_S

            # Push the latest pose data, if any.
            if pose_analyzer:
                data = pose_analyzer.get_frame_data()
                if data:
                    await websocket.send_json(data)

            # Listen for client commands for the rest of this frame's budget.
            # This wait is the loop's only pacing (there is no separate sleep),
            # so an idle client actually receives ~20 FPS.
            while True:
                timeout = max(frame_deadline - loop.time(), _WS_MIN_POLL_S)
                try:
                    message = await asyncio.wait_for(
                        websocket.receive_text(), timeout=timeout
                    )
                except asyncio.TimeoutError:
                    break
                _handle_ws_command(message)
                if loop.time() >= frame_deadline:
                    break

    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        # logger.exception also records the traceback for post-mortem debugging.
        logger.exception(f"WebSocket error: {e}")
    finally:
        if websocket in connected_clients:
            connected_clients.remove(websocket)
        logger.info(f"Client removed. Total clients: {len(connected_clients)}")


@app.on_event("shutdown")
async def shutdown():
    """Stop the analyzer and close every open WebSocket connection.

    NOTE(review): FastAPI does not invoke ``on_event`` handlers when the app
    is created with a ``lifespan`` (as ``app`` is above). This handler is
    therefore likely never called, and cleanup that must run belongs in
    ``lifespan``. It is kept safe in case the app is reconfigured without one.
    """
    if pose_analyzer:
        pose_analyzer.stop()

    # Close all WebSocket connections. Iterate over a snapshot because each
    # endpoint removes its socket from connected_clients in its `finally`.
    # One failed close (e.g. an already-closed socket) must not prevent
    # closing the rest.
    for client in list(connected_clients):
        try:
            await client.close()
        except Exception:
            logger.warning("Failed to close WebSocket during shutdown", exc_info=True)


if __name__ == "__main__":
    # Development entry point: serve `app` with uvicorn on all interfaces.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8765}
    uvicorn.run(app, **server_options)
