Skip to content

Commit

Permalink
feat: allow multiple input files
Browse files Browse the repository at this point in the history
  • Loading branch information
IgnacioHeredia authored and alvarolopez committed Jul 19, 2024
1 parent f4e9ec4 commit ecff7f6
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions deepaas/cmd/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
fields.URL: str,
fields.Url: str,
fields.UUID: str,
fields.Field: str,
}


Expand Down Expand Up @@ -159,6 +160,18 @@ def _get_model_name(model_name=None):
sys.exit(1)


def _get_file_args(fields_in):
"""Function to retrieve a list of file-type fields
:param fields_in: mashmallow fields
:return: list
"""
file_fields = []
for k, v in fields_in.items():
if type(v) is fields.Field:
file_fields.append(k)
return file_fields


# Get the model name
model_name = CONF.model_name

Expand All @@ -174,6 +187,10 @@ def _get_model_name(model_name=None):
predict_args = _fields_to_dict(model_obj.get_predict_args())
train_args = _fields_to_dict(model_obj.get_train_args())

# Find which of the arguments are going to be files
file_args = {}
file_args['predict'] = _get_file_args(model_obj.get_predict_args())
file_args['train'] = _get_file_args(model_obj.get_train_args())

# Function to add later these arguments to CONF via SubCommandOpt
def _add_methods(subparsers):
Expand Down Expand Up @@ -285,29 +302,31 @@ def main():
if CONF.deepaas_with_multiprocessing:
mp.set_start_method("spawn", force=True)

# TODO(multi-file): change to many files ('for' itteration)
if CONF.methods.__contains__("files"):
if CONF.methods.files:
# Create file wrapper for file args (if provided)
for farg in file_args[CONF.methods.name]:
if getattr(CONF.methods, farg, None):
fpath = conf_vars[farg]

# create tmp file as later it supposed
# to be deleted by the application
temp = tempfile.NamedTemporaryFile()
temp.close()
# copy original file into tmp file
with open(conf_vars["files"], "rb") as f:
with open(fpath, "rb") as f:
with open(temp.name, "wb") as f_tmp:
for line in f:
f_tmp.write(line)

# create file object
file_type = mimetypes.MimeTypes().guess_type(conf_vars["files"])[0]
file_type = mimetypes.MimeTypes().guess_type(fpath)[0]
file_obj = v2_wrapper.UploadedFile(
name="data",
filename=temp.name,
content_type=file_type,
original_filename=conf_vars["files"],
original_filename=fpath,
)
# re-write 'files' parameter in conf_vars
conf_vars["files"] = file_obj
# re-write parameter in conf_vars
conf_vars[farg] = file_obj

# debug of input parameters
LOG.debug("[DEBUG provided options, conf_vars]: {}".format(conf_vars))
Expand Down

0 comments on commit ecff7f6

Please sign in to comment.