Skip to content

Commit

Permalink
Merge pull request #5 from dagardner-nv/david-rest-source-sink-callbacks
Browse files Browse the repository at this point in the history
Fix race condition by moving task_done to a callback
  • Loading branch information
dagardner-nv authored Jun 13, 2023
2 parents 6a83898 + 13871d8 commit b67fe76
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 33 deletions.
15 changes: 10 additions & 5 deletions morpheus/_lib/include/morpheus/utilities/rest_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

#pragma once

#include <boost/asio/io_context.hpp> // for io_context
#include <pybind11/pytypes.h> // for pybind11::function
#include <boost/asio/io_context.hpp> // for io_context
#include <boost/system/error_code.hpp> // for error_code
#include <pybind11/pytypes.h> // for pybind11::function

#include <atomic> // for atomic
#include <chrono> // for seconds
Expand All @@ -44,18 +45,22 @@ namespace morpheus {
*/

#pragma GCC visibility push(default)
using on_complete_cb_fn_t = std::function<void(const boost::system::error_code& /* error message */)>;

/**
* @brief A tuple consisting of the HTTP status code, mime type to be used for the Content-Type header, and the body of
* the response.
*/
using parse_status_t = std::
tuple<unsigned /*http status code*/, std::string /* Content-Type of response */, std::string /* response body */>;
using parse_status_t = std::tuple<unsigned /*http status code*/,
std::string /* Content-Type of response */,
std::string /* response body */,
on_complete_cb_fn_t /* optional callback function, ignored if null */>;

/**
* @brief A function that receives the post body and returns an HTTP status code, Content-Type string and body.
*
* @details The function is expected to return a tuple conforming to `parse_status_t` consisting of the HTTP status
* code, mime type value for the Content-Type header and the body of the response.
* code, mime type value for the Content-Type header, body of the response and optionally a callback function.
*/
using payload_parse_fn_t = std::function<parse_status_t(const std::string& /* post body */)>;

Expand Down
8 changes: 4 additions & 4 deletions morpheus/_lib/src/stages/rest_source.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ RestSourceStage::RestSourceStage(std::string bind_address,
{
std::string error_msg = "Error occurred converting REST payload to Dataframe";
LOG(ERROR) << error_msg << ": " << e.what();
return std::make_tuple(400, "text/plain", error_msg);
return std::make_tuple(400, "text/plain", error_msg, nullptr);
}

try
Expand All @@ -68,7 +68,7 @@ RestSourceStage::RestSourceStage(std::string bind_address,

if (queue_status == boost::fibers::channel_op_status::success)
{
return std::make_tuple(201, "text/plain", std::string());
return std::make_tuple(201, "text/plain", std::string(), nullptr);
}

std::string error_msg = "REST payload queue is ";
Expand All @@ -90,12 +90,12 @@ RestSourceStage::RestSourceStage(std::string bind_address,
}
}

return std::make_tuple(503, "text/plain", std::move(error_msg));
return std::make_tuple(503, "text/plain", std::move(error_msg), nullptr);
} catch (const std::exception& e)
{
std::string error_msg = "Error occurred while pushing payload to queue";
LOG(ERROR) << error_msg << ": " << e.what();
return std::make_tuple(500, "text/plain", error_msg);
return std::make_tuple(500, "text/plain", error_msg, nullptr);
}
};
m_server = std::make_unique<RestServer>(std::move(parser),
Expand Down
67 changes: 58 additions & 9 deletions morpheus/_lib/src/utilities/rest_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@

#include "morpheus/utilities/rest_server.hpp"

#include "pymrc/utilities/function_wrappers.hpp" // for PyFuncWrapper

#include <boost/asio.hpp> // for dispatch
#include <boost/asio/ip/tcp.hpp> // for acceptor, endpoint, socket,
#include <boost/beast/core.hpp> // for bind_front_handler, error_code, flat_buffer, tcp_stream
#include <boost/beast/http.hpp> // for read_async, request, response, verb, write_async
#include <glog/logging.h> // for CHECK and LOG
#include <pybind11/gil.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>

#include <utility> // for move

Expand Down Expand Up @@ -56,7 +59,8 @@ class Session : public std::enable_shared_from_this<Session>
m_url_endpoint{url_endpoint},
m_method{method},
m_max_payload_size{max_payload_size},
m_timeout{timeout}
m_timeout{timeout},
m_on_complete_cb{nullptr}
{}

~Session() = default;
Expand Down Expand Up @@ -107,6 +111,7 @@ class Session : public std::enable_shared_from_this<Session>
m_response->result(std::get<0>(parse_status));
m_response->set(http::field::content_type, std::get<1>(parse_status));
m_response->body() = std::get<2>(parse_status);
m_on_complete_cb = std::get<3>(parse_status);
}
else
{
Expand Down Expand Up @@ -147,6 +152,22 @@ class Session : public std::enable_shared_from_this<Session>
m_parser.reset(nullptr);
m_response.reset(nullptr);

if (m_on_complete_cb)
{
try
{
m_on_complete_cb(ec);
} catch (const std::exception& e)
{
LOG(ERROR) << "Caught exception while calling on_complete callback: " << e.what();
} catch (...)
{
LOG(ERROR) << "Caught unknown exception while calling on_complete callback";
}

m_on_complete_cb = nullptr;
}

do_read();
}

