diff --git a/backend/oasst_backend/api/v1/admin.py b/backend/oasst_backend/api/v1/admin.py index 6dd593f5d6..6ae753352c 100644 --- a/backend/oasst_backend/api/v1/admin.py +++ b/backend/oasst_backend/api/v1/admin.py @@ -8,13 +8,14 @@ from oasst_backend.api import deps from oasst_backend.config import Settings, settings from oasst_backend.models import ApiClient, User -from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.prompt_repository import PromptRepository, UserRepository from oasst_backend.tree_manager import TreeManager from oasst_backend.utils.database_utils import CommitMode, managed_tx_function from oasst_shared import utils from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas.protocol import PageResult, SystemStats -from oasst_shared.utils import ScopeTimer, unaware_to_utc +from oasst_shared.utils import ScopeTimer, log_timing, unaware_to_utc +from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() @@ -263,7 +264,7 @@ async def get_flagged_messages( return resp -@router.post("/admin/flagged_messages/{message_id}/processed", response_model=FlaggedMessageResponse) +@router.post("/flagged_messages/{message_id}/processed", response_model=FlaggedMessageResponse) async def process_flagged_messages( message_id: UUID, session: deps.Session = Depends(deps.get_db), @@ -275,3 +276,24 @@ async def process_flagged_messages( flagged_msg = pr.process_flagged_message(message_id=message_id) resp = FlaggedMessageResponse(**flagged_msg.__dict__) return resp + + +class MergeUsersRequest(pydantic.BaseModel): + destination_user_id: UUID + source_user_ids: list[UUID] + + +@log_timing(level="INFO") +@router.post("/merge_users", response_model=None, status_code=HTTP_204_NO_CONTENT) +def merge_users( + request: MergeUsersRequest, + api_client: ApiClient = Depends(deps.get_trusted_api_client), +) -> None: + @managed_tx_function(CommitMode.COMMIT) + def merge_users_tx(session: deps.Session): + ur = UserRepository(session, api_client) + ur.merge_users(destination_user_id=request.destination_user_id, source_user_ids=request.source_user_ids) + + merge_users_tx() + + logger.info(f"Merged users: {request=}") diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 3c326afb4a..cdea17bf44 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -1,6 +1,7 @@ from typing import Optional from uuid import UUID +import oasst_backend.models as models from oasst_backend.config import settings from oasst_backend.models import ApiClient, User from oasst_backend.utils.database_utils import CommitMode, managed_tx_method @@ -9,7 +10,7 @@ from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.utils import utcnow from sqlalchemy.exc import IntegrityError -from sqlmodel import Session, and_, or_ +from sqlmodel import Session, and_, delete, or_, update from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND @@ -346,3 +347,36 @@ def update_user_last_activity(self, user: User, update_streak: bool = False) -> user.streak_days = (current_time - user.streak_last_day_date).days self.db.add(user) + + @managed_tx_method(CommitMode.FLUSH) + def merge_users(self, destination_user_id: UUID, source_user_ids: list[UUID]) -> None: + source_user_ids = list(filter(lambda x: x != destination_user_id, source_user_ids)) + if not source_user_ids: + return + + # ensure the destination user exists + self.get_user(id=destination_user_id) + + # update rows in tables that have affected users_ids as FK + models_to_update = [ + models.Message, + models.MessageRevision, + models.MessageReaction, + models.MessageEmoji, + models.TextLabels, + models.Task, + models.Journal, + ] + for table in models_to_update: + qry = update(table).where(table.user_id.in_(source_user_ids)).values(user_id=destination_user_id) + self.db.execute(qry) + + # delete rows in user stats tables + models_to_delete = [models.UserStats, models.TrollStats] + for table in models_to_delete: + qry = delete(table).where(table.user_id.in_(source_user_ids)) + self.db.execute(qry) + + # finally delete source users from main user table + qry = delete(User).where(User.id.in_(source_user_ids)) + self.db.execute(qry)