New Questionnaire format and Parallelisation #83

Merged · 9 commits · Oct 3, 2024

Changes from all commits
183 changes: 183 additions & 0 deletions radarpipeline/io/connection.py
@@ -7,6 +7,7 @@
import logging
import os
from radarpipeline.common.utils import reparent
import posixpath

from contextlib import contextmanager
from stat import S_IMODE, S_ISDIR, S_ISREG
@@ -15,6 +16,81 @@
logger = logging.getLogger(__name__)


class WTCallbacks(object):
'''an object to house the callbacks, used internally'''
def __init__(self):
'''set instance vars'''
self._flist = []
self._dlist = []
self._ulist = []

def file_cb(self, pathname):
'''called for regular files, appends pathname to .flist

:param str pathname: file path
'''
self._flist.append(pathname)

def dir_cb(self, pathname):
'''called for directories, appends pathname to .dlist

:param str pathname: directory path
'''
self._dlist.append(pathname)

def unk_cb(self, pathname):
'''called for unknown file types, appends pathname to .ulist

:param str pathname: unknown entity path
'''
self._ulist.append(pathname)

@property
def flist(self):
'''return a sorted list of files currently traversed

:getter: returns the list
:setter: sets the list
:type: list
'''
return sorted(self._flist)

@flist.setter
def flist(self, val):
'''setter for _flist '''
self._flist = val

@property
def dlist(self):
'''return a sorted list of directories currently traversed

:getter: returns the list
:setter: sets the list
:type: list
'''
return sorted(self._dlist)

@dlist.setter
def dlist(self, val):
'''setter for _dlist '''
self._dlist = val

@property
def ulist(self):
'''return a sorted list of unknown entities currently traversed

:getter: returns the list
:setter: sets the list
:type: list
'''
return sorted(self._ulist)

@ulist.setter
def ulist(self, val):
'''setter for _ulist '''
self._ulist = val
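
A quick aside on how these callbacks behave in isolation; a minimal sketch (the paths are made up for illustration):

    # minimal sketch of WTCallbacks on its own; the paths are hypothetical
    wtcb = WTCallbacks()
    wtcb.file_cb('b.csv')
    wtcb.file_cb('a.csv')
    assert wtcb.flist == ['a.csv', 'b.csv']   # properties return sorted copies
    wtcb.flist = []                           # setters replace the list wholesale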


class ConnectionException(Exception):
"""Exception raised for connection problems

@@ -238,6 +314,87 @@ def get_d(self, remotedir, localdir, preserve_mtime=False):
self.get(rname, reparent(localdir, rname),
preserve_mtime=preserve_mtime)

def get_r(self, remotedir, localdir, preserve_mtime=False):
"""recursively copy remotedir structure to localdir

:param str remotedir: the remote directory to copy from
:param str localdir: the local directory to copy to
:param bool preserve_mtime: *Default: False* -
preserve modification time on files

:returns: None

:raises:

"""
self._sftp_connect()
wtcb = WTCallbacks()
self.walktree(remotedir, wtcb.file_cb, wtcb.dir_cb, wtcb.unk_cb)
# handle directories we recursed through
for dname in wtcb.dlist:
for subdir in path_advance(dname):
try:
os.mkdir(reparent(localdir, subdir))
# force result to a list for the setter
wtcb.dlist = wtcb.dlist + [subdir, ]
except OSError: # dir exists
pass

for fname in wtcb.flist:
# they may have told us to start down farther, so we may not have
# recursed through some, ensure local dir structure matches
head, _ = os.path.split(fname)
if head not in wtcb.dlist:
for subdir in path_advance(head):
if subdir not in wtcb.dlist and subdir != '.':
os.mkdir(reparent(localdir, subdir))
wtcb.dlist = wtcb.dlist + [subdir, ]

self.get(fname,
reparent(localdir, fname),
preserve_mtime=preserve_mtime)
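
For orientation, this is how downloader.py below ends up calling it; the connection setup is elided and the paths are hypothetical:

    # usage sketch; `sftp` is an already-open connection, paths are hypothetical
    sftp.get_r('uid-1/questionnaire', '/tmp/radar-cache', preserve_mtime=True)
    # afterwards /tmp/radar-cache/uid-1/questionnaire mirrors the remote tree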

def walktree(self, remotepath, fcallback, dcallback, ucallback,
recurse=True):
'''recursively descend, depth first, the directory tree rooted at
remotepath, calling discrete callback functions for each regular file,
directory and unknown file type.

:param str remotepath:
root of remote directory to descend, use '.' to start at
:attr:`.pwd`
:param callable fcallback:
callback function to invoke for a regular file.
(form: ``func(str)``)
:param callable dcallback:
callback function to invoke for a directory. (form: ``func(str)``)
:param callable ucallback:
callback function to invoke for an unknown file type.
(form: ``func(str)``)
:param bool recurse: *Default: True* - should it recurse

:returns: None

:raises:

'''
self._sftp_connect()
for entry in self.listdir(remotepath):
pathname = posixpath.join(remotepath, entry)
mode = self._sftp.stat(pathname).st_mode
if S_ISDIR(mode):
# It's a directory, call the dcallback function
dcallback(pathname)
if recurse:
# now, recurse into it
self.walktree(pathname, fcallback, dcallback, ucallback)
elif S_ISREG(mode):
# It's a file, call the fcallback function
fcallback(pathname)
else:
# Unknown file type
ucallback(pathname)
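
get_r above shows the intended pairing with WTCallbacks; used standalone, the same pattern lists every regular file under a remote root (the root path here is hypothetical):

    wtcb = WTCallbacks()
    sftp.walktree('.', wtcb.file_cb, wtcb.dir_cb, wtcb.unk_cb)
    for fname in wtcb.flist:      # sorted list of regular files found
        print(fname)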

def close(self):
"""Closes the connection and cleans up."""
# Close SFTP Connection.
@@ -248,3 +405,29 @@ def close(self):
if self._transport:
self._transport.close()
self._transport = None


def path_advance(thepath, sep=os.sep):
'''generator to iterate over a file path forwards

:param str thepath: the path to navigate forwards
:param str sep: *Default: os.sep* - the path separator to use

:returns: (iter)able of strings

'''
# handle an absolute path (leading separator)
pre = ''
if thepath[0] == sep:
pre = sep
curpath = ''
parts = thepath.split(sep)
if pre:
if parts[0]:
parts[0] = pre + parts[0]
else:
parts[1] = pre + parts[1]
for part in parts:
curpath = os.path.join(curpath, part)
if curpath:
yield curpath
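
Note that path_advance yields every leading prefix of a path in turn, which is what lets get_r create intermediate directories one level at a time; for example, on a POSIX system (where os.path.join uses '/'):

    list(path_advance('a/b/c', sep='/'))   # ['a', 'a/b', 'a/b/c']
    list(path_advance('/a/b', sep='/'))    # ['/a', '/a/b']
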
12 changes: 4 additions & 8 deletions radarpipeline/io/downloader.py
@@ -89,16 +89,10 @@ def _fetch_data(self, root_path, sftp_source_path, included_var_cat, uid):
root_path, dir_path, src_file
)
):
-                        os.makedirs(
-                            os.path.join(
-                                root_path, dir_path,
-                                src_file),
-                            exist_ok=True)
-                        sftp.get_d(src_file,
+                        sftp.get_r(src_file,
                                    os.path.join(
                                        root_path,
-                                       dir_path,
-                                       src_file),
+                                       dir_path),
                                    preserve_mtime=True)
except FileNotFoundError:
logger.warning("Folder not found: " + dir_path
@@ -116,6 +110,8 @@ def _is_src_in_category(self, src, categories):
if categories == "all":
return True
for category in categories:
if "/" in category:
category = category.split("/")[0]
if src[:len(category)] == category:
return True
return False
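
With this change a category may carry a sub-path such as questionnaire/phq8 (a hypothetical example); only its top-level component is prefix-matched against the source name:

    # illustrative behaviour; `downloader` and the names are hypothetical
    downloader._is_src_in_category("questionnaire", ["questionnaire/phq8"])          # True
    downloader._is_src_in_category("android_phone_battery", ["questionnaire/phq8"])  # False
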
42 changes: 29 additions & 13 deletions radarpipeline/io/reader.py
@@ -5,6 +5,7 @@
import gzip
import re
from typing import Any, Dict, List, Optional, Union
import concurrent.futures

import pyspark.sql as ps
from pyspark.sql import SparkSession, DataFrame
@@ -127,6 +128,9 @@ def __init__(self, spark_session: ps.SparkSession,
# RADAR_NEW: uid/variable/yyyymm/yyyymmdd.csv.gz
"RADAR_NEW": re.compile(r"""^[\w-]+/([\w]+)/
[\d]+/([\d]+.csv.gz$|schema-\1.json$)""", re.X),
# RADAR_QUES: uid/questionnaire/QuestionnaireName/yyyymm/yyyymmdd.csv.gz
"RADAR_QUES": re.compile(r"""^[\w-]+/([\w]+)/([\w]+)/
[\d]+/([\d]+.csv.gz$|schema-\1.json$)""", re.X)
}
self.required_data = required_data
self.df_type = df_type
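
The only difference between the two layouts is one extra path level for the questionnaire name. A quick sanity check of the verbose patterns (the sample paths are hypothetical; the unescaped dots in .csv.gz match any character, which is harmless for real filenames):

    import re
    RADAR_NEW = re.compile(r"""^[\w-]+/([\w]+)/
        [\d]+/([\d]+.csv.gz$|schema-\1.json$)""", re.X)
    RADAR_QUES = re.compile(r"""^[\w-]+/([\w]+)/([\w]+)/
        [\d]+/([\d]+.csv.gz$|schema-\1.json$)""", re.X)
    assert RADAR_NEW.match("uid-1/android_phone_battery/202401/20240101.csv.gz")
    assert RADAR_QUES.match("uid-1/questionnaire/phq8/202401/20240101.csv.gz")
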
@@ -141,16 +145,15 @@ def _get_source_type(self, source_path):
"""
Returns the source type of the data
"""
-        files = [y for x in os.walk(source_path) for y in
-                 glob(os.path.join(x[0], '*.*'))]
        if source_path[-1] != "/":
            source_path = source_path + "/"
-        for key, value in self.source_formats.items():
-            file = files[0]
-            file_format = file.replace(source_path, "")
-            if re.match(value, file_format):
-                return key
-        raise ValueError("Source type not recognized")
+        for x in os.walk(source_path, topdown=False):
+            for file in glob(os.path.join(x[0], '*.*')):
+                for key, value in self.source_formats.items():
+                    file_format = file.replace(source_path, "")
+                    if re.match(value, file_format):
+                        return key
+        raise ValueError("Source path not recognized")

def read_data(self) -> RadarData:
"""
@@ -172,7 +175,7 @@ def read_data(self) -> RadarData:
logger.info("Reading data from old RADAR format")
radar_data, user_data_dict = self._read_data_from_old_format(
source_path_item, user_data_dict)
-        elif source_type == "RADAR_NEW":
+        elif source_type == "RADAR_NEW" or source_type == "RADAR_QUES":
logger.info("Reading data from new RADAR format")
radar_data, user_data_dict = self._read_data_from_new_format(
source_path_item, user_data_dict)
@@ -269,7 +272,8 @@ def _read_data_from_old_format(self, source_path: str, user_data_dict: dict):
uids = self._remove_hidden_dirs(uids)
if self.user_sampler is not None:
uids = self.user_sampler.sample_uids(uids)
-        for uid in uids:
+
+        def process_uid(uid):
logger.info(f"Reading data for user: {uid}")
variable_data_dict = {}
for dirname in self.required_data:
@@ -297,6 +301,9 @@ def _read_data_from_old_format(self, source_path: str, user_data_dict: dict):
if variable_data.get_data_size() > 0:
variable_data_dict[dirname] = variable_data
user_data_dict[uid] = RadarUserData(variable_data_dict, self.df_type)

with concurrent.futures.ThreadPoolExecutor() as executor:
executor.map(process_uid, uids)
radar_data = RadarData(user_data_dict, self.df_type)
return radar_data, user_data_dict
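
One caveat with this parallelisation pattern: executor.map returns a lazy iterator, so an exception raised inside process_uid only surfaces when the results are consumed. The per-user CSV reads are I/O-bound, which is why threads (rather than processes) pay off here; a stricter variant would drain the iterator (a sketch, not what the PR does):

    with concurrent.futures.ThreadPoolExecutor() as executor:
        # list() forces the iterator so any per-user exception propagates here
        list(executor.map(process_uid, uids))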

@@ -306,14 +313,15 @@ def _read_data_from_new_format(self, source_path: str, user_data_dict: dict):
uids = self._remove_hidden_dirs(uids)
if self.user_sampler is not None:
uids = self.user_sampler.sample_uids(uids)
-        for uid in uids:
+
+        def process_uid(uid):
            # Skip hidden files
            if uid[0] == ".":
-                continue
+                return
logger.info(f"Reading data for user: {uid}")
variable_data_dict = {}
for dirname in self.required_data:
-            if dirname not in os.listdir(os.path.join(source_path, uid)):
+            if not os.path.exists(os.path.join(source_path, uid, dirname)):
continue
logger.info(f"Reading data for variable: {dirname}")
data_files = []
@@ -341,6 +349,9 @@ def _read_data_from_new_format(self, source_path: str, user_data_dict: dict):
if variable_data.get_data_size() > 0:
variable_data_dict[dirname] = variable_data
user_data_dict[uid] = RadarUserData(variable_data_dict, self.df_type)

with concurrent.futures.ThreadPoolExecutor() as executor:
executor.map(process_uid, uids)
radar_data = RadarData(user_data_dict, self.df_type)
return radar_data, user_data_dict

@@ -369,6 +380,8 @@ def is_schema_present(self, schema_dir, schema_dir_base) -> bool:
bool
True if schema is present, False otherwise
"""
if "/" in schema_dir_base:
schema_dir_base = schema_dir_base.split("/")[0]
schema_file = os.path.join(
schema_dir, f"schema-{schema_dir_base}.json"
)
@@ -378,6 +391,9 @@ def is_schema_present(self, schema_dir, schema_dir_base) -> bool:
return False

def get_schema(self, schema_dir, schema_dir_base) -> StructType:

if "/" in schema_dir_base:
schema_dir_base = schema_dir_base.split("/")[0]
if schema_dir_base in self.schema_dict:
return self.schema_dict[schema_dir_base]
else:
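
Both schema helpers above normalise a nested variable name to its top-level directory before looking up the schema file; a minimal illustration (the name is hypothetical):

    schema_dir_base = "questionnaire/phq8"
    if "/" in schema_dir_base:
        schema_dir_base = schema_dir_base.split("/")[0]
    # -> "questionnaire", so the file consulted is schema-questionnaire.json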