Expand All @@ -167,6 +188,7 @@ class Session : public std::enable_shared_from_this<Session>
// The response, and parser are all reset for each incoming request
std::unique_ptr<http::request_parser<http::string_body>> m_parser;
std::unique_ptr<http::response<http::string_body>> m_response;
morpheus::on_complete_cb_fn_t m_on_complete_cb;
};

class Listener : public std::enable_shared_from_this<Listener>
Expand Down Expand Up @@ -257,7 +279,8 @@ RestServer::RestServer(payload_parse_fn_t payload_parse_fn,
m_num_threads(num_threads),
m_request_timeout(request_timeout),
m_max_payload_size(max_payload_size),
m_io_context{nullptr}
m_io_context{nullptr},
m_is_running{false}
{
if (m_method == http::verb::unknown)
{
Expand Down Expand Up @@ -342,7 +365,10 @@ RestServer::~RestServer()
}

/****** RestServerInterfaceProxy *************************/
std::shared_ptr<RestServer> RestServerInterfaceProxy::init(pybind11::function py_parse_fn,
using mrc::pymrc::PyFuncWrapper;
namespace py = pybind11;

std::shared_ptr<RestServer> RestServerInterfaceProxy::init(py::function py_parse_fn,
std::string bind_address,
unsigned short port,
std::string endpoint,
Expand All @@ -351,12 +377,35 @@ std::shared_ptr<RestServer> RestServerInterfaceProxy::init(pybind11::function py
std::size_t max_payload_size,
int64_t request_timeout)
{
payload_parse_fn_t payload_parse_fn = [py_parse_fn = std::move(py_parse_fn)](const std::string& payload) {
pybind11::gil_scoped_acquire gil;
auto py_payload = pybind11::str(payload);
auto py_result = py_parse_fn(py_payload);
auto result = pybind11::cast<parse_status_t>(py_result);
return result;
auto wrapped_parse_fn = PyFuncWrapper(std::move(py_parse_fn));
payload_parse_fn_t payload_parse_fn = [wrapped_parse_fn = std::move(wrapped_parse_fn)](const std::string& payload) {
py::gil_scoped_acquire gil;
auto py_payload = py::str(payload);
auto py_result = wrapped_parse_fn.operator()<py::tuple, py::str>(py_payload);
on_complete_cb_fn_t cb_fn{nullptr};
if (!py_result[3].is_none())
{
auto py_cb_fn = py_result[3].cast<py::function>();
auto wrapped_cb_fn = PyFuncWrapper(std::move(py_cb_fn));

cb_fn = [wrapped_cb_fn = std::move(wrapped_cb_fn)](const beast::error_code& ec) {
py::gil_scoped_acquire gil;
py::bool_ has_error = false;
py::str error_msg;
if (ec)
{
has_error = true;
error_msg = ec.message();
}

wrapped_cb_fn.operator()<void, py::bool_, py::str>(has_error, error_msg);
};
}

return std::make_tuple(py::cast<unsigned>(py_result[0]),
py::cast<std::string>(py_result[1]),
py::cast<std::string>(py_result[2]),
std::move(cb_fn));
};

return std::make_shared<RestServer>(std::move(payload_parse_fn),
Expand Down
8 changes: 4 additions & 4 deletions morpheus/stages/input/rest_source_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,23 +117,23 @@ def _parse_payload(self, payload: str) -> typing.Tuple[int, str]:
except Exception as e:
err_msg = "Error occurred converting REST payload to Dataframe"
logger.error(f"{err_msg}: {e}")
return (400, MimeTypes.TEXT.value, err_msg)
return (400, MimeTypes.TEXT.value, err_msg, None)

try:
self._queue.put(df, block=True, timeout=self._queue_timeout)
return (201, MimeTypes.TEXT.value, "")
return (201, MimeTypes.TEXT.value, "", None)
except (queue.Full, Closed) as e:
err_msg = "REST payload queue is "
if isinstance(e, queue.Full):
err_msg += "full"
else:
err_msg += "closed"
logger.error(err_msg)
return (503, MimeTypes.TEXT.value, err_msg)
return (503, MimeTypes.TEXT.value, err_msg, None)
except Exception as e:
err_msg = "Error occurred while pushing payload to queue"
logger.error(f"{err_msg}: {e}")
return (500, MimeTypes.TEXT.value, err_msg)
return (500, MimeTypes.TEXT.value, err_msg, None)

def _generate_frames(self) -> typing.Iterator[MessageMeta]:
from morpheus.common import FiberQueue
Expand Down
32 changes: 21 additions & 11 deletions morpheus/stages/output/rest_server_sink_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import queue
import time
import typing
from functools import partial
from io import StringIO

import mrc
Expand Down Expand Up @@ -147,9 +148,21 @@ def _default_df_serializer(self, df: DataFrameType) -> str:
str_buf.seek(0)
return str_buf.read()

def _request_callback(self, df: DataFrameType, num_tasks: int, has_error: bool, error_msg: str) -> None:
try:
if has_error:
logger.error(error_msg)

# If the client failed to read the response, then we need to put the dataframe back into the queue
self._queue.put(df)

# Even in the event of an error, we need to mark the tasks as done.
for _ in range(num_tasks):
self._queue.task_done()
except Exception as e:
logger.error("Unknown error in request callback: %s", e)

def _request_handler(self, _: str) -> typing.Tuple[int, str]:
# TODO: If this takes longer than `request_timeout_secs` then the request will be terminated, and the messages
# will be lost
num_rows = 0
data_frames = []
try:
Expand All @@ -162,22 +175,20 @@ def _request_handler(self, _: str) -> typing.Tuple[int, str]:
except Exception as e:
err_msg = "Unknown error processing request"
logger.error(f"{err_msg}: %s", e)
return (500, MimeTypes.TEXT.value, err_msg)
return (500, MimeTypes.TEXT.value, err_msg, None)

if (len(data_frames) > 0):
df = data_frames[0]
if len(data_frames) > 1:
cat_fn = pd.concat if isinstance(df, pd.DataFrame) else cudf.concat
df = cat_fn(data_frames)

# TODO: Move to a callback so that we only call task_done once the response has been sent, potentially
# allowing us to re-queue the message in the event of a network error
for _ in range(len(data_frames)):
self._queue.task_done()

return (200, self._content_type, self._df_serializer_fn(df))
return (200,
self._content_type,
self._df_serializer_fn(df),
partial(self._request_callback, df, len(data_frames)))
else:
return (204, MimeTypes.TEXT.value, "No messages available")
return (204, MimeTypes.TEXT.value, "", None)

def _partition_df(self, df: DataFrameType) -> typing.Iterable[DataFrameType]:
"""
Expand Down Expand Up @@ -209,7 +220,6 @@ def _process_message(self, msg: MessageMeta) -> MessageMeta:
def _block_until_empty(self):
logger.debug("Waiting for queue to empty")
self._queue.join()
time.sleep(1) # TODO: race condition, need some sort of on req callback, and only call task_done() there
logger.debug("stopping server")
self._server.stop()
logger.debug("stopped")
Expand Down

0 comments on commit b67fe76

Please sign in to comment.