black linting
sangkeun00 committed Nov 25, 2023
1 parent 7dbddc0 commit 16ed520
Showing 3 changed files with 16 additions and 10 deletions.
19 changes: 13 additions & 6 deletions examples/bert_influence/compute_influence.py
@@ -34,7 +34,10 @@ def single_checkpoint_influence(
     _, eval_train_loader, test_loader = get_loaders(data_name=data_name)

     # Set-up
-    analog = AnaLog(project="test", config="/data/tir/projects/tir6/general/hahn2/analog/examples/bert_influence/config.yaml")
+    analog = AnaLog(
+        project="test",
+        config="/data/tir/projects/tir6/general/hahn2/analog/examples/bert_influence/config.yaml",
+    )

     # Hessian logging
     analog.watch(model, type_filter=[torch.nn.Linear], lora=False)
@@ -52,9 +55,7 @@ def single_checkpoint_influence(

         logits = outputs.view(-1, outputs.shape[-1])
         labels = batch["labels"].view(-1).to(DEVICE)
-        loss = F.cross_entropy(
-            logits, labels, reduction="sum", ignore_index=-100
-        )
+        loss = F.cross_entropy(logits, labels, reduction="sum", ignore_index=-100)
         loss.backward()
     analog.finalize()

@@ -74,7 +75,10 @@ def single_checkpoint_influence(
         logits = outputs.view(-1, outputs.shape[-1])
         labels = batch["labels"].view(-1).to(DEVICE)
         loss = F.cross_entropy(
-            logits, labels, reduction="sum", ignore_index=-100,
+            logits,
+            labels,
+            reduction="sum",
+            ignore_index=-100,
         )
         loss.backward()
     analog.finalize()
@@ -97,7 +101,10 @@ def single_checkpoint_influence(
         test_logits = test_outputs.view(-1, outputs.shape[-1])
         test_labels = test_batch["labels"].view(-1).to(DEVICE)
         test_loss = F.cross_entropy(
-            test_logits, test_labels, reduction="sum", ignore_index=-100,
+            test_logits,
+            test_labels,
+            reduction="sum",
+            ignore_index=-100,
         )
         test_loss.backward()

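For context, the hunks above only touch formatting, but the calls they reformat — AnaLog(...), analog.watch(...), analog.finalize(), and the summed cross-entropy loss — make up the Hessian-logging pass in compute_influence.py. A minimal sketch of how those pieces fit together is below; it assumes the model, eval_train_loader, and DEVICE set up earlier in that script, and the loop structure, the forward-call arguments, and the analog(data_id=...) context manager are assumptions for illustration, not part of this commit.

    from analog import AnaLog  # assumed import path for the AnaLog package
    import torch
    import torch.nn.functional as F

    analog = AnaLog(project="test", config="config.yaml")  # config path shortened here
    analog.watch(model, type_filter=[torch.nn.Linear], lora=False)  # track Linear layers only

    for batch in eval_train_loader:
        # Assumed per-batch logging context; the diff only shows the loss/backward lines.
        with analog(data_id=batch["input_ids"]):
            model.zero_grad()
            outputs = model(
                batch["input_ids"].to(DEVICE),
                batch["attention_mask"].to(DEVICE),
            )
            logits = outputs.view(-1, outputs.shape[-1])
            labels = batch["labels"].view(-1).to(DEVICE)
            # Summed, padding-ignoring loss, matching the reformatted line in the diff.
            loss = F.cross_entropy(logits, labels, reduction="sum", ignore_index=-100)
            loss.backward()

    analog.finalize()  # flush the logged statistics once the pass is done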
1 change: 1 addition & 0 deletions examples/bert_influence/train.py
@@ -26,6 +26,7 @@
 train_loader, _, valid_loader = get_loaders(data_name=args.data_name)
 model = construct_model(data_name=args.data_name).to(device)

+
 def train(
     model: nn.Module,
     loader: torch.utils.data.DataLoader,
6 changes: 2 additions & 4 deletions examples/bert_influence/utils.py
@@ -84,9 +84,7 @@ def forward(
         ).logits


-def construct_model(
-    data_name: str, ckpt_path: Union[None, str] = None
-) -> nn.Module:
+def construct_model(data_name: str, ckpt_path: Union[None, str] = None) -> nn.Module:
     model = SequenceClassificationModel(data_name)
     if ckpt_path is not None:
         model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
@@ -190,4 +188,4 @@ def preprocess_function(examples):
         batch_size=batch_size,
         shuffle=split == "train",
         collate_fn=default_data_collator,
-    )
+    )
