Skip to content

Commit

Permalink
add feature
Browse files Browse the repository at this point in the history
  • Loading branch information
miaohancheng committed Nov 24, 2024
1 parent 3d84f0d commit c6cfb44
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 17 deletions.
8 changes: 4 additions & 4 deletions Example.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@
# for reproducibility
np.random.seed(20240919)

# m.fit_scores(balance=True, nmodels=10,n_jobs=3,model_type='knn')
m.fit_scores(balance=True, nmodels=10, n_jobs=3, model_type='tree')
# m.fit_scores(balance=True, nmodels=10,n_jobs=3,model_type='linear')
m.fit_scores(balance=True, nmodels=10,n_jobs=3,model_type='knn')
# m.fit_scores(balance=True, nmodels=10, n_jobs=3, model_type='tree', max_iter=100)
# m.fit_scores(balance=True, nmodels=10,n_jobs=3,model_type='linear', max_iter=200)


m.predict_scores()

m.plot_scores()
m.tune_threshold(method='min')
m.tune_threshold(method='random')
m.match(method="min", nmatches=1, threshold=0.0001, replacement=False)

m.record_frequency()
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ With pysmatch, you can choose between linear models (logistic regression), tree-
np.random.seed(42)

# Fit propensity score models
m.fit_scores(balance=True, nmodels=100, n_jobs=5, model_type='linear')# model_type='linear', model_type='tree'
m.fit_scores(balance=True, nmodels=10,n_jobs=3,model_type='knn')
# m.fit_scores(balance=True, nmodels=10, n_jobs=3, model_type='tree', max_iter=100)
# m.fit_scores(balance=True, nmodels=10,n_jobs=3,model_type='linear', max_iter=200)
```

Output:
Expand Down
4 changes: 3 additions & 1 deletion README_CHINESE.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ n minority: 1219
np.random.seed(42)

# Fit propensity score models
m.fit_scores(balance=True, nmodels=100, n_jobs=5, model_type='linear') # model_type='knn',model_type='tree'
m.fit_scores(balance=True, nmodels=10,n_jobs=3,model_type='knn')
# m.fit_scores(balance=True, nmodels=10, n_jobs=3, model_type='tree', max_iter=100)
# m.fit_scores(balance=True, nmodels=10,n_jobs=3,model_type='linear', max_iter=200)
```

输出:
Expand Down
29 changes: 18 additions & 11 deletions pysmatch/Matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')


class Matcher:
"""
Matcher Class -- Match data for an observational study.
Expand Down Expand Up @@ -78,6 +80,7 @@ def __init__(self, test, control, yvar, formula=None, exclude=None):
logging.info(f'Formula:{yvar} ~ {"+".join(self.xvars)}')
logging.info(f'n majority:{len(self.data[self.data[yvar] == self.majority])}')
logging.info(f'n minority:{len(self.data[self.data[yvar] == self.minority])}')

def preprocess_data(self, X, fit_scaler=False, index=None):
X_encoded = pd.get_dummies(X)

Expand All @@ -97,7 +100,8 @@ def preprocess_data(self, X, fit_scaler=False, index=None):
X_scaled = scaler.transform(X_encoded)

return X_scaled
def fit_model(self, index, X, y, model_type, balance):

def fit_model(self, index, X, y, model_type, balance, max_iter=100):
X_train, _, y_train, _ = train_test_split(X, y, train_size=0.7, random_state=index)

if balance:
Expand All @@ -112,12 +116,12 @@ def fit_model(self, index, X, y, model_type, balance):
X_processed = X_resampled

if model_type == 'linear':
model = LogisticRegression(max_iter=100)
model = LogisticRegression(max_iter=max_iter)
model.fit(X_processed, y_resampled.iloc[:, 0])
accuracy = model.score(X_processed, y_resampled)
elif model_type == 'tree':
cat_features_indices = np.where(X_resampled.dtypes == 'object')[0]
model = CatBoostClassifier(iterations=100, depth=6,
model = CatBoostClassifier(iterations=max_iter, depth=6,
eval_metric='AUC', l2_leaf_reg=3,
cat_features=cat_features_indices,
learning_rate=0.02, loss_function='Logloss',
Expand All @@ -133,7 +137,7 @@ def fit_model(self, index, X, y, model_type, balance):
logging.info(f"Model {index + 1}/{self.nmodels} trained. Accuracy: {accuracy:.2%}")
return {'model': model, 'accuracy': accuracy}

def fit_scores(self, balance=True, nmodels=None, n_jobs=1, model_type='linear'):
def fit_scores(self, balance=True, nmodels=None, n_jobs=1, model_type='linear', max_iter=100):
self.models, self.model_accuracy = [], []
self.model_type = model_type
num_cores = mp.cpu_count()
Expand All @@ -148,16 +152,17 @@ def fit_scores(self, balance=True, nmodels=None, n_jobs=1, model_type='linear'):
if balance:
with Pool(min(num_cores, n_jobs)) as pool:
results = pool.starmap(self.fit_model,
[(i, self.X, self.y, self.model_type, balance) for i in range(nmodels)])
[(i, self.X, self.y, self.model_type, balance, max_iter) for i in
range(nmodels)])
for res in results:
self.models.append(res['model'])
self.model_accuracy.append(res['accuracy'])
logging.info(f"Average Accuracy:{np.mean(self.model_accuracy):.2%} ")
else:
result = self.fit_model(0, self.X, self.y, self.model_type, balance)
result = self.fit_model(0, self.X, self.y, self.model_type, balance, max_iter)
self.models.append(result['model'])
self.model_accuracy.append(result['accuracy'])
logging.info(f"Accuracy:{round(self.model_accuracy[0] * 100,2)}%")
logging.info(f"Accuracy:{round(self.model_accuracy[0] * 100, 2)}%")

def predict_scores(self):
"""
Expand Down Expand Up @@ -252,8 +257,10 @@ def match(self, threshold=0.001, nmatches=1, method='min', max_rand=10, replacem
if len(matches) == 0: # Check again if there are matches after filtering
continue

select = nmatches if method != 'random' else np.random.choice(range(1, max_rand + 1), 1) # Select number of matches
chosen = np.random.choice(matches.index, min(select, nmatches), replace=False) # Choose the indices for matching
select = nmatches if method != 'random' else np.random.choice(range(1, max_rand + 1),
1) # Select number of matches
chosen = np.random.choice(matches.index, min(select, nmatches),
replace=False) # Choose the indices for matching

if not replacement: # If no replacement, update the used indices
used_indices.update(chosen)
Expand Down Expand Up @@ -309,7 +316,7 @@ def prop_test(self, col):
else:
logging.info(f"{col} is a continuous variable")

def compare_continuous(self, save=False, return_table=False,plot_result = True):
def compare_continuous(self, save=False, return_table=False, plot_result=True):
"""
Plots the ECDFs for continuous features before and
after matching. Each chart title contains test results
Expand Down Expand Up @@ -415,7 +422,7 @@ def compare_continuous(self, save=False, return_table=False,plot_result = True):

return pd.DataFrame(test_results)[var_order] if return_table else None

def compare_categorical(self, return_table=False,plot_result=True):
def compare_categorical(self, return_table=False, plot_result=True):
"""
Plots the proportional differences of each enumerated
discrete column for test and control.
Expand Down

0 comments on commit c6cfb44

Please sign in to comment.