-
Notifications
You must be signed in to change notification settings - Fork 530
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[REVIEW] Faster Treelite serialization #2263
[REVIEW] Faster Treelite serialization #2263
Conversation
Very promising! I suspect that the RF -> treelite can be accelerated a lot in a future rev too... I don't know why it should have to be much longer than treelite-> fil in general if we move to convert the representation efficiently. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did a fairly quick pass for the first rev... will review again in more detail as things get wrapped up.
Overall, I think it's great and looks pretty clear. The changes to the core cython code are smaller than I expected - it fits in well with the existing pattern, so I think it's pretty understandable.
I don't fully get the serialization to frames - looks like it simply passing through the binary format used within tl? Seems reasonable to me, but I haven't used this style of conversion before with Py_buffers etc.
@@ -515,18 +509,21 @@ class RandomForestClassifier(Base): | |||
to a shared file. Cuml issue #1854 has been created to track this. | |||
""" | |||
def _tl_model_handles(self, model_bytes): | |||
cdef ModelHandle cuml_model_ptr = NULL | |||
cdef uintptr_t tl_handle_int |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In another pr, I proposed renaming this something like _alloc_and_convert_model
to make it clear that the caller needs to free the result.
""" | ||
Returns the self.model_pbuf_bytes. | ||
Returns the self.model_bytes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe clarify the type and update the rest of docstring: "Returns the treelite binary format representation of this model."
Treelite objects now exposes the Python buffer protocol interface, so that we can transparently convert Treelite objects to memory views with zero overhead. In @jakirkham gave me valuable advice for implementing the Python buffer protocol. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, its great to see the time required for predict drop so much! I have a few suggestions and questions
117d313
to
a935bee
Compare
This is quite a strange error:
|
I managed to fix the failing benchmark test. Marking this as ready for review. |
List of changes made to Treelite:
Once all tests pass, I will go ahead and release 0.92 version of Treelite. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Exciting to see all of the progress here, @hcho3! 😄
Added a couple of comments about how we might simplify things a bit here. Please let me know if you have any questions 🙂
I submitted Treelite 0.92 to conda-forge: conda-forge/staged-recipes#11926. Fingers crossed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! I have only small questions/suggestions.
Also, are there additional unit tests that would be helpful here? Serializing+deserializing model variants (e.g. classification/regression/multiclass) and ensuring we get the properties right? Not sure...
if (task_category > 2) { | ||
// Multi-class classification | ||
TREELITE_CHECK(TreeliteModelBuilderSetModelParam( | ||
model_builder, "pred_transform", "max_index")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't multiclass currently disabled until #2248 goes in?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so. I included the line because it was already part of the current codebase.
@JohnZed The Treelite repo contains several round-trip tests for the new serializer: https://github.com/dmlc/treelite/blob/master/tests/cpp/test_serializer.cc |
@jakirkham I addressed all your comments, except the one about casting buffer frames. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Grouping together suggested format_str
encoding changes for clearer discussion.
model: uintptr_t | ||
) -> Dict[str, Union[List[str], List[np.ndarray]]]: | ||
frames = get_frames(model) | ||
header = {'format_str': [x.format.encode('utf-8') for x in frames], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
header = {'format_str': [x.format.encode('utf-8') for x in frames], | |
header = {'format_str': [x.format for x in frames], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I went ahead and applied your suggestion.
Is it preferable to pass str
to pickle, rather than bytes
? I'd like to understand your reasoning behind this suggestion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was thinking about this in the context of moving from pickle to Dask serialization down the road (assuming that is still the plan). Typically the header consists of things like int
, str
, dict
, list
, and tuple
. Generally things that are MsgPack serializable. Typically bytes
and memoryviews
are reserved for frames instead.
Was unsure at first whether bytes
would work in the header. However after playing with things a bit bytes
may work. MsgPack is at least able to handle them with the flags that Dask is using.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. Thanks for your explanation. I agree that built-in types like str
would be well supported by Dask serializer.
Co-authored-by: jakirkham <jakirkham@gmail.com>
Thanks for all of the work here @hcho3! 😄 Looks good Sounds like we are going to handle switching to the |
Yes, I’ll work on it after this PR is merged. |
The failing test in the CI should be fixed by #2432 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great!
Summary
Speed up serialization of Treelite model objects and reduce overhead in multi-GPU RF prediction.
Features
Benchmark setup
n_gpus=2, n_gb=2, n_features=20, depth=25, n_estimators=10
. This leads to a forest consisting of 10 depth-25 trees, and we run through 25 million data rows.Benchmark Results
Aggregate
Breakdown by components
Notice that some of the overhead is still yet to be explained. Also note that the actual prediction time is a small portion of the total run time. Due to the nature of distributed algorithm, timing measures are approximate.
Current cuML
This PR