Change filenames
RattataKing committed Aug 19, 2024
1 parent 3553a89 commit 41f8b7d
Showing 5 changed files with 313 additions and 313 deletions.
2 changes: 1 addition & 1 deletion tuning/tune.py → tuning/candidate_gen.py
@@ -24,7 +24,7 @@
from iree.compiler.dialects import _linalg_ops_gen, _util_ops_gen

"""
Usage: ./tune.py 121.mlir -o "tuning/candidates" -l 1024 --lhs-dims=mk --rhs-dims=nk --tile-dims=mnk
Usage: ./candidate_gen.py 121.mlir -o "tuning/candidates" -l 1024 --lhs-dims=mk --rhs-dims=nk --tile-dims=mnk
"""

tune_logger = logging.getLogger("tune")
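
For context on how the renamed module is consumed: the call in libtuner.py below invokes the generator as candidate_gen.tune(...). A minimal sketch of that programmatic entry point, assuming only the keyword arguments visible in this diff (input, output, limit); the values are placeholders borrowed from the usage string above, and the collapsed portion of the call in libtuner.py may pass additional required arguments:

import candidate_gen

# Sketch only: argument names mirror the call in libtuner.py further down in this
# commit; the values are illustrative placeholders, not part of the change.
candidate_gen.tune(
    input="121.mlir",
    output="tuning/candidates",
    limit=1024,
)
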
20 changes: 10 additions & 10 deletions tuning/autotune.py → tuning/libtuner.py
@@ -15,7 +15,7 @@
import time
import multiprocessing
import queue
import tune
import candidate_gen
from tqdm import tqdm
import re
import hashlib
@@ -48,7 +48,7 @@ class CandidateTracker:
candidate_id: int
dispatch_mlir_path: Optional[Path] = None
dispatch_config_path: Optional[Path] = None
configuration: Optional[tune.Configuration] = None
configuration: Optional[candidate_gen.Configuration] = None
compilation_successful: Optional[bool] = None
compiled_dispatch_path: Optional[Path] = None
compiled_dispatch_hash: Optional[str] = None
@@ -408,12 +408,12 @@ def parse_arguments() -> argparse.Namespace:
help="Do not attempt to run any modules or initialize the IREE runtime",
)

# tune.tune() options
# candidate_gen.tune() options
parser.add_argument(
"--num-candidates",
type=int,
default=DEFAULT_NUM_CANDIDATES,
help=f"Number of candidates to be generated by tune.py (default: {DEFAULT_NUM_CANDIDATES})",
help=f"Number of candidates to be generated by candidate_gen.py (default: {DEFAULT_NUM_CANDIDATES})",
)
parser.add_argument(
"--num-subgroups",
@@ -477,7 +477,7 @@ def format(self, record):
verbose_console_handler.setFormatter(file_formatter)
logging.getLogger().addHandler(verbose_console_handler)

# config logger in tune.py
# config logger in candidate_gen.py
tune_logger = logging.getLogger("tune")
tune_logger.setLevel(logging.DEBUG)

@@ -725,8 +725,8 @@ def generate_candidates(

mlirs = []
try:
logging.debug("Captured messages from tune.py:")
tune.tune(
logging.debug("Captured messages from candidate_gen.py:")
candidate_gen.tune(
input=str(path_config.template_mlir),
output=str(path_config.candidates_dir),
limit=args.num_candidates,
@@ -740,14 +740,14 @@
)
except Exception as e:
logging.error("An error occurred during candidates generation: %s", str(e))
# Capture and log debug messages from tune.py
# Capture and log debug messages from candidate_gen.py
tune_logger = logging.getLogger("tune")
for handler in logging.getLogger().handlers:
if isinstance(handler, logging.FileHandler):
tune_logger.handlers.append(handler)
tune_logger.exception("Error in tune.py:")
tune_logger.exception("Error in candidate_gen.py:")
raise
logging.debug("tune.py ends")
logging.debug("candidate_gen.py ends")

candidate_configs = load_pickle(path_config.candidate_configs_pkl)
candidate_configs.insert(0, None) # No Configuration class for 0.mlir
48 changes: 24 additions & 24 deletions tuning/punet_autotune.py
@@ -4,7 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import autotune
import libtuner
from pathlib import Path


@@ -26,10 +26,10 @@
"""


class PunetClient(autotune.TuningClient):
class PunetClient(libtuner.TuningClient):

def get_dispatch_compile_command(
self, candidate_tracker: autotune.CandidateTracker
self, candidate_tracker: libtuner.CandidateTracker
) -> list[str]:
mlir_path = candidate_tracker.dispatch_mlir_path
assert mlir_path is not None
@@ -41,7 +41,7 @@ def get_dispatch_compile_command(
return command

def get_dispatch_benchmark_command(
self, candidate_tracker: autotune.CandidateTracker
self, candidate_tracker: libtuner.CandidateTracker
) -> list[str]:
compiled_vmfb_path = candidate_tracker.compiled_dispatch_path
assert compiled_vmfb_path is not None
@@ -52,7 +52,7 @@ def get_dispatch_benchmark_command(
return command

def get_model_compile_command(
self, candidate_tracker: autotune.CandidateTracker
self, candidate_tracker: libtuner.CandidateTracker
) -> list[str]:
mlir_spec_path = candidate_tracker.spec_path
assert mlir_spec_path is not None
@@ -64,7 +64,7 @@ def get_model_compile_command(
return command

def get_model_benchmark_command(
self, candidate_tracker: autotune.CandidateTracker
self, candidate_tracker: libtuner.CandidateTracker
) -> list[str]:
unet_candidate_path = candidate_tracker.model_path
assert unet_candidate_path is not None
@@ -76,73 +76,73 @@


def main():
args = autotune.parse_arguments()
path_config = autotune.PathConfig()
args = libtuner.parse_arguments()
path_config = libtuner.PathConfig()
path_config.base_dir.mkdir(parents=True, exist_ok=True)
path_config.output_unilog.touch()
candidate_trackers: list[autotune.CandidateTracker] = []
candidate_trackers: list[libtuner.CandidateTracker] = []
punet_client = PunetClient()
stop_after_phase: str = args.stop_after

print("Setup logging")
autotune.setup_logging(args, path_config)
libtuner.setup_logging(args, path_config)
print(path_config.run_log, end="\n\n")

print("Validating devices")
autotune.validate_devices(args.devices)
libtuner.validate_devices(args.devices)
print("Validation successful!\n")

print("Generating candidates...")
candidates = autotune.generate_candidates(
candidates = libtuner.generate_candidates(
args, path_config, candidate_trackers, punet_client
)
print(f"Generated [{len(candidates)}] candidates in {path_config.candidates_dir}\n")
if stop_after_phase == autotune.ExecutionPhases.generate_candidates:
if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
return

print("Compiling candidates...")
compiled_candidates = autotune.compile_dispatches(
compiled_candidates = libtuner.compile_dispatches(
args, path_config, candidates, candidate_trackers, punet_client
)
print(f"Compiled files are stored in {path_config.compiled_dir}\n")
if stop_after_phase == autotune.ExecutionPhases.compile_dispatches:
if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
return

print("Benchmarking compiled candidates...")
top_candidates = autotune.benchmark_dispatches(
top_candidates = libtuner.benchmark_dispatches(
args, path_config, compiled_candidates, candidate_trackers, punet_client
)
print(f"Stored results in {path_config.output_unilog}\n")
if stop_after_phase == autotune.ExecutionPhases.benchmark_dispatches:
if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
return

print(f"Compiling top model candidates...")
punet_candidates = autotune.compile_models(
punet_candidates = libtuner.compile_models(
args, path_config, top_candidates, candidate_trackers, punet_client
)
print(f"Model candidates compiled in {path_config.base_dir}\n")
if stop_after_phase == autotune.ExecutionPhases.compile_models:
if stop_after_phase == libtuner.ExecutionPhases.compile_models:
return

print("Benchmarking model candidates...")
autotune.benchmark_models(
libtuner.benchmark_models(
args, path_config, punet_candidates, candidate_trackers, punet_client
)
print(f"Stored results in {path_config.output_unilog}")
if stop_after_phase == autotune.ExecutionPhases.benchmark_models:
if stop_after_phase == libtuner.ExecutionPhases.benchmark_models:
return

autotune.summerize_top_candidates(path_config, candidate_trackers)
libtuner.summerize_top_candidates(path_config, candidate_trackers)
print(f"Stored top candidates info in {path_config.result_summary_log}\n")

autotune.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers)
libtuner.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers)
print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n")

print("Check the detailed execution logs in:")
print(path_config.run_log)

for candidate in candidate_trackers:
autotune.logging.debug(candidate)
libtuner.logging.debug(candidate)
if args.verbose:
print(candidate)

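A hypothetical invocation of the updated driver, in the spirit of the usage string in candidate_gen.py above; only --num-candidates and --num-subgroups are confirmed by this diff, so check any other arguments (input path, devices, --stop-after) against libtuner.parse_arguments() before relying on them:

# Hypothetical; flag spellings beyond --num-candidates are assumptions.
./punet_autotune.py --num-candidates=1024
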
(The remaining 2 changed files in this commit are not shown above.)
