Skip to content

Commit

Permalink
Improve invitations (#7)
Browse files Browse the repository at this point in the history
* Add invitation table

* invitation wip

* Fix test suite

* Fix test

* Fix sqlmodel relationship issue

See: fastapi/sqlmodel#315

* Attempt to fix mypy

* Don't use transformer for testing

* Remove MyPy

* Black 22.3

Co-authored-by: Jeremy Fisher <jeremy@adamsfisher.me>
  • Loading branch information
jeremyadamsfisher and Jeremy Fisher authored May 20, 2022
1 parent bf6df2d commit fa7a166
Show file tree
Hide file tree
Showing 26 changed files with 606 additions and 161 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest pytest-mypy black
pip install -r requirements-dev.txt
- name: Run tests with pytest
run: |
pytest -vv --mypy
pytest -vv
black --check .
34 changes: 32 additions & 2 deletions backend/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from sqlmodel import Session, select

from .models import Story, User, UserStoriesRead
from .models import Invitation, InvitationNew, PlayerOrder, Story, User, UserStoriesRead

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -78,7 +78,7 @@ def get_stories_originated_by_user(user_id: int, session: Session):
return u


def get_story(story_uuid: str, session: Session) -> Optional[Story]:
def get_story(story_uuid: str, session: Session) -> Story:
statement = select(Story).where(Story.story_uuid == story_uuid)
try:
(story,) = session.exec(statement)
Expand All @@ -100,3 +100,33 @@ def convert_story_to_multiplayer_if_needed(story: Story, user: User, session: Se
segment.author_id = user.id
session.add(story)
session.commit()


def add_invitation(invitation: InvitationNew, session: Session) -> Invitation:
"""create an invitation that can be redeemed by anyone with the link
in the email"""
story = get_story(story_uuid=invitation.story_uuid, session=session)
i = Invitation(
invitee_email=invitation.invitee_email,
story=story,
responded=False,
)
session.add(i)
session.commit()
session.refresh(i)
return i


def respond_to_invitation(invitation_id: int, user: User, session: Session):
"""redeem the invitation so that the logged in user will be allowed
to add to the story"""
invitation = session.get(Invitation, invitation_id)
if invitation is None:
raise DbNotFound
invitation.responded = True
n_players = len(invitation.story.players) + 1
invitation.story.player_ordering.append(PlayerOrder(user=user, order=n_players))
session.add(invitation)
session.commit()

return invitation
23 changes: 12 additions & 11 deletions backend/db.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
from google.cloud.sql.connector import connector
from sqlmodel import Session, create_engine

engine = create_engine(
"postgresql+pg8000://",
creator=lambda: connector.connect(
"story-circle-ai:us-east1:yakul",
"pg8000",
user="story-circle-app-sa@story-circle-ai.iam",
db="faboo",
enable_iam_auth=True,
),
)


def get_engine():
return create_engine(
"postgresql+pg8000://",
creator=lambda: connector.connect(
"story-circle-ai:us-east1:yakul",
"pg8000",
user="story-circle-app-sa@story-circle-ai.iam",
db="faboo",
enable_iam_auth=True,
),
)
return engine


def get_session():
engine = get_engine()
with Session(engine) as session:
yield session
12 changes: 10 additions & 2 deletions backend/game.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import re
import string

Expand All @@ -7,15 +8,22 @@

from . import crud
from .db import get_engine
from .models import Story, StorySegment
from .models import StorySegment

logger = logging.getLogger(__name__)

N_FAILURES_ALLOWED = 10
MAX_PROMPT_LENGTH = 50
WORDS_THAT_CAN_HAVE_A_PERIOD = ["mr" "ms" "mrs" "jr" "sr"]

text_generator = pipeline("text-generation", "pranavpsv/gpt2-genre-story-generator")
if os.environ["APP_ENV"] == "TESTING":

def text_generator(prompt):
EXAMPLE = "So we beat on, boats against the current, borne back ceaselessly into the past."
return [{"generated_text": prompt + EXAMPLE}]

else:
text_generator = pipeline("text-generation", "pranavpsv/gpt2-genre-story-generator")


class InferenceProblem(Exception):
Expand Down
4 changes: 4 additions & 0 deletions backend/lib/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,8 @@
VALIDATE_CERTS=True,
)


email_client = FastMail(conf)

if os.environ.get("SUPPRESS_EMAIL", False):
email_client.config.SUPPRESS_SEND = 1
14 changes: 10 additions & 4 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import os
import warnings
import logging
from pathlib import Path

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
Expand All @@ -17,10 +17,9 @@
logger.warning(".env file does not exists, falling down to environment variables")


from .routers.invitations import router as invitations_router
from .routers.story import router as story_router
from .routers.users import router as user_router
from .routers.invitations import router as invitations_router


warnings.filterwarnings(
"ignore", ".*Class SelectOfScalar will not make use of SQL compilation caching.*"
Expand All @@ -30,7 +29,14 @@
app = FastAPI(title="Story Circle")


if os.environ["APP_ENV"] == "DEV":
app_env = os.environ["APP_ENV"]


logger.info(f"booting into {app_env} mode")


if app_env == "DEV":
logger.info("allowing CORS from anywhere")
origins = [
"*",
"http://localhost",
Expand Down
21 changes: 20 additions & 1 deletion backend/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class Story(SQLModel, table=True):
story_uuid: str = Field(default_factory=lambda: str(uuid4()), index=True)
original_author_id: Optional[int] = Field(default=None, foreign_key="users.id")
original_author: Optional[User] = Relationship(back_populates="stories_originated")
invitations: List["Invitation"] = Relationship(back_populates="story")
segments: List["StorySegment"] = Relationship(back_populates="story")
player_ordering: List["PlayerOrder"] = Relationship(back_populates="story")
single_player_mode: bool
Expand Down Expand Up @@ -64,7 +65,25 @@ class PlayerOrder(SQLModel, table=True):
story: Optional[Story] = Relationship(back_populates="player_ordering")
user_id: Optional[int] = Field(default=None, foreign_key="users.id")
user: Optional[User] = Relationship(back_populates="player_ordering")
invitation_accepted: bool


class Invitation(SQLModel, table=True):
__tablename__ = "invitations"
id: Optional[int] = Field(default=None, primary_key=True)
invitee_email: str
responded: bool
story_id: Optional[int] = Field(default=None, foreign_key="stories.id")
story: Optional[Story] = Relationship(back_populates="invitations")


class InvitationNew(SQLModel):
story_uuid: str
invitee_email: str


class InvitationRead(SQLModel):
id: int
story_id: str


class UserRead(SQLModel):
Expand Down
56 changes: 31 additions & 25 deletions backend/routers/invitations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,60 +4,66 @@

from fastapi import BackgroundTasks, Depends, HTTPException
from fastapi_mail import MessageSchema
from pydantic import BaseModel
from sqlmodel import Session

from .. import crud
from ..auth import get_user_from_request
from ..db import get_session
from ..lib.email import email_client
from ..lib.shims import APIRouter
from ..models import PlayerOrder
from ..models import Invitation, InvitationNew, InvitationRead

router = APIRouter()

INVITE_SUBJECT = "You've been invited to contribute to an AI story!"
INVITE_HTML = """
<p>You've been invited to work on a story with AI agents!</p>
<p>Check it out: {}</p>
<html>
<body>
<p>You've been invited to work on a story with AI agents!</p>
<a id="email-link" href="{}">Click here!</a>
<body>
</html>
"""


class Invitation(BaseModel):
story_uuid: str
other_player_email: str


@router.post("/")
async def invite_user(
@router.post("/send", response_model=InvitationRead)
async def send_invitation(
*,
session: Session = Depends(get_session),
user=Depends(get_user_from_request),
background_tasks: BackgroundTasks,
invitation: Invitation
invitation_: InvitationNew,
):
if not any(
story.story_uuid == invitation.story_uuid for story in user.stories_originated
story.story_uuid == invitation_.story_uuid for story in user.stories_originated
):
raise HTTPException(
403, "can only invite other users to stories user originated"
)

story = crud.get_story(story_uuid=invitation.story_uuid, session=session)
other_player = crud.get_user_by_name(invitation.other_player_email, session=session)
player_order = PlayerOrder(
order=len(story.player_ordering) + 1,
user=other_player,
story=story,
invitation_accepted=False,
)
session.add(player_order)
invitation = crud.add_invitation(invitation_, session)

story_url = reduce(urljoin, [os.environ["APP_ORIGIN"], "/s", invitation.story_uuid])
story_url = reduce(
urljoin,
[os.environ["APP_ORIGIN"], "/invitations/respond", str(invitation.id)],
)

msg = MessageSchema(
subject="You've been invited to contribute to an AI story!",
recipients=[invitation.other_player_email],
subject=INVITE_SUBJECT,
recipients=[invitation.invitee_email],
body=INVITE_HTML.format(story_url),
)

background_tasks.add_task(email_client.send_message, msg)

return invitation


@router.get("/respond/{invitation_id}", response_model=InvitationRead)
def respond_to_invitation(
*,
invitation_id: int,
session: Session = Depends(get_session),
user=Depends(get_user_from_request),
):
return crud.respond_to_invitation(invitation_id, user, session)
1 change: 0 additions & 1 deletion backend/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from ..lib.shims import APIRouter
from ..models import UserStoriesRead


router = APIRouter()


Expand Down
10 changes: 2 additions & 8 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,18 @@
from sqlmodel.pool import StaticPool

from backend import crud
from backend.routers import story
from backend.auth import get_user_from_request
from backend.db import get_session
from backend.lib.email import email_client
from backend.main import app
from backend.models import *
from backend.routers import invitations, story

SQLALCHEMY_DATABASE_URL = "sqlite://"

EXAMPLE_USER_EMAILS = [f"player{i}@foo.com" for i in range(1, 3)]


def pytest_configure(config):
plugin = config.pluginmanager.getplugin("mypy")
plugin.mypy_argv.extend(
["--no-strict-optional", "--warn-unused-ignores", "--ignore-missing-imports"]
)


class NeedToSetAUser(Exception):
...

Expand All @@ -56,6 +49,7 @@ def session_fixture():
def client_context(session: Session, monkeypatch):
os.environ["APP_ORIGIN"] = "http://localhost"

os.environ["SUPPRESS_EMAIL"] = "1"
email_client.config.SUPPRESS_SEND = 1

def test_perform_ai_turn(story_uuid):
Expand Down
3 changes: 2 additions & 1 deletion frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
"@types/react": "^16.9.0",
"@types/react-dom": "^16.9.0",
"axios": "^0.26.0",
"formik": "^2.2.9",
"framer-motion": "^4.0.0",
"react": "^17.0.2",
"react-dom": "^17.0.2",
"react-icons": "^4.3.1",
"react-query": "^3.34.16",
"react-router-dom": "^6.2.1",
"react-router-dom": "^6.2.2",
"react-scripts": "5.0.0",
"react-spinners": "^0.11.0",
"typescript": "^4.3.5",
Expand Down
Loading

0 comments on commit fa7a166

Please sign in to comment.