Skip to content

Commit

Permalink
fixed bug in NDCG and reranking
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewklayk committed Sep 17, 2023
1 parent 42c1d01 commit 6fef28d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 27 deletions.
14 changes: 11 additions & 3 deletions aif360/metrics/regression_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def infeasible_index(self, target_prop: dict, r: int = None):
pr_attr_values = np.ravel(
self.dataset.unprivileged_protected_attributes + self.dataset.privileged_protected_attributes)
if set(list(target_prop.keys())) != set(pr_attr_values):
raise ValueError()
raise ValueError('Desired proportions must be specified for all values of the protected attributes!')

ranking = np.column_stack((self.dataset.scores, self.dataset.protected_attributes))
if r is None:
Expand All @@ -67,7 +67,7 @@ def infeasible_index(self, target_prop: dict, r: int = None):
k_viol.add(k-1)
return ii, list(k_viol)

def discounted_cum_gain(self, r: int = None, normalized=False):
def discounted_cum_gain(self, r: int = None, full_dataset: RegressionDataset=None, normalized=False):
"""
Discounted Cumulative Gain metric.
Expand All @@ -78,10 +78,18 @@ def discounted_cum_gain(self, r: int = None, normalized=False):
Returns:
The calculated DCG.
"""
if r is None:
r = np.ravel(self.dataset.scores).shape[0]
if r < 0:
raise ValueError(f'r must be >= 0, got {r}')
if normalized == True and full_dataset is None:
raise ValueError('`normalized` is set to True, but `full_dataset` is not specified')
if not isinstance(full_dataset, RegressionDataset) and not (full_dataset is None):
raise TypeError(f'`full_datset`: expected `RegressionDataset`, got {type(full_dataset)}')
scores = np.ravel(self.dataset.scores)[:r]
z = self._dcg(scores)
if normalized:
z /= self._dcg(np.sort(scores)[::-1][:r])
z /= self._dcg(np.sort(np.ravel(full_dataset.scores))[::-1][:r])
return z

def _dcg(self, scores):
Expand Down
48 changes: 24 additions & 24 deletions examples/demo_deterministic_reranking.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -94,7 +94,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -167,7 +167,7 @@
"5 b 60"
]
},
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -196,7 +196,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -274,7 +274,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -295,7 +295,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -318,7 +318,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -391,7 +391,7 @@
"6 0.0 50.0"
]
},
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -429,7 +429,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -454,7 +454,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand All @@ -480,7 +480,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -500,7 +500,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -720,7 +720,7 @@
"10563 1.0 1.0 1.000000 0.900 0.909091"
]
},
"execution_count": 12,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -740,7 +740,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -960,7 +960,7 @@
"5951 0.0 0.0 0.482759 0.210526 0.848485"
]
},
"execution_count": 13,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -995,7 +995,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -1004,7 +1004,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 13,
"metadata": {},
"outputs": [
{
Expand All @@ -1013,7 +1013,7 @@
"(18, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18])"
]
},
"execution_count": 15,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -1025,7 +1025,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 14,
"metadata": {},
"outputs": [
{
Expand All @@ -1034,7 +1034,7 @@
"(9, [1, 3, 5, 7, 9, 11, 13, 15, 17])"
]
},
"execution_count": 16,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -1056,24 +1056,24 @@
"\n",
"where $score(j)$ denotes the score of the item at position $j$.\n",
"\n",
"This metric can also be normalized against the value of the \"perfect\" strictly score-based ordering, giving it a range from 0 to 1."
"Setting `normalized` to `True` divides the metric by the DCG of the top `r` elements of the full dataset when ranked purely by score, allowing us to compare the fair ranking to a purely score-based one."
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Normalized DCG of fair ranking: 0.9473492630864591\n"
"Normalized DCG of fair ranking: 0.9773523486131529\n"
]
}
],
"source": [
"print(f'Normalized DCG of fair ranking: {m_fair.discounted_cum_gain(normalized=True, r=20)}')"
"print(f'Normalized DCG of fair ranking: {m_fair.discounted_cum_gain(normalized=True, full_dataset=dataset, r=20)}')"
]
},
{
Expand Down

0 comments on commit 6fef28d

Please sign in to comment.