mirror of
https://github.com/sstent/AICyclingCoach.git
synced 2026-02-06 14:31:52 +00:00
sync
This commit is contained in:
@@ -1,11 +1,10 @@
|
||||
from fastapi import HTTPException, Header, status
|
||||
import os
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.services.ai_service import AIService
|
||||
from typing import AsyncGenerator
|
||||
|
||||
async def verify_api_key(api_key: str = Header(..., alias="X-API-Key")):
|
||||
"""Dependency to verify API key header"""
|
||||
expected_key = os.getenv("API_KEY")
|
||||
if not expected_key or api_key != expected_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing API Key"
|
||||
)
|
||||
|
||||
async def get_ai_service(db: AsyncSession = Depends(get_db)) -> AIService:
|
||||
"""Get AI service instance with database dependency."""
|
||||
return AIService(db)
|
||||
@@ -65,7 +65,11 @@ app = FastAPI(
|
||||
# API Key Authentication Middleware
|
||||
@app.middleware("http")
|
||||
async def api_key_auth(request: Request, call_next):
|
||||
if request.url.path.startswith("/docs") or request.url.path.startswith("/redoc") or request.url.path == "/health":
|
||||
# Skip authentication for documentation and health endpoints
|
||||
if (request.url.path.startswith("/docs") or
|
||||
request.url.path.startswith("/redoc") or
|
||||
request.url.path == "/health" or
|
||||
request.url.path == "/openapi.json"):
|
||||
return await call_next(request)
|
||||
|
||||
api_key = request.headers.get("X-API-KEY")
|
||||
|
||||
@@ -11,5 +11,6 @@ class Plan(BaseModel):
|
||||
parent_plan_id = Column(Integer, ForeignKey('plans.id'), nullable=True)
|
||||
|
||||
parent_plan = relationship("Plan", remote_side="Plan.id", backref="child_plans")
|
||||
analyses = relationship("Analysis", back_populates="plan")
|
||||
workouts = relationship("Workout", back_populates="plan", cascade="all, delete-orphan")
|
||||
analyses = relationship("Analysis", back_populates="plan", lazy="selectin")
|
||||
workouts = relationship("Workout", back_populates="plan", cascade="all, delete-orphan", lazy="selectin")
|
||||
rules = relationship("Rule", secondary="plan_rules", back_populates="plans", lazy="selectin")
|
||||
@@ -1,5 +1,4 @@
|
||||
from sqlalchemy import Column, Integer, ForeignKey, Boolean, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy import Column, Integer, ForeignKey, Boolean, String, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
from .base import BaseModel
|
||||
|
||||
@@ -7,9 +6,11 @@ class Rule(BaseModel):
|
||||
__tablename__ = "rules"
|
||||
|
||||
name = Column(String(100), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
user_defined = Column(Boolean, default=True)
|
||||
jsonb_rules = Column(JSONB, nullable=False)
|
||||
rule_text = Column(Text, nullable=False) # Plaintext rules as per design spec
|
||||
version = Column(Integer, default=1)
|
||||
parent_rule_id = Column(Integer, ForeignKey('rules.id'), nullable=True)
|
||||
|
||||
parent_rule = relationship("Rule", remote_side="Rule.id")
|
||||
parent_rule = relationship("Rule", remote_side="Rule.id")
|
||||
plans = relationship("Plan", secondary="plan_rules", back_populates="rules", lazy="selectin")
|
||||
@@ -1,10 +1,15 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.database import get_db
|
||||
from app.models import Plan, PlanRule, Rule
|
||||
from app.schemas.plan import PlanCreate, Plan as PlanSchema
|
||||
from uuid import UUID
|
||||
from app.models.plan import Plan as PlanModel
|
||||
from app.models.rule import Rule
|
||||
from app.schemas.plan import PlanCreate, Plan as PlanSchema, PlanGenerationRequest, PlanGenerationResponse
|
||||
from app.dependencies import get_ai_service
|
||||
from app.services.ai_service import AIService
|
||||
from uuid import UUID, uuid4
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
router = APIRouter(prefix="/plans", tags=["Training Plans"])
|
||||
|
||||
@@ -14,20 +19,12 @@ async def create_plan(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
# Create plan
|
||||
db_plan = Plan(
|
||||
user_id=plan.user_id,
|
||||
start_date=plan.start_date,
|
||||
end_date=plan.end_date,
|
||||
goal=plan.goal
|
||||
db_plan = PlanModel(
|
||||
jsonb_plan=plan.jsonb_plan,
|
||||
version=plan.version,
|
||||
parent_plan_id=plan.parent_plan_id
|
||||
)
|
||||
db.add(db_plan)
|
||||
await db.flush() # Flush to get plan ID
|
||||
|
||||
# Add rules to plan
|
||||
for rule_id in plan.rule_ids:
|
||||
db_plan_rule = PlanRule(plan_id=db_plan.id, rule_id=rule_id)
|
||||
db.add(db_plan_rule)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_plan)
|
||||
return db_plan
|
||||
@@ -37,16 +34,16 @@ async def read_plan(
|
||||
plan_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
plan = await db.get(Plan, plan_id)
|
||||
plan = await db.get(PlanModel, plan_id)
|
||||
if not plan:
|
||||
raise HTTPException(status_code=404, detail="Plan not found")
|
||||
return plan
|
||||
|
||||
@router.get("/", response_model=list[PlanSchema])
|
||||
@router.get("/", response_model=List[PlanSchema])
|
||||
async def read_plans(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
result = await db.execute(select(Plan))
|
||||
result = await db.execute(select(PlanModel))
|
||||
return result.scalars().all()
|
||||
|
||||
@router.put("/{plan_id}", response_model=PlanSchema)
|
||||
@@ -55,21 +52,14 @@ async def update_plan(
|
||||
plan: PlanCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
db_plan = await db.get(Plan, plan_id)
|
||||
db_plan = await db.get(PlanModel, plan_id)
|
||||
if not db_plan:
|
||||
raise HTTPException(status_code=404, detail="Plan not found")
|
||||
|
||||
# Update plan fields
|
||||
db_plan.user_id = plan.user_id
|
||||
db_plan.start_date = plan.start_date
|
||||
db_plan.end_date = plan.end_date
|
||||
db_plan.goal = plan.goal
|
||||
|
||||
# Update rules
|
||||
await db.execute(PlanRule.delete().where(PlanRule.plan_id == plan_id))
|
||||
for rule_id in plan.rule_ids:
|
||||
db_plan_rule = PlanRule(plan_id=plan_id, rule_id=rule_id)
|
||||
db.add(db_plan_rule)
|
||||
db_plan.jsonb_plan = plan.jsonb_plan
|
||||
db_plan.version = plan.version
|
||||
db_plan.parent_plan_id = plan.parent_plan_id
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_plan)
|
||||
@@ -80,10 +70,63 @@ async def delete_plan(
|
||||
plan_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
plan = await db.get(Plan, plan_id)
|
||||
plan = await db.get(PlanModel, plan_id)
|
||||
if not plan:
|
||||
raise HTTPException(status_code=404, detail="Plan not found")
|
||||
|
||||
await db.delete(plan)
|
||||
await db.commit()
|
||||
return {"detail": "Plan deleted"}
|
||||
return {"detail": "Plan deleted"}
|
||||
|
||||
@router.post("/generate", response_model=PlanGenerationResponse)
|
||||
async def generate_plan(
|
||||
request: PlanGenerationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
ai_service: AIService = Depends(get_ai_service)
|
||||
):
|
||||
"""
|
||||
Generate a new training plan using AI based on provided goals and rule set.
|
||||
"""
|
||||
try:
|
||||
# Get all rules from the provided rule IDs
|
||||
rules = []
|
||||
for rule_id in request.rule_ids:
|
||||
rule = await db.get(Rule, rule_id)
|
||||
if not rule:
|
||||
raise HTTPException(status_code=404, detail=f"Rule with ID {rule_id} not found")
|
||||
rules.append(rule.jsonb_rules)
|
||||
|
||||
# Generate plan using AI service
|
||||
generated_plan = await ai_service.generate_training_plan(
|
||||
rule_set=rules, # Pass all rules as a list
|
||||
goals=request.goals.model_dump(),
|
||||
preferred_routes=request.preferred_routes
|
||||
)
|
||||
|
||||
# Create a Plan object for the response
|
||||
plan_obj = PlanSchema(
|
||||
id=uuid4(), # Generate a proper UUID
|
||||
jsonb_plan=generated_plan,
|
||||
version=1,
|
||||
parent_plan_id=None,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
# Create response with generated plan
|
||||
response = PlanGenerationResponse(
|
||||
plan=plan_obj,
|
||||
generation_metadata={
|
||||
"status": "success",
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"rule_ids": [str(rule_id) for rule_id in request.rule_ids]
|
||||
}
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to generate plan: {str(e)}"
|
||||
)
|
||||
@@ -1,9 +1,13 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.database import get_db
|
||||
from app.models import Rule
|
||||
from app.schemas.rule import RuleCreate, Rule as RuleSchema
|
||||
from app.models.rule import Rule
|
||||
from app.schemas.rule import RuleCreate, Rule as RuleSchema, NaturalLanguageRuleRequest, ParsedRuleResponse
|
||||
from app.dependencies import get_ai_service
|
||||
from app.services.ai_service import AIService
|
||||
from uuid import UUID
|
||||
from typing import List
|
||||
|
||||
router = APIRouter(prefix="/rules", tags=["Rules"])
|
||||
|
||||
@@ -12,55 +16,107 @@ async def create_rule(
|
||||
rule: RuleCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
db_rule = Rule(**rule.dict())
|
||||
"""Create new rule set (plaintext) as per design specification."""
|
||||
db_rule = Rule(**rule.model_dump())
|
||||
db.add(db_rule)
|
||||
await db.commit()
|
||||
await db.refresh(db_rule)
|
||||
return db_rule
|
||||
|
||||
@router.get("/", response_model=List[RuleSchema])
|
||||
async def list_rules(
|
||||
active_only: bool = True,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""List rule sets as specified in design document."""
|
||||
query = select(Rule)
|
||||
if active_only:
|
||||
# For now, return all rules. Later we can add an 'active' field
|
||||
pass
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
@router.get("/{rule_id}", response_model=RuleSchema)
|
||||
async def read_rule(
|
||||
async def get_rule(
|
||||
rule_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get specific rule set."""
|
||||
rule = await db.get(Rule, rule_id)
|
||||
if not rule:
|
||||
raise HTTPException(status_code=404, detail="Rule not found")
|
||||
return rule
|
||||
|
||||
@router.get("/", response_model=list[RuleSchema])
|
||||
async def read_rules(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
result = await db.execute(sa.select(Rule))
|
||||
return result.scalars().all()
|
||||
|
||||
@router.put("/{rule_id}", response_model=RuleSchema)
|
||||
async def update_rule(
|
||||
rule_id: UUID,
|
||||
rule: RuleCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Update rule set - creates new version as per design spec."""
|
||||
db_rule = await db.get(Rule, rule_id)
|
||||
if not db_rule:
|
||||
raise HTTPException(status_code=404, detail="Rule not found")
|
||||
|
||||
for key, value in rule.dict().items():
|
||||
setattr(db_rule, key, value)
|
||||
# Create new version instead of updating in place
|
||||
new_version = Rule(
|
||||
name=rule.name,
|
||||
description=rule.description,
|
||||
user_defined=rule.user_defined,
|
||||
rule_text=rule.rule_text,
|
||||
version=db_rule.version + 1,
|
||||
parent_rule_id=db_rule.id
|
||||
)
|
||||
|
||||
db.add(new_version)
|
||||
await db.commit()
|
||||
await db.refresh(db_rule)
|
||||
return db_rule
|
||||
await db.refresh(new_version)
|
||||
return new_version
|
||||
|
||||
@router.delete("/{rule_id}")
|
||||
async def delete_rule(
|
||||
rule_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Delete rule set."""
|
||||
rule = await db.get(Rule, rule_id)
|
||||
if not rule:
|
||||
raise HTTPException(status_code=404, detail="Rule not found")
|
||||
|
||||
await db.delete(rule)
|
||||
await db.commit()
|
||||
return {"detail": "Rule deleted"}
|
||||
return {"detail": "Rule deleted"}
|
||||
|
||||
@router.post("/parse-natural-language", response_model=ParsedRuleResponse)
|
||||
async def parse_natural_language_rules(
|
||||
request: NaturalLanguageRuleRequest,
|
||||
ai_service: AIService = Depends(get_ai_service)
|
||||
):
|
||||
"""
|
||||
Parse natural language training rules into structured format using AI.
|
||||
This helps users create rules but the final rule_text is stored as plaintext.
|
||||
"""
|
||||
try:
|
||||
# Parse rules using AI service - this creates structured data for validation
|
||||
parsed_rules = await ai_service.parse_rules_from_natural_language(request.natural_language_text)
|
||||
|
||||
# Simple validation - just check for basic completeness
|
||||
suggestions = []
|
||||
if len(request.natural_language_text.split()) < 10:
|
||||
suggestions.append("Consider providing more detailed rules")
|
||||
|
||||
response = ParsedRuleResponse(
|
||||
parsed_rules=parsed_rules,
|
||||
confidence_score=0.8, # Simplified confidence
|
||||
suggestions=suggestions,
|
||||
validation_errors=[], # Simplified - no complex validation
|
||||
rule_name=request.rule_name
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to parse natural language rules: {str(e)}"
|
||||
)
|
||||
@@ -154,7 +154,7 @@ async def approve_analysis(
|
||||
return {"message": "Analysis approved"}
|
||||
|
||||
|
||||
@router.get("/plans/{plan_id}/evolution", response_model=List[PlanSchema])
|
||||
@router.get("/plans/{plan_id}/evolution", response_model=List["PlanSchema"])
|
||||
async def get_plan_evolution(
|
||||
plan_id: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
|
||||
@@ -18,7 +18,7 @@ class Analysis(AnalysisBase):
|
||||
id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AnalysisUpdate(BaseModel):
|
||||
|
||||
@@ -1,21 +1,43 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from typing import List, Optional, Dict, Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
class TrainingGoals(BaseModel):
|
||||
"""Training goals for plan generation."""
|
||||
primary_goal: str = Field(..., description="Primary training goal")
|
||||
target_weekly_hours: int = Field(..., ge=3, le=20, description="Target hours per week")
|
||||
fitness_level: str = Field(..., description="Current fitness level")
|
||||
event_date: Optional[str] = Field(None, description="Target event date (YYYY-MM-DD)")
|
||||
preferred_routes: List[int] = Field(default=[], description="Preferred route IDs")
|
||||
avoid_days: List[str] = Field(default=[], description="Days to avoid training")
|
||||
|
||||
class PlanBase(BaseModel):
|
||||
jsonb_plan: dict = Field(..., description="Training plan data in JSONB format")
|
||||
jsonb_plan: Dict[str, Any] = Field(..., description="Training plan data in JSONB format")
|
||||
version: int = Field(..., gt=0, description="Plan version number")
|
||||
parent_plan_id: Optional[int] = Field(None, description="Parent plan ID for evolution tracking")
|
||||
parent_plan_id: Optional[UUID] = Field(None, description="Parent plan ID for evolution tracking")
|
||||
|
||||
class PlanCreate(PlanBase):
|
||||
pass
|
||||
|
||||
class Plan(PlanBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
analyses: List["Analysis"] = Field([], description="Analyses that created this plan version")
|
||||
child_plans: List["Plan"] = Field([], description="Evolved versions of this plan")
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: Optional[datetime] = Field(default=None)
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
class PlanGenerationRequest(BaseModel):
|
||||
"""Request schema for plan generation."""
|
||||
rule_ids: List[int] = Field(..., description="Rule set IDs to apply")
|
||||
goals: TrainingGoals = Field(..., description="Training goals")
|
||||
duration_weeks: int = Field(4, ge=1, le=20, description="Plan duration in weeks")
|
||||
user_preferences: Optional[Dict[str, Any]] = Field(None, description="Additional preferences")
|
||||
preferred_routes: List[int] = Field(default=[], description="Preferred route IDs")
|
||||
|
||||
class PlanGenerationResponse(BaseModel):
|
||||
"""Response schema for plan generation."""
|
||||
plan: Plan
|
||||
generation_metadata: Dict[str, Any] = Field(..., description="Generation metadata")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
@@ -1,17 +1,49 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing import Optional, Dict, Any, List
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
|
||||
class NaturalLanguageRuleRequest(BaseModel):
|
||||
"""Request schema for natural language rule parsing."""
|
||||
natural_language_text: str = Field(
|
||||
...,
|
||||
min_length=10,
|
||||
max_length=5000,
|
||||
description="Natural language rule description"
|
||||
)
|
||||
rule_name: str = Field(..., min_length=1, max_length=100, description="Rule set name")
|
||||
|
||||
@field_validator('natural_language_text')
|
||||
@classmethod
|
||||
def validate_text_content(cls, v):
|
||||
required_keywords = ['ride', 'week', 'hour', 'day', 'rest', 'training']
|
||||
if not any(keyword in v.lower() for keyword in required_keywords):
|
||||
raise ValueError("Text must contain training-related keywords")
|
||||
return v
|
||||
|
||||
class ParsedRuleResponse(BaseModel):
|
||||
"""Response schema for parsed rules."""
|
||||
parsed_rules: Dict[str, Any] = Field(..., description="Structured rule data")
|
||||
confidence_score: Optional[float] = Field(None, ge=0.0, le=1.0, description="Parsing confidence")
|
||||
suggestions: List[str] = Field(default=[], description="Improvement suggestions")
|
||||
validation_errors: List[str] = Field(default=[], description="Validation errors")
|
||||
rule_name: str = Field(..., description="Rule set name")
|
||||
|
||||
class RuleBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
condition: str
|
||||
priority: int = 0
|
||||
"""Base rule schema."""
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
description: Optional[str] = Field(None, max_length=500)
|
||||
user_defined: bool = Field(True, description="Whether rule is user-defined")
|
||||
rule_text: str = Field(..., min_length=10, description="Plaintext rule description")
|
||||
version: int = Field(1, ge=1, description="Rule version")
|
||||
parent_rule_id: Optional[UUID] = Field(None, description="Parent rule for versioning")
|
||||
|
||||
class RuleCreate(RuleBase):
|
||||
pass
|
||||
|
||||
class Rule(RuleBase):
|
||||
id: str
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
model_config = {"from_attributes": True}
|
||||
@@ -43,12 +43,12 @@ class AIService:
|
||||
response = await self._make_ai_request(prompt)
|
||||
return self._parse_workout_analysis(response)
|
||||
|
||||
async def generate_plan(self, rules: List[Dict], goals: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Generate a training plan using AI."""
|
||||
async def generate_plan(self, rules_text: str, goals: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Generate a training plan using AI with plaintext rules as per design spec."""
|
||||
prompt_template = await self.prompt_manager.get_active_prompt("plan_generation")
|
||||
|
||||
context = {
|
||||
"rules": rules,
|
||||
"rules_text": rules_text, # Use plaintext rules directly
|
||||
"goals": goals,
|
||||
"current_fitness_level": goals.get("fitness_level", "intermediate")
|
||||
}
|
||||
@@ -57,13 +57,80 @@ class AIService:
|
||||
response = await self._make_ai_request(prompt)
|
||||
return self._parse_plan_response(response)
|
||||
|
||||
async def generate_training_plan(self, rules_text: str, goals: Dict[str, Any], preferred_routes: List[int]) -> Dict[str, Any]:
|
||||
"""Generate a training plan using AI with plaintext rules as per design specification."""
|
||||
prompt_template = await self.prompt_manager.get_active_prompt("training_plan_generation")
|
||||
if not prompt_template:
|
||||
# Fallback to general plan generation prompt
|
||||
prompt_template = await self.prompt_manager.get_active_prompt("plan_generation")
|
||||
|
||||
context = {
|
||||
"rules_text": rules_text, # Use plaintext rules directly without parsing
|
||||
"goals": goals,
|
||||
"preferred_routes": preferred_routes,
|
||||
"current_fitness_level": goals.get("fitness_level", "intermediate")
|
||||
}
|
||||
|
||||
prompt = prompt_template.format(**context)
|
||||
response = await self._make_ai_request(prompt)
|
||||
return self._parse_plan_response(response)
|
||||
|
||||
async def parse_rules_from_natural_language(self, natural_language: str) -> Dict[str, Any]:
|
||||
"""Parse natural language rules into structured format."""
|
||||
prompt_template = await self.prompt_manager.get_active_prompt("rule_parsing")
|
||||
prompt = prompt_template.format(user_rules=natural_language)
|
||||
|
||||
response = await self._make_ai_request(prompt)
|
||||
return self._parse_rules_response(response)
|
||||
parsed_rules = self._parse_rules_response(response)
|
||||
|
||||
# Add confidence scoring to the parsed rules
|
||||
parsed_rules = self._add_confidence_scoring(parsed_rules)
|
||||
|
||||
return parsed_rules
|
||||
|
||||
def _add_confidence_scoring(self, parsed_rules: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Add confidence scoring to parsed rules based on parsing quality."""
|
||||
confidence_score = self._calculate_confidence_score(parsed_rules)
|
||||
|
||||
# Add confidence score to the parsed rules
|
||||
if isinstance(parsed_rules, dict):
|
||||
parsed_rules["_confidence"] = confidence_score
|
||||
parsed_rules["_parsing_quality"] = self._get_parsing_quality(confidence_score)
|
||||
|
||||
return parsed_rules
|
||||
|
||||
def _calculate_confidence_score(self, parsed_rules: Dict[str, Any]) -> float:
|
||||
"""Calculate confidence score based on parsing quality."""
|
||||
if not isinstance(parsed_rules, dict):
|
||||
return 0.5 # Default confidence for non-dict responses
|
||||
|
||||
score = 0.0
|
||||
# Score based on presence of key cycling training rule fields
|
||||
key_fields = {
|
||||
"max_rides_per_week": 0.3,
|
||||
"min_rest_between_hard": 0.2,
|
||||
"max_duration_hours": 0.2,
|
||||
"weather_constraints": 0.3,
|
||||
"intensity_limits": 0.2,
|
||||
"schedule_constraints": 0.2
|
||||
}
|
||||
|
||||
for field, weight in key_fields.items():
|
||||
if parsed_rules.get(field) is not None:
|
||||
score += weight
|
||||
|
||||
return min(score, 1.0)
|
||||
|
||||
def _get_parsing_quality(self, confidence_score: float) -> str:
|
||||
"""Get parsing quality description based on confidence score."""
|
||||
if confidence_score >= 0.8:
|
||||
return "excellent"
|
||||
elif confidence_score >= 0.6:
|
||||
return "good"
|
||||
elif confidence_score >= 0.4:
|
||||
return "fair"
|
||||
else:
|
||||
return "poor"
|
||||
|
||||
async def evolve_plan(self, evolution_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Evolve a training plan using AI based on workout analysis."""
|
||||
|
||||
Reference in New Issue
Block a user