Skip to content

Commit

Permalink
[RUNTIME]crt error handling (apache#5147)
Browse files Browse the repository at this point in the history
* crt error handling

* Review comments fixed
  • Loading branch information
siju-samuel authored and zhiics committed Apr 17, 2020
1 parent 66f6f2f commit 91d3fce
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 33 deletions.
163 changes: 136 additions & 27 deletions src/runtime/crt/graph_runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,19 @@ int NodeEntry_Load(TVMGraphRuntimeNodeEntry * entry, JSONReader * reader) {
reader->BeginArray(reader);
if (!(reader->NextArrayItem(reader))) {
fprintf(stderr, "invalid json format: failed to parse `node_id`\n");
status = -1;
}
reader->ReadUnsignedInteger(reader, &(entry->node_id));
if (!(reader->NextArrayItem(reader))) {
fprintf(stderr, "invalid json format: failed to parse `index`\n");
status = -1;
}
reader->ReadUnsignedInteger(reader, &(entry->index));
if (reader->NextArrayItem(reader)) {
reader->ReadUnsignedInteger(reader, &(entry->version));
if (reader->NextArrayItem(reader)) {
fprintf(stderr, "invalid json format: failed to parse `version`\n");
status = -1;
}
} else {
entry->version = 0;
Expand Down Expand Up @@ -151,7 +154,10 @@ int TVMGraphRuntimeNode_Load(TVMGraphRuntimeNode * node, JSONReader *reader) {
}
if (status != 0) { break; }
}
if (bitmask != (1|2|4)) { fprintf(stderr, "invalid format\n"); }
if (bitmask != (1|2|4)) {
fprintf(stderr, "invalid format\n");
status = -1;
}
return status;
}

Expand All @@ -175,37 +181,81 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr, JSONReader *r
while (reader->NextObjectItem(reader, key)) {
if (!strcmp(key, "dltype")) {
reader->BeginArray(reader);
if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); }
if (!(reader->NextArrayItem(reader))) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
reader->ReadString(reader, type);
if (strcmp(type, "list_str")) { fprintf(stderr, "Invalid json format\n"); }
if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); }
if (strcmp(type, "list_str")) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
if (!(reader->NextArrayItem(reader))) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
reader->BeginArray(reader);
while (reader->NextArrayItem(reader)) {
reader->ReadString(reader, attr->dltype[dltype_count]);
dltype_count++;
}
attr->dltype_count = dltype_count;;
if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); }
if (reader->NextArrayItem(reader)) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
bitmask |= 1;
} else if (!strcmp(key, "storage_id")) {
reader->BeginArray(reader);
if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); }
if (!(reader->NextArrayItem(reader))) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
reader->ReadString(reader, type);
if (strcmp(type, "list_int")) { fprintf(stderr, "Invalid json format\n"); }
if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); }
if (strcmp(type, "list_int")) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
if (!(reader->NextArrayItem(reader))) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
reader->BeginArray(reader);
while (reader->NextArrayItem(reader)) {
reader->ReadUnsignedInteger(reader, &(attr->storage_id[storage_id_count]));
storage_id_count++;
}
if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); }
if (reader->NextArrayItem(reader)) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
bitmask |= 2;
} else if (!strcmp(key, "shape")) {
reader->BeginArray(reader);
if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); }
if (!(reader->NextArrayItem(reader))) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
reader->ReadString(reader, type);
if (strcmp(type, "list_shape")) { fprintf(stderr, "Invalid json format\n"); }
if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); }
if (strcmp(type, "list_shape")) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
if (!(reader->NextArrayItem(reader))) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
reader->BeginArray(reader);
while (reader->NextArrayItem(reader)) {
reader->BeginArray(reader);
Expand Down Expand Up @@ -233,25 +283,53 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr, JSONReader *r
shape_count++;
}
attr->shape_count = shape_count;
if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); }
if (reader->NextArrayItem(reader)) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
bitmask |= 4;
} else if (!strcmp(key, "device_index")) {
reader->BeginArray(reader);
if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); }
if (!(reader->NextArrayItem(reader))) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
reader->ReadString(reader, type);
if (strcmp(type, "list_int")) { fprintf(stderr, "Invalid json format\n"); }
if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); }
if (strcmp(type, "list_int")) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
if (!(reader->NextArrayItem(reader))) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
while (reader->NextArrayItem(reader)) {
reader->ReadUnsignedInteger(reader, &(attr->device_index[device_index_count]));
device_index_count++;
}
if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); }
if (reader->NextArrayItem(reader)) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
} else {
reader->BeginArray(reader);
if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); }
if (!(reader->NextArrayItem(reader))) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
reader->ReadString(reader, type);
if (!strcmp(type, "list_int")) {
if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); }
if (!(reader->NextArrayItem(reader))) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
uint32_t temp[GRAPH_RUNTIME_MAX_NODES];
uint32_t temp_count = 0;
reader->BeginArray(reader);
Expand All @@ -260,16 +338,29 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr, JSONReader *r
temp_count++;
}
} else if (!strcmp(type, "size_t")) {
if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); }
if (!(reader->NextArrayItem(reader))) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
uint32_t temp;
reader->ReadUnsignedInteger(reader, &temp);
} else {
fprintf(stderr, "cannot skip graph attr %s", key);
status = -1;
break;
}
if (reader->NextArrayItem(reader)) {
fprintf(stderr, "Invalid json format\n");
status = -1;
break;
}
if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); }
}
}
if (bitmask != (1|2|4)) { fprintf(stderr, "invalid format\n"); }
if (bitmask != (1|2|4)) {
fprintf(stderr, "invalid format\n");
status = -1;
}
return status;
}

