Skip to content

Commit

Permalink
tc
Browse files Browse the repository at this point in the history
  • Loading branch information
clee2000 committed Jan 29, 2025
1 parent 55fa38d commit bc182a2
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 37 deletions.
11 changes: 11 additions & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,17 @@ init_command = [
]
is_formatter = true

# Checks the structure of the ClickHouse query params.json files by running
# tools/linter/adapters/sql_params_linter.py on every matching path.
[[linter]]
code = 'SQL_PARAMS'
include_patterns = ['torchci/clickhouse_queries/**/params.json']
exclude_patterns = []
command = [
'python3',
'tools/linter/adapters/sql_params_linter.py',
# The adapter's argument parser treats '@<file>' as "read the arguments from <file>".
'@{{PATHSFILE}}',
]
is_formatter = false

[[linter]]
code = 'RUSTFMT'
include_patterns = ['**/*.rs']
Expand Down
129 changes: 129 additions & 0 deletions tools/linter/adapters/sql_params_linter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import argparse
import concurrent.futures
import json
import logging
import os
import re
import subprocess
import time
from enum import Enum
from typing import List, NamedTuple, Optional, Pattern


# Identifier placed in the `code` field of every LintMessage this adapter emits.
LINTER_CODE = "SQL_PARAMS"


class LintSeverity(str, Enum):
    """Severity levels that a lint message can carry.

    The ``str`` mix-in makes every member a real string, so a member can be
    handed to ``json.dumps`` directly and is serialized as its value
    (for example ``"warning"``).
    """

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"


class LintMessage(NamedTuple):
    """A single lint finding.

    ``main`` converts each instance with ``_asdict()`` and prints it as one
    JSON object per line.
    """

    # File the finding refers to.
    path: Optional[str]
    # Position of the finding; None when it applies to the file as a whole.
    line: Optional[int]
    char: Optional[int]
    # Identifier of the linter that produced the finding.
    code: str
    severity: LintSeverity
    # Short label for the finding.
    name: str
    # Presumably the text before/after a suggested fix; this adapter always
    # passes None for both -- confirm against the consumer of the JSON output.
    original: Optional[str]
    replacement: Optional[str]
    # Human-readable explanation of the finding.
    description: Optional[str]


# Verbose, multiline pattern for output lines shaped like
# "<file>:<line>:<char>: <message> [<code>]".
# NOTE(review): nothing in this file references RESULTS_RE -- confirm whether
# it is still needed.
RESULTS_RE: Pattern[str] = re.compile(
r"""(?mx)
^
(?P<file>.*?):
(?P<line>\d+):
(?P<char>\d+):
\s(?P<message>.*)
\s(?P<code>\[.*\])
$
"""
)


def run_command(
    args: List[str],
) -> "subprocess.CompletedProcess[bytes]":
    """Run ``args`` as a subprocess and capture its output.

    The command line and the elapsed time are logged at DEBUG level.

    Args:
        args: Program and arguments, passed to ``subprocess.run`` as a list.

    Returns:
        The completed process; ``stdout`` and ``stderr`` hold the captured
        bytes. A non-zero exit status is returned to the caller, not raised.
    """
    logging.debug("$ %s", " ".join(args))
    start_time = time.monotonic()
    try:
        return subprocess.run(
            args,
            capture_output=True,
        )
    finally:
        # Runs on both success and failure, so the duration is always logged.
        end_time = time.monotonic()
        logging.debug("took %dms", (end_time - start_time) * 1000)


def _find_structure_problems(data: object) -> List[str]:
    """Describe every structural problem in already-parsed params.json data."""
    if not isinstance(data, dict):
        # Without this guard a scalar such as a number would raise TypeError
        # on the membership tests below.
        return ["The top-level JSON value is not an object."]

    problems = []
    if "params" not in data:
        problems.append("The file does not contain a 'params' key.")
    elif not isinstance(data["params"], dict):
        problems.append("The 'params' key is not a dictionary.")
    if "tests" not in data:
        problems.append("The file does not contain a 'tests' key.")
    elif not isinstance(data["tests"], list):
        problems.append("The 'tests' key is not a list.")
    return problems


