From e69dee8f62c7ac34409eaf4dfa8b7ba4062a6436 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Sun, 14 Jul 2024 19:38:34 -0400 Subject: [PATCH] Fix typos --- examples/README.md | 16 ++++++++-------- examples/openwebtext/compute_scores.py | 13 ++++++++++++- examples/openwebtext/generate.py | 1 - examples/openwebtext/inpsect_factors.py | 6 +++--- examples/openwebtext/task.py | 24 ++++++++++++++++++++++++ kronfluence/version.py | 2 +- 6 files changed, 48 insertions(+), 14 deletions(-) diff --git a/examples/README.md b/examples/README.md index 328a50f..da1470c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -18,14 +18,14 @@ Our examples cover the following tasks:
-| Task | Example Datasets | -|----------------------|:------------------------:| -| Regression | UCI | -| Image Classification | CIFAR-10 & ImageNet | -| Text Classification | GLUE | -| Multiple-Choice | SWAG | -| Summarization | DNN/DailyMail | -| Language Modeling | WikiText-2 & OpenWebText | +| Task | Example Datasets | +|----------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------:| +| Regression | [UCI](https://github.com/pomonam/kronfluence/tree/main/examples/uci) | +| Image Classification | [CIFAR-10](https://github.com/pomonam/kronfluence/tree/main/examples/cifar) & [ImageNet](https://github.com/pomonam/kronfluence/tree/main/examples/imagenet) | +| Text Classification | [GLUE](https://github.com/pomonam/kronfluence/tree/main/examples/glue) | +| Multiple-Choice | [SWAG](https://github.com/pomonam/kronfluence/tree/main/examples/swag) | +| Summarization | [CNN/DailyMail](https://github.com/pomonam/kronfluence/tree/main/examples/dailymail) | +| Language Modeling | [WikiText-2](https://github.com/pomonam/kronfluence/tree/main/examples/wikitext) & [OpenWebText](https://github.com/pomonam/kronfluence/tree/main/examples/openwebtext) |
diff --git a/examples/openwebtext/compute_scores.py b/examples/openwebtext/compute_scores.py index 59bd670..d9d94d7 100644 --- a/examples/openwebtext/compute_scores.py +++ b/examples/openwebtext/compute_scores.py @@ -11,8 +11,9 @@ get_custom_dataset, get_openwebtext_dataset, ) -from examples.openwebtext.task import LanguageModelingTask +from examples.openwebtext.task import LanguageModelingTask, LanguageModelingWithMarginMeasurementTask from kronfluence.analyzer import Analyzer, prepare_model +from kronfluence.utils.common.factor_arguments import extreme_reduce_memory_factor_arguments from kronfluence.utils.common.score_arguments import ( extreme_reduce_memory_score_arguments, ) @@ -28,13 +29,21 @@ def parse_args(): parser.add_argument( "--factors_name", type=str, + required=True, help="Name of the factor.", ) parser.add_argument( "--scores_name", type=str, + required=True, help="Name of the score.", ) + parser.add_argument( + "--use_margin_for_measurement", + action="store_true", + default=False, + help="Boolean flag whether to use margin for measurement.", + ) parser.add_argument( "--query_gradient_rank", type=int, @@ -71,6 +80,8 @@ def main(): # Define task and prepare model. task = LanguageModelingTask() + if args.use_margin_for_measurement: + task = LanguageModelingWithMarginMeasurementTask() model = prepare_model(model, task) kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400)) # 1.5 hours. diff --git a/examples/openwebtext/generate.py b/examples/openwebtext/generate.py index 46eead5..941a4d5 100644 --- a/examples/openwebtext/generate.py +++ b/examples/openwebtext/generate.py @@ -10,7 +10,6 @@ # prompt = "Machine learning can be defined as" # prompt = "Using a distributed database has many advantages." # prompt = "Inflation is typically measured by" -# prompt = "The prime minister of Canada is definitely Justin Bieber. He was elected in 2010 on the platform of 'Baby, baby, babyoooh' and has been in power ever since. Some of Bieber’s key accomplishments as prime minister include:" outputs = pipeline(prompt) print("Prompt:") diff --git a/examples/openwebtext/inpsect_factors.py b/examples/openwebtext/inpsect_factors.py index 3b719ed..90c8aa6 100644 --- a/examples/openwebtext/inpsect_factors.py +++ b/examples/openwebtext/inpsect_factors.py @@ -11,11 +11,11 @@ def main(): plt.rcParams.update(markers.with_edge()) plt.rcParams["axes.axisbelow"] = True - layer_num = 18 + layer_num = 31 module_name = f"model.layers.{layer_num}.mlp.down_proj" # module_name = f"model.layers.{layer_num}.mlp.up_proj" - lambda_processed = Analyzer.load_file("influence_results/num_lambda_processed.safetensors")[module_name] - lambda_matrix = Analyzer.load_file("influence_results/lambda_matrix.safetensors")[module_name] + lambda_processed = Analyzer.load_file("num_lambda_processed.safetensors")[module_name] + lambda_matrix = Analyzer.load_file("lambda_matrix.safetensors")[module_name] lambda_matrix.div_(lambda_processed) lambda_matrix = lambda_matrix.float() plt.matshow(lambda_matrix, cmap="PuBu", norm=LogNorm()) diff --git a/examples/openwebtext/task.py b/examples/openwebtext/task.py index 3bfb078..d43e8f3 100644 --- a/examples/openwebtext/task.py +++ b/examples/openwebtext/task.py @@ -69,3 +69,27 @@ def get_influence_tracked_modules(self) -> List[str]: def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor: return batch["attention_mask"] + + +class LanguageModelingWithMarginMeasurementTask(LanguageModelingTask): + def compute_measurement( + self, + batch: BATCH_TYPE, + model: nn.Module, + ) -> torch.Tensor: + logits = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + ).logits.float() + labels = batch["labels"][..., 1:].contiguous().view(-1) + masks = labels != -100 + logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1)) + + bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False) + logits_correct = logits[bindex, labels] + + cloned_logits = logits.clone() + cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype) + + margins = logits_correct - cloned_logits.logsumexp(dim=-1) + return -margins[masks].sum() diff --git a/kronfluence/version.py b/kronfluence/version.py index 5becc17..5c4105c 100644 --- a/kronfluence/version.py +++ b/kronfluence/version.py @@ -1 +1 @@ -__version__ = "1.0.0" +__version__ = "1.0.1"