import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import yaml

@dataclass
class PretrainingData:
    """Information about a model's pretraining data."""
    dataset: str
    dataset_version: str
    dataset_size: str


@dataclass
class ModelConfig:
    """Configuration for model evaluation."""
    name: str
    cls: str
    batch_size: int = 8
    device: str = "cuda"
    architecture: str = "Unknown"
    dataset: str = "Unknown"
    dataset_version: str = "Unknown"
    dataset_size: str = "Unknown"
    max_length: Optional[int] = None
    revision: Optional[str] = None
    tokenizer_name: Optional[str] = None  # if not specified, defaults to the model name

    @classmethod
    def from_dict(cls, data: Dict):
        """Create a ModelConfig from a dictionary."""
        # Copy the data to avoid modifying the original
        config_data = data.copy()
        return cls(**config_data)

@dataclass
class ValidationConfig:
    """Configuration for a validation dataset."""
    path: str
    batch_size: int = 16
    max_seq_len: Optional[int] = 1024
    max_steps: Optional[int] = None


@dataclass
class EvalConfig:
    """Configuration for evaluation settings."""
    tasks: Dict
    validation_configs: Dict[str, ValidationConfig]
    output_dir: str = "results"
    save_details: bool = True
    compute_loss: bool = False

def load_yaml(file_path: str):
    """Load a YAML file and return its parsed contents."""
    with open(file_path, 'r') as file:
        return yaml.safe_load(file)


def get_model_configs(yaml_path: str) -> List[ModelConfig]:
    """Read the models YAML file and build a ModelConfig for each entry."""
    data = load_yaml(yaml_path)
    return [ModelConfig.from_dict(model) for model in data['models']]

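# Illustrative layout of the models YAML consumed above. The top-level key
# 'models' and the field names come from ModelConfig; the concrete values are
# hypothetical and only show the expected shape:
#
# models:
#   - name: EleutherAI/pythia-1b          # hypothetical model id
#     cls: HFLM                           # hypothetical loader class name
#     architecture: GPT-NeoX
#     dataset: The Pile
#     dataset_version: v1
#     dataset_size: "300B tokens"
#     batch_size: 8
#     revision: step100000                # optional checkpoint revision
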
def get_eval_config(yaml_path: str, output_dir: str, compute_loss: bool = False) -> EvalConfig:
    """Read the tasks YAML file and build an EvalConfig."""
    data = load_yaml(yaml_path)
    # Convert validation configs to ValidationConfig objects
    validation_configs = {}
    if 'validation_configs' in data:
        for name, config in data['validation_configs'].items():
            validation_configs[name] = ValidationConfig(**config)
    return EvalConfig(
        tasks=data['task_configs'],
        validation_configs=validation_configs,
        output_dir=output_dir,
        compute_loss=compute_loss
    )

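# Illustrative layout of the tasks YAML consumed above. The keys 'task_configs'
# and 'validation_configs' match the lookups in get_eval_config; the individual
# entries are hypothetical:
#
# task_configs:
#   hellaswag:
#     num_fewshot: 0
#
# validation_configs:
#   pile_val:
#     path: data/pile/val.jsonl
#     batch_size: 16
#     max_seq_len: 1024
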
def format_model_info(model_config) -> str:
    """Format model info for filenames and metadata."""
    dataset = model_config.dataset.replace(' ', '')
    tokens = model_config.dataset_size.split()[0]
    filename = (
        f"name={model_config.name.replace('/', '_')}"
        f"__arch={model_config.architecture}"
        f"__dataset={dataset}"
        f"__dataset_version={model_config.dataset_version}"
        f"__size={tokens}"
    )
    if model_config.revision:
        filename += f"__checkpoint={model_config.revision}"
    return filename

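# As a worked (hypothetical) example, a ModelConfig with name="org/model",
# architecture="Transformer", dataset="The Pile", dataset_version="v1",
# dataset_size="300B tokens" and no revision would be formatted as:
#   name=org_model__arch=Transformer__dataset=ThePile__dataset_version=v1__size=300B
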
def get_model_metadata(model_config) -> dict:
    """Get model metadata for storing alongside results."""
    metadata = {
        "model_name": model_config.name,
        "architecture": model_config.architecture,
        "dataset": model_config.dataset,
        "dataset_version": model_config.dataset_version,
        "dataset_size": model_config.dataset_size,
        "batch_size": model_config.batch_size,
        "device": model_config.device
    }
    if model_config.revision:
        metadata["checkpoint"] = model_config.revision
    return metadata

def parse_args():
    """Parse command-line arguments for the evaluation framework."""
    parser = argparse.ArgumentParser(description='LLM Evaluation Framework')
    parser.add_argument('--models_yaml', type=str, default='pile_models.yaml',
                        help='Path to the models configuration file')
    parser.add_argument('--tasks_yaml', type=str, default='tasks.yaml',
                        help='Path to the tasks configuration file')
    parser.add_argument('--results_dir', type=str,
                        default='/is/cluster/fast/pmayilvahanan/llm_line/results/',
                        help='Directory in which to store evaluation results')
    parser.add_argument('--shuffle', action='store_true',
                        help='Shuffle the order in which models are evaluated')
    # Evaluation mode: the three flags below are mutually exclusive
    eval_mode = parser.add_mutually_exclusive_group()
    eval_mode.add_argument('--accuracy-only', action='store_true',
                           help='Only compute accuracy metrics')
    eval_mode.add_argument('--loss-only', action='store_true',
                           help='Only compute cross-entropy loss')
    eval_mode.add_argument('--compute-both', action='store_true',
                           help='Compute both accuracy and loss metrics (default)')
    return parser.parse_args()
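
# Minimal sketch of how a driver script might wire these helpers together; the
# mapping from the mode flags to compute_loss is an assumption, since the
# evaluation loop itself lives outside this file:
#
#   args = parse_args()
#   model_configs = get_model_configs(args.models_yaml)
#   eval_config = get_eval_config(args.tasks_yaml, args.results_dir,
#                                 compute_loss=args.loss_only or args.compute_both)
#   for model_config in model_configs:
#       run_name = format_model_info(model_config)    # e.g. used as a results filename
#       metadata = get_model_metadata(model_config)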