ear.py · 43 lines (35 loc) · 1.47 KB
import torch


def compute_negative_entropy(
    inputs: tuple, attention_mask: torch.Tensor, return_values: bool = False
):
    """Compute the negative entropy across layers of a network for given inputs.

    Args:
        - inputs: tuple of length num_layers. Each item is an attention tensor
          of shape BHSS (batch, heads, seq_len, seq_len).
        - attention_mask: tensor of shape BS (batch, seq_len).
        - return_values: if True, also return the per-sample negative entropies.
    """
    inputs = torch.stack(inputs)  # LBHSS: layers, batch, heads, seq_len, seq_len
    assert inputs.ndim == 5, "Expected 5 dimensions in the form LBHSS"

    # average over attention heads
    pool_heads = inputs.mean(2)

    batch_size = pool_heads.shape[1]
    samples_entropy = list()
    neg_entropies = list()
    for b in range(batch_size):
        # keep only the non-padded tokens of the current sample
        mask = attention_mask[b]
        sample = pool_heads[:, b, mask.bool(), :]
        sample = sample[:, :, mask.bool()]

        # negative entropy of each non-padded token's attention distribution
        neg_entropy = (sample.softmax(-1) * sample.log_softmax(-1)).sum(-1)
        if return_values:
            neg_entropies.append(neg_entropy.detach())

        # average over tokens, one value per layer
        mean_entropy = neg_entropy.mean(-1)

        # store the sum across all the layers
        samples_entropy.append(mean_entropy.sum(0))

    # average over the batch
    final_entropy = torch.stack(samples_entropy).mean()
    if return_values:
        return final_entropy, neg_entropies
    else:
        return final_entropy
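

# Example usage: a minimal sketch, not part of the original file. It assumes a
# Hugging Face transformers model loaded with output_attentions=True, so that
# outputs.attentions is a tuple of per-layer (batch, heads, seq, seq) tensors,
# matching the BHSS shape expected above.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased", output_attentions=True)

    batch = tokenizer(
        ["a short sentence", "a slightly longer example sentence"],
        padding=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = model(**batch)

    # scalar negative entropy of the attention maps, averaged over the batch
    neg_entropy = compute_negative_entropy(outputs.attentions, batch["attention_mask"])
    print(neg_entropy)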