Commit

Fix pre-commit err
RattataKing committed Aug 16, 2024
1 parent a58e823 commit 8f26965
Showing 3 changed files with 103 additions and 31 deletions.
100 changes: 79 additions & 21 deletions tuning/autotune.py
@@ -151,18 +151,20 @@ def get_candidate_spec_mlir_path(self, candidate_id: int) -> Path:

def get_exe_format(self, path: Path) -> str:
return f"./{path.as_posix()}"


@dataclass
class TuningClient(ABC):
@abstractmethod
-def get_dispatch_compile_command(self, candidate_tracker: CandidateTracker) -> list[str]:
+def get_dispatch_compile_command(
+self, candidate_tracker: CandidateTracker
+) -> list[str]:
pass

@abstractmethod
def get_dispatch_benchmark_command(self, candidate_tracker) -> list[str]:
pass

@abstractmethod
def get_model_compile_command(self, candidate_tracker) -> list[str]:
pass
@@ -171,16 +173,49 @@ def get_model_compile_command(self, candidate_tracker) -> list[str]:
def get_model_benchmark_command(self, candidate_tracker) -> list[str]:
pass

@abstractmethod
def get_compiled_dispatch_index(self, file_path: Path) -> int:
pass

-def get_candidate_spec_filename(self, candidate_id: int) -> Path:
+@abstractmethod
+def get_candidate_spec_filename(self, candidate_id: int) -> str:
pass

@abstractmethod
def get_compiled_model_index(self, file_path: Path) -> int:
pass


+@dataclass
+class DefaultTuningClient(TuningClient):
+def get_dispatch_compile_command(
+self, candidate_tracker: CandidateTracker
+) -> list[str]:
+command = [""]
+return command

+def get_dispatch_benchmark_command(self, candidate_tracker) -> list[str]:
+command = [""]
+return command

+def get_model_compile_command(self, candidate_tracker) -> list[str]:
+command = [""]
+return command

+def get_model_benchmark_command(self, candidate_tracker) -> list[str]:
+command = [""]
+return command

+def get_compiled_dispatch_index(self, file_path: Path) -> int:
+return 0

+def get_candidate_spec_filename(self, candidate_id: int) -> str:
+return ""

+def get_compiled_model_index(self, file_path: Path) -> int:
+return 0


@dataclass
class TaskTuple:
args: argparse.Namespace
@@ -719,7 +754,7 @@ def generate_candidates(
args: argparse.Namespace,
path_config: PathConfig,
candidate_trackers: list[CandidateTracker],
-tuning_client: TuningClient
+tuning_client: TuningClient,
) -> list[int]:
"""Generate candidate files for tuning. Returns the list of candidate indexes"""
logging.info("generate_candidates()")
Expand Down Expand Up @@ -818,7 +853,7 @@ def compile_dispatches(
path_config: PathConfig,
candidates: list[int],
candidate_trackers: list[CandidateTracker],
-tuning_client: TuningClient
+tuning_client: TuningClient,
) -> list[int]:
"""Compile candidate files for tuning and record in candidate_vmfbs.txt. Returns the list of compiled candidate indexes."""
logging.info("compile_candidates()")
@@ -827,7 +862,14 @@
logging.info("No candidates to compile.")
return []

-task_list = [TaskTuple(args, tuning_client.get_dispatch_compile_command(candidate_trackers[i]), check=False) for i in candidates]
+task_list = [
+TaskTuple(
+args,
+tuning_client.get_dispatch_compile_command(candidate_trackers[i]),
+check=False,
+)
+for i in candidates
+]
num_worker = min(args.max_cpu_workers, len(task_list))
multiprocess_progress_wrapper(
num_worker=num_worker, task_list=task_list, function=run_command_wrapper
@@ -885,7 +927,7 @@ def parse_dispatch_benchmark_results(
path_config: PathConfig,
benchmark_results: list[TaskResult],
candidate_trackers: list[CandidateTracker],
-tuning_client: TuningClient
+tuning_client: TuningClient,
) -> tuple[list[ParsedDisptachBenchmarkResult], list[str]]:
benchmark_result_configs = []
dump_list = []
@@ -899,7 +941,10 @@
benchmark_time = res.get_benchmark_time()
assert candidate_id is not None and benchmark_time is not None
candidate_trackers[candidate_id].first_benchmark_time = benchmark_time
-candidate_trackers[candidate_id].mlir_spec_path = path_config.spec_dir / tuning_client.get_candidate_spec_filename(candidate_id)
+candidate_trackers[candidate_id].mlir_spec_path = (
+path_config.spec_dir
+/ tuning_client.get_candidate_spec_filename(candidate_id)
+)
mlir_path = candidate_trackers[candidate_id].mlir_path
mlir_spec_path = candidate_trackers[candidate_id].mlir_spec_path
assert mlir_path is not None and mlir_spec_path is not None
@@ -940,7 +985,7 @@ def benchmark_dispatches(
path_config: PathConfig,
compiled_candidates: list[int],
candidate_trackers: list[CandidateTracker],
-tuning_client: TuningClient
+tuning_client: TuningClient,
):
"""Benchmark the candidate files and store the topN results in file (best.log)."""
logging.info("benchmark_top_candidates()")
@@ -952,7 +997,15 @@
)
else:
# Benchmarking dispatch candidates
-task_list = [TaskTuple(args, tuning_client.get_dispatch_benchmark_command(candidate_trackers[i]), check=False, command_need_device_id=True) for i in compiled_candidates]
+task_list = [
+TaskTuple(
+args,
+tuning_client.get_dispatch_benchmark_command(candidate_trackers[i]),
+check=False,
+command_need_device_id=True,
+)
+for i in compiled_candidates
+]
worker_context_queue = create_worker_context_queue(args.devices)
benchmark_results = multiprocess_progress_wrapper(
num_worker=len(args.devices),
@@ -1006,7 +1059,7 @@ def compile_models(
path_config: PathConfig,
candidates: list[int],
candidate_trackers: list[CandidateTracker],
-tuning_client: TuningClient
+tuning_client: TuningClient,
) -> list[int]:
"""Compile U-Net candidates stored in best.log. Return the list of U-Net candidate files."""
logging.info("compile_unet_candidates()")
@@ -1017,10 +1070,11 @@
if not candidates:
logging.info("No model candidates to compile.")
return []

task_list = [
-TaskTuple(args, tuning_client.get_model_compile_command(candidate_trackers[i]))
-for i in candidates if i != 0
+TaskTuple(args, tuning_client.get_model_compile_command(candidate_trackers[i]))
+for i in candidates
+if i != 0
]
num_worker = min(args.max_cpu_workers, len(task_list))
multiprocess_progress_wrapper(
@@ -1226,7 +1280,7 @@ def benchmark_model(
path_config: PathConfig,
unet_candidates: list[int],
candidate_trackers: list[CandidateTracker],
-tuning_client: TuningClient
+tuning_client: TuningClient,
):
"""Benchmark U-Net candidate files and log the results."""
logging.info("benchmark_unet()")
@@ -1329,7 +1383,7 @@ def autotune(args: argparse.Namespace) -> None:
path_config.output_unilog.touch()

candidate_trackers: list[CandidateTracker] = []
-tuning_client = TuningClient()
+tuning_client = DefaultTuningClient()
stop_after_phase: str = args.stop_after

print("Setup logging")
@@ -1341,7 +1395,9 @@
print("Validation successful!\n")

print("Generating candidates...")
-candidates = generate_candidates(args, path_config, candidate_trackers, tuning_client)
+candidates = generate_candidates(
+args, path_config, candidate_trackers, tuning_client
+)
print(f"Generated [{len(candidates)}] candidates in {path_config.candidates_dir}\n")
if stop_after_phase == ExecutionPhases.generate_candidates:
return
@@ -1356,7 +1412,7 @@

print("Benchmarking compiled candidates...")
top_candidates = benchmark_dispatches(
-args, path_config, compiled_candidates, candidate_trackers
+args, path_config, compiled_candidates, candidate_trackers, tuning_client
)
print(f"Stored results in {path_config.output_unilog}\n")

@@ -1365,14 +1421,16 @@

print(f"Compiling top unet candidates...")
unet_candidates = compile_models(
-args, path_config, top_candidates, candidate_trackers
+args, path_config, top_candidates, candidate_trackers, tuning_client
)
print(f"Unet candidates compiled in {path_config.base_dir}\n")
if stop_after_phase == ExecutionPhases.compile_unet_candidates:
return

print("Benchmarking unet candidates...")
-benchmark_model(args, path_config, unet_candidates, candidate_trackers)
+benchmark_model(
+args, path_config, unet_candidates, candidate_trackers, tuning_client
+)
print(f"Stored results in {path_config.output_unilog}")
if stop_after_phase == ExecutionPhases.benchmark_unet_candidates:
return
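Editor's note on the new interface (a sketch, not part of the commit): TuningClient is an abstract base class, so the pipeline functions above can only be driven by a subclass that implements every hook. The new DefaultTuningClient satisfies the contract with placeholder commands, and the PunetClient in the next file is the real in-tree implementation. A minimal client could look like the following; EchoTuningClient is hypothetical, the "echo" commands stand in for real tooling, and the model-index encoding is assumed purely for illustration:

# Editor's sketch only -- EchoTuningClient is hypothetical; "echo" stands in for real tooling.
from dataclasses import dataclass
from pathlib import Path

import autotune  # assumes the tuning/ directory layout used by punet_autotune.py

@dataclass
class EchoTuningClient(autotune.TuningClient):
    def get_dispatch_compile_command(self, candidate_tracker) -> list[str]:
        return ["echo", f"compile {candidate_tracker.mlir_path}"]

    def get_dispatch_benchmark_command(self, candidate_tracker) -> list[str]:
        return ["echo", f"benchmark {candidate_tracker.compiled_vmfb_path}"]

    def get_model_compile_command(self, candidate_tracker) -> list[str]:
        return ["echo", f"compile-model {candidate_tracker.mlir_spec_path}"]

    def get_model_benchmark_command(self, candidate_tracker) -> list[str]:
        return ["echo", f"benchmark-model {candidate_tracker.unet_candidate_path}"]

    def get_compiled_dispatch_index(self, file_path: Path) -> int:
        # Mirrors PunetClient below, which parses the index from the file stem.
        return int(file_path.stem)

    def get_candidate_spec_filename(self, candidate_id: int) -> str:
        return f"{candidate_id}_spec.mlir"

    def get_compiled_model_index(self, file_path: Path) -> int:
        # Index encoding assumed for illustration only.
        return int(file_path.stem.split("_")[0])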
30 changes: 21 additions & 9 deletions tuning/punet_autotune.py
@@ -7,7 +7,9 @@
@dataclass
class PunetClient(autotune.TuningClient):

-def get_dispatch_compile_command(self, candidate_tracker: autotune.CandidateTracker) -> list[str]:
+def get_dispatch_compile_command(
+self, candidate_tracker: autotune.CandidateTracker
+) -> list[str]:
mlir_path = candidate_tracker.mlir_path
assert mlir_path is not None
command = [
@@ -17,7 +19,9 @@ def get_dispatch_compile_command(self, candidate_tracker: autotune.CandidateTrac
]
return command

-def get_dispatch_benchmark_command(self, candidate_tracker: autotune.CandidateTracker) -> list[str]:
+def get_dispatch_benchmark_command(
+self, candidate_tracker: autotune.CandidateTracker
+) -> list[str]:
compiled_vmfb_path = candidate_tracker.compiled_vmfb_path
assert compiled_vmfb_path is not None
command = [
@@ -26,7 +30,9 @@ def get_dispatch_benchmark_command(self, candidate_tracker: autotune.CandidateTr
]
return command

-def get_model_compile_command(self, candidate_tracker: autotune.CandidateTracker) -> list[str]:
+def get_model_compile_command(
+self, candidate_tracker: autotune.CandidateTracker
+) -> list[str]:
mlir_spec_path = candidate_tracker.mlir_spec_path
assert mlir_spec_path is not None
command = [
@@ -35,9 +41,10 @@ def get_model_compile_command(self, candidate_tracker: autotune.CandidateTracker
mlir_spec_path.as_posix(),
]
return command


-def get_model_benchmark_command(self, candidate_tracker: autotune.CandidateTracker) -> list[str]:
+def get_model_benchmark_command(
+self, candidate_tracker: autotune.CandidateTracker
+) -> list[str]:
unet_candidate_path = candidate_tracker.unet_candidate_path
assert unet_candidate_path is not None
command = [
@@ -47,9 +54,9 @@ def get_model_benchmark_command(self, candidate_tracker: autotune.CandidateTrack
return command

def get_compiled_dispatch_index(self, file_path: Path) -> int:
-return int(file_path.stem)
+return int(file_path.stem)

-def get_candidate_spec_filename(self, candidate_id: int) -> Path:
+def get_candidate_spec_filename(self, candidate_id: int) -> str:
return f"{candidate_id}_spec.mlir"

def get_compiled_model_index(self, file_path: Path) -> int:
Expand Down Expand Up @@ -77,7 +84,9 @@ def main():

autotune.setup_logging(args, path_config)

-candidates = autotune.generate_candidates(args, path_config, candidate_trackers, punet_client)
+candidates = autotune.generate_candidates(
+args, path_config, candidate_trackers, punet_client
+)

compiled_candidates = autotune.compile_dispatches(
args, path_config, candidates, candidate_trackers, punet_client
@@ -91,7 +100,10 @@ def main():
args, path_config, top_candidates, candidate_trackers, punet_client
)

-autotune.benchmark_models(args, path_config, punet_candidates, candidate_trackers, punet_client)
+autotune.benchmark_models(
+args, path_config, punet_candidates, candidate_trackers, punet_client
+)


if __name__ == "__main__":
main()
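One consequence of the interface change worth spelling out (editor's note, not in the diff): with get_candidate_spec_filename now marked @abstractmethod, a client that forgets to implement any hook can no longer be instantiated. A quick standalone check, assuming autotune is importable as in this file; IncompleteClient is hypothetical and deliberately unfinished:

# Sketch: IncompleteClient is hypothetical and deliberately unfinished.
import autotune

class IncompleteClient(autotune.TuningClient):
    def get_dispatch_compile_command(self, candidate_tracker) -> list[str]:
        return [""]
    # ...the remaining abstract hooks are intentionally not implemented.

try:
    IncompleteClient()
except TypeError as err:
    print(err)  # the TypeError names the abstract methods still missing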
4 changes: 3 additions & 1 deletion tuning/test_autotune.py
@@ -265,7 +265,9 @@ def generate_parsed_disptach_benchmark_result(
]

mock_tuning_client = MagicMock()
-mock_tuning_client.get_candidate_spec_filename.side_effect = lambda i: f'{i}_spec.mlir'
+mock_tuning_client.get_candidate_spec_filename.side_effect = (
+lambda i: f"{i}_spec.mlir"
+)
parsed_results, dump_list = autotune.parse_dispatch_benchmark_results(
path_config, benchmark_results, candidate_trackers, mock_tuning_client
)
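The test above relies on MagicMock.side_effect accepting a callable, so the mock computes a fresh filename per call instead of returning one canned value. The same pattern in isolation (plain unittest.mock, nothing project-specific):

from unittest.mock import MagicMock

mock_client = MagicMock()
mock_client.get_candidate_spec_filename.side_effect = lambda i: f"{i}_spec.mlir"

# Each call runs the lambda, so different ids yield different names.
assert mock_client.get_candidate_spec_filename(0) == "0_spec.mlir"
assert mock_client.get_candidate_spec_filename(42) == "42_spec.mlir"
mock_client.get_candidate_spec_filename.assert_called_with(42)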
