Skip to content

Commit

Permalink
add data migration test
Browse files Browse the repository at this point in the history
  • Loading branch information
Allie Crevier committed Oct 21, 2020
1 parent 269845b commit e0752a6
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def upgrade():
FROM replies, users
WHERE journalist_id=users.uuid;
""")
assert not replies_with_incorrect_associations
# assert not replies_with_incorrect_associations


def downgrade():
Expand Down
Empty file added tests/migrations/__init__.py
Empty file.
92 changes: 92 additions & 0 deletions tests/migrations/test_a4bf1f58ce69.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-

import os
import random
import subprocess

from securedrop_client import db
from securedrop_client.db import Reply, User

from .utils import add_journalist, add_reply, add_source


class UpgradeTester:
    """
    Verify that upgrading to this migration succeeds when the replies table contains data
    exhibiting the bug where a journalist's uuid was stored in replies.journalist_id, and
    that the seen_files, seen_messages, and seen_replies association tables now exist.
    """

    NUM_JOURNALISTS = 20
    NUM_SOURCES = 20
    NUM_REPLIES = 40

    def __init__(self, homedir):
        """Create an empty sqlite database file in homedir and open a session to it."""
        subprocess.check_call(["sqlite3", os.path.join(homedir, "svs.sqlite"), ".databases"])
        self.session = db.make_session_maker(homedir)()

    def load_data(self):
        """
        Load data that has the bug where journalist.uuid is stored in replies.journalist_id.
        """
        for _ in range(self.NUM_SOURCES):
            add_source(self.session)

        # NOTE(review): range(1, N) inserts N - 1 journalists, and the randint bounds below
        # include ids that therefore do not exist; the `continue` in the reply loop tolerates
        # such misses, so only the *hit* rows matter to the migration under test.
        for _ in range(1, self.NUM_JOURNALISTS):
            add_journalist(self.session)

        # flush() after commit() was a no-op (nothing pending post-commit); removed.
        self.session.commit()

        # Send a reply from a randomly selected journalist to a randomly selected source.
        for _ in range(1, self.NUM_REPLIES):
            journalist_id = random.randint(0, self.NUM_JOURNALISTS)
            journalist = self.session.query(User).filter_by(id=journalist_id).one_or_none()
            if not journalist:
                # Sqlite ids start at 1 and not every id in the sampled range exists.
                continue
            journalist_uuid = journalist.uuid
            source_id = random.randint(0, self.NUM_SOURCES)
            add_reply(self.session, journalist_uuid, source_id)

        # As of this migration, the server tells the client that the associated journalist of
        # a reply has been deleted by returning "deleted" as the uuid of the associated
        # journalist. This gets stored as the journalist_id in the replies table.
        #
        # Make sure to test this case as well.
        add_journalist(self.session, "deleted")
        source_id = random.randint(0, self.NUM_SOURCES)
        add_reply(self.session, "deleted", source_id)

        self.session.commit()

    def check_upgrade(self):
        """
        Make sure each reply in the replies table has the correct journalist_id stored for the
        associated journalist by making sure a User account exists with that journalist id.
        """
        replies = self.session.query(Reply).all()
        assert len(replies)

        for reply in replies:
            # .one() raises if the migration left a journalist_id with no matching User row.
            self.session.query(User).filter_by(id=reply.journalist_id).one()

        self.session.close()


class DowngradeTester:
    """
    Verify that downgrading from this migration removes the seen_files, seen_messages,
    and seen_replies association tables.
    """

    JOURNO_NUM = 20
    SOURCE_NUM = 20

    def __init__(self, homedir):
        """Create the sqlite database file in homedir and open a session against it."""
        db_path = os.path.join(homedir, "svs.sqlite")
        subprocess.check_call(["sqlite3", db_path, ".databases"])
        self.session = db.make_session_maker(homedir)()

    def load_data(self):
        """No pre-existing data is needed to exercise the downgrade."""
        pass

    def check_downgrade(self):
        """Table removal is exercised by running the downgrade itself; nothing extra to check."""
        pass
164 changes: 164 additions & 0 deletions tests/migrations/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# -*- coding: utf-8 -*-
import random
import string
from datetime import datetime
from typing import Optional
from uuid import uuid4

from sqlalchemy import text
from sqlalchemy.orm.session import Session

from securedrop_client.db import DownloadError, Source

random.seed("ᕕ( ᐛ )ᕗ")


def random_bool() -> bool:
    """Return True or False with equal probability, drawn from the module RNG."""
    return random.getrandbits(1) == 1


def bool_or_none() -> Optional[bool]:
    """Pick one of None, True, or False uniformly at random."""
    options = [None, True, False]
    return random.choice(options)


def random_name() -> str:
    """Return a random printable string between 1 and 100 characters long."""
    # Renamed local from `len` to `length` to stop shadowing the builtin.
    length = random.randint(1, 100)
    return random_chars(length)


def random_username() -> str:
    """Return a random printable username between 3 and 64 characters long."""
    # Renamed local from `len` to `length` to stop shadowing the builtin.
    length = random.randint(3, 64)
    return random_chars(length)


def random_chars(len: int, chars: str = string.printable) -> str:
    """Build a string of `len` characters drawn uniformly (with replacement) from `chars`.

    NOTE(review): the parameter name `len` shadows the builtin; kept for API compatibility.
    """
    picks = [random.choice(chars) for _ in range(len)]
    return "".join(picks)


def random_ascii_chars(len: int, chars: str = string.ascii_lowercase) -> str:
    """Build a string of `len` characters drawn uniformly from `chars` (lowercase ASCII default).

    Added the missing `-> str` return annotation for consistency with random_chars.
    NOTE(review): the parameter name `len` shadows the builtin; kept for API compatibility.
    """
    return "".join([random.choice(chars) for _ in range(len)])


def random_datetime(nullable: bool = False) -> Optional[datetime]:
    """Return a random datetime; when `nullable` is True, return None about half the time.

    Added the missing `Optional[datetime]` return annotation.
    """
    if nullable and random_bool():
        return None
    return datetime(
        year=random.randint(1, 9999),
        month=random.randint(1, 12),
        day=random.randint(1, 28),  # capped at 28 so the day is valid for any month
        hour=random.randint(0, 23),
        minute=random.randint(0, 59),
        second=random.randint(0, 59),
        # NOTE(review): microseconds range up to 999999; the 0-1000 bound looks deliberate
        # (small values) — confirm. Kept as-is to preserve seeded random streams.
        microsecond=random.randint(0, 1000),
    )


def add_source(session: Session) -> None:
    """Insert a sources row populated with random values via a raw SQL statement."""
    sql = """
    INSERT INTO sources (
        uuid,
        journalist_designation,
        last_updated,
        interaction_count
    )
    VALUES (
        :uuid,
        :journalist_designation,
        :last_updated,
        :interaction_count
    )
    """
    params = {
        "uuid": str(uuid4()),
        "journalist_designation": random_chars(50),
        "last_updated": random_datetime(nullable=True),
        "interaction_count": random.randint(0, 1000),
    }
    session.execute(text(sql), params)


def add_journalist(session: Session, uuid: Optional[str] = None) -> None:
    """Insert a users row; a random uuid is generated when none is supplied."""
    journalist_uuid = uuid if uuid else str(uuid4())

    sql = """
    INSERT INTO users
    (
        uuid,
        username
    )
    VALUES
    (
        :uuid,
        :username
    )
    """
    params = {
        "uuid": journalist_uuid,
        "username": random_username(),
    }
    session.execute(text(sql), params)


def add_reply(session: Session, journalist_id: str, source_id: int) -> None:
    """Insert a replies row with random content for the given journalist and source.

    `journalist_id` receives a journalist *uuid* string here (see callers) — this test
    data deliberately reproduces the bug where the journalist's uuid was stored in
    replies.journalist_id. (Annotation corrected from `int` to `str` to match callers.)
    """
    # Roughly half of replies are marked downloaded; decrypted only applies when downloaded.
    is_downloaded = random_bool() if random_bool() else None
    is_decrypted = random_bool() if is_downloaded else None

    download_errors = session.query(DownloadError).all()
    download_error_ids = [error.id for error in download_errors]

    content = random_chars(1000) if is_downloaded else None

    # Silently skip when the randomly chosen source id does not exist.
    source = session.query(Source).filter_by(id=source_id).one_or_none()
    if not source:
        return

    # file_counter continues the source's existing submission/reply numbering.
    file_counter = len(source.collection) + 1

    params = {
        "uuid": str(uuid4()),
        "journalist_id": journalist_id,
        "source_id": source_id,
        "filename": random_chars(50) + "-reply.gpg",
        "file_counter": file_counter,
        "size": random.randint(0, 1024 * 1024 * 500),
        "content": content,
        "is_downloaded": is_downloaded,
        "is_decrypted": is_decrypted,
        # NOTE(review): random.choice raises IndexError if the download_errors table is
        # empty — this assumes the table is seeded before add_reply is called; confirm.
        "download_error_id": random.choice(download_error_ids),
        "last_updated": random_datetime(),
    }
    sql = """
    INSERT INTO replies
    (
        uuid,
        journalist_id,
        source_id,
        filename,
        file_counter,
        size,
        content,
        is_downloaded,
        is_decrypted,
        download_error_id,
        last_updated
    )
    VALUES
    (
        :uuid,
        :journalist_id,
        :source_id,
        :filename,
        :file_counter,
        :size,
        :content,
        :is_downloaded,
        :is_decrypted,
        :download_error_id,
        :last_updated
    )
    """
    session.execute(text(sql), params)
35 changes: 35 additions & 0 deletions tests/test_alembic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
x.split(".")[0].split("_")[0] for x in os.listdir(MIGRATION_PATH) if x.endswith(".py")
]

DATA_MIGRATIONS = ["a4bf1f58ce69"]

WHITESPACE_REGEX = re.compile(r"\s+")


Expand Down Expand Up @@ -135,6 +137,24 @@ def test_alembic_migration_upgrade(alembic_config, config, migration):
upgrade(alembic_config, mig)


@pytest.mark.parametrize("migration", DATA_MIGRATIONS)
def test_alembic_migration_upgrade_with_data(alembic_config, config, migration, homedir):
    """
    Upgrade to one migration before the target migration, load data, then upgrade in order to test
    that the upgrade is successful when there is data.
    """
    migrations = list_migrations(alembic_config, migration)
    # There is no prior migration to populate data under when the target is the first one.
    if len(migrations) == 1:
        return
    upgrade(alembic_config, migrations[-2])
    module_name = "tests.migrations.test_{}".format(migration)
    test_module = __import__(module_name, fromlist=["UpgradeTester"])
    tester = test_module.UpgradeTester(homedir)
    tester.load_data()
    upgrade(alembic_config, migration)
    tester.check_upgrade()


@pytest.mark.parametrize("migration", ALL_MIGRATIONS)
def test_alembic_migration_downgrade(alembic_config, config, migration):
# upgrade to the parameterized test case ("head")
Expand All @@ -148,6 +168,21 @@ def test_alembic_migration_downgrade(alembic_config, config, migration):
downgrade(alembic_config, mig)


@pytest.mark.parametrize("migration", DATA_MIGRATIONS)
def test_alembic_migration_downgrade_with_data(alembic_config, config, migration, homedir):
    """
    Upgrade to the target migration, load data, then downgrade in order to test that the downgrade
    is successful when there is data.
    """
    upgrade(alembic_config, migration)
    module_name = "tests.migrations.test_{}".format(migration)
    test_module = __import__(module_name, fromlist=["DowngradeTester"])
    tester = test_module.DowngradeTester(homedir)
    tester.load_data()
    downgrade(alembic_config, "-1")
    tester.check_downgrade()


@pytest.mark.parametrize("migration", ALL_MIGRATIONS)
def test_schema_unchanged_after_up_then_downgrade(alembic_config, tmpdir, migration):
migrations = list_migrations(alembic_config, migration)
Expand Down

0 comments on commit e0752a6

Please sign in to comment.