Expand Down Expand Up @@ -339,7 +430,10 @@ int TVMGraphRuntime_Load(TVMGraphRuntime * runtime, JSONReader *reader) {
}
if (status != 0) { break; }
}
if (!(bitmask == (1|2|4|8|16))) { fprintf(stderr, "invalid format\n"); }
if (!(bitmask == (1|2|4|8|16))) {
fprintf(stderr, "invalid format\n");
status = -1;
}
return status;
}

Expand All @@ -350,6 +444,7 @@ uint32_t TVMGraphRuntime_GetEntryId(TVMGraphRuntime * runtime,

/*!
* \brief Get the input index given the name of input.
* \param runtime The graph runtime.
* \param name The name of the input.
* \return The index of input.
*/
Expand All @@ -370,8 +465,9 @@ int TVMGraphRuntime_GetInputIndex(TVMGraphRuntime * runtime, const char * name)
}

/*!
* \brief set index-th input to the graph.
* \param index The input index.
* \brief set input to the graph based on name.
* \param runtime The graph runtime.
* \param name The name of the input.
* \param data_in The input data.
*/
void TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime, const char * name, DLTensor* data_in) {
Expand All @@ -383,6 +479,13 @@ void TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime, const char * name, DLTe
runtime->data_entry[eid].dl_tensor = *data_in;
}

