Skip to content

Commit

Permalink
event studies with two-way-mundlak specification (#11)
Browse files Browse the repository at this point in the history
* query optimization + logging, mundlak event study wip

* pandas compression

* add multi-cohort

* revise multi-cohort specification

* add mundlak event study
  • Loading branch information
apoorvalal authored Aug 21, 2024
1 parent 2667779 commit 8c31f92
Show file tree
Hide file tree
Showing 3 changed files with 1,112 additions and 9 deletions.
26 changes: 24 additions & 2 deletions duckreg/duckreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(
seed: int,
n_bootstraps: int = 100,
fitter="numpy",
keep_connection_open=False,
):
self.db_name = db_name
self.table_name = table_name
Expand All @@ -20,6 +21,7 @@ def __init__(
self.conn = duckdb.connect(db_name)
self.rng = np.random.default_rng(seed)
self.fitter = fitter
self.keep_connection_open = keep_connection_open

@abstractmethod
def prepare_data(self):
Expand Down Expand Up @@ -54,6 +56,7 @@ def fit(self):
self.point_estimate = self.estimate()
if self.n_bootstraps > 0:
self.vcov = self.bootstrap()
self.conn.close() if not self.keep_connection_open else None
return None
elif self.fitter == "feols":
fit = self.estimate_feols()
Expand All @@ -64,6 +67,7 @@ def fit(self):
fit.get_inference()
fit._vcov_type = "NP-Bootstrap"
fit._vcov_type_detail = "NP-Bootstrap"
self.conn.close() if not self.keep_connection_open else None
return fit

else:
Expand All @@ -73,17 +77,35 @@ def fit(self):
)
)

def summary(self):
def summary(self) -> dict:
"""Summary of regression
Returns:
dict
"""
if self.n_bootstraps > 0:
return {
"point_estimate": self.point_estimate,
"standard_error": np.sqrt(np.diag(self.vcov)),
}
return {"point_estimate": self.point_estimate}

def queries(self) -> dict:
"""Collect all query methods in the class
def wls(X: np.ndarray, y: np.ndarray, n: np.ndarray) -> np.ndarray:
Returns:
dict: Dictionary of query methods
"""
self._query_names = [x for x in dir(self) if "query" in x]
self.queries = {
k: getattr(self, self._query_names[c])
for c, k in enumerate(self._query_names)
}
return self.queries


def wls(X: np.ndarray, y: np.ndarray, n: np.ndarray) -> np.ndarray:
"""Weighted least squares with frequency weights"""
N = np.sqrt(n)
N = N.reshape(-1, 1) if N.ndim == 1 else N
Xn = X * N
Expand Down
Loading

0 comments on commit 8c31f92

Please sign in to comment.