From 9848701bc35ef06964d14dcd7f0c7988df878343 Mon Sep 17 00:00:00 2001
From: Mohammad Hossein Forouhesh Tehrani
<17898264+MohammadForouhesh@users.noreply.github.com>
Date: Mon, 14 Mar 2022 13:40:19 +0330
Subject: [PATCH] #1
---
README.md | 2 +-
tracking_policy_agendas/api.py | 18 ++++++++----------
.../classifiers/meta_clf.py | 4 ++--
tracking_policy_agendas/word2vec/w2v_emb.py | 3 ++-
4 files changed, 13 insertions(+), 14 deletions(-)
diff --git a/README.md b/README.md
index 05e5b8c..3715160 100644
--- a/README.md
+++ b/README.md
@@ -114,7 +114,7 @@
| | | Lasso | 0.700 | 0.764 | 0.646 | 0.738 | 0.742
## Implementation details:
-[![](https://mermaid.ink/img/pako:eNp10s1uwjAMAOBXiXLaJHqBWw-TBi1_kzgMxAZ0B9MYiJYmXZJOQpR3X9qmG2ysp7r5bNluTjRVDGlI9xryA1lEiSTuedzsINxBkAqebxVoFghuLBkonRfmjQTBQ_miNOsuMS1J_87jA6bvgfkoQCOpjkmcbZExLvf3Tdl-lUnW3TZB7Uksqwb0Nej9C9a9WqxO38CQgQBj-I6nYLmS5waObk8QgQWDth6BlPNccFuS8W9LFhq4JHMHfbUrP_nr0dX-4WM_RiKbeFXF0ebGml5HfaVMm1e7-JaL3YyWpzO8osNbdAb8E_twRONpVHcz9SszCDo9BLkoDFmC4Oxya3FDm2B4GUz8r2sWMZE71GV7Nq0_PvlmfBvPaAphqx5oh2aoM-DM3bNTlZJQe8AMExq6V4mF1SASmsizo0XuOsKYcas0dfWEwQ6Fwqr5UaY0tLrAFkUc3LXNvDp_AYJt5H0)](https://mermaid-js.github.io/mermaid-live-editor/edit/#pako:eNp10s1uwjAMAOBXiXLaJHqBWw-TBi1_kzgMxAZ0B9MYiJYmXZJOQpR3X9qmG2ysp7r5bNluTjRVDGlI9xryA1lEiSTuedzsINxBkAqebxVoFghuLBkonRfmjQTBQ_miNOsuMS1J_87jA6bvgfkoQCOpjkmcbZExLvf3Tdl-lUnW3TZB7Uksqwb0Nej9C9a9WqxO38CQgQBj-I6nYLmS5waObk8QgQWDth6BlPNccFuS8W9LFhq4JHMHfbUrP_nr0dX-4WM_RiKbeFXF0ebGml5HfaVMm1e7-JaL3YyWpzO8osNbdAb8E_twRONpVHcz9SszCDo9BLkoDFmC4Oxya3FDm2B4GUz8r2sWMZE71GV7Nq0_PvlmfBvPaAphqx5oh2aoM-DM3bNTlZJQe8AMExq6V4mF1SASmsizo0XuOsKYcas0dfWEwQ6Fwqr5UaY0tLrAFkUc3LXNvDp_AYJt5H0)
+![mermaid_kroki](https://kroki.io/mermaid/svg/eNp1kcFOwzAMhu97Ch_hkMt244BE127rkHZg02CbOHiJ20aEpiQpEqK8O10apg5KTnb-z85vJzdYFbCJR9Ceu0OGNxkyrmR11GgEU9I6mGpT1fYZGLttHrUR4y3xBqKrABfEX5h9q9EQnGRIXo8khCzza981OhXCfvzD6xySkmtB5kKf_KfvJx7YfZ51C1OF1spMcnRSl1-emw-7j9GhJeftQ7OulHQNLH6zsDEoS1i3YNfsAk__4tS2PtOLMMLIZ7tTFh8G1vM0j7S2ochjyRCWtNM5yVfUJ2dD5ArlO0X4QbYjY29kGVZlCQ0vWKVqC1tUUvS2lXSkj2e9OA2_1c2flhmZJkhLf3cffAQHD2Rr5drnvwEH6rEj)
## Reproducing Results for XGB
diff --git a/tracking_policy_agendas/api.py b/tracking_policy_agendas/api.py
index 1d55bc5..96e0130 100644
--- a/tracking_policy_agendas/api.py
+++ b/tracking_policy_agendas/api.py
@@ -27,10 +27,10 @@
'pa_vaccine': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/bin/pa_vaccine.zip',
'lasso_vaccine': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/bin/lasso_vaccine.zip',
'gnb_vaccine': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/bin/gnb_vaccine.zip',
- 'xgb_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/xgb_jcpoa',
- 'pa_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/pa_jcpoa',
- 'lasso_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/lasso_jcpoa',
- 'gnb_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/gnb_jcpoa'}
+ 'xgb_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/xgb_jcpoa.zip',
+ 'pa_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/pa_jcpoa.zip',
+ 'lasso_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/lasso_jcpoa.zip',
+ 'gnb_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/gnb_jcpoa.zip'}
def downloader(path: str, save_path: str) -> Union[int, None]:
@@ -40,12 +40,10 @@ def downloader(path: str, save_path: str) -> Union[int, None]:
:param save_path: The intended storage path.
:return: If the file exists, it returns 0 (int), otherwise nothing would be returned.
"""
- try:
- model_bin = requests.get(path, allow_redirects=True)
- with zipfile.ZipFile(BytesIO(model_bin.content)) as resource:
- resource.extractall(save_path)
- except Exception:
- raise Exception('not a proper webpage')
+ assert path in http_dict.values(), f'{path[-path[::-1].find("/"):]} is not a supported models, use: \n{http_dict}'
+ model_bin = requests.get(path, allow_redirects=True)
+ with zipfile.ZipFile(BytesIO(model_bin.content)) as resource:
+ resource.extractall(save_path)
return 0
diff --git a/tracking_policy_agendas/classifiers/meta_clf.py b/tracking_policy_agendas/classifiers/meta_clf.py
index dd2e80d..af60306 100644
--- a/tracking_policy_agendas/classifiers/meta_clf.py
+++ b/tracking_policy_agendas/classifiers/meta_clf.py
@@ -26,7 +26,7 @@
class MetaClf:
- def __init__(self, classifier_instance, text_array: list = None, embedding_doc: list = None, labels: list = None, load_path: str = None):
+ def __init__(self, classifier_instance, text_array: List[str] = None, embedding_doc: list = None, labels: list = None, load_path: str = None):
if not isinstance(text_array, pd.Series): text_array = pd.Series(text_array)
self.clf = classifier_instance
@@ -49,7 +49,7 @@ def __init__(self, classifier_instance, text_array: list = None, embedding_doc:
self.scaler = self.prep_scaler(encoded)
self.encoded_input = self.scaler.transform(encoded)
- def prep_scaler(self, encoded: List[int]) -> MinMaxScaler:
+ def prep_scaler(self, encoded: List[np.ndarray]) -> MinMaxScaler:
"""
Fitting a Min-Max Scaler to use in the pipeline
:param encoded: An array of numbers.
diff --git a/tracking_policy_agendas/word2vec/w2v_emb.py b/tracking_policy_agendas/word2vec/w2v_emb.py
index f6649dd..2c72600 100644
--- a/tracking_policy_agendas/word2vec/w2v_emb.py
+++ b/tracking_policy_agendas/word2vec/w2v_emb.py
@@ -13,8 +13,9 @@
import gensim
import numpy as np
import pandas as pd
-from gensim import utils
from typing import List, Generator
+
+from gensim import utils
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from .w2v_corpus import W2VCorpus