Fix rss calculation (#7)
* Fix model selection

* Run example clean

* Normalize rss using mean weights in gamlss update

* Update mean_of_weights in update

* Fix init of mean_of_weights_inner object

* Fix rss calculation

* Bump version number

---------

Co-authored-by: simon-hirsch <simon.hirsch@stud.uni-due.de>
BerriJ and simon-hirsch authored Jul 18, 2024
1 parent 10cd702 commit a992d43
Showing 2 changed files with 11 additions and 4 deletions.
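In brief, the change replaces the running sum of weights with the running mean of weights wherever the stored residual sum of squares is carried forward during an online update, and adds the bookkeeping needed to keep that mean available. Below is a minimal, self-contained sketch of the corrected recursion as it appears in the diff; the function and variable names (update_rss, rss_old, mean_w_old, forget) are chosen here purely for illustration and are not part of the rolch API.

import numpy as np

def update_rss(rss_old, mean_w_old, residuals, w, forget):
    # One step of the recursion from the diff below: the new weighted squared
    # residuals are combined with the previous RSS, which is rescaled by the
    # previous *mean* of the weights (the old code used the *sum*), and the
    # result is divided by the decayed mean weight plus the new weight.
    return (
        (residuals**2).flatten() * w
        + (1 - forget) * (rss_old * mean_w_old)
    ) / (mean_w_old * (1 - forget) + w)

# Example: one new observation with weight 1.0 and a forget factor of 0.05.
print(update_rss(rss_old=2.5, mean_w_old=1.0, residuals=np.array([1.2]), w=1.0, forget=0.05))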
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "rolch"
version = "0.1.3"
version = "0.1.4"
authors = [
{name="Simon Hirsch", email="simon.hirsch@stud.uni-due.de"},
{name="Jonathan Berrisch", email="jonathan.berrisch@uni-due.de"},
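A quick way to confirm which build is installed in the current environment after this release (assuming the package was installed under the distribution name rolch, e.g. via pip):

from importlib.metadata import version

# Should report 0.1.4 once this commit is packaged and installed.
print(version("rolch"))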
13 changes: 10 additions & 3 deletions src/rolch/online_gamlss.py
@@ -239,7 +239,7 @@ def update_beta(

rss = (
    (residuals**2).flatten() * w
-   + (1 - self.forget) * (self.rss[param] * self.sum_of_weights[param])
+   + (1 - self.forget) * (self.rss[param] * self.mean_of_weights[param])
) / (self.mean_of_weights[param] * (1 - self.forget) + w)

elif (self.method == "lasso") & self.intercept_only[param]:
@@ -263,7 +263,8 @@

rss = (
    (residuals**2).flatten() * w
-   + (1 - self.forget) * (self.rss[param] * self.sum_of_weights[param])
+   + (1 - self.forget) * (self.rss[param] * self.mean_of_weights[param])
) / (self.mean_of_weights[param] * (1 - self.forget) + w)

elif self.method == "lasso":
@@ -292,7 +293,8 @@ def update_beta(

rss = (
    (residuals**2).flatten() * w
-   + (1 - self.forget) * (self.rss[param] * self.sum_of_weights[param])
+   + (1 - self.forget) * (self.rss[param] * self.mean_of_weights[param])
) / (self.mean_of_weights[param] * (1 - self.forget) + w)

model_params_n = np.sum(np.isclose(beta_path, 0), axis=1)
@@ -510,6 +512,7 @@ def update(
self.y_gram_inner = copy.copy(self.y_gram)
self.rss_inner = copy.copy(self.rss)
self.sum_of_weights_inner = copy.copy(self.sum_of_weights)
+self.mean_of_weights_inner = copy.copy(self.mean_of_weights)

self.lambda_max_inner = copy.copy(self.lambda_max)
self.lambda_path_inner = copy.copy(self.lambda_path)
@@ -536,6 +539,7 @@ def update(
self.x_gram = copy.copy(self.x_gram_inner)
self.y_gram = copy.copy(self.y_gram_inner)
self.sum_of_weights = copy.copy(self.sum_of_weights_inner)
+self.mean_of_weights = copy.copy(self.mean_of_weights_inner)
self.rss = copy.copy(self.rss_inner)

self.lambda_max = copy.copy(self.lambda_max_inner)
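The two hunks above follow the existing checkpoint/restore pattern in update(): outer state is copied into *_inner attributes before the inner iterations run, and the accepted inner state is copied back afterwards; the fix simply adds mean_of_weights to both directions. A toy sketch of that pattern with invented attribute names, not the class's actual state:

import copy

class ToyOnlineModel:
    def __init__(self):
        self.sum_of_weights = {0: 1.0}
        self.mean_of_weights = {0: 1.0}

    def update(self, inner_step):
        # Checkpoint: inner iterations work on copies of the outer state.
        self.sum_of_weights_inner = copy.copy(self.sum_of_weights)
        self.mean_of_weights_inner = copy.copy(self.mean_of_weights)  # added by the fix
        inner_step(self)
        # Restore: accept whatever the inner iterations produced.
        self.sum_of_weights = copy.copy(self.sum_of_weights_inner)
        self.mean_of_weights = copy.copy(self.mean_of_weights_inner)  # added by the fix

model = ToyOnlineModel()
model.update(lambda m: m.mean_of_weights_inner.update({0: 0.8}))
print(model.mean_of_weights)  # {0: 0.8}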
@@ -880,6 +884,9 @@ def _inner_update(
self.sum_of_weights_inner[param] = (
    np.sum(w * wt) + (1 - self.forget) * self.sum_of_weights[param]
)
+self.mean_of_weights_inner[param] = (
+    self.sum_of_weights_inner[param] / self.n_training
+)

olddv = dv

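The last hunk keeps the running mean consistent with the decayed running sum by dividing it by the number of training observations. A simplified scalar check of that relationship: the library combines what appear to be observation weights w and working weights wt, while this sketch uses a single weight per step, sets forget to 0, and assumes n_training counts all observations seen so far.

import numpy as np

weights = np.array([0.5, 1.0, 2.0, 1.5])
forget = 0.0

sum_of_weights = 0.0
for n_training, w in enumerate(weights, start=1):
    sum_of_weights = w + (1 - forget) * sum_of_weights
    mean_of_weights = sum_of_weights / n_training

# With forget = 0 the running mean equals the plain average of the weights seen.
assert np.isclose(mean_of_weights, weights.mean())
print(mean_of_weights)  # 1.25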
