From 9848701bc35ef06964d14dcd7f0c7988df878343 Mon Sep 17 00:00:00 2001 From: Mohammad Hossein Forouhesh Tehrani <17898264+MohammadForouhesh@users.noreply.github.com> Date: Mon, 14 Mar 2022 13:40:19 +0330 Subject: [PATCH] #1 --- README.md | 2 +- tracking_policy_agendas/api.py | 18 ++++++++---------- .../classifiers/meta_clf.py | 4 ++-- tracking_policy_agendas/word2vec/w2v_emb.py | 3 ++- 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 05e5b8c..3715160 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,7 @@ | | | Lasso | 0.700 | 0.764 | 0.646 | 0.738 | 0.742 ## Implementation details: -[![](https://mermaid.ink/img/pako:eNp10s1uwjAMAOBXiXLaJHqBWw-TBi1_kzgMxAZ0B9MYiJYmXZJOQpR3X9qmG2ysp7r5bNluTjRVDGlI9xryA1lEiSTuedzsINxBkAqebxVoFghuLBkonRfmjQTBQ_miNOsuMS1J_87jA6bvgfkoQCOpjkmcbZExLvf3Tdl-lUnW3TZB7Uksqwb0Nej9C9a9WqxO38CQgQBj-I6nYLmS5waObk8QgQWDth6BlPNccFuS8W9LFhq4JHMHfbUrP_nr0dX-4WM_RiKbeFXF0ebGml5HfaVMm1e7-JaL3YyWpzO8osNbdAb8E_twRONpVHcz9SszCDo9BLkoDFmC4Oxya3FDm2B4GUz8r2sWMZE71GV7Nq0_PvlmfBvPaAphqx5oh2aoM-DM3bNTlZJQe8AMExq6V4mF1SASmsizo0XuOsKYcas0dfWEwQ6Fwqr5UaY0tLrAFkUc3LXNvDp_AYJt5H0)](https://mermaid-js.github.io/mermaid-live-editor/edit/#pako:eNp10s1uwjAMAOBXiXLaJHqBWw-TBi1_kzgMxAZ0B9MYiJYmXZJOQpR3X9qmG2ysp7r5bNluTjRVDGlI9xryA1lEiSTuedzsINxBkAqebxVoFghuLBkonRfmjQTBQ_miNOsuMS1J_87jA6bvgfkoQCOpjkmcbZExLvf3Tdl-lUnW3TZB7Uksqwb0Nej9C9a9WqxO38CQgQBj-I6nYLmS5waObk8QgQWDth6BlPNccFuS8W9LFhq4JHMHfbUrP_nr0dX-4WM_RiKbeFXF0ebGml5HfaVMm1e7-JaL3YyWpzO8osNbdAb8E_twRONpVHcz9SszCDo9BLkoDFmC4Oxya3FDm2B4GUz8r2sWMZE71GV7Nq0_PvlmfBvPaAphqx5oh2aoM-DM3bNTlZJQe8AMExq6V4mF1SASmsizo0XuOsKYcas0dfWEwQ6Fwqr5UaY0tLrAFkUc3LXNvDp_AYJt5H0) +![mermaid_kroki)](https://kroki.io/mermaid/svg/eNp1kcFOwzAMhu97Ch_hkMt244BE127rkHZg02CbOHiJ20aEpiQpEqK8O10apg5KTnb-z85vJzdYFbCJR9Ceu0OGNxkyrmR11GgEU9I6mGpT1fYZGLttHrUR4y3xBqKrABfEX5h9q9EQnGRIXo8khCzza981OhXCfvzD6xySkmtB5kKf_KfvJx7YfZ51C1OF1spMcnRSl1-emw-7j9GhJeftQ7OulHQNLH6zsDEoS1i3YNfsAk__4tS2PtOLMMLIZ7tTFh8G1vM0j7S2ochjyRCWtNM5yVfUJ2dD5ArlO0X4QbYjY29kGVZlCQ0vWKVqC1tUUvS2lXSkj2e9OA2_1c2flhmZJkhLf3cffAQHD2Rr5drnvwEH6rEj) ## Reproducing Results for XGB diff --git a/tracking_policy_agendas/api.py b/tracking_policy_agendas/api.py index 1d55bc5..96e0130 100644 --- a/tracking_policy_agendas/api.py +++ b/tracking_policy_agendas/api.py @@ -27,10 +27,10 @@ 'pa_vaccine': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/bin/pa_vaccine.zip', 'lasso_vaccine': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/bin/lasso_vaccine.zip', 'gnb_vaccine': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/bin/gnb_vaccine.zip', - 'xgb_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/xgb_jcpoa', - 'pa_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/pa_jcpoa', - 'lasso_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/lasso_jcpoa', - 'gnb_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/gnb_jcpoa'} + 'xgb_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/xgb_jcpoa.zip', + 'pa_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/pa_jcpoa.zip', + 'lasso_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/lasso_jcpoa.zip', + 'gnb_jcpoa': 'https://github.com/MohammadForouhesh/tracking-policy-agendas/releases/download/v1.0.0/gnb_jcpoa.zip'} def downloader(path: str, save_path: str) -> Union[int, None]: @@ -40,12 +40,10 @@ def downloader(path: str, save_path: str) -> Union[int, None]: :param save_path: The intended storage path. :return: If the file exists, it returns 0 (int), otherwise nothing would be returned. """ - try: - model_bin = requests.get(path, allow_redirects=True) - with zipfile.ZipFile(BytesIO(model_bin.content)) as resource: - resource.extractall(save_path) - except Exception: - raise Exception('not a proper webpage') + assert path in http_dict.values(), f'{path[-path[::-1].find("/"):]} is not a supported models, use: \n{http_dict}' + model_bin = requests.get(path, allow_redirects=True) + with zipfile.ZipFile(BytesIO(model_bin.content)) as resource: + resource.extractall(save_path) return 0 diff --git a/tracking_policy_agendas/classifiers/meta_clf.py b/tracking_policy_agendas/classifiers/meta_clf.py index dd2e80d..af60306 100644 --- a/tracking_policy_agendas/classifiers/meta_clf.py +++ b/tracking_policy_agendas/classifiers/meta_clf.py @@ -26,7 +26,7 @@ class MetaClf: - def __init__(self, classifier_instance, text_array: list = None, embedding_doc: list = None, labels: list = None, load_path: str = None): + def __init__(self, classifier_instance, text_array: List[str] = None, embedding_doc: list = None, labels: list = None, load_path: str = None): if not isinstance(text_array, pd.Series): text_array = pd.Series(text_array) self.clf = classifier_instance @@ -49,7 +49,7 @@ def __init__(self, classifier_instance, text_array: list = None, embedding_doc: self.scaler = self.prep_scaler(encoded) self.encoded_input = self.scaler.transform(encoded) - def prep_scaler(self, encoded: List[int]) -> MinMaxScaler: + def prep_scaler(self, encoded: List[np.ndarray]) -> MinMaxScaler: """ Fitting a Min-Max Scaler to use in the pipeline :param encoded: An array of numbers. diff --git a/tracking_policy_agendas/word2vec/w2v_emb.py b/tracking_policy_agendas/word2vec/w2v_emb.py index f6649dd..2c72600 100644 --- a/tracking_policy_agendas/word2vec/w2v_emb.py +++ b/tracking_policy_agendas/word2vec/w2v_emb.py @@ -13,8 +13,9 @@ import gensim import numpy as np import pandas as pd -from gensim import utils from typing import List, Generator + +from gensim import utils from sklearn.pipeline import Pipeline from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer from .w2v_corpus import W2VCorpus