mirror of
https://github.com/sstent/aicyclingcoach-go.git
synced 2026-02-17 20:56:11 +00:00
sync
This commit is contained in:
@@ -1,10 +1,15 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.database import get_db
|
||||
from app.models import Plan, PlanRule, Rule
|
||||
from app.schemas.plan import PlanCreate, Plan as PlanSchema
|
||||
from uuid import UUID
|
||||
from app.models.plan import Plan as PlanModel
|
||||
from app.models.rule import Rule
|
||||
from app.schemas.plan import PlanCreate, Plan as PlanSchema, PlanGenerationRequest, PlanGenerationResponse
|
||||
from app.dependencies import get_ai_service
|
||||
from app.services.ai_service import AIService
|
||||
from uuid import UUID, uuid4
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
router = APIRouter(prefix="/plans", tags=["Training Plans"])
|
||||
|
||||
@@ -14,20 +19,12 @@ async def create_plan(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
# Create plan
|
||||
db_plan = Plan(
|
||||
user_id=plan.user_id,
|
||||
start_date=plan.start_date,
|
||||
end_date=plan.end_date,
|
||||
goal=plan.goal
|
||||
db_plan = PlanModel(
|
||||
jsonb_plan=plan.jsonb_plan,
|
||||
version=plan.version,
|
||||
parent_plan_id=plan.parent_plan_id
|
||||
)
|
||||
db.add(db_plan)
|
||||
await db.flush() # Flush to get plan ID
|
||||
|
||||
# Add rules to plan
|
||||
for rule_id in plan.rule_ids:
|
||||
db_plan_rule = PlanRule(plan_id=db_plan.id, rule_id=rule_id)
|
||||
db.add(db_plan_rule)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_plan)
|
||||
return db_plan
|
||||
@@ -37,16 +34,16 @@ async def read_plan(
|
||||
plan_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
plan = await db.get(Plan, plan_id)
|
||||
plan = await db.get(PlanModel, plan_id)
|
||||
if not plan:
|
||||
raise HTTPException(status_code=404, detail="Plan not found")
|
||||
return plan
|
||||
|
||||
@router.get("/", response_model=list[PlanSchema])
|
||||
@router.get("/", response_model=List[PlanSchema])
|
||||
async def read_plans(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
result = await db.execute(select(Plan))
|
||||
result = await db.execute(select(PlanModel))
|
||||
return result.scalars().all()
|
||||
|
||||
@router.put("/{plan_id}", response_model=PlanSchema)
|
||||
@@ -55,21 +52,14 @@ async def update_plan(
|
||||
plan: PlanCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
db_plan = await db.get(Plan, plan_id)
|
||||
db_plan = await db.get(PlanModel, plan_id)
|
||||
if not db_plan:
|
||||
raise HTTPException(status_code=404, detail="Plan not found")
|
||||
|
||||
# Update plan fields
|
||||
db_plan.user_id = plan.user_id
|
||||
db_plan.start_date = plan.start_date
|
||||
db_plan.end_date = plan.end_date
|
||||
db_plan.goal = plan.goal
|
||||
|
||||
# Update rules
|
||||
await db.execute(PlanRule.delete().where(PlanRule.plan_id == plan_id))
|
||||
for rule_id in plan.rule_ids:
|
||||
db_plan_rule = PlanRule(plan_id=plan_id, rule_id=rule_id)
|
||||
db.add(db_plan_rule)
|
||||
db_plan.jsonb_plan = plan.jsonb_plan
|
||||
db_plan.version = plan.version
|
||||
db_plan.parent_plan_id = plan.parent_plan_id
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_plan)
|
||||
@@ -80,10 +70,63 @@ async def delete_plan(
|
||||
plan_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
plan = await db.get(Plan, plan_id)
|
||||
plan = await db.get(PlanModel, plan_id)
|
||||
if not plan:
|
||||
raise HTTPException(status_code=404, detail="Plan not found")
|
||||
|
||||
await db.delete(plan)
|
||||
await db.commit()
|
||||
return {"detail": "Plan deleted"}
|
||||
return {"detail": "Plan deleted"}
|
||||
|
||||
@router.post("/generate", response_model=PlanGenerationResponse)
|
||||
async def generate_plan(
|
||||
request: PlanGenerationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
ai_service: AIService = Depends(get_ai_service)
|
||||
):
|
||||
"""
|
||||
Generate a new training plan using AI based on provided goals and rule set.
|
||||
"""
|
||||
try:
|
||||
# Get all rules from the provided rule IDs
|
||||
rules = []
|
||||
for rule_id in request.rule_ids:
|
||||
rule = await db.get(Rule, rule_id)
|
||||
if not rule:
|
||||
raise HTTPException(status_code=404, detail=f"Rule with ID {rule_id} not found")
|
||||
rules.append(rule.jsonb_rules)
|
||||
|
||||
# Generate plan using AI service
|
||||
generated_plan = await ai_service.generate_training_plan(
|
||||
rule_set=rules, # Pass all rules as a list
|
||||
goals=request.goals.model_dump(),
|
||||
preferred_routes=request.preferred_routes
|
||||
)
|
||||
|
||||
# Create a Plan object for the response
|
||||
plan_obj = PlanSchema(
|
||||
id=uuid4(), # Generate a proper UUID
|
||||
jsonb_plan=generated_plan,
|
||||
version=1,
|
||||
parent_plan_id=None,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
# Create response with generated plan
|
||||
response = PlanGenerationResponse(
|
||||
plan=plan_obj,
|
||||
generation_metadata={
|
||||
"status": "success",
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"rule_ids": [str(rule_id) for rule_id in request.rule_ids]
|
||||
}
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to generate plan: {str(e)}"
|
||||
)
|
||||
Reference in New Issue
Block a user