Skip to content

Commit

Permalink
query optimization + logging, mundlak event study wip
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvalal committed Aug 19, 2024
1 parent 2667779 commit 35516bc
Show file tree
Hide file tree
Showing 3 changed files with 710 additions and 2 deletions.
26 changes: 24 additions & 2 deletions duckreg/duckreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(
seed: int,
n_bootstraps: int = 100,
fitter="numpy",
keep_connection_open=False,
):
self.db_name = db_name
self.table_name = table_name
Expand All @@ -20,6 +21,7 @@ def __init__(
self.conn = duckdb.connect(db_name)
self.rng = np.random.default_rng(seed)
self.fitter = fitter
self.keep_connection_open = keep_connection_open

@abstractmethod
def prepare_data(self):
Expand Down Expand Up @@ -54,6 +56,7 @@ def fit(self):
self.point_estimate = self.estimate()
if self.n_bootstraps > 0:
self.vcov = self.bootstrap()
self.conn.close() if not self.keep_connection_open else None
return None
elif self.fitter == "feols":
fit = self.estimate_feols()
Expand All @@ -64,6 +67,7 @@ def fit(self):
fit.get_inference()
fit._vcov_type = "NP-Bootstrap"
fit._vcov_type_detail = "NP-Bootstrap"
self.conn.close() if not self.keep_connection_open else None
return fit

else:
Expand All @@ -73,17 +77,35 @@ def fit(self):
)
)

def summary(self):
def summary(self) -> dict:
"""Summary of regression
Returns:
dict
"""
if self.n_bootstraps > 0:
return {
"point_estimate": self.point_estimate,
"standard_error": np.sqrt(np.diag(self.vcov)),
}
return {"point_estimate": self.point_estimate}

def queries(self) -> dict:
"""Collect all query methods in the class
def wls(X: np.ndarray, y: np.ndarray, n: np.ndarray) -> np.ndarray:
Returns:
dict: Dictionary of query methods
"""
self._query_names = [x for x in dir(self) if "query" in x]
self.queries = {
k: getattr(self, self._query_names[c])
for c, k in enumerate(self._query_names)
}
return self.queries


def wls(X: np.ndarray, y: np.ndarray, n: np.ndarray) -> np.ndarray:
"""Weighted least squares with frequency weights"""
N = np.sqrt(n)
N = N.reshape(-1, 1) if N.ndim == 1 else N
Xn = X * N
Expand Down
8 changes: 8 additions & 0 deletions duckreg/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@ def __init__(
cluster_col: str,
seed: int,
n_bootstraps: int = 100,
event_study: bool = False,
rowid_col: str = "rowid",
fitter: str = "numpy",
**kwargs,
):
super().__init__(
db_name=db_name,
table_name=table_name,
seed=seed,
n_bootstraps=n_bootstraps,
fitter=fitter,
**kwargs,
)
self.formula = formula
self.cluster_col = cluster_col
Expand Down Expand Up @@ -234,12 +237,14 @@ def __init__(
time_col: str = None,
n_bootstraps: int = 100,
cluster_col: str = None,
**kwargs,
):
super().__init__(
db_name=db_name,
table_name=table_name,
seed=seed,
n_bootstraps=n_bootstraps,
**kwargs,
)
self.outcome_var = outcome_var
self.covariates = covariates
Expand Down Expand Up @@ -299,6 +304,7 @@ def compress_data(self):
{', ' + ', '.join([f'avg_{cov}_time' for cov in self.covariates]) if self.time_col is not None else ''}
"""
self.df_compressed = self.conn.execute(self.compress_query).fetchdf()

self.df_compressed[f"mean_{self.outcome_var}"] = (
self.df_compressed[f"sum_{self.outcome_var}"] / self.df_compressed["count"]
)
Expand Down Expand Up @@ -416,12 +422,14 @@ def __init__(
seed: int,
n_bootstraps: int = 100,
cluster_col: str = None,
**kwargs,
):
super().__init__(
db_name=db_name,
table_name=table_name,
seed=seed,
n_bootstraps=n_bootstraps,
**kwargs
)
self.outcome_var = outcome_var
self.treatment_var = treatment_var
Expand Down
Loading

0 comments on commit 35516bc

Please sign in to comment.