[feat] Add command to migrate aim Repo
alberttorosyan committed Jul 19, 2023
1 parent cf3047a commit 4738ad4
Showing 5 changed files with 237 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/aimcore/cli/cli.py
@@ -7,6 +7,7 @@
from aimcore.cli.server import commands as server_commands
from aimcore.cli.telemetry import commands as telemetry_commands
from aimcore.cli.package import commands as package_commands
from aimcore.cli.migrate import commands as migrate_commands

core._verify_python3_env = lambda: None

@@ -22,3 +23,4 @@ def cli_entry_point():
cli_entry_point.add_command(server_commands.server)
cli_entry_point.add_command(telemetry_commands.telemetry)
cli_entry_point.add_command(package_commands.package)
cli_entry_point.add_command(migrate_commands.migrate)
0 changes: 0 additions & 0 deletions src/aimcore/cli/migrate/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions src/aimcore/cli/migrate/commands.py
@@ -0,0 +1,46 @@
import click
import pathlib
import shutil

from aim import Repo
from aim._sdk.configs import get_aim_repo_name, get_data_version
from aimcore.cli.migrate import utils


@click.command('migrate')
@click.option('--repo', required=True, type=click.Path(exists=True,
                                                        file_okay=False,
                                                        dir_okay=True,
                                                        writable=True))
@click.option('--run', required=False, type=str, default=None)
@click.option('-y', '--yes', is_flag=True, help='Automatically confirm prompt')
def migrate(repo, run, yes):
    data_version = Repo.get_version(repo)
    if data_version is None:
        click.secho(f'Cannot run migration for directory \'{repo}\'. Data version is unknown.')
        exit(1)
    elif data_version == get_data_version():
        click.secho(f'Aim Repo \'{repo}\' is already up-to-date. Skipping.')
        exit(0)
    repo_path = pathlib.Path(repo)
    aim_dir_path = repo_path / get_aim_repo_name()
    aim_v3_dir_path = repo_path / f'{get_aim_repo_name()}_v3'
    if data_version[0] == 1:
        if yes or click.confirm(f'This command will convert Aim Runs at \'{repo}\' to a new format. '
                                f'This might take a while. Would you like to continue?'):
            shutil.move(aim_dir_path, aim_v3_dir_path)
            try:
                Repo.init(repo)
                repo_inst = Repo.from_path(repo, read_only=False)
                if run is not None:
                    utils.migrate_v3_run_data(repo_inst, aim_v3_dir_path, run_hash=run)
                else:
                    utils.migrate_v3_data(repo_inst, aim_v3_dir_path)
            except Exception as e:
                shutil.rmtree(aim_dir_path)
                click.secho(f'Failed to migrate Aim repo \'{repo}\'. Reason: {e}.')
                click.secho(f'Old repository backup is available at \'{aim_v3_dir_path}\'.')
                raise
            else:
                click.secho(f'Successfully migrated Aim repo \'{repo}\'.')
                click.secho(f'Old repository backup is available at \'{aim_v3_dir_path}\'.')
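
A minimal sketch of how the new command could be exercised once this lands, assuming the package is installed and the command is wired into the CLI entry point as shown in cli.py above. The repository path used here is a placeholder, not a path from this commit:

    from click.testing import CliRunner

    from aimcore.cli.migrate.commands import migrate

    # Programmatic equivalent of `aim migrate --repo /path/to/project --yes`.
    # '/path/to/project' is hypothetical; it must be an existing, writable directory
    # containing an Aim repo, otherwise click.Path(exists=True) rejects it (exit code 2).
    runner = CliRunner()
    result = runner.invoke(migrate, ['--repo', '/path/to/project', '--yes'])
    print(result.exit_code)  # 0 on success or when the repo is already up-to-date
    print(result.output)     # status messages echoed via click.secho

The command backs up the old repo directory with a `_v3` suffix, re-initializes a fresh repo in place, and replays the data into it; on failure the fresh repo is removed and the backup is kept.
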
188 changes: 188 additions & 0 deletions src/aimcore/cli/migrate/utils.py
@@ -0,0 +1,188 @@
import json
import pathlib

import click
import tqdm
import logging

from typing import Dict

from aim import Repo
from aim._core.storage.rockscontainer import RocksContainer
from aim._core.storage.treeview import TreeView
from aimstack.asp import Run, SystemMetric

logger = logging.getLogger(__name__)

RUN_DATA_QUERY_TEMPLATE = """
SELECT
    run.hash as hash,
    run.name as name,
    run.description as description,
    run.is_archived as archived,
    experiment.name as experiment,
    json_group_array(tag.name) as tags,
    {select_notes} as notes
FROM
    run
    LEFT OUTER JOIN experiment ON run.experiment_id = experiment.id
    {notes_join_clause}
    LEFT OUTER JOIN run_tag ON run.id = run_tag.run_id LEFT JOIN tag ON run_tag.tag_id = tag.id
GROUP BY
    run.hash;
"""


SEQUENCE_TYPE_MAP = {
    'float': Run.get_metric,
    'int': Run.get_metric,
    'number': Run.get_metric,
    'aim.image': Run.get_image_sequence,
    'list(aim.image)': Run.get_image_sequence,
    'aim.audio': Run.get_audio_sequence,
    'list(aim.audio)': Run.get_audio_sequence,
    'aim.text': Run.get_text_sequence,
    'list(aim.text)': Run.get_text_sequence,
    'aim.distribution': Run.get_distribution_sequence,
    'aim.figure': Run.get_figure_sequence,
}


def migrate_v1_sequence_data(run: Run, trace_data_tree: TreeView, length: int, item_type: str, name: str, context: Dict):
    if name.startswith('__system'):
        seq = run.sequences.typed_sequence(SystemMetric, name, context)
    elif item_type == 'aim.log_line':
        seq = run.logs
    else:
        get_seq_method = SEQUENCE_TYPE_MAP.get(item_type)
        if get_seq_method is None:
            logger.warning(f'Unknown type of sequence element \'{item_type}\'. Skipping.')
            return
        seq = get_seq_method(run, name, context)
    trace_iter = zip(trace_data_tree.subtree('val').items(),
                     trace_data_tree.subtree('time').items(),
                     trace_data_tree.subtree('epoch').items())
    trace_iter = tqdm.tqdm(trace_iter, leave=False, total=length)
    context_str = str(context)
    if len(context_str) > 20:
        context_str = context_str[:16] + '...}'
    for (step, value), (_, time), (_, epoch) in trace_iter:
        trace_iter.set_description(f'Processing sequence context={context_str}, name=\'{name}\'')
        seq.track(value, step=step, epoch=epoch, time=time)


def migrate_v2_sequence_data(run: Run, trace_data_tree: TreeView, length: int, name: str, context: Dict):
    if name.startswith('__system'):
        seq = run.sequences.typed_sequence(SystemMetric, name, context)
    else:
        seq = run.get_metric(name, context)  # only Metric sequences had V2 data format
    trace_iter = zip(trace_data_tree.subtree('step').items(),
                     trace_data_tree.subtree('val').items(),
                     trace_data_tree.subtree('time').items(),
                     trace_data_tree.subtree('epoch').items())
    trace_iter = tqdm.tqdm(trace_iter, leave=False, total=length)
    context_str = str(context)
    if len(context_str) > 20:
        context_str = context_str[:16] + '...}'
    for (_, step), (_, value), (_, time), (_, epoch) in trace_iter:
        trace_iter.set_description(f'Processing sequence context={context_str}, name=\'{name}\'')
        seq.track(value, step=step, epoch=epoch, time=time)


def migrate_single_run(repo: Repo, v3_repo_path: pathlib.Path, run_hash: str, run_data: Dict):
    meta_container_path = v3_repo_path / 'meta' / 'chunks' / run_hash
    meta_container = RocksContainer(str(meta_container_path), read_only=True)
    meta_tree: TreeView = meta_container.tree()
    run_info_tree = meta_tree.subtree('meta').subtree('chunks').subtree(run_hash)
    context_info = run_info_tree.collect('contexts')
    trace_info = run_info_tree.collect('traces')

    new_run = Run(repo=repo, mode='WRITE')
    new_run[...] = run_info_tree.get('attrs', {})
    new_run['hash_'] = run_hash

    if run_data is not None:
        new_run.name = run_data['name']
        new_run.archived = run_data['archived']
        new_run.description = run_data['description']
        new_run['experiment_name'] = run_data['experiment']
        if len(run_data['tags']) > 0 and run_data['tags'][0] is not None:
            new_run['tags'] = run_data['tags']
        if len(run_data['notes']) > 0 and run_data['notes'][0] is not None:
            new_run['notes'] = run_data['notes']

    trace_container_path = v3_repo_path / 'seqs' / 'chunks' / run_hash
    trace_container = RocksContainer(str(trace_container_path), read_only=True)
    trace_tree: TreeView = trace_container.tree()
    traces_data_tree = trace_tree.subtree('seqs').subtree('chunks').subtree(run_hash)
    v2_traces_data_tree = trace_tree.subtree('seqs').subtree('v2').subtree('chunks').subtree(run_hash)

    for context_idx, context_data in trace_info.items():
        for name, info in context_data.items():
            if info.get('version', 1) == 1:
                trace_data_tree = traces_data_tree.subtree(context_idx).subtree(name)
                item_type = info.get('dtype', 'float')
                migrate_v1_sequence_data(
                    new_run, trace_data_tree,
                    length=info.get('last_step'), item_type=item_type, name=name, context=context_info[context_idx]
                )
            else:  # v2 sequence
                trace_data_tree = v2_traces_data_tree.subtree(context_idx).subtree(name)
                migrate_v2_sequence_data(
                    new_run, trace_data_tree,
                    length=info.get('last_step'), name=name, context=context_info[context_idx]
                )


def get_relational_data(sql_db_path: pathlib.Path) -> Dict:
    def table_exists(tbl):
        res = cursor.execute(f'SELECT count(name) FROM sqlite_master WHERE type=\'table\' AND name=\'{tbl}\';')
        return res.fetchone()[0] == 1

    try:
        import sqlite3
    except ModuleNotFoundError:
        if not click.confirm(f'Missing package \'sqlite3\'. Cannot migrate Run experiment, tags and notes info. '
                             f'Would you like to proceed?'):
            exit(0)
        return {}
    else:
        conn = sqlite3.connect(str(sql_db_path))
        cursor = conn.cursor()
        runs_data = {}

        notes_table_exists = table_exists('note')
        select_notes = 'json_group_array(note.content)' if notes_table_exists else '\'[]\''
        notes_join_clause = 'LEFT OUTER JOIN note ON run.id = note.run_id' if notes_table_exists else ''
        query = RUN_DATA_QUERY_TEMPLATE.format(select_notes=select_notes, notes_join_clause=notes_join_clause)
        for (run_hash, name, desc, archived, exp, tags, notes) in cursor.execute(query):
            runs_data[run_hash] = {
                'name': name,
                'description': desc,
                'archived': archived != 0,
                'experiment': exp,
                'tags': json.loads(tags),
                'notes': json.loads(notes)
            }
        return runs_data


def migrate_v3_data(repo: Repo, v3_repo_path: pathlib.Path):
    chunks_dir = v3_repo_path / 'meta' / 'chunks'
    sql_db_path = v3_repo_path / 'run_metadata.sqlite'
    run_hash_list = []
    if chunks_dir.exists():
        run_hash_list = list(map(lambda x: x.relative_to(chunks_dir).name, chunks_dir.glob('*')))

    runs_data = get_relational_data(sql_db_path)
    runs_iter = tqdm.tqdm(run_hash_list, leave=False)
    for run_hash in runs_iter:
        runs_iter.set_description(f'Processing Run "{run_hash}"')
        migrate_single_run(repo, v3_repo_path, run_hash=run_hash, run_data=runs_data.get(run_hash))


def migrate_v3_run_data(repo: Repo, v3_repo_path: pathlib.Path, run_hash: str):
    sql_db_path = v3_repo_path / 'run_metadata.sqlite'

    runs_data = get_relational_data(sql_db_path)
    migrate_single_run(repo, v3_repo_path, run_hash=run_hash, run_data=runs_data.get(run_hash))
1 change: 1 addition & 0 deletions src/aimcore/cli/package/commands.py
@@ -5,6 +5,7 @@
from .utils import init_template, pyproject_toml_template, get_pkg_distribution_sources
from .watcher import PackageSourceWatcher


@click.group('package')
def package():
pass