From 91d3fce7654a88e5dbaae3caa61d5e24a61ea09d Mon Sep 17 00:00:00 2001 From: Samuel Date: Fri, 27 Mar 2020 05:29:59 +0530 Subject: [PATCH] [RUNTIME]crt error handling (#5147) * crt error handling * Review comments fixed --- src/runtime/crt/graph_runtime.c | 163 ++++++++++++++++++++++++++------ src/runtime/crt/graph_runtime.h | 22 +++-- 2 files changed, 152 insertions(+), 33 deletions(-) diff --git a/src/runtime/crt/graph_runtime.c b/src/runtime/crt/graph_runtime.c index 1957d0bead4b..89c325acb216 100644 --- a/src/runtime/crt/graph_runtime.c +++ b/src/runtime/crt/graph_runtime.c @@ -42,16 +42,19 @@ int NodeEntry_Load(TVMGraphRuntimeNodeEntry * entry, JSONReader * reader) { reader->BeginArray(reader); if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "invalid json format: failed to parse `node_id`\n"); + status = -1; } reader->ReadUnsignedInteger(reader, &(entry->node_id)); if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "invalid json format: failed to parse `index`\n"); + status = -1; } reader->ReadUnsignedInteger(reader, &(entry->index)); if (reader->NextArrayItem(reader)) { reader->ReadUnsignedInteger(reader, &(entry->version)); if (reader->NextArrayItem(reader)) { fprintf(stderr, "invalid json format: failed to parse `version`\n"); + status = -1; } } else { entry->version = 0; @@ -151,7 +154,10 @@ int TVMGraphRuntimeNode_Load(TVMGraphRuntimeNode * node, JSONReader *reader) { } if (status != 0) { break; } } - if (bitmask != (1|2|4)) { fprintf(stderr, "invalid format\n"); } + if (bitmask != (1|2|4)) { + fprintf(stderr, "invalid format\n"); + status = -1; + } return status; } @@ -175,37 +181,81 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr, JSONReader *r while (reader->NextObjectItem(reader, key)) { if (!strcmp(key, "dltype")) { reader->BeginArray(reader); - if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + if (!(reader->NextArrayItem(reader))) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } reader->ReadString(reader, type); - if (strcmp(type, "list_str")) { fprintf(stderr, "Invalid json format\n"); } - if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + if (strcmp(type, "list_str")) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } + if (!(reader->NextArrayItem(reader))) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } reader->BeginArray(reader); while (reader->NextArrayItem(reader)) { reader->ReadString(reader, attr->dltype[dltype_count]); dltype_count++; } attr->dltype_count = dltype_count;; - if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); } + if (reader->NextArrayItem(reader)) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } bitmask |= 1; } else if (!strcmp(key, "storage_id")) { reader->BeginArray(reader); - if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + if (!(reader->NextArrayItem(reader))) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } reader->ReadString(reader, type); - if (strcmp(type, "list_int")) { fprintf(stderr, "Invalid json format\n"); } - if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + if (strcmp(type, "list_int")) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } + if (!(reader->NextArrayItem(reader))) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } reader->BeginArray(reader); while (reader->NextArrayItem(reader)) { reader->ReadUnsignedInteger(reader, &(attr->storage_id[storage_id_count])); storage_id_count++; } - if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); } + if (reader->NextArrayItem(reader)) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } bitmask |= 2; } else if (!strcmp(key, "shape")) { reader->BeginArray(reader); - if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + if (!(reader->NextArrayItem(reader))) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } reader->ReadString(reader, type); - if (strcmp(type, "list_shape")) { fprintf(stderr, "Invalid json format\n"); } - if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + if (strcmp(type, "list_shape")) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } + if (!(reader->NextArrayItem(reader))) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } reader->BeginArray(reader); while (reader->NextArrayItem(reader)) { reader->BeginArray(reader); @@ -233,25 +283,53 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr, JSONReader *r shape_count++; } attr->shape_count = shape_count; - if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); } + if (reader->NextArrayItem(reader)) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } bitmask |= 4; } else if (!strcmp(key, "device_index")) { reader->BeginArray(reader); - if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + if (!(reader->NextArrayItem(reader))) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } reader->ReadString(reader, type); - if (strcmp(type, "list_int")) { fprintf(stderr, "Invalid json format\n"); } - if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + if (strcmp(type, "list_int")) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } + if (!(reader->NextArrayItem(reader))) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } while (reader->NextArrayItem(reader)) { reader->ReadUnsignedInteger(reader, &(attr->device_index[device_index_count])); device_index_count++; } - if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); } + if (reader->NextArrayItem(reader)) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } } else { reader->BeginArray(reader); - if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + if (!(reader->NextArrayItem(reader))) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } reader->ReadString(reader, type); if (!strcmp(type, "list_int")) { - if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + if (!(reader->NextArrayItem(reader))) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } uint32_t temp[GRAPH_RUNTIME_MAX_NODES]; uint32_t temp_count = 0; reader->BeginArray(reader); @@ -260,16 +338,29 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr, JSONReader *r temp_count++; } } else if (!strcmp(type, "size_t")) { - if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + if (!(reader->NextArrayItem(reader))) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; + } uint32_t temp; reader->ReadUnsignedInteger(reader, &temp); } else { fprintf(stderr, "cannot skip graph attr %s", key); + status = -1; + break; + } + if (reader->NextArrayItem(reader)) { + fprintf(stderr, "Invalid json format\n"); + status = -1; + break; } - if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); } } } - if (bitmask != (1|2|4)) { fprintf(stderr, "invalid format\n"); } + if (bitmask != (1|2|4)) { + fprintf(stderr, "invalid format\n"); + status = -1; + } return status; } @@ -339,7 +430,10 @@ int TVMGraphRuntime_Load(TVMGraphRuntime * runtime, JSONReader *reader) { } if (status != 0) { break; } } - if (!(bitmask == (1|2|4|8|16))) { fprintf(stderr, "invalid format\n"); } + if (!(bitmask == (1|2|4|8|16))) { + fprintf(stderr, "invalid format\n"); + status = -1; + } return status; } @@ -350,6 +444,7 @@ uint32_t TVMGraphRuntime_GetEntryId(TVMGraphRuntime * runtime, /*! * \brief Get the input index given the name of input. + * \param runtime The graph runtime. * \param name The name of the input. * \return The index of input. */ @@ -370,8 +465,9 @@ int TVMGraphRuntime_GetInputIndex(TVMGraphRuntime * runtime, const char * name) } /*! - * \brief set index-th input to the graph. - * \param index The input index. + * \brief set input to the graph based on name. + * \param runtime The graph runtime. + * \param name The name of the input. * \param data_in The input data. */ void TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime, const char * name, DLTensor* data_in) { @@ -383,6 +479,13 @@ void TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime, const char * name, DLTe runtime->data_entry[eid].dl_tensor = *data_in; } +/*! + * \brief Load parameters from parameter blob. + * \param runtime The graph runtime. + * \param param_blob A binary blob of parameter. + * \param param_size The parameter size. + * \return The result of this function execution. + */ int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blob, const uint32_t param_size) { int status = 0; @@ -392,6 +495,7 @@ int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blo bptr += sizeof(header); if (header != kTVMNDArrayListMagic) { fprintf(stderr, "Invalid parameters file format"); + status = -1; } reserved = ((uint64_t*)bptr)[0]; // NOLINT(*) bptr += sizeof(reserved); @@ -409,6 +513,7 @@ int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blo bptr += sizeof(name_length); if (name_length >= 80) { fprintf(stderr, "Error: function name longer than expected.\n"); + status = -1; } memcpy(names[idx], bptr, name_length); bptr += name_length; @@ -451,6 +556,7 @@ int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blo /*! * \brief Run all the operations one by one. + * \param runtime The graph runtime. */ void TVMGraphRuntime_Run(TVMGraphRuntime * runtime) { // setup the array and requirements. @@ -565,7 +671,8 @@ int TVMGraphRuntime_SetupOpExecs(TVMGraphRuntime * runtime) { args_count++; } if (strcmp(inode->op_type, "tvm_op")) { - fprintf(stderr, "Can only take tvm_op as op\n"); status = -1; + fprintf(stderr, "Can only take tvm_op as op\n"); + status = -1; break; } if (args_count >= TVM_CRT_MAX_ARGS) { @@ -599,6 +706,7 @@ typedef struct TVMOpArgs { int32_t TVMGraphRuntime_CreateTVMOp(TVMGraphRuntime * runtime, const TVMOpParam * param, DLTensorPtr * args, const uint32_t args_count, uint32_t num_inputs, TVMPackedFunc * pf) { + int status = 0; uint32_t idx; TVMOpArgs arg_ptr; memset(&arg_ptr, 0, sizeof(TVMOpArgs)); @@ -624,13 +732,14 @@ int32_t TVMGraphRuntime_CreateTVMOp(TVMGraphRuntime * runtime, const TVMOpParam } if (!strcmp(param->func_name, "__nop") || !strcmp(param->func_name, "__copy")) { fprintf(stderr, "%s function is not yet supported.", param->func_name); + status = -1; } runtime->module.GetFunction(param->func_name, pf); TVMArgs targs = TVMArgs_Create(arg_ptr.arg_values, arg_ptr.arg_tcodes, arg_ptr.arg_values_count); pf->SetArgs(pf, &targs); - return 0; + return status; } /*! diff --git a/src/runtime/crt/graph_runtime.h b/src/runtime/crt/graph_runtime.h index 7fe395c5b09c..5b6e9058840d 100644 --- a/src/runtime/crt/graph_runtime.h +++ b/src/runtime/crt/graph_runtime.h @@ -99,6 +99,7 @@ typedef struct TVMGraphRuntime { /*! * \brief Initialize the graph executor with graph and context. + * \param runtime The graph runtime. * \param graph_json The execution graph. * \param module The module containing the compiled functions for the host * processor. @@ -112,27 +113,34 @@ typedef struct TVMGraphRuntime { /*! * \brief Get the input index given the name of input. + * \param runtime The graph runtime. * \param name The name of the input. * \return The index of input. */ int (*GetInputIndex)(struct TVMGraphRuntime * runtime, const char * name); /*! - * \brief set index-th input to the graph. - * \param index The input index. + * \brief set input to the graph based on name. + * \param runtime The graph runtime. + * \param name The name of the input. * \param data_in The input data. */ void (*SetInput)(struct TVMGraphRuntime * runtime, const char * name, DLTensor* data_in); + /*! * \brief Return NDArray for given output index. + * \param runtime The graph runtime. * \param index The output index. - * - * \return NDArray corresponding to given output node index. + * \param out The DLTensor corresponding to given output node index. + * \return The result of this function execution. */ int (*GetOutput)(struct TVMGraphRuntime * runtime, const int32_t index, DLTensor * out); /*! * \brief Load parameters from parameter blob. + * \param runtime The graph runtime. * \param param_blob A binary blob of parameter. + * \param param_size The parameter size. + * \return The result of this function execution. */ int (*LoadParams)(struct TVMGraphRuntime * runtime, const char * param_blob, const uint32_t param_size); @@ -146,10 +154,13 @@ typedef struct TVMGraphRuntime { /*! * \brief Create an execution function given input. + * \param runtime The graph runtime. * \param attrs The node attributes. * \param args The arguments to the functor, including inputs and outputs. + * \param args_count The total number of arguments. * \param num_inputs Number of inputs. - * \return The created executor. + * \param pf The created executor. + * \return The result of this function execution. */ int32_t (*CreateTVMOp)(struct TVMGraphRuntime * runtime, const TVMOpParam * attrs, DLTensorPtr * args, const uint32_t args_count, @@ -159,7 +170,6 @@ typedef struct TVMGraphRuntime { uint32_t (*GetEntryId)(struct TVMGraphRuntime * runtime, uint32_t nid, uint32_t index); // /*! \brief The graph nodes. */ - /* GraphRuntimeNode nodes_[GRAPH_RUNTIME_MAX_NODES]; */ TVMGraphRuntimeNode nodes[GRAPH_RUNTIME_MAX_NODES]; uint32_t nodes_count; /*! \brief The argument nodes. */