2025-09-10 11:46:57 -07:00
parent 2cc2b4c9ce
commit f443e7a64e
33 changed files with 887 additions and 1467 deletions

View File

@@ -45,21 +45,23 @@ COPY . .
RUN echo '#!/bin/bash\n\
set -e\n\
\n\
# Run database migrations\n\
# Run database migrations synchronously\n\
echo "Running database migrations..."\n\
alembic upgrade head\n\
python -m alembic upgrade head\n\
\n\
# Verify migration success\n\
echo "Verifying migration status..."\n\
alembic current\n\
python -m alembic current\n\
\n\
# Start the application\n\
echo "Starting application..."\n\
exec "$@"' > /app/entrypoint.sh && \
chmod +x /app/entrypoint.sh
# Create non-root user
RUN useradd -m appuser && chown -R appuser:appuser /app
# Create non-root user and logs directory
RUN useradd -m appuser && \
mkdir -p /app/logs && \
chown -R appuser:appuser /app
USER appuser
# Expose application port
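For readability, the entrypoint script written by the RUN echo above presumably expands to the following file (assuming the image's /bin/sh interprets the \n escapes, as dash does):

#!/bin/bash
set -e

# Run database migrations synchronously
echo "Running database migrations..."
python -m alembic upgrade head

# Verify migration success
echo "Verifying migration status..."
python -m alembic current

# Start the application
echo "Starting application..."
exec "$@"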

View File

@@ -1,6 +1,6 @@
[alembic]
script_location = alembic
sqlalchemy.url = postgresql+asyncpg://appuser:password@db:5432/cyclingdb
sqlalchemy.url = postgresql+asyncpg://postgres:password@db:5432/cycling
[loggers]
keys = root
@@ -8,6 +8,9 @@ keys = root
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console

View File

@@ -9,8 +9,8 @@ import os
sys.path.append(os.getcwd())
# Import base and models
from app.models import Base
from app.database import DATABASE_URL
from app.models.base import Base
from app.config import settings
config = context.config
fileConfig(config.config_file_name)
@@ -30,7 +30,7 @@ def run_migrations_offline():
with context.begin_transaction():
context.run_migrations()
def run_migrations_online():
async def run_migrations_online():
"""Run migrations in 'online' mode."""
connectable = AsyncEngine(
engine_from_config(
@@ -38,16 +38,17 @@ def run_migrations_online():
prefix="sqlalchemy.",
poolclass=pool.NullPool,
future=True,
url=DATABASE_URL,
url=settings.DATABASE_URL,
)
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
async def do_run_migrations(connection):
def do_run_migrations(connection):
context.configure(connection=connection, target_metadata=target_metadata)
await connection.run_sync(context.run_migrations)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
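The hunk cuts off before the offline/online dispatch; a minimal sketch of the presumed tail of env.py, assuming it follows Alembic's standard asyncio template:

import asyncio

if context.is_offline_mode():
    run_migrations_offline()
else:
    asyncio.run(run_migrations_online())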

View File

@@ -1,11 +1,11 @@
from pydantic_settings import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
DATABASE_URL: str
GPX_STORAGE_PATH: str
AI_MODEL: str = "openrouter/auto"
API_KEY: str
class Config:
env_file = ".env"
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
settings = Settings()
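With SettingsConfigDict(env_file=".env", extra="ignore"), unrelated variables in .env no longer fail validation; usage of the settings singleton is unchanged. A hedged usage sketch:

from app.config import settings

# Values come from the environment or .env; DATABASE_URL is required, AI_MODEL has a default
engine_url = settings.DATABASE_URL
model_name = settings.AI_MODEL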

View File

@@ -1,7 +1,8 @@
import os
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import declarative_base, sessionmaker
DATABASE_URL = "postgresql+asyncpg://appuser:password@db:5432/cyclingdb"
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://postgres:password@db:5432/cycling")
engine = create_async_engine(DATABASE_URL, echo=True)
AsyncSessionLocal = sessionmaker(

View File

@@ -1,6 +1,9 @@
import logging
import json
from datetime import datetime
from fastapi import FastAPI, Depends, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from .database import get_db, get_database_url
from .database import get_db
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from alembic.config import Config
@@ -14,6 +17,45 @@ from .routes import prompts as prompt_routes
from .routes import dashboard as dashboard_routes
from .config import settings
# Configure structured JSON logging
class StructuredJSONFormatter(logging.Formatter):
def format(self, record):
log_data = {
"timestamp": datetime.utcnow().isoformat(),
"level": record.levelname,
"message": record.getMessage(),
"logger": record.name,
"module": record.module,
"function": record.funcName,
"line": record.lineno,
"thread": record.threadName,
}
# The logging `extra` kwarg sets attributes directly on the record, so merge
# any non-standard record attributes (e.g. "component", "metric") into the payload
reserved = logging.makeLogRecord({}).__dict__
for key, value in record.__dict__.items():
    if key not in reserved and key not in log_data:
        log_data[key] = value
if record.exc_info:
log_data["exception"] = self.formatException(record.exc_info)
return json.dumps(log_data)
# Set up logging
logger = logging.getLogger("ai_cycling_coach")
logger.setLevel(logging.INFO)
# Create console handler with structured JSON format
console_handler = logging.StreamHandler()
console_handler.setFormatter(StructuredJSONFormatter())
logger.addHandler(console_handler)
# Configure rotating file handler
from logging.handlers import RotatingFileHandler
file_handler = RotatingFileHandler(
filename="/app/logs/app.log",
maxBytes=10*1024*1024, # 10 MB
backupCount=5,
encoding='utf-8'
)
file_handler.setFormatter(StructuredJSONFormatter())
logger.addHandler(file_handler)
app = FastAPI(
title="AI Cycling Coach API",
description="Backend service for AI-assisted cycling training platform",
@@ -49,61 +91,16 @@ app.include_router(workout_routes.router, prefix="/workouts", tags=["workouts"])
app.include_router(prompt_routes.router, prefix="/prompts", tags=["prompts"])
app.include_router(dashboard_routes.router, prefix="/api/dashboard", tags=["dashboard"])
async def check_migration_status():
"""Check if database migrations are up to date."""
try:
# Get Alembic configuration
config = Config("alembic.ini")
config.set_main_option("sqlalchemy.url", get_database_url())
script = ScriptDirectory.from_config(config)
# Get current database revision
from sqlalchemy import create_engine
engine = create_engine(get_database_url())
with engine.connect() as conn:
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
# Get head revision
head_rev = script.get_current_head()
return {
"current_revision": current_rev,
"head_revision": head_rev,
"migrations_up_to_date": current_rev == head_rev
}
except Exception as e:
return {
"error": str(e),
"migrations_up_to_date": False
}
@app.get("/health")
async def health_check(db: AsyncSession = Depends(get_db)):
"""Enhanced health check with migration verification."""
health_status = {
async def health_check():
"""Simplified health check endpoint."""
return {
"status": "healthy",
"version": "0.1.0",
"timestamp": "2024-01-15T10:30:00Z" # Should be dynamic
"timestamp": datetime.utcnow().isoformat()
}
# Database connection check
try:
await db.execute(text("SELECT 1"))
health_status["database"] = "connected"
except Exception as e:
health_status["status"] = "unhealthy"
health_status["database"] = f"error: {str(e)}"
# Migration status check
migration_info = await check_migration_status()
health_status["migrations"] = migration_info
if not migration_info.get("migrations_up_to_date", False):
health_status["status"] = "unhealthy"
return health_status
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
logger.info("Starting AI Cycling Coach API server")
uvicorn.run(app, host="0.0.0.0", port=8000, log_config=None)
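For reference, the formatter above emits one JSON object per line, which is what the dashboard route later parses line by line; an illustrative entry (values are examples only):

{"timestamp": "2025-09-10T18:46:57.123456", "level": "INFO", "message": "Starting AI Cycling Coach API server", "logger": "ai_cycling_coach", "module": "main", "function": "<module>", "line": 120, "thread": "MainThread"}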

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, ForeignKey, JSON, Boolean, DateTime
from sqlalchemy import Column, Integer, String, ForeignKey, JSON, Boolean, DateTime, func
from sqlalchemy.orm import relationship
from .base import BaseModel

View File

@@ -0,0 +1,12 @@
from sqlalchemy import Column, Integer, ForeignKey
from sqlalchemy.orm import relationship
from .base import BaseModel
class PlanRule(BaseModel):
__tablename__ = "plan_rules"
plan_id = Column(Integer, ForeignKey('plans.id'), primary_key=True)
rule_id = Column(Integer, ForeignKey('rules.id'), primary_key=True)
plan = relationship("Plan", back_populates="rules")
rule = relationship("Rule", back_populates="plans")

View File

@@ -1,5 +1,6 @@
from sqlalchemy import Column, Integer, ForeignKey, Boolean
from sqlalchemy import Column, Integer, ForeignKey, Boolean, String
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
from .base import BaseModel
class Rule(BaseModel):

View File

@@ -1,9 +1,54 @@
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse, JSONResponse
from app.services.health_monitor import HealthMonitor
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST, Gauge
from pathlib import Path
import json
router = APIRouter()
monitor = HealthMonitor()
# Prometheus metrics
SYNC_QUEUE = Gauge('sync_queue_size', 'Current Garmin sync queue size')
PENDING_ANALYSES = Gauge('pending_analyses', 'Number of pending workout analyses')
@router.get("/health")
async def get_health():
return monitor.check_system_health()
return monitor.check_system_health()
@router.get("/metrics")
async def prometheus_metrics():
# Update metrics with latest values
health_data = monitor.check_system_health()
SYNC_QUEUE.set(health_data['services'].get('sync_queue_size', 0))
PENDING_ANALYSES.set(health_data['services'].get('pending_analyses', 0))
return PlainTextResponse(
content=generate_latest(),
media_type=CONTENT_TYPE_LATEST
)
@router.get("/dashboard/health", response_class=JSONResponse)
async def health_dashboard():
"""Health dashboard endpoint with aggregated monitoring data"""
health_data = monitor.check_system_health()
# Get recent logs (last 100 lines)
log_file = Path("/app/logs/app.log")
recent_logs = []
try:
with log_file.open() as f:
lines = f.readlines()[-100:]
# Skip blank lines; assumes one JSON object per line, as written by the app formatter
recent_logs = [json.loads(line) for line in lines if line.strip()]
except (FileNotFoundError, json.JSONDecodeError):
pass
return {
"system": health_data,
"logs": recent_logs,
"statistics": {
"log_entries": len(recent_logs),
"error_count": sum(1 for log in recent_logs if log.get('level') == 'ERROR'),
"warning_count": sum(1 for log in recent_logs if log.get('level') == 'WARNING')
}
}

View File

@@ -8,8 +8,9 @@ from app.models.workout import Workout
from app.models.analysis import Analysis
from app.models.garmin_sync_log import GarminSyncLog
from app.models.plan import Plan
from app.schemas.workout import Workout as WorkoutSchema, WorkoutSyncStatus
from app.schemas.workout import Workout as WorkoutSchema, WorkoutSyncStatus, WorkoutMetric
from app.schemas.analysis import Analysis as AnalysisSchema
from app.schemas.plan import Plan as PlanSchema
from app.services.workout_sync import WorkoutSyncService
from app.services.ai_service import AIService
from app.services.plan_evolution import PlanEvolutionService
@@ -32,7 +33,7 @@ async def read_workout(workout_id: int, db: AsyncSession = Depends(get_db)):
raise HTTPException(status_code=404, detail="Workout not found")
return workout
@router.get("/{workout_id}/metrics", response_model=list[schemas.WorkoutMetric])
@router.get("/{workout_id}/metrics", response_model=list[WorkoutMetric])
async def get_workout_metrics(
workout_id: int,
db: AsyncSession = Depends(get_db)
@@ -153,7 +154,7 @@ async def approve_analysis(
return {"message": "Analysis approved"}
@router.get("/plans/{plan_id}/evolution", response_model=List[schemas.Plan])
@router.get("/plans/{plan_id}/evolution", response_model=List[PlanSchema])
async def get_plan_evolution(
plan_id: int,
db: AsyncSession = Depends(get_db)

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field
from datetime import datetime
from typing import List, Optional
from uuid import UUID

View File

@@ -36,38 +36,51 @@ class HealthMonitor:
return {
'database': self._check_database(),
'garmin_sync': self._check_garmin_sync(),
'ai_service': self._check_ai_service()
'ai_service': self._check_ai_service(),
'sync_queue_size': self._get_sync_queue_size(),
'pending_analyses': self._count_pending_analyses()
}
def _get_sync_queue_size(self) -> int:
"""Get number of pending sync operations"""
from app.models.garmin_sync_log import GarminSyncLog, SyncStatus
return GarminSyncLog.query.filter_by(status=SyncStatus.PENDING).count()
def _count_pending_analyses(self) -> int:
"""Count workouts needing analysis"""
from app.models.workout import Workout
return Workout.query.filter_by(analysis_status='pending').count()
def _check_database(self) -> str:
try:
with get_db() as db:
db.execute(text("SELECT 1"))
return "ok"
except Exception as e:
logger.error(f"Database check failed: {str(e)}")
logger.error("Database check failed", extra={"component": "database", "error": str(e)})
return "down"
def _check_garmin_sync(self) -> str:
try:
last_sync = GarminSyncLog.get_latest()
if last_sync and last_sync.status == SyncStatus.FAILED:
logger.warning("Garmin sync has failed status", extra={"component": "garmin_sync", "status": last_sync.status.value})
return "warning"
return "ok"
except Exception as e:
logger.error(f"Garmin sync check failed: {str(e)}")
logger.error("Garmin sync check failed", extra={"component": "garmin_sync", "error": str(e)})
return "down"
def _check_ai_service(self) -> str:
try:
response = requests.get(
f"{settings.AI_SERVICE_URL}/ping",
f"{settings.AI_SERVICE_URL}/ping",
timeout=5,
headers={"Authorization": f"Bearer {settings.OPENROUTER_API_KEY}"}
)
return "ok" if response.ok else "down"
except Exception as e:
logger.error(f"AI service check failed: {str(e)}")
logger.error("AI service check failed", extra={"component": "ai_service", "error": str(e)})
return "down"
def _log_anomalies(self, metrics: Dict[str, Any]):
@@ -75,6 +88,7 @@ class HealthMonitor:
for metric, value in metrics.items():
if metric in self.warning_thresholds and value > self.warning_thresholds[metric]:
alerts.append(f"{metric} {value}%")
logger.warning("System threshold exceeded", extra={"metric": metric, "value": value, "threshold": self.warning_thresholds[metric]})
if alerts:
logger.warning(f"System thresholds exceeded: {', '.join(alerts)}")
logger.warning("System thresholds exceeded", extra={"alerts": alerts})

View File

@@ -6,6 +6,7 @@ from app.models.garmin_sync_log import GarminSyncLog
from app.models.garmin_sync_log import GarminSyncLog
from datetime import datetime, timedelta
import logging
from typing import Dict, Any
import asyncio
logger = logging.getLogger(__name__)

View File

@@ -8,4 +8,5 @@ pydantic-settings==2.2.1
python-multipart==0.0.9
gpxpy # Add GPX parsing library
garth==0.4.46 # Garmin Connect API client
httpx==0.25.2 # Async HTTP client for OpenRouter API
httpx==0.25.2 # Async HTTP client for OpenRouter API
asyncpg==0.29.0 # Async PostgreSQL driver

View File

@@ -24,6 +24,9 @@ class DatabaseManager:
def __init__(self, backup_dir: str = "/app/data/backups"):
self.backup_dir = Path(backup_dir)
self.backup_dir.mkdir(parents=True, exist_ok=True)
self.gpx_dir = Path("/app/data/gpx")
self.manifest_file = self.backup_dir / "gpx_manifest.json"
self.encryption_key = os.environ["BACKUP_ENCRYPTION_KEY"].encode()  # fail fast with a clear KeyError if the key is unset
def get_db_connection_params(self):
"""Extract database connection parameters from URL."""
@@ -39,15 +42,91 @@ class DatabaseManager:
'database': parsed.path.lstrip('/')
}
def _backup_gpx_files(self, backup_dir: Path) -> Optional[Path]:
"""Backup GPX files directory"""
gpx_dir = Path("/app/data/gpx")
if not gpx_dir.exists():
return None
backup_path = backup_dir / "gpx.tar.gz"
with tarfile.open(backup_path, "w:gz") as tar:
tar.add(gpx_dir, arcname="gpx")
return backup_path
def _backup_sessions(self, backup_dir: Path) -> Optional[Path]:
"""Backup Garmin sessions directory"""
sessions_dir = Path("/app/data/sessions")
if not sessions_dir.exists():
return None
backup_path = backup_dir / "sessions.tar.gz"
with tarfile.open(backup_path, "w:gz") as tar:
tar.add(sessions_dir, arcname="sessions")
return backup_path
def _generate_checksum(self, file_path: Path) -> str:
"""Generate SHA256 checksum for a file"""
hash_sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
def _verify_backup_integrity(self, backup_path: Path):
"""Verify backup file integrity using checksum"""
checksum_file = backup_path.with_suffix('.sha256')
if not checksum_file.exists():
raise FileNotFoundError(f"Checksum file missing for {backup_path.name}")
with open(checksum_file) as f:
expected_checksum = f.read().split()[0]
actual_checksum = self._generate_checksum(backup_path)
if actual_checksum != expected_checksum:
raise ValueError(f"Checksum mismatch for {backup_path.name}")
def create_backup(self, name: Optional[str] = None) -> str:
"""Create a database backup."""
"""Create a full system backup including database, GPX files, and sessions"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_name = name or f"backup_{timestamp}"
backup_file = self.backup_dir / f"{backup_name}.sql"
backup_name = name or f"full_backup_{timestamp}"
backup_dir = self.backup_dir / backup_name
backup_dir.mkdir(parents=True, exist_ok=True)
try:
# Backup database
db_backup_path = self._backup_database(backup_dir)
# Backup GPX files
gpx_backup_path = self._backup_gpx_files(backup_dir)
# Backup sessions
sessions_backup_path = self._backup_sessions(backup_dir)
# Generate checksums for all backup files
for file in backup_dir.glob("*"):
if file.is_file():
checksum = self._generate_checksum(file)
with open(f"{file}.sha256", "w") as f:
f.write(f"{checksum} {file.name}")
# Verify backups
for file in backup_dir.glob("*"):
if file.is_file() and not file.name.endswith('.sha256'):
self._verify_backup_integrity(file)
print(f"✅ Full backup created successfully: {backup_dir}")
return str(backup_dir)
except Exception as e:
shutil.rmtree(backup_dir, ignore_errors=True)
print(f"❌ Backup failed: {str(e)}")
raise
def _backup_database(self, backup_dir: Path) -> Path:
"""Create database backup"""
params = self.get_db_connection_params()
backup_file = backup_dir / "database.dump"
# Use pg_dump for backup
cmd = [
"pg_dump",
"-h", params['host'],
@@ -56,28 +135,18 @@ class DatabaseManager:
"-d", params['database'],
"-f", str(backup_file),
"--no-password",
"--format=custom", # Custom format for better compression
"--format=custom",
"--compress=9"
]
# Set password environment variable
env = os.environ.copy()
env['PGPASSWORD'] = params['password']
try:
print(f"Creating backup: {backup_file}")
result = subprocess.run(cmd, env=env, capture_output=True, text=True)
if result.returncode == 0:
print(f"✅ Backup created successfully: {backup_file}")
return str(backup_file)
else:
print(f"❌ Backup failed: {result.stderr}")
raise Exception(f"Backup failed: {result.stderr}")
except FileNotFoundError:
print("❌ pg_dump not found. Ensure PostgreSQL client tools are installed.")
raise
result = subprocess.run(cmd, env=env, capture_output=True, text=True)
if result.returncode != 0:
raise Exception(f"Database backup failed: {result.stderr}")
return backup_file
def restore_backup(self, backup_file: str, confirm: bool = False) -> None:
"""Restore database from backup."""
@@ -128,6 +197,80 @@ class DatabaseManager:
print("❌ pg_restore not found. Ensure PostgreSQL client tools are installed.")
raise
def backup_gpx_files(self, incremental: bool = True) -> Optional[Path]:
"""Handle GPX backup creation with incremental/full strategy"""
try:
if incremental:
return self._incremental_gpx_backup()
return self._full_gpx_backup()
except Exception as e:
print(f"GPX backup failed: {str(e)}")
return None
def _full_gpx_backup(self) -> Path:
"""Create full GPX backup"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = self.backup_dir / f"gpx_full_{timestamp}"
backup_path.mkdir()
# Copy all GPX files
subprocess.run(["rsync", "-a", f"{self.gpx_dir}/", f"{backup_path}/"])
self._encrypt_backup(backup_path)
return backup_path
def _incremental_gpx_backup(self) -> Optional[Path]:
"""Create incremental GPX backup using rsync --link-dest"""
last_full = self._find_last_full_backup()
if not last_full:
return self._full_gpx_backup()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = self.backup_dir / f"gpx_inc_{timestamp}"
backup_path.mkdir()
# Use hardlinks to previous backup for incremental
subprocess.run([
"rsync", "-a",
"--link-dest", str(last_full),
f"{self.gpx_dir}/",
f"{backup_path}/"
])
self._encrypt_backup(backup_path)
return backup_path
def _find_last_full_backup(self) -> Optional[Path]:
"""Find most recent full backup"""
full_backups = sorted(self.backup_dir.glob("gpx_full_*"), reverse=True)
return full_backups[0] if full_backups else None
def _encrypt_backup(self, backup_path: Path):
"""Encrypt backup directory using Fernet (AES-256-CBC with HMAC-SHA256)"""
from cryptography.fernet import Fernet
fernet = Fernet(self.encryption_key)
for file in backup_path.rglob('*'):
if file.is_file():
with open(file, 'rb') as f:
data = f.read()
encrypted = fernet.encrypt(data)
with open(file, 'wb') as f:
f.write(encrypted)
def decrypt_backup(self, backup_path: Path):
"""Decrypt backup directory"""
from cryptography.fernet import Fernet
fernet = Fernet(self.encryption_key)
for file in backup_path.rglob('*'):
if file.is_file():
with open(file, 'rb') as f:
data = f.read()
decrypted = fernet.decrypt(data)
with open(file, 'wb') as f:
f.write(decrypted)
def _recreate_database(self):
"""Drop and recreate the database."""
params = self.get_db_connection_params()
@@ -184,10 +327,11 @@ class DatabaseManager:
cutoff = datetime.now() - timedelta(days=keep_days)
removed = []
for backup in self.backup_dir.glob("*.sql"):
if datetime.fromtimestamp(backup.stat().st_mtime) < cutoff:
backup.unlink()
removed.append(backup.name)
# Clean all backup directories (full_backup_*)
for backup_dir in self.backup_dir.glob("full_backup_*"):
if backup_dir.is_dir() and datetime.fromtimestamp(backup_dir.stat().st_mtime) < cutoff:
shutil.rmtree(backup_dir)
removed.append(backup_dir.name)
if removed:
print(f"Removed {len(removed)} old backups: {', '.join(removed)}")
@@ -198,10 +342,12 @@ def main():
if len(sys.argv) < 2:
print("Usage: python backup_restore.py <command> [options]")
print("Commands:")
print(" backup [name] - Create a new backup")
print(" backup [name] - Create a new database backup")
print(" gpx-backup [--full] - Create GPX backup (incremental by default)")
print(" restore <file> [--yes] - Restore from backup")
print(" list - List available backups")
print(" cleanup [days] - Remove backups older than N days (default: 30)")
print(" decrypt <dir> - Decrypt backup directory")
sys.exit(1)
manager = DatabaseManager()
@@ -210,13 +356,21 @@ def main():
try:
if command == "backup":
name = sys.argv[2] if len(sys.argv) > 2 else None
manager.create_backup(name)
name = sys.argv[2] if len(sys.argv) > 2 else None
manager.create_backup(name)
elif command == "gpx-backup":
if len(sys.argv) > 2 and sys.argv[2] == "--full":
manager.backup_gpx_files(incremental=False)
else:
manager.backup_gpx_files()
elif command == "restore":
if len(sys.argv) < 3:
print("Error: Please specify backup file to restore from")
sys.exit(1)
backup_file = sys.argv[2]
confirm = "--yes" in sys.argv
backup_file = sys.argv[2]
confirm = "--yes" in sys.argv
manager.restore_backup(backup_file, confirm)
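The encryption helpers above assume BACKUP_ENCRYPTION_KEY holds a urlsafe base64-encoded 32-byte Fernet key; one can be generated with the cryptography package:

from cryptography.fernet import Fernet

# Print a key suitable for the BACKUP_ENCRYPTION_KEY environment variable
print(Fernet.generate_key().decode())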

View File

@@ -0,0 +1,102 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from app.services.ai_service import AIService, AIServiceError
from app.models.workout import Workout
import json
@pytest.mark.asyncio
async def test_analyze_workout_success():
"""Test successful workout analysis with valid API response"""
mock_db = MagicMock()
mock_prompt = MagicMock()
mock_prompt.format.return_value = "test prompt"
ai_service = AIService(mock_db)
ai_service.prompt_manager.get_active_prompt = AsyncMock(return_value=mock_prompt)
test_response = json.dumps({
"performance_summary": "Good workout",
"suggestions": ["More recovery"]
})
with patch('httpx.AsyncClient.post') as mock_post:
mock_post.return_value = AsyncMock(
status_code=200,
json=lambda: {"choices": [{"message": {"content": test_response}}]}
)
workout = Workout(activity_type="cycling", duration_seconds=3600)
result = await ai_service.analyze_workout(workout)
assert "performance_summary" in result
assert len(result["suggestions"]) == 1
@pytest.mark.asyncio
async def test_generate_plan_success():
"""Test plan generation with structured response"""
mock_db = MagicMock()
ai_service = AIService(mock_db)
ai_service.prompt_manager.get_active_prompt = AsyncMock(return_value="Plan prompt: {rules} {goals}")
test_plan = {
"weeks": [{"workouts": ["ride"]}],
"focus": "endurance"
}
with patch('httpx.AsyncClient.post') as mock_post:
mock_post.return_value = AsyncMock(
status_code=200,
json=lambda: {"choices": [{"message": {"content": json.dumps(test_plan)}}]}
)
result = await ai_service.generate_plan([], {})
assert "weeks" in result
assert result["focus"] == "endurance"
@pytest.mark.asyncio
async def test_api_retry_logic():
"""Test API request retries on failure"""
mock_db = MagicMock()
ai_service = AIService(mock_db)
with patch('httpx.AsyncClient.post') as mock_post:
mock_post.side_effect = Exception("API failure")
with pytest.raises(AIServiceError):
await ai_service._make_ai_request("test")
assert mock_post.call_count == 3
@pytest.mark.asyncio
async def test_invalid_json_handling():
"""Test graceful handling of invalid JSON responses"""
mock_db = MagicMock()
ai_service = AIService(mock_db)
with patch('httpx.AsyncClient.post') as mock_post:
mock_post.return_value = AsyncMock(
status_code=200,
json=lambda: {"choices": [{"message": {"content": "invalid{json"}}]}
)
result = await ai_service.parse_rules_from_natural_language("test")
assert "raw_rules" in result
assert not result["structured"]
@pytest.mark.asyncio
async def test_code_block_parsing():
"""Test extraction of JSON from code blocks"""
mock_db = MagicMock()
ai_service = AIService(mock_db)
test_response = "```json\n" + json.dumps({"max_rides": 4}) + "\n```"
with patch('httpx.AsyncClient.post') as mock_post:
mock_post.return_value = AsyncMock(
status_code=200,
json=lambda: {"choices": [{"message": {"content": test_response}}]}
)
result = await ai_service.evolve_plan({})
assert "max_rides" in result
assert result["max_rides"] == 4

View File

@@ -0,0 +1,56 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
from app.services.plan_evolution import PlanEvolutionService
from app.models.plan import Plan
from app.models.analysis import Analysis
from datetime import datetime
@pytest.mark.asyncio
async def test_evolve_plan_with_valid_analysis():
"""Test plan evolution with approved analysis and suggestions"""
mock_db = AsyncMock()
mock_plan = Plan(
id=1,
version=1,
jsonb_plan={"weeks": []},
parent_plan_id=None
)
mock_analysis = Analysis(
approved=True,
jsonb_feedback={"suggestions": ["More recovery"]}
)
service = PlanEvolutionService(mock_db)
service.ai_service.evolve_plan = AsyncMock(return_value={"weeks": [{"recovery": True}]})
result = await service.evolve_plan_from_analysis(mock_analysis, mock_plan)
assert result.version == 2
assert result.parent_plan_id == 1
mock_db.add.assert_called_once()
mock_db.commit.assert_awaited_once()
@pytest.mark.asyncio
async def test_evolution_skipped_for_unapproved_analysis():
"""Test plan evolution is skipped for unapproved analysis"""
mock_db = AsyncMock()
mock_analysis = Analysis(approved=False)
service = PlanEvolutionService(mock_db)
result = await service.evolve_plan_from_analysis(mock_analysis, MagicMock())
assert result is None
@pytest.mark.asyncio
async def test_evolution_history_retrieval():
"""Test getting plan evolution history"""
mock_db = AsyncMock()
mock_db.execute.return_value.scalars.return_value = [
Plan(version=1), Plan(version=2)
]
service = PlanEvolutionService(mock_db)
history = await service.get_plan_evolution_history(1)
assert len(history) == 2
assert history[0].version == 1

View File

@@ -0,0 +1,81 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.services.workout_sync import WorkoutSyncService
from app.models.workout import Workout
from app.models.garmin_sync_log import GarminSyncLog
from datetime import datetime, timedelta
import asyncio
@pytest.mark.asyncio
async def test_successful_sync():
"""Test successful sync of new activities"""
mock_db = AsyncMock()
mock_garmin = MagicMock()
mock_garmin.get_activities.return_value = [{'activityId': '123'}]
mock_garmin.get_activity_details.return_value = {'metrics': 'data'}
service = WorkoutSyncService(mock_db)
service.garmin_service = mock_garmin
result = await service.sync_recent_activities()
assert result == 1
mock_db.add.assert_called()
mock_db.commit.assert_awaited()
@pytest.mark.asyncio
async def test_duplicate_activity_handling():
"""Test skipping duplicate activities"""
mock_db = AsyncMock()
mock_db.execute.return_value.scalar_one_or_none.return_value = True
mock_garmin = MagicMock()
mock_garmin.get_activities.return_value = [{'activityId': '123'}]
service = WorkoutSyncService(mock_db)
service.garmin_service = mock_garmin
result = await service.sync_recent_activities()
assert result == 0
@pytest.mark.asyncio
async def test_activity_detail_retry_logic():
"""Test retry logic for activity details"""
mock_db = AsyncMock()
mock_garmin = MagicMock()
mock_garmin.get_activities.return_value = [{'activityId': '123'}]
mock_garmin.get_activity_details.side_effect = [Exception(), {'metrics': 'data'}]
service = WorkoutSyncService(mock_db)
service.garmin_service = mock_garmin
result = await service.sync_recent_activities()
assert mock_garmin.get_activity_details.call_count == 2
assert result == 1
@pytest.mark.asyncio
async def test_auth_error_handling():
"""Test authentication error handling"""
mock_db = AsyncMock()
mock_garmin = MagicMock()
mock_garmin.get_activities.side_effect = Exception("Auth failed")
service = WorkoutSyncService(mock_db)
service.garmin_service = mock_garmin
with pytest.raises(Exception):
await service.sync_recent_activities()
sync_log = mock_db.add.call_args[0][0]
assert sync_log.status == "auth_error"
@pytest.mark.asyncio
async def test_get_sync_status():
"""Test retrieval of latest sync status"""
mock_db = AsyncMock()
mock_log = GarminSyncLog(status="success")
mock_db.execute.return_value.scalar_one_or_none.return_value = mock_log
service = WorkoutSyncService(mock_db)
result = await service.get_latest_sync_status()
assert result.status == "success"