"""Tests for the ``GET /api/analysis/{analysis_id}/charts`` endpoint."""

import pytest
from fastapi.testclient import TestClient
from api.main import app
from uuid import uuid4
from unittest.mock import patch

client = TestClient(app)

# Content written to chart files that the endpoint is expected to stream back.
DUMMY_PNG = b"dummy_png_content"

# Chart types the fake analysis exposes; each maps to "<name>.png" on disk.
CHART_TYPES = (
    "power_curve",
    "elevation_profile",
    "zone_distribution_power",
    "zone_distribution_hr",
    "zone_distribution_speed",
)

# NOTE(review): ``patch('src.db.session.get_db')`` replaces only the module
# attribute. If the route declares ``Depends(get_db)``, FastAPI holds its own
# reference to the original function and these mocks may never be consulted;
# ``app.dependency_overrides[get_db] = ...`` is the FastAPI-supported way to
# swap a dependency. Confirm against the router in api.main.


class MockWorkoutAnalysis:
    """Minimal stand-in for a WorkoutAnalysis row returned by the database."""

    def __init__(self, analysis_id, chart_paths):
        self.id = analysis_id
        # Maps chart type -> filesystem path of the rendered PNG.
        self.chart_paths = chart_paths


def _configure_db(mock_get_db, analysis):
    """Make the mocked session's ``query(...).filter(...).first()`` return *analysis*."""
    mock_db_session = mock_get_db.return_value
    mock_db_session.query.return_value.filter.return_value.first.return_value = analysis


def _charts_url(analysis_id, chart_type):
    """Build the charts endpoint URL for *analysis_id* and *chart_type*."""
    return f"/api/analysis/{analysis_id}/charts?chart_type={chart_type}"


@pytest.fixture
def mock_workout_analysis(tmp_path):
    """Return a fake analysis whose chart paths live in a per-test temp dir.

    Paths are built under pytest's ``tmp_path`` rather than a shared ``/tmp``
    so that files written by one test (e.g. the success case) cannot leak into
    another test that relies on the files being absent. The files themselves
    are NOT created here; tests that need them write them explicitly.
    """
    chart_paths = {name: str(tmp_path / f"{name}.png") for name in CHART_TYPES}
    return MockWorkoutAnalysis(uuid4(), chart_paths)


@patch('src.db.session.get_db')
@patch('src.core.chart_generator.ChartGenerator')
def test_get_analysis_charts_success(mock_chart_generator, mock_get_db, mock_workout_analysis):
    """An existing analysis with an existing chart file returns the PNG bytes."""
    _configure_db(mock_get_db, mock_workout_analysis)

    # Stub chart generation so no real rendering happens.
    mock_chart_instance = mock_chart_generator.return_value
    mock_chart_instance.generate_power_curve_chart.return_value = None
    mock_chart_instance.generate_elevation_profile_chart.return_value = None
    mock_chart_instance.generate_zone_distribution_chart.return_value = None

    # Create the chart files the endpoint will read from disk.
    for path in mock_workout_analysis.chart_paths.values():
        with open(path, "wb") as f:
            f.write(DUMMY_PNG)

    response = client.get(_charts_url(mock_workout_analysis.id, "power_curve"))

    assert response.status_code == 200
    assert response.headers["content-type"] == "image/png"
    assert response.content == DUMMY_PNG


@patch('src.db.session.get_db')
def test_get_analysis_charts_not_found(mock_get_db):
    """An unknown analysis id yields 404 with code ANALYSIS_NOT_FOUND."""
    _configure_db(mock_get_db, None)

    response = client.get(_charts_url(uuid4(), "power_curve"))

    assert response.status_code == 404
    assert response.json()["code"] == "ANALYSIS_NOT_FOUND"


@patch('src.db.session.get_db')
def test_get_analysis_charts_chart_type_not_found(mock_get_db, mock_workout_analysis):
    """A chart type missing from ``chart_paths`` yields 404 with code CHART_NOT_FOUND."""
    _configure_db(mock_get_db, mock_workout_analysis)

    # Remove the requested chart type so the analysis has no path for it.
    mock_workout_analysis.chart_paths.pop("power_curve")

    response = client.get(_charts_url(mock_workout_analysis.id, "power_curve"))

    assert response.status_code == 404
    assert response.json()["code"] == "CHART_NOT_FOUND"


@patch('src.db.session.get_db')
def test_get_analysis_charts_file_not_found(mock_get_db, mock_workout_analysis):
    """A chart path pointing at a missing file yields 500 with code CHART_FILE_ERROR."""
    _configure_db(mock_get_db, mock_workout_analysis)

    # The fixture's paths live in a fresh per-test directory and no file is
    # written here, so the path is guaranteed not to exist.
    response = client.get(_charts_url(mock_workout_analysis.id, "power_curve"))

    assert response.status_code == 500
    assert response.json()["code"] == "CHART_FILE_ERROR"