/*!
* \brief Load parameters from parameter blob.
* \param runtime The graph runtime.
* \param param_blob A binary blob of parameter.
* \param param_size The parameter size.
* \return The result of this function execution.
*/
int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blob,
const uint32_t param_size) {
int status = 0;
Expand All @@ -392,6 +495,7 @@ int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blo
bptr += sizeof(header);
if (header != kTVMNDArrayListMagic) {
fprintf(stderr, "Invalid parameters file format");
status = -1;
}
reserved = ((uint64_t*)bptr)[0]; // NOLINT(*)
bptr += sizeof(reserved);
Expand All @@ -409,6 +513,7 @@ int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blo
bptr += sizeof(name_length);
if (name_length >= 80) {
fprintf(stderr, "Error: function name longer than expected.\n");
status = -1;
}
memcpy(names[idx], bptr, name_length);
bptr += name_length;
Expand Down Expand Up @@ -451,6 +556,7 @@ int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blo

/*!
* \brief Run all the operations one by one.
* \param runtime The graph runtime.
*/
void TVMGraphRuntime_Run(TVMGraphRuntime * runtime) {
// setup the array and requirements.
Expand Down Expand Up @@ -565,7 +671,8 @@ int TVMGraphRuntime_SetupOpExecs(TVMGraphRuntime * runtime) {
args_count++;
}
if (strcmp(inode->op_type, "tvm_op")) {
fprintf(stderr, "Can only take tvm_op as op\n"); status = -1;
fprintf(stderr, "Can only take tvm_op as op\n");
status = -1;
break;
}
if (args_count >= TVM_CRT_MAX_ARGS) {
Expand Down Expand Up @@ -599,6 +706,7 @@ typedef struct TVMOpArgs {
int32_t TVMGraphRuntime_CreateTVMOp(TVMGraphRuntime * runtime, const TVMOpParam * param,
DLTensorPtr * args, const uint32_t args_count,
uint32_t num_inputs, TVMPackedFunc * pf) {
int status = 0;
uint32_t idx;
TVMOpArgs arg_ptr;
memset(&arg_ptr, 0, sizeof(TVMOpArgs));
Expand All @@ -624,13 +732,14 @@ int32_t TVMGraphRuntime_CreateTVMOp(TVMGraphRuntime * runtime, const TVMOpParam
}
if (!strcmp(param->func_name, "__nop") || !strcmp(param->func_name, "__copy")) {
fprintf(stderr, "%s function is not yet supported.", param->func_name);
status = -1;
}

runtime->module.GetFunction(param->func_name, pf);
TVMArgs targs = TVMArgs_Create(arg_ptr.arg_values, arg_ptr.arg_tcodes, arg_ptr.arg_values_count);
pf->SetArgs(pf, &targs);

return 0;
return status;
}

/*!
Expand Down
22 changes: 16 additions & 6 deletions src/runtime/crt/graph_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ typedef struct TVMGraphRuntime {

/*!
* \brief Initialize the graph executor with graph and context.
* \param runtime The graph runtime.
* \param graph_json The execution graph.
* \param module The module containing the compiled functions for the host
* processor.
Expand All @@ -112,27 +113,34 @@ typedef struct TVMGraphRuntime {

/*!
* \brief Get the input index given the name of input.
* \param runtime The graph runtime.
* \param name The name of the input.
* \return The index of input.
*/
int (*GetInputIndex)(struct TVMGraphRuntime * runtime, const char * name);

/*!
* \brief set index-th input to the graph.
* \param index The input index.
* \brief set input to the graph based on name.
* \param runtime The graph runtime.
* \param name The name of the input.
* \param data_in The input data.
*/
void (*SetInput)(struct TVMGraphRuntime * runtime, const char * name, DLTensor* data_in);

/*!
* \brief Return NDArray for given output index.
* \param runtime The graph runtime.
* \param index The output index.
*
* \return NDArray corresponding to given output node index.
* \param out The DLTensor corresponding to given output node index.
* \return The result of this function execution.
*/
int (*GetOutput)(struct TVMGraphRuntime * runtime, const int32_t index, DLTensor * out);
/*!
* \brief Load parameters from parameter blob.
* \param runtime The graph runtime.
* \param param_blob A binary blob of parameter.
* \param param_size The parameter size.
* \return The result of this function execution.
*/
int (*LoadParams)(struct TVMGraphRuntime * runtime, const char * param_blob,
const uint32_t param_size);
Expand All @@ -146,10 +154,13 @@ typedef struct TVMGraphRuntime {

/*!
* \brief Create an execution function given input.
* \param runtime The graph runtime.
* \param attrs The node attributes.
* \param args The arguments to the functor, including inputs and outputs.
* \param args_count The total number of arguments.
* \param num_inputs Number of inputs.
* \return The created executor.
* \param pf The created executor.
* \return The result of this function execution.
*/
int32_t (*CreateTVMOp)(struct TVMGraphRuntime * runtime, const TVMOpParam * attrs,
DLTensorPtr * args, const uint32_t args_count,
Expand All @@ -159,7 +170,6 @@ typedef struct TVMGraphRuntime {
uint32_t (*GetEntryId)(struct TVMGraphRuntime * runtime, uint32_t nid, uint32_t index);

// /*! \brief The graph nodes. */
/* GraphRuntimeNode nodes_[GRAPH_RUNTIME_MAX_NODES]; */
TVMGraphRuntimeNode nodes[GRAPH_RUNTIME_MAX_NODES];
uint32_t nodes_count;
/*! \brief The argument nodes. */
Expand Down

0 comments on commit 91d3fce

Please sign in to comment.