Skip to content

Commit

Permalink
Inversion counting for misranked pairs
Browse files Browse the repository at this point in the history
  • Loading branch information
ArtemSokolov committed Mar 30, 2023
1 parent 43091c7 commit c1f6e03
Showing 1 changed file with 39 additions and 1 deletion.
40 changes: 39 additions & 1 deletion paired-eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,43 @@ def paired_eval(scores, labels, min_dist=0.5):
# a rankable pair relative to r
nrp = nrp + (l + 1) # +1 to account for 0-based indexing
r = r + 1

# Compute the number of misranked pairs through inversion counting
arr = np.array(labels)[np.argsort(scores)] # Array for merge sort
tmp = np.zeros_like(arr) # Temporary work space

# Computes an inversion count (cnt) on the [left, right] region of arr
def _merge_sort(l, r):
cnt = 0
if l >= r: return cnt

# Recurse on [left, mid) and [mid, right]
m = (l + r)//2
cnt += _merge_sort(l, m)
cnt += _merge_sort(m+1, r)

return nrp, 0
# Count inversions from the merge
i = l; j = m+1; k = l
while i <= m and j <= r:

# No inversions if (i, j) is ranked correctly
if arr[i] <= arr[j]:
tmp[k] = arr[i]
i += 1; k += 1

# Otherwise, everything up to mid is an inversion relative to j
else:
cnt += m - i + 1
tmp[k] = arr[j]
j += 1; k += 1

# Copy the remaining bits of left and right
while i <= m: tmp[k] = arr[i]; k += 1; i += 1
while j <= r: tmp[k] = arr[j]; k += 1; j += 1

# Update the array chunk with its sorted version
arr[l:(r+1)] = tmp[l:(r+1)]
return cnt

# Ranked correctly = total - misranked
return nrp, nrp - _merge_sort(0, len(arr)-1)

0 comments on commit c1f6e03

Please sign in to comment.