Skip to content

Commit

Permalink
Replace uses of str with Path, as per Nikita's suggestion (#326)
Browse files (browse the repository at this point in the history)
* move transformer configs into JSON files

* fixes

* replace str with Path for path manipulation

* code merge
  • Loading branch information
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 3bc5fe3 commit 9209436
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions build/model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,8 +22,7 @@ def find_multiple(n: int, k: int) -> int:
return n
return n + k - (n % k)

config_dir = f"{str(Path(__file__).parent)}/known_model_params"
config_path = Path(config_dir)
config_path = Path(f"{str(Path(__file__).parent)}/known_model_params")

@dataclass
class ModelArgs:
Expand Down Expand Up @@ -74,25 +73,24 @@ def from_params(cls, params_path):
@classmethod
def from_table(cls, name: str):
print(f"name {name}")
json_path = Path(f"{config_dir}/{name}.json")
json_path = config_path / f"{name}.json"
if json_path.is_file():
return ModelArgs.from_params(json_path)
else:
config_dir = f"{__file__}/known_model_params"
known_model_params = [config.replace(".json", "") for config in os.listdir(config_dir)]
known_model_params = [config.replace(".json", "") for config in os.listdir(config_path)]
raise RuntimeError(f"unknown table index {name} for transformer config, must be from {known_model_params}")

@classmethod
def from_name(cls, name: str):
print(f"Name {name}")
json_path=f"{config_dir}/{name}.json"
print(f"name {name}")
json_path=config_path / f"{name}.json"
if Path(json_path).is_file():
return ModelArgs.from_params(json_path)

known_model_params = [config.replace(".json", "") for config in os.listdir(config_dir)]
known_model_params = [config.replace(".json", "") for config in os.listdir(config_path)]

print(f"known configs: {known_model_params}")
# Fuzzy search by name (e.g. "7B" and "Mistral-7B")
print(f"Known configs: {known_model_params}")
config = [
config
for config in known_model_params
Expand All @@ -111,7 +109,7 @@ def from_name(cls, name: str):
f"Unknown model directory name {name}. Must be one of {known_model_params}."
)

return ModelArgs.from_params(f"{config_dir}/{config[0]}.json")
return ModelArgs.from_params(config_path / f"{config[0]}.json")



Expand Down

0 comments on commit 9209436

Please sign in to comment.