Skip to content

Commit

Permalink
Add no-users option to DumpImporter and DumpExporter (cms-dev#1165)
Browse files Browse the repository at this point in the history
* feat: Allow DumpExporter to only export tasks

* fix: Dump exporter tests

* fix: Add DumpExporterTest for skip_users, fix bug

* feat: Add no-users option to DumpImporter

* fixup

Co-authored-by: Andrey Vihrov <andrey.vihrov@gmail.com>
  • Loading branch information
2 people authored and zj-cs2103 committed Mar 20, 2022
1 parent e67e816 commit 2dc0372
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 24 deletions.
10 changes: 5 additions & 5 deletions cms/db/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,8 @@ def get_datasets_to_judge(task):

def enumerate_files(
session, contest=None,
skip_submissions=False, skip_user_tests=False, skip_print_jobs=False,
skip_generated=False):
skip_submissions=False, skip_user_tests=False, skip_users=False,
skip_print_jobs=False, skip_generated=False):
"""Enumerate all the files (by digest) referenced by the
contest.
Expand All @@ -302,7 +302,7 @@ def enumerate_files(
queries.append(dataset_q.join(Dataset.testcases)
.with_entities(Testcase.output))

if not skip_submissions:
if not skip_submissions and not skip_users:
submission_q = task_q.join(Task.submissions)
queries.append(submission_q.join(Submission.files)
.with_entities(File.digest))
Expand All @@ -312,7 +312,7 @@ def enumerate_files(
.join(SubmissionResult.executables)
.with_entities(Executable.digest))

if not skip_user_tests:
if not skip_user_tests and not skip_users:
user_test_q = task_q.join(Task.user_tests)
queries.append(user_test_q.with_entities(UserTest.input))
queries.append(user_test_q.join(UserTest.files)
Expand All @@ -327,7 +327,7 @@ def enumerate_files(
queries.append(user_test_result_q
.with_entities(UserTestResult.output))

if not skip_print_jobs:
if not skip_print_jobs and not skip_users:
queries.append(contest_q.join(Contest.participations)
.join(Participation.printjobs)
.with_entities(PrintJob.digest))
Expand Down
27 changes: 23 additions & 4 deletions cmscontrib/DumpExporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from cms.db import version as model_version, Codename, Filename, \
FilenameSchema, FilenameSchemaArray, Digest, SessionGen, Contest, User, \
Task, Submission, UserTest, SubmissionResult, UserTestResult, PrintJob, \
enumerate_files
Announcement, Participation, enumerate_files
from cms.db.filecacher import FileCacher
from cmscommon.datetime import make_timestamp
from cmscommon.digest import path_digest
Expand Down Expand Up @@ -136,13 +136,16 @@ class DumpExporter:

def __init__(self, contest_ids, export_target,
dump_files, dump_model, skip_generated,
skip_submissions, skip_user_tests, skip_print_jobs):
skip_submissions, skip_user_tests, skip_users, skip_print_jobs):
if contest_ids is None:
with SessionGen() as session:
contests = session.query(Contest).all()
self.contests_ids = [contest.id for contest in contests]
users = session.query(User).all()
self.users_ids = [user.id for user in users]
if not skip_users:
users = session.query(User).all()
self.users_ids = [user.id for user in users]
else:
self.users_ids = []
tasks = session.query(Task)\
.filter(Task.contest_id.is_(None)).all()
self.tasks_ids = [task.id for task in tasks]
Expand All @@ -158,6 +161,7 @@ def __init__(self, contest_ids, export_target,
self.skip_generated = skip_generated
self.skip_submissions = skip_submissions
self.skip_user_tests = skip_user_tests
self.skip_users = skip_users
self.skip_print_jobs = skip_print_jobs
self.export_target = export_target

Expand Down Expand Up @@ -208,6 +212,7 @@ def do_export(self):
session, contest,
skip_submissions=self.skip_submissions,
skip_user_tests=self.skip_user_tests,
skip_users=self.skip_users,
skip_print_jobs=self.skip_print_jobs,
skip_generated=self.skip_generated)
for file_ in files:
Expand Down Expand Up @@ -317,6 +322,17 @@ class of the given object), an item for each column property
if self.skip_user_tests and other_cls is UserTest:
continue

if self.skip_users:
skip = False
# User-related classes reachable from root
for rel_class in [Participation, Submission, UserTest,
Announcement]:
if other_cls is rel_class:
skip = True
break
if skip:
continue

# Skip print jobs if requested
if self.skip_print_jobs and other_cls is PrintJob:
continue
Expand Down Expand Up @@ -397,6 +413,8 @@ def main():
help="don't export submissions")
parser.add_argument("-U", "--no-user-tests", action="store_true",
help="don't export user tests")
parser.add_argument("-X", "--no-users", action="store_true",
help="don't export users")
parser.add_argument("-P", "--no-print-jobs", action="store_true",
help="don't export print jobs")
parser.add_argument("export_target", action="store",
Expand All @@ -412,6 +430,7 @@ def main():
skip_generated=args.no_generated,
skip_submissions=args.no_submissions,
skip_user_tests=args.no_user_tests,
skip_users=args.no_users,
skip_print_jobs=args.no_print_jobs)
success = exporter.do_export()
return 0 if success is True else 1
Expand Down
41 changes: 29 additions & 12 deletions cmscontrib/DumpImporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@
from cms import utf8_decoder
from cms.db import version as model_version, Codename, Filename, \
FilenameSchema, FilenameSchemaArray, Digest, SessionGen, Contest, \
Submission, SubmissionResult, UserTest, UserTestResult, PrintJob, init_db, \
drop_db, enumerate_files
Submission, SubmissionResult, User, Participation, UserTest, \
UserTestResult, PrintJob, Announcement, init_db, drop_db, enumerate_files
from cms.db.filecacher import FileCacher
from cmscommon.archive import Archive
from cmscommon.datetime import make_datetime
Expand Down Expand Up @@ -128,13 +128,14 @@ class DumpImporter:

def __init__(self, drop, import_source,
load_files, load_model, skip_generated,
skip_submissions, skip_user_tests, skip_print_jobs):
skip_submissions, skip_user_tests, skip_users, skip_print_jobs):
self.drop = drop
self.load_files = load_files
self.load_model = load_model
self.skip_generated = skip_generated
self.skip_submissions = skip_submissions
self.skip_user_tests = skip_user_tests
self.skip_users = skip_users
self.skip_print_jobs = skip_print_jobs

self.import_source = import_source
Expand Down Expand Up @@ -233,9 +234,6 @@ def do_import(self):
for id_, data in self.datas.items():
if not id_.startswith("_"):
self.objs[id_] = self.import_object(data)
for id_, data in self.datas.items():
if not id_.startswith("_"):
self.add_relationships(data, self.objs[id_])

for k, v in list(self.objs.items()):

Expand All @@ -244,18 +242,28 @@ def do_import(self):
del self.objs[k]

# Skip user_tests if requested
if self.skip_user_tests and isinstance(v, UserTest):
elif self.skip_user_tests and isinstance(v, UserTest):
del self.objs[k]

# Skip users if requested
elif self.skip_users and \
isinstance(v, (User, Participation, Submission,
UserTest, Announcement)):
del self.objs[k]

# Skip print jobs if requested
if self.skip_print_jobs and isinstance(v, PrintJob):
elif self.skip_print_jobs and isinstance(v, PrintJob):
del self.objs[k]

# Skip generated data if requested
if self.skip_generated and \
elif self.skip_generated and \
isinstance(v, (SubmissionResult, UserTestResult)):
del self.objs[k]

for id_, data in self.datas.items():
if not id_.startswith("_") and id_ in self.objs:
self.add_relationships(data, self.objs[id_])

contest_id = list()
contest_files = set()

Expand All @@ -266,6 +274,11 @@ def do_import(self):
# that depended on submissions or user tests that we
# might have removed above).
for id_ in self.datas["_objects"]:

# It could have been removed by request
if id_ not in self.objs:
continue

obj = self.objs[id_]
session.add(obj)
session.flush()
Expand All @@ -277,6 +290,7 @@ def do_import(self):
skip_submissions=self.skip_submissions,
skip_user_tests=self.skip_user_tests,
skip_print_jobs=self.skip_print_jobs,
skip_users=self.skip_users,
skip_generated=self.skip_generated)

session.commit()
Expand Down Expand Up @@ -405,12 +419,12 @@ def add_relationships(self, data, obj):
if val is None:
setattr(obj, prp.key, None)
elif isinstance(val, str):
setattr(obj, prp.key, self.objs[val])
setattr(obj, prp.key, self.objs.get(val))
elif isinstance(val, list):
setattr(obj, prp.key, list(self.objs[i] for i in val))
setattr(obj, prp.key, list(self.objs[i] for i in val if i in self.objs))
elif isinstance(val, dict):
setattr(obj, prp.key,
dict((k, self.objs[v]) for k, v in val.items()))
dict((k, self.objs[v]) for k, v in val.items() if v in self.objs))
else:
raise RuntimeError(
"Unknown RelationshipProperty value: %s" % type(val))
Expand Down Expand Up @@ -472,6 +486,8 @@ def main():
help="don't import submissions")
parser.add_argument("-U", "--no-user-tests", action="store_true",
help="don't import user tests")
parser.add_argument("-X", "--no-users", action="store_true",
help="don't import users")
parser.add_argument("-P", "--no-print-jobs", action="store_true",
help="don't import print jobs")
parser.add_argument("import_source", action="store", type=utf8_decoder,
Expand All @@ -486,6 +502,7 @@ def main():
skip_generated=args.no_generated,
skip_submissions=args.no_submissions,
skip_user_tests=args.no_user_tests,
skip_users=args.no_users,
skip_print_jobs=args.no_print_jobs)
success = importer.do_import()
return 0 if success is True else 1
Expand Down
22 changes: 21 additions & 1 deletion cmstestsuite/unit_tests/cmscontrib/DumpExporterTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def tearDown(self):
super().tearDown()

def do_export(self, contest_ids, dump_files=True, skip_generated=False,
skip_submissions=False):
skip_submissions=False, skip_users=False):
"""Create an exporter and call do_export in a convenient way"""
r = DumpExporter(
contest_ids,
Expand All @@ -95,6 +95,7 @@ def do_export(self, contest_ids, dump_files=True, skip_generated=False,
skip_generated=skip_generated,
skip_submissions=skip_submissions,
skip_user_tests=False,
skip_users=skip_users,
skip_print_jobs=False).do_export()
dump_path = os.path.join(self.target, "contest.json")
try:
Expand Down Expand Up @@ -269,6 +270,25 @@ def test_skip_generated(self):
self.assertNotInDump(SubmissionResult)
self.assertFileNotInDump(self.exe_digest)

def test_skip_users(self):
    """Check that exporting with skip_users leaves user data out.

    The statement and its file must still be part of the dump, while
    users, participations, submissions, submission results and the
    files identified by file_digest and exe_digest must be absent.

    """
    exported = self.do_export(None, skip_users=True)
    self.assertTrue(exported)

    # Task-side data is not affected by the option.
    self.assertInDump(Statement, digest=self.st_digest)
    self.assertFileInDump(self.st_digest, self.st_content)

    # Neither the users nor the objects that depend on them are dumped.
    for cls in (User, Participation, Submission, SubmissionResult):
        self.assertNotInDump(cls)
    for digest in (self.file_digest, self.exe_digest):
        self.assertFileNotInDump(digest)


if __name__ == "__main__":
unittest.main()
30 changes: 28 additions & 2 deletions cmstestsuite/unit_tests/cmscontrib/DumpImporterTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# Needs to be first to allow for monkey patching the DB connection string.
from cmstestsuite.unit_tests.databasemixin import DatabaseMixin

from cms.db import Contest, FSObject, Session, version
from cms.db import Contest, User, FSObject, Session, version
from cmscommon.digest import bytes_digest
from cmscontrib.DumpImporter import DumpImporter
from cmstestsuite.unit_tests.filesystemmixin import FileSystemMixin
Expand Down Expand Up @@ -136,7 +136,8 @@ def tearDown(self):
super().tearDown()

def do_import(self, drop=False, load_files=True,
skip_generated=False, skip_submissions=False):
skip_generated=False, skip_submissions=False,
skip_users=False):
"""Create an importer and call do_import in a convenient way"""
return DumpImporter(
drop,
Expand All @@ -146,6 +147,7 @@ def do_import(self, drop=False, load_files=True,
skip_generated=skip_generated,
skip_submissions=skip_submissions,
skip_user_tests=False,
skip_users=skip_users,
skip_print_jobs=False).do_import()

def write_dump(self, dump):
Expand Down Expand Up @@ -195,6 +197,12 @@ def assertContestNotInDb(self, name):
.filter(Contest.name == name).all()
self.assertEqual(len(db_contests), 0)

def assertUserNotInDb(self, username):
    """Assert that no user with the given username exists in the DB."""
    matching = self.session.query(User).filter(User.username == username)
    self.assertEqual(len(matching.all()), 0)

def assertFileInDb(self, digest, description, content):
"""Assert that the file with the given data is in the DB."""
fsos = self.session.query(FSObject)\
Expand Down Expand Up @@ -279,6 +287,24 @@ def test_import_skip_files(self):
self.assertFileNotInDb(TestDumpImporter.GENERATED_FILE_DIGEST)
self.assertFileNotInDb(TestDumpImporter.NON_GENERATED_FILE_DIGEST)

def test_import_skip_users(self):
    """Import a dump with skip_users and check no user gets loaded.

    Both contests must be imported (the first one with its task),
    while the user "username" and the two dump files must not end up
    in the DB.

    """
    self.write_dump(TestDumpImporter.DUMP)
    self.write_files(TestDumpImporter.FILES)

    imported = self.do_import(skip_users=True)
    self.assertTrue(imported)

    # NOTE(review): the trailing empty list is presumably the expected
    # users of the contest -- confirm against assertContestInDb.
    self.assertContestInDb(
        "contestname", "contest description 你好",
        [("taskname", "task title")], [])
    self.assertContestInDb(
        self.other_contest_name, self.other_contest_description, [], [])

    self.assertUserNotInDb("username")
    for digest in (TestDumpImporter.GENERATED_FILE_DIGEST,
                   TestDumpImporter.NON_GENERATED_FILE_DIGEST):
        self.assertFileNotInDb(digest)


def test_import_old(self):
"""Test importing an old dump.
Expand Down

0 comments on commit 2dc0372

Please sign in to comment.