mirror of
https://github.com/sstent/AICyclingCoach.git
synced 2026-02-15 11:52:10 +00:00
sync
This commit is contained in:
130
backend/app/services/ai_service.py
Normal file
130
backend/app/services/ai_service.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional
|
||||
import httpx
|
||||
import json
|
||||
from app.services.prompt_manager import PromptManager
|
||||
from app.models.workout import Workout
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIService:
|
||||
"""Service for AI-powered analysis and plan generation."""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.db = db_session
|
||||
self.prompt_manager = PromptManager(db_session)
|
||||
self.api_key = os.getenv("OPENROUTER_API_KEY")
|
||||
self.model = os.getenv("AI_MODEL", "anthropic/claude-3-sonnet-20240229")
|
||||
self.base_url = "https://openrouter.ai/api/v1"
|
||||
|
||||
async def analyze_workout(self, workout: Workout, plan: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
"""Analyze a workout using AI and generate feedback."""
|
||||
prompt_template = await self.prompt_manager.get_active_prompt("workout_analysis")
|
||||
|
||||
if not prompt_template:
|
||||
raise ValueError("No active workout analysis prompt found")
|
||||
|
||||
# Build context from workout data
|
||||
workout_context = {
|
||||
"activity_type": workout.activity_type,
|
||||
"duration_minutes": workout.duration_seconds / 60 if workout.duration_seconds else 0,
|
||||
"distance_km": workout.distance_m / 1000 if workout.distance_m else 0,
|
||||
"avg_hr": workout.avg_hr,
|
||||
"avg_power": workout.avg_power,
|
||||
"elevation_gain": workout.elevation_gain_m,
|
||||
"planned_workout": plan
|
||||
}
|
||||
|
||||
prompt = prompt_template.format(**workout_context)
|
||||
|
||||
response = await self._make_ai_request(prompt)
|
||||
return self._parse_workout_analysis(response)
|
||||
|
||||
async def generate_plan(self, rules: List[Dict], goals: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Generate a training plan using AI."""
|
||||
prompt_template = await self.prompt_manager.get_active_prompt("plan_generation")
|
||||
|
||||
context = {
|
||||
"rules": rules,
|
||||
"goals": goals,
|
||||
"current_fitness_level": goals.get("fitness_level", "intermediate")
|
||||
}
|
||||
|
||||
prompt = prompt_template.format(**context)
|
||||
response = await self._make_ai_request(prompt)
|
||||
return self._parse_plan_response(response)
|
||||
|
||||
async def parse_rules_from_natural_language(self, natural_language: str) -> Dict[str, Any]:
|
||||
"""Parse natural language rules into structured format."""
|
||||
prompt_template = await self.prompt_manager.get_active_prompt("rule_parsing")
|
||||
prompt = prompt_template.format(user_rules=natural_language)
|
||||
|
||||
response = await self._make_ai_request(prompt)
|
||||
return self._parse_rules_response(response)
|
||||
|
||||
async def _make_ai_request(self, prompt: str) -> str:
|
||||
"""Make async request to OpenRouter API with retry logic."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
for attempt in range(3): # Simple retry logic
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": 2000,
|
||||
},
|
||||
timeout=30.0
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data["choices"][0]["message"]["content"]
|
||||
|
||||
except Exception as e:
|
||||
if attempt == 2: # Last attempt
|
||||
logger.error(f"AI request failed after 3 attempts: {str(e)}")
|
||||
raise AIServiceError(f"AI request failed after 3 attempts: {str(e)}")
|
||||
await asyncio.sleep(2 ** attempt) # Exponential backoff
|
||||
|
||||
def _parse_workout_analysis(self, response: str) -> Dict[str, Any]:
|
||||
"""Parse AI response for workout analysis."""
|
||||
try:
|
||||
# Assume AI returns JSON
|
||||
clean_response = response.strip()
|
||||
if clean_response.startswith("```json"):
|
||||
clean_response = clean_response[7:-3]
|
||||
return json.loads(clean_response)
|
||||
except json.JSONDecodeError:
|
||||
return {"raw_analysis": response, "structured": False}
|
||||
|
||||
def _parse_plan_response(self, response: str) -> Dict[str, Any]:
|
||||
"""Parse AI response for plan generation."""
|
||||
try:
|
||||
clean_response = response.strip()
|
||||
if clean_response.startswith("```json"):
|
||||
clean_response = clean_response[7:-3]
|
||||
return json.loads(clean_response)
|
||||
except json.JSONDecodeError:
|
||||
return {"raw_plan": response, "structured": False}
|
||||
|
||||
def _parse_rules_response(self, response: str) -> Dict[str, Any]:
|
||||
"""Parse AI response for rule parsing."""
|
||||
try:
|
||||
clean_response = response.strip()
|
||||
if clean_response.startswith("```json"):
|
||||
clean_response = clean_response[7:-3]
|
||||
return json.loads(clean_response)
|
||||
except json.JSONDecodeError:
|
||||
return {"raw_rules": response, "structured": False}
|
||||
|
||||
|
||||
class AIServiceError(Exception):
|
||||
"""Raised when AI service requests fail."""
|
||||
pass
|
||||
84
backend/app/services/garmin.py
Normal file
84
backend/app/services/garmin.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import os
|
||||
import garth
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GarminService:
|
||||
"""Service for interacting with Garmin Connect API."""
|
||||
|
||||
def __init__(self):
|
||||
self.username = os.getenv("GARMIN_USERNAME")
|
||||
self.password = os.getenv("GARMIN_PASSWORD")
|
||||
self.client: Optional[garth.Client] = None
|
||||
self.session_dir = "/app/data/sessions"
|
||||
|
||||
# Ensure session directory exists
|
||||
os.makedirs(self.session_dir, exist_ok=True)
|
||||
|
||||
async def authenticate(self) -> bool:
|
||||
"""Authenticate with Garmin Connect and persist session."""
|
||||
if not self.client:
|
||||
self.client = garth.Client()
|
||||
|
||||
try:
|
||||
# Try to load existing session
|
||||
self.client.load(self.session_dir)
|
||||
logger.info("Loaded existing Garmin session")
|
||||
return True
|
||||
except Exception:
|
||||
# Fresh authentication required
|
||||
try:
|
||||
await self.client.login(self.username, self.password)
|
||||
self.client.save(self.session_dir)
|
||||
logger.info("Successfully authenticated with Garmin Connect")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Garmin authentication failed: {str(e)}")
|
||||
raise GarminAuthError(f"Authentication failed: {str(e)}")
|
||||
|
||||
async def get_activities(self, limit: int = 10, start_date: datetime = None) -> List[Dict[str, Any]]:
|
||||
"""Fetch recent activities from Garmin Connect."""
|
||||
if not self.client:
|
||||
await self.authenticate()
|
||||
|
||||
if not start_date:
|
||||
start_date = datetime.now() - timedelta(days=7)
|
||||
|
||||
try:
|
||||
activities = self.client.get_activities(limit=limit, start=start_date)
|
||||
logger.info(f"Fetched {len(activities)} activities from Garmin")
|
||||
return activities
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch activities: {str(e)}")
|
||||
raise GarminAPIError(f"Failed to fetch activities: {str(e)}")
|
||||
|
||||
async def get_activity_details(self, activity_id: str) -> Dict[str, Any]:
|
||||
"""Get detailed activity data including metrics."""
|
||||
if not self.client:
|
||||
await self.authenticate()
|
||||
|
||||
try:
|
||||
details = self.client.get_activity(activity_id)
|
||||
logger.info(f"Fetched details for activity {activity_id}")
|
||||
return details
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch activity details for {activity_id}: {str(e)}")
|
||||
raise GarminAPIError(f"Failed to fetch activity details: {str(e)}")
|
||||
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Check if we have a valid authenticated session."""
|
||||
return self.client is not None
|
||||
|
||||
|
||||
class GarminAuthError(Exception):
|
||||
"""Raised when Garmin authentication fails."""
|
||||
pass
|
||||
|
||||
|
||||
class GarminAPIError(Exception):
|
||||
"""Raised when Garmin API calls fail."""
|
||||
pass
|
||||
62
backend/app/services/gpx.py
Normal file
62
backend/app/services/gpx.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
from fastapi import UploadFile, HTTPException
|
||||
import gpxpy
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def store_gpx_file(file: UploadFile) -> str:
|
||||
"""Store uploaded GPX file and return path"""
|
||||
try:
|
||||
file_ext = os.path.splitext(file.filename)[1]
|
||||
if file_ext.lower() != '.gpx':
|
||||
raise HTTPException(status_code=400, detail="Invalid file type")
|
||||
|
||||
file_name = f"{uuid.uuid4()}{file_ext}"
|
||||
file_path = os.path.join(settings.GPX_STORAGE_PATH, file_name)
|
||||
|
||||
# Ensure storage directory exists
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
# Save file
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(await file.read())
|
||||
|
||||
return file_path
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing GPX file: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error storing file")
|
||||
|
||||
async def parse_gpx(file_path: str) -> dict:
|
||||
"""Parse GPX file and extract key metrics"""
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
gpx = gpxpy.parse(f)
|
||||
|
||||
total_distance = 0.0
|
||||
elevation_gain = 0.0
|
||||
points = []
|
||||
|
||||
for track in gpx.tracks:
|
||||
for segment in track.segments:
|
||||
total_distance += segment.length_3d()
|
||||
for i in range(1, len(segment.points)):
|
||||
elevation_gain += max(0, segment.points[i].elevation - segment.points[i-1].elevation)
|
||||
|
||||
points = [{
|
||||
'lat': point.latitude,
|
||||
'lon': point.longitude,
|
||||
'ele': point.elevation,
|
||||
'time': point.time.isoformat() if point.time else None
|
||||
} for point in segment.points]
|
||||
|
||||
return {
|
||||
'total_distance': total_distance,
|
||||
'elevation_gain': elevation_gain,
|
||||
'points': points
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing GPX file: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error parsing GPX file")
|
||||
74
backend/app/services/plan_evolution.py
Normal file
74
backend/app/services/plan_evolution.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.services.ai_service import AIService
|
||||
from app.models.analysis import Analysis
|
||||
from app.models.plan import Plan
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlanEvolutionService:
|
||||
"""Service for evolving training plans based on workout analysis."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.ai_service = AIService(db)
|
||||
|
||||
async def evolve_plan_from_analysis(
|
||||
self,
|
||||
analysis: Analysis,
|
||||
current_plan: Plan
|
||||
) -> Plan:
|
||||
"""Create a new plan version based on workout analysis."""
|
||||
if not analysis.approved:
|
||||
return None
|
||||
|
||||
suggestions = analysis.suggestions
|
||||
if not suggestions:
|
||||
return None
|
||||
|
||||
# Generate new plan incorporating suggestions
|
||||
evolution_context = {
|
||||
"current_plan": current_plan.jsonb_plan,
|
||||
"workout_analysis": analysis.jsonb_feedback,
|
||||
"suggestions": suggestions,
|
||||
"evolution_type": "workout_feedback"
|
||||
}
|
||||
|
||||
new_plan_data = await self.ai_service.evolve_plan(evolution_context)
|
||||
|
||||
# Create new plan version
|
||||
new_plan = Plan(
|
||||
jsonb_plan=new_plan_data,
|
||||
version=current_plan.version + 1,
|
||||
parent_plan_id=current_plan.id
|
||||
)
|
||||
|
||||
self.db.add(new_plan)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(new_plan)
|
||||
|
||||
logger.info(f"Created new plan version {new_plan.version} from analysis {analysis.id}")
|
||||
return new_plan
|
||||
|
||||
async def get_plan_evolution_history(self, plan_id: int) -> list[Plan]:
|
||||
"""Get the evolution history for a plan."""
|
||||
result = await self.db.execute(
|
||||
select(Plan)
|
||||
.where(
|
||||
(Plan.id == plan_id) |
|
||||
(Plan.parent_plan_id == plan_id)
|
||||
)
|
||||
.order_by(Plan.version)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_current_active_plan(self) -> Plan:
|
||||
"""Get the most recent active plan."""
|
||||
result = await self.db.execute(
|
||||
select(Plan)
|
||||
.order_by(Plan.version.desc())
|
||||
.limit(1)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
92
backend/app/services/prompt_manager.py
Normal file
92
backend/app/services/prompt_manager.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, func
|
||||
from app.models.prompt import Prompt
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""Service for managing AI prompts with versioning."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_active_prompt(self, action_type: str, model: str = None) -> str:
|
||||
"""Get the active prompt for a specific action type."""
|
||||
query = select(Prompt).where(
|
||||
Prompt.action_type == action_type,
|
||||
Prompt.active == True
|
||||
)
|
||||
if model:
|
||||
query = query.where(Prompt.model == model)
|
||||
|
||||
result = await self.db.execute(query.order_by(Prompt.version.desc()))
|
||||
prompt = result.scalar_one_or_none()
|
||||
return prompt.prompt_text if prompt else None
|
||||
|
||||
async def create_prompt_version(
|
||||
self,
|
||||
action_type: str,
|
||||
prompt_text: str,
|
||||
model: str = None
|
||||
) -> Prompt:
|
||||
"""Create a new version of a prompt."""
|
||||
# Deactivate previous versions
|
||||
await self.db.execute(
|
||||
update(Prompt)
|
||||
.where(Prompt.action_type == action_type)
|
||||
.values(active=False)
|
||||
)
|
||||
|
||||
# Get next version number
|
||||
result = await self.db.execute(
|
||||
select(func.max(Prompt.version))
|
||||
.where(Prompt.action_type == action_type)
|
||||
)
|
||||
max_version = result.scalar() or 0
|
||||
|
||||
# Create new prompt
|
||||
new_prompt = Prompt(
|
||||
action_type=action_type,
|
||||
model=model,
|
||||
prompt_text=prompt_text,
|
||||
version=max_version + 1,
|
||||
active=True
|
||||
)
|
||||
|
||||
self.db.add(new_prompt)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(new_prompt)
|
||||
|
||||
logger.info(f"Created new prompt version {new_prompt.version} for {action_type}")
|
||||
return new_prompt
|
||||
|
||||
async def get_prompt_history(self, action_type: str) -> list[Prompt]:
|
||||
"""Get all versions of prompts for an action type."""
|
||||
result = await self.db.execute(
|
||||
select(Prompt)
|
||||
.where(Prompt.action_type == action_type)
|
||||
.order_by(Prompt.version.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def activate_prompt_version(self, prompt_id: int) -> bool:
|
||||
"""Activate a specific prompt version."""
|
||||
# First deactivate all prompts for this action type
|
||||
prompt = await self.db.get(Prompt, prompt_id)
|
||||
if not prompt:
|
||||
return False
|
||||
|
||||
await self.db.execute(
|
||||
update(Prompt)
|
||||
.where(Prompt.action_type == prompt.action_type)
|
||||
.values(active=False)
|
||||
)
|
||||
|
||||
# Activate the specific version
|
||||
prompt.active = True
|
||||
await self.db.commit()
|
||||
|
||||
logger.info(f"Activated prompt version {prompt.version} for {prompt.action_type}")
|
||||
return True
|
||||
90
backend/app/services/workout_sync.py
Normal file
90
backend/app/services/workout_sync.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.services.garmin import GarminService, GarminAPIError
|
||||
from app.models.workout import Workout
|
||||
from app.models.garmin_sync_log import GarminSyncLog
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkoutSyncService:
|
||||
"""Service for syncing Garmin activities to database."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.garmin_service = GarminService()
|
||||
|
||||
async def sync_recent_activities(self, days_back: int = 7) -> int:
|
||||
"""Sync recent Garmin activities to database."""
|
||||
try:
|
||||
# Create sync log entry
|
||||
sync_log = GarminSyncLog(status="in_progress")
|
||||
self.db.add(sync_log)
|
||||
await self.db.commit()
|
||||
|
||||
# Calculate start date
|
||||
start_date = datetime.now() - timedelta(days=days_back)
|
||||
|
||||
# Fetch activities from Garmin
|
||||
activities = await self.garmin_service.get_activities(
|
||||
limit=50, start_date=start_date
|
||||
)
|
||||
|
||||
synced_count = 0
|
||||
for activity in activities:
|
||||
if await self.activity_exists(activity['activityId']):
|
||||
continue
|
||||
|
||||
# Parse and create workout
|
||||
workout_data = await self.parse_activity_data(activity)
|
||||
workout = Workout(**workout_data)
|
||||
self.db.add(workout)
|
||||
synced_count += 1
|
||||
|
||||
# Update sync log
|
||||
sync_log.status = "success"
|
||||
sync_log.activities_synced = synced_count
|
||||
sync_log.last_sync_time = datetime.now()
|
||||
|
||||
await self.db.commit()
|
||||
logger.info(f"Successfully synced {synced_count} activities")
|
||||
return synced_count
|
||||
|
||||
except GarminAPIError as e:
|
||||
sync_log.status = "error"
|
||||
sync_log.error_message = str(e)
|
||||
await self.db.commit()
|
||||
logger.error(f"Garmin API error during sync: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
sync_log.status = "error"
|
||||
sync_log.error_message = str(e)
|
||||
await self.db.commit()
|
||||
logger.error(f"Unexpected error during sync: {str(e)}")
|
||||
raise
|
||||
|
||||
async def activity_exists(self, garmin_activity_id: str) -> bool:
|
||||
"""Check if activity already exists in database."""
|
||||
result = await self.db.execute(
|
||||
select(Workout).where(Workout.garmin_activity_id == garmin_activity_id)
|
||||
)
|
||||
return result.scalar_one_or_none() is not None
|
||||
|
||||
async def parse_activity_data(self, activity: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Parse Garmin activity data into workout model format."""
|
||||
return {
|
||||
"garmin_activity_id": activity['activityId'],
|
||||
"activity_type": activity.get('activityType', {}).get('typeKey'),
|
||||
"start_time": datetime.fromisoformat(activity['startTimeLocal'].replace('Z', '+00:00')),
|
||||
"duration_seconds": activity.get('duration'),
|
||||
"distance_m": activity.get('distance'),
|
||||
"avg_hr": activity.get('averageHR'),
|
||||
"max_hr": activity.get('maxHR'),
|
||||
"avg_power": activity.get('avgPower'),
|
||||
"max_power": activity.get('maxPower'),
|
||||
"avg_cadence": activity.get('averageBikingCadenceInRevPerMinute'),
|
||||
"elevation_gain_m": activity.get('elevationGain'),
|
||||
"metrics": activity # Store full Garmin data as JSONB
|
||||
}
|
||||
Reference in New Issue
Block a user