def check_file(
    filename: str,
) -> List[LintMessage]:
    """Validate the top-level structure of one ``params.json`` file.

    The file must hold a JSON object with a ``params`` key whose value is an
    object and a ``tests`` key whose value is a list.

    Args:
        filename: Path of the JSON file to check.

    Returns:
        An empty list when the file is well formed. Otherwise a list with a
        single ``LintMessage`` whose description joins every problem found
        with ``"; "``. Structural problems are reported as warnings; a file
        that cannot be parsed as JSON is reported as an error.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    severity = LintSeverity.WARNING
    try:
        # Binary mode: json.load detects the UTF-8/16/32 encoding itself.
        with open(filename, "rb") as f:
            data = json.load(f)
    except ValueError as err:
        # json.JSONDecodeError and UnicodeDecodeError both derive from
        # ValueError. Report the file instead of crashing the whole run.
        severity = LintSeverity.ERROR
        problems = [f"The file is not valid JSON: {err}"]
    else:
        problems = _find_structure_problems(data)

    if not problems:
        return []
    return [
        LintMessage(
            path=filename,
            line=None,
            char=None,
            code=LINTER_CODE,
            severity=severity,
            name="lint",
            replacement=None,
            original=None,
            description="; ".join(problems),
        )
    ]


def main() -> None:
    """Check every file named on the command line and print the findings.

    Files are checked concurrently on a thread pool. Each finding is printed
    to stdout as one JSON object per line, in the order the checks complete
    (not necessarily the order the files were given).

    Raises:
        Exception: Whatever ``check_file`` raised for a file; the failing
            file name is logged at CRITICAL level before re-raising.
    """
    parser = argparse.ArgumentParser(
        # Plain string: there is nothing to interpolate, so no f-prefix.
        description="A simple linter for params.json files for sql queries",
        # An argument of the form "@<file>" is replaced by the arguments
        # listed in that file.
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "filenames",
        nargs="+",
        help="paths to lint",
    )

    args = parser.parse_args()

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count(),
        thread_name_prefix="Thread",
    ) as executor:
        # Map each future back to its file so failures can name the file.
        futures = {
            executor.submit(
                check_file,
                filename,
            ): filename
            for filename in args.filenames
        }
        for future in concurrent.futures.as_completed(futures):
            try:
                for lint_message in future.result():
                    print(json.dumps(lint_message._asdict()), flush=True)
            except Exception:
                logging.critical('Failed at "%s".', futures[future])
                raise


# Run the linter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
82 changes: 45 additions & 37 deletions tools/torchci/clickhouse_query_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,37 @@
from prettytable import PrettyTable
from torchci.clickhouse import get_clickhouse_client, query_clickhouse
from torchci.utils import REPO_ROOT
from tqdm import tqdm
from tqdm import tqdm # type: ignore[import]


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Queue alert for torchci")
parser.add_argument("--query", type=str, help="Query name", required=True)
parser.add_argument(
"--perf", action="store_true", help="Run performance comparison"
"--head",
type=str,
help="Sha for the query to compare or get evaluations for",
required=True,
)
parser.add_argument("--base", type=str, help="Base sha for comparison")
parser.add_argument(
"--perf", action="store_true", help="Run performance analysis/comparison"
)
parser.add_argument(
"--results",
action="store_true",
help="Run results comparison. Requires --base",
)
parser.add_argument("--results", action="store_true", help="Run results comparison")
parser.add_argument(
"--times",
type=int,
help="Number of times to run the query. Only relevant if --perf is used",
default=10,
)
parser.add_argument(
"--compare",
type=str,
help="Either a sha or a branch name to compare against. These should be available locally. Required for --results",
)
args = parser.parse_args()
return args


def get_query_id(query: str, params: dict) -> Optional[str]:
try:
res = get_clickhouse_client().query(query, params)
return res.query_id
except Exception as e:
print(f"Error: {e}")
return None


@cache
def get_base_query(query: str, sha: str) -> str:
return subprocess.check_output(
Expand All @@ -72,24 +69,31 @@ def get_avg_stats(query_ids: list) -> tuple:
return metrics[0]["realTimeMSAvg"], metrics[0]["memoryBytesAvg"]


def get_query_ids(query: str, params: dict, times: int) -> tuple:
def get_query_ids(query: str, params: dict, times: int) -> list[str]:
def _get_query_id(query: str, params: dict) -> Optional[str]:
try:
res = get_clickhouse_client().query(query, params)
return res.query_id
except Exception as e:
print(f"Error: {e}")
return None

return [
x for _ in tqdm(range(times)) if (x := get_query_id(query, params)) is not None
x for _ in tqdm(range(times)) if (x := _get_query_id(query, params)) is not None
]


def format_comparision_string(new: float, old: float) -> str:
return f"{new} vs {old} ({new - old}, {round(100 * (new - old) / old)}%)"


@cache
def get_query(query: str) -> tuple:
with open(
REPO_ROOT / "torchci" / "clickhouse_queries" / query / "params.json"
) as f:
tests = json.load(f).get("tests", [])
with open(REPO_ROOT / "torchci" / "clickhouse_queries" / query / "query.sql") as f:
query = f.read()
def get_query(query: str, sha: str) -> tuple:
def _get_file(file_path: str) -> str:
return subprocess.check_output(["git", "show", f"{sha}:{file_path}"]).decode(
"utf-8"
)

tests = json.loads(_get_file(f"torchci/clickhouse_queries/{query}/params.json"))[
"tests"
]
query = _get_file(f"torchci/clickhouse_queries/{query}/query.sql")
for test in tests:
for key, value in test.items():
if isinstance(value, dict):
Expand All @@ -101,7 +105,7 @@ def get_query(query: str) -> tuple:


def perf_compare(args: argparse.Namespace) -> None:
query, tests = get_query(args.query)
query, tests = get_query(args.query, args.head)

print(
f"Gathering perf stats for: {args.query}\nNum tests: {len(tests)}\nNum times: {args.times}"
Expand All @@ -112,8 +116,8 @@ def perf_compare(args: argparse.Namespace) -> None:
new = get_query_ids(query, test, args.times)

base = None
if args.compare:
base_query = get_base_query(args.query, args.compare)
if args.base:
base_query, _ = get_query(args.query, args.base)
base = get_query_ids(base_query, test, args.times)
query_ids.append((new, base))

Expand All @@ -122,7 +126,7 @@ def perf_compare(args: argparse.Namespace) -> None:
# time to populate
time.sleep(20)
table = PrettyTable()
if args.compare:
if args.base:
table.field_names = [
"Test",
"Avg Time",
Expand Down Expand Up @@ -160,11 +164,12 @@ def perf_compare(args: argparse.Namespace) -> None:

def results_compare(args: argparse.Namespace) -> None:
query, tests = get_query(args.query)
if not args.compare:
if not args.base:
print("Base sha is required for results comparison")
return
base_query = get_base_query(args.query, args.compare)
base_query, _ = get_query(args.query, args.base)
print(
f"Comparing results for query: {args.query}\nNum tests: {len(tests)}\nBase: {args.compare}"
f"Comparing results for query: {args.query}\nNum tests: {len(tests)}\nHead: {args.head} Base: {args.base}"
)
for i, test in enumerate(tests):
new_results = query_clickhouse(query, test)
Expand All @@ -179,6 +184,9 @@ def results_compare(args: argparse.Namespace) -> None:

if __name__ == "__main__":
args = parse_args()
if not args.perf and not args.results:
print("Please specify --perf or --results")
exit(1)
if args.perf:
perf_compare(args)
if args.results:
Expand Down

0 comments on commit bc182a2

Please sign in to comment.