diff --git a/backend/README.md b/backend/README.md index 2cf7fb34ac..f22f0ca827 100644 --- a/backend/README.md +++ b/backend/README.md @@ -26,7 +26,7 @@ Next, to install all requirements, You can run 1. `pip install -r requirements.txt` inside the `backend` folder; and 2. `pip install -e .` inside the `oasst-shared` folder. 3. `pip install -e .` inside the `oasst-data` folder. -4. `./scripts/backend-development/run-local.sh` to run the backend. This will +4. `../scripts/backend-development/run-local.sh` to run the backend. This will start the backend server at `http://localhost:8080`. ## REST Server Configuration diff --git a/backend/main.py b/backend/main.py index 168de5cbc1..8cef801bd7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -32,8 +32,6 @@ from sqlmodel import Session from starlette.middleware.cors import CORSMiddleware -# from worker.scheduled_tasks import create_task - app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json") startup_time: datetime = utcnow() diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index b2b5aff74f..6e4a6b630e 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -174,7 +174,7 @@ async def interaction_tx(session: deps.Session): ur = UserRepository(session, api_client) task = await tm.handle_interaction(interaction) if type(task) is protocol_schema.TaskDone: - ur.update_user_last_activity(user=pr.user) + ur.update_user_last_activity(user=pr.user, update_streak=True) return task try: diff --git a/backend/oasst_backend/celery_worker.py b/backend/oasst_backend/celery_worker.py index fe99527b48..714772e443 100644 --- a/backend/oasst_backend/celery_worker.py +++ b/backend/oasst_backend/celery_worker.py @@ -19,9 +19,9 @@ # see https://docs.celeryq.dev/en/stable/userguide/periodic-tasks.html app.conf.beat_schedule = { - "update-user-streak": { - "task": "update_user_streak", - "schedule": 60.0 * 60.0 * 4, # seconds + "reset-user-streak": { + "task": "periodic_user_streak_reset", + "schedule": 60.0 * 60.0 * 4, # in seconds, every 4h }, "update-search-vectors": { "task": "update_search_vectors", diff --git a/backend/oasst_backend/scheduled_tasks.py b/backend/oasst_backend/scheduled_tasks.py index 407835c492..4ae1ea1264 100644 --- a/backend/oasst_backend/scheduled_tasks.py +++ b/backend/oasst_backend/scheduled_tasks.py @@ -1,23 +1,20 @@ from __future__ import absolute_import, unicode_literals -from datetime import datetime +from datetime import timedelta from typing import Any, Dict, List from asgiref.sync import async_to_sync from celery import shared_task from loguru import logger from oasst_backend.celery_worker import app -from oasst_backend.models import ApiClient, Message +from oasst_backend.models import ApiClient, Message, User from oasst_backend.models.db_payload import MessagePayload from oasst_backend.prompt_repository import PromptRepository -from oasst_backend.user_repository import User from oasst_backend.utils.database_utils import db_lang_to_postgres_ts_lang, default_session_factory from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI -from oasst_shared.utils import utcnow +from oasst_shared.utils import log_timing, utcnow from sqlalchemy import func -from sqlmodel import select - -startup_time: datetime = utcnow() +from sqlmodel import update async def useHFApi(text, url, model_name): @@ -92,46 +89,19 @@ def update_search_vectors(batch_size: int) -> None: logger.error(f"update_search_vectors failed with error: {str(e)}") -@shared_task(name="update_user_streak") -def update_user_streak() -> None: - logger.info("update_user_streak start...") +@shared_task(name="periodic_user_streak_reset") +@log_timing(level="INFO") +def periodic_user_streak_reset() -> None: try: with default_session_factory() as session: - current_time = utcnow() - timedelta = current_time - startup_time - if timedelta.days > 0: - # Update only greater than 24 hours . Do nothing - logger.info("Process timedelta greater than 24h") - statement = select(User) - result = session.exec(statement).all() - if result is not None: - for user in result: - last_activity_date = user.last_activity_date - streak_last_day_date = user.streak_last_day_date - # set NULL streak_days to 0 - if user.streak_days is None: - user.streak_days = 0 - # if the user had completed a task - if last_activity_date is not None: - lastactitvitydelta = current_time - last_activity_date - # if the user missed consecutive days of completing a task - # reset the streak_days to 0 and set streak_last_day_date to the current_time - if lastactitvitydelta.days > 1 or user.streak_days is None: - user.streak_days = 0 - user.streak_last_day_date = current_time - # streak_last_day_date has a current timestamp in DB. Ideally should not be NULL. - if streak_last_day_date is not None: - streak_delta = current_time - streak_last_day_date - # if user completed tasks on consecutive days then increment the streak days - # update the streak_last_day_date to current time for the next calculation - if streak_delta.days > 0: - user.streak_days += 1 - user.streak_last_day_date = current_time - session.add(user) - session.commit() - - else: - logger.info("Not yet 24hours since the process started! ...") - logger.info("User streak end...") - except Exception as e: - logger.error(str(e)) + # Reset streak_days to 0 for users with more than 1.5 days of inactivity + streak_timeout = utcnow() - timedelta(hours=36) + reset_query = ( + update(User) + .filter(User.last_activity_date < streak_timeout, User.streak_last_day_date.is_not(None)) + .values(streak_days=0, streak_last_day_date=None) + ) + session.execute(reset_query) + session.commit() + except Exception: + logger.exception("Error during periodic user streak reset") diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 13c84074a9..3c326afb4a 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -332,6 +332,17 @@ def query_users_ordered_by_display_name( return qry.all() @managed_tx_method(CommitMode.FLUSH) - def update_user_last_activity(self, user: User) -> None: - user.last_activity_date = utcnow() + def update_user_last_activity(self, user: User, update_streak: bool = False) -> None: + current_time = utcnow() + user.last_activity_date = current_time + + if update_streak: + if user.streak_last_day_date is None or user.streak_last_day_date > current_time: + # begin new streak + user.streak_last_day_date = current_time + user.streak_days = 0 + else: + # update streak day count + user.streak_days = (current_time - user.streak_last_day_date).days + self.db.add(user)