Skip to content

Commit

Permalink
Fix CSV ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem committed Jan 7, 2025
1 parent 921c282 commit 6000b1d
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 24 deletions.
14 changes: 13 additions & 1 deletion py/core/database/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,19 @@ async def export_to_csv(
if not rows:
break
for row in rows:
writer.writerow(row)
row_dict = {
"id": row[0],
"owner_id": row[1],
"name": row[2],
"description": row[3],
"graph_sync_status": row[4],
"graph_cluster_status": row[5],
"created_at": row[6],
"updated_at": row[7],
"user_count": row[8],
"document_count": row[9],
}
writer.writerow([row_dict[col] for col in columns])

temp_file.flush()
return temp_file.name, temp_file
Expand Down
18 changes: 16 additions & 2 deletions py/core/database/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,13 @@ async def export_conversations_to_csv(
if not rows:
break
for row in rows:
writer.writerow(row)
row_dict = {
"id": row[0],
"user_id": row[1],
"created_at": row[2],
"name": row[3],
}
writer.writerow([row_dict[col] for col in columns])

temp_file.flush()
return temp_file.name, temp_file
Expand Down Expand Up @@ -640,7 +646,15 @@ async def export_messages_to_csv(
if not rows:
break
for row in rows:
writer.writerow(row)
row_dict = {
"id": row[0],
"conversation_id": row[1],
"parent_id": row[2],
"content": row[3],
"metadata": row[4],
"created_at": row[5],
}
writer.writerow([row_dict[col] for col in columns])

temp_file.flush()
return temp_file.name, temp_file
Expand Down
17 changes: 16 additions & 1 deletion py/core/database/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,22 @@ async def export_to_csv(
if not rows:
break
for row in rows:
writer.writerow(row)
row_dict = {
"id": row[0],
"collection_ids": row[1],
"owner_id": row[2],
"type": row[3],
"metadata": row[4],
"title": row[5],
"summary": row[6],
"version": row[7],
"size_in_bytes": row[8],
"ingestion_status": row[9],
"extraction_status": row[10],
"created_at": row[11],
"updated_at": row[12],
}
writer.writerow([row_dict[col] for col in columns])

temp_file.flush()
return temp_file.name, temp_file
Expand Down
13 changes: 12 additions & 1 deletion py/core/database/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,18 @@ async def export_to_csv(
if not rows:
break
for row in rows:
writer.writerow(row)
row_dict = {
"id": row[0],
"name": row[1],
"category": row[2],
"description": row[3],
"parent_id": row[4],
"chunk_ids": row[5],
"metadata": row[6],
"created_at": row[7],
"updated_at": row[8],
}
writer.writerow([row_dict[col] for col in columns])

temp_file.flush()
return temp_file.name, temp_file
Expand Down
13 changes: 12 additions & 1 deletion py/core/database/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,18 @@ async def export_to_csv(
if not rows:
break
for row in rows:
writer.writerow(row)
row_dict = {
"id": row[0],
"email": row[1],
"is_superuser": row[2],
"is_active": row[3],
"is_verified": row[4],
"name": row[5],
"bio": row[6],
"created_at": row[7],
"updated_at": row[8],
}
writer.writerow([row_dict[col] for col in columns])

temp_file.flush()
return temp_file.name, temp_file
Expand Down
1 change: 0 additions & 1 deletion py/core/main/api/v3/users_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -1673,7 +1673,6 @@ async def delete_user_api_key(
# client.login(...)
user_limits = client.users.get_limits("550e8400-e29b-41d4-a716-446655440000")
print(user_limits)
""",
},
{
Expand Down
61 changes: 45 additions & 16 deletions py/core/main/services/management_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,9 @@ async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]:
- The usage for each relevant limit (how many requests used, how many remain, etc.)
"""
# 1. Fetch the user to see if they have overrides
user = await self.providers.database.users_handler.get_user_by_id(user_id)
user = await self.providers.database.users_handler.get_user_by_id(
user_id
)

# 2. System defaults
system_defaults = {
Expand All @@ -948,7 +950,9 @@ async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]:

# If the user added "global_per_min" or "monthly_limit" overrides, override them
if user_overrides.get("global_per_min") is not None:
effective_limits["global_per_min"] = user_overrides["global_per_min"]
effective_limits["global_per_min"] = user_overrides[
"global_per_min"
]
if user_overrides.get("monthly_limit") is not None:
effective_limits["monthly_limit"] = user_overrides["monthly_limit"]
if user_overrides.get("route_per_min") is not None:
Expand All @@ -958,8 +962,12 @@ async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]:
# - self.config.route_limits (system route overrides)
# - user_overrides["route_overrides"] (user route overrides)
# So we can later show usage for each route.
system_route_limits = self.config.database.route_limits # dict[str, LimitSettings]
user_route_overrides = user_overrides.get("route_overrides", {}) # e.g. { "/api/foo": {...}, ... }
system_route_limits = (
self.config.database.route_limits
) # dict[str, LimitSettings]
user_route_overrides = user_overrides.get(
"route_overrides", {}
) # e.g. { "/api/foo": {...}, ... }

# 5. Build usage data
usage = {}
Expand All @@ -971,10 +979,14 @@ async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]:
one_min_ago = now - timedelta(minutes=1)

# Use your limits_handler to count
global_per_min_used = await self.providers.database.limits_handler._count_requests(
user_id, route=None, since=one_min_ago
global_per_min_used = (
await self.providers.database.limits_handler._count_requests(
user_id, route=None, since=one_min_ago
)
)
monthly_used = await self.providers.database.limits_handler._count_monthly_requests(
user_id
)
monthly_used = await self.providers.database.limits_handler._count_monthly_requests(user_id)

# The final effective global/min is in `effective_limits["global_per_min"]`, etc.
usage["global_per_min"] = {
Expand All @@ -998,17 +1010,33 @@ async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]:

# (b) Build route-level usage
# We'll gather a union of the routes from system_route_limits + user_route_overrides
route_keys = set(system_route_limits.keys()) | set(user_route_overrides.keys())
route_keys = set(system_route_limits.keys()) | set(
user_route_overrides.keys()
)
usage["routes"] = {}
for route in route_keys:
# 1) System route-limits
sys_route_lim = system_route_limits.get(route) # or None
route_global_per_min = sys_route_lim.global_per_min if sys_route_lim else system_defaults["global_per_min"]
route_route_per_min = sys_route_lim.route_per_min if sys_route_lim else system_defaults["route_per_min"]
route_monthly_limit = sys_route_lim.monthly_limit if sys_route_lim else system_defaults["monthly_limit"]
route_global_per_min = (
sys_route_lim.global_per_min
if sys_route_lim
else system_defaults["global_per_min"]
)
route_route_per_min = (
sys_route_lim.route_per_min
if sys_route_lim
else system_defaults["route_per_min"]
)
route_monthly_limit = (
sys_route_lim.monthly_limit
if sys_route_lim
else system_defaults["monthly_limit"]
)

# 2) Merge user overrides for that route
user_route_cfg = user_route_overrides.get(route, {}) # e.g. { "route_per_min": 25, "global_per_min": 80, ... }
user_route_cfg = user_route_overrides.get(
route, {}
) # e.g. { "route_per_min": 25, "global_per_min": 80, ... }
if user_route_cfg.get("global_per_min") is not None:
route_global_per_min = user_route_cfg["global_per_min"]
if user_route_cfg.get("route_per_min") is not None:
Expand All @@ -1017,8 +1045,10 @@ async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]:
route_monthly_limit = user_route_cfg["monthly_limit"]

# Now let's measure usage for this route over the last minute
route_per_min_used = await self.providers.database.limits_handler._count_requests(
user_id, route, one_min_ago
route_per_min_used = (
await self.providers.database.limits_handler._count_requests(
user_id, route, one_min_ago
)
)
# monthly usage is the same for all routes if there's a global monthly limit,
# but if you have route-specific monthly limits, we still want to do a global monthly count.
Expand All @@ -1043,7 +1073,6 @@ async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]:
# If you want to represent the "global_per_min" that applies to this route,
# you could put that here too if it’s route-specific.
# But typically "global_per_min" is for all requests, so usage is the same as above.

# The route-specific monthly usage, in your code, is not specifically counted by route,
# but if you want to do it the same as route_per_min, you'd do:
# route_monthly_used = await self.providers.database.limits_handler._count_requests(
Expand All @@ -1058,7 +1087,7 @@ async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]:
if route_monthly_limit is not None
else None
),
}
},
}

# Return a structured response
Expand Down
2 changes: 1 addition & 1 deletion py/core/main/services/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def num_tokens_from_messages(messages, model="gpt-4o"):
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
logger.warning("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")

tokens = 0
Expand Down

0 comments on commit 6000b1d

Please sign in to comment.