1
+ import math
1
2
import pytest
2
3
from unittest .mock import MagicMock , AsyncMock
3
4
from sqlalchemy .ext .asyncio import AsyncSession
4
5
5
6
from lcfs .db .models .compliance import FuelSupply
6
7
from lcfs .web .api .fuel_supply .repo import FuelSupplyRepository
7
8
from lcfs .web .api .fuel_supply .schema import FuelSupplyCreateUpdateSchema
9
+ from lcfs .web .api .fuel_supply .schema import (
10
+ FuelSuppliesSchema ,
11
+ PaginationResponseSchema ,
12
+ FuelSupplyResponseSchema ,
13
+ )
14
+ from lcfs .web .api .base import PaginationRequestSchema
8
15
9
16
10
17
@pytest .fixture
11
18
def mock_db_session ():
12
19
session = AsyncMock (spec = AsyncSession )
13
20
14
- # Create a mock that properly mimics SQLAlchemy's async result chain
15
21
async def mock_execute (* args , ** kwargs ):
16
- mock_result = (
17
- MagicMock ()
18
- ) # Changed to MagicMock since the chained methods are sync
22
+ mock_result = MagicMock ()
19
23
mock_result .scalars = MagicMock (return_value = mock_result )
20
24
mock_result .unique = MagicMock (return_value = mock_result )
21
25
mock_result .all = MagicMock (return_value = [MagicMock (spec = FuelSupply )])
@@ -36,24 +40,34 @@ def fuel_supply_repo(mock_db_session):
36
40
37
41
38
42
@pytest .mark .anyio
39
- async def test_get_fuel_supply_list (fuel_supply_repo , mock_db_session ):
43
+ async def test_get_fuel_supply_list_exclude_draft_reports (
44
+ fuel_supply_repo , mock_db_session
45
+ ):
40
46
compliance_report_id = 1
41
- mock_result = [MagicMock (spec = FuelSupply )]
47
+ expected_fuel_supplies = [MagicMock (spec = FuelSupply )]
42
48
43
- # Set up the mock to return our desired result
49
+ # Set up the mock result chain with proper method chaining.
44
50
mock_result_chain = MagicMock ()
45
51
mock_result_chain .scalars = MagicMock (return_value = mock_result_chain )
46
52
mock_result_chain .unique = MagicMock (return_value = mock_result_chain )
47
- mock_result_chain .all = MagicMock (return_value = mock_result )
53
+ mock_result_chain .all = MagicMock (return_value = expected_fuel_supplies )
48
54
49
- async def mock_execute (* args , ** kwargs ):
55
+ async def mock_execute (query , * args , ** kwargs ):
50
56
return mock_result_chain
51
57
52
58
mock_db_session .execute = mock_execute
53
59
54
- result = await fuel_supply_repo .get_fuel_supply_list (compliance_report_id )
60
+ # Test when drafts should be excluded (e.g. government user).
61
+ result_gov = await fuel_supply_repo .get_fuel_supply_list (
62
+ compliance_report_id , exclude_draft_reports = True
63
+ )
64
+ assert result_gov == expected_fuel_supplies
55
65
56
- assert result == mock_result
66
+ # Test when drafts are not excluded.
67
+ result_non_gov = await fuel_supply_repo .get_fuel_supply_list (
68
+ compliance_report_id , exclude_draft_reports = False
69
+ )
70
+ assert result_non_gov == expected_fuel_supplies
57
71
58
72
59
73
@pytest .mark .anyio
@@ -80,19 +94,89 @@ async def test_check_duplicate(fuel_supply_repo, mock_db_session):
80
94
units = "L" ,
81
95
)
82
96
83
- # Set up the mock chain using regular MagicMock since the chained methods are sync
97
+ # Set up the mock chain using MagicMock for synchronous chained methods.
84
98
mock_result_chain = MagicMock ()
85
99
mock_result_chain .scalars = MagicMock (return_value = mock_result_chain )
86
- mock_result_chain .first = MagicMock (
87
- return_value = MagicMock (spec = FuelSupply ))
100
+ mock_result_chain .first = MagicMock (return_value = MagicMock (spec = FuelSupply ))
88
101
89
- # Define an async execute function that returns our mock chain
90
102
async def mock_execute (* args , ** kwargs ):
91
103
return mock_result_chain
92
104
93
- # Replace the session's execute with our new mock
94
105
mock_db_session .execute = mock_execute
95
106
96
107
result = await fuel_supply_repo .check_duplicate (fuel_supply_data )
97
108
98
109
assert result is not None
110
+
111
+
112
+ @pytest .mark .anyio
113
+ async def test_get_fuel_supplies_paginated_exclude_draft_reports (fuel_supply_repo ):
114
+ # Define a sample pagination request.
115
+ pagination = PaginationRequestSchema (page = 1 , size = 10 )
116
+ compliance_report_id = 1
117
+ total_count = 20
118
+
119
+ # Build a valid fuel supply record that passes validation.
120
+ valid_fuel_supply = {
121
+ "fuel_supply_id" : 1 ,
122
+ "complianceReportId" : 1 ,
123
+ "version" : 0 ,
124
+ "fuelTypeId" : 1 ,
125
+ "quantity" : 100 ,
126
+ "groupUuid" : "some-uuid" ,
127
+ "userType" : "SUPPLIER" ,
128
+ "actionType" : "CREATE" ,
129
+ "fuelType" : {"fuel_type_id" : 1 , "fuelType" : "Diesel" , "units" : "L" },
130
+ "fuelCategory" : {"fuel_category_id" : 1 , "category" : "Diesel" },
131
+ "endUseType" : {"endUseTypeId" : 1 , "type" : "Transport" , "subType" : "Personal" },
132
+ "provisionOfTheAct" : {"provisionOfTheActId" : 1 , "name" : "Act Provision" },
133
+ "compliancePeriod" : "2024" ,
134
+ "units" : "L" ,
135
+ "fuelCode" : {
136
+ "fuelStatus" : {"status" : "Approved" },
137
+ "fuelCode" : "FUEL123" ,
138
+ "carbonIntensity" : 15.0 ,
139
+ },
140
+ "fuelTypeOther" : "Optional" ,
141
+ }
142
+ expected_fuel_supplies = [valid_fuel_supply ]
143
+
144
+ async def mock_get_fuel_supplies_paginated (
145
+ pagination , compliance_report_id , exclude_draft_reports
146
+ ):
147
+ total_pages = math .ceil (total_count / pagination .size ) if total_count > 0 else 0
148
+ pagination_response = PaginationResponseSchema (
149
+ page = pagination .page ,
150
+ size = pagination .size ,
151
+ total = total_count ,
152
+ total_pages = total_pages ,
153
+ )
154
+ processed = [
155
+ FuelSupplyResponseSchema .model_validate (fs ) for fs in expected_fuel_supplies
156
+ ]
157
+ return FuelSuppliesSchema (
158
+ pagination = pagination_response , fuel_supplies = processed
159
+ )
160
+
161
+ fuel_supply_repo .get_fuel_supplies_paginated = AsyncMock (
162
+ side_effect = mock_get_fuel_supplies_paginated
163
+ )
164
+
165
+ result = await fuel_supply_repo .get_fuel_supplies_paginated (
166
+ pagination , compliance_report_id , exclude_draft_reports = True
167
+ )
168
+
169
+ # Validate pagination values.
170
+ assert result .pagination .page == pagination .page
171
+ assert result .pagination .size == pagination .size
172
+ assert result .pagination .total == total_count
173
+ expected_total_pages = (
174
+ math .ceil (total_count / pagination .size ) if total_count > 0 else 0
175
+ )
176
+ assert result .pagination .total_pages == expected_total_pages
177
+
178
+ # Validate that the fuel supplies list is correctly transformed.
179
+ expected_processed = [
180
+ FuelSupplyResponseSchema .model_validate (fs ) for fs in expected_fuel_supplies
181
+ ]
182
+ assert result .fuel_supplies == expected_processed
0 commit comments