-
Notifications
You must be signed in to change notification settings - Fork 40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: static model loader #603
Changes from 12 commits
f053fbb
88981d3
6ec44cc
34f4005
2d923dd
f50d2e6
d2acecd
a40c9ed
f8c7d06
f7f212c
5162279
dc8851a
150e9f6
f4a7ff5
003d1f1
2261efd
fe46f5d
990a552
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#include "binding_static_model.h" | ||
|
||
#include "err_constants.h" | ||
#include "model_mgmt.h" | ||
|
||
using namespace rl_net_native; | ||
using namespace reinforcement_learning; | ||
|
||
binding_static_model::binding_static_model(const char* vw_model, const size_t len) : vw_model(vw_model), len(len) {} | ||
|
||
int binding_static_model::get_data(model_transport::model_data& data, reinforcement_learning::api_status* status) | ||
{ | ||
if (this->vw_model == nullptr || this->len == 0) | ||
{ | ||
return reinforcement_learning::error_code::static_model_load_error; | ||
} | ||
|
||
char* buffer = data.alloc(this->len); | ||
if (buffer == nullptr) { return reinforcement_learning::error_code::static_model_load_error; } | ||
|
||
memcpy(buffer, this->vw_model, this->len); | ||
data.data_sz(this->len); | ||
|
||
return reinforcement_learning::error_code::success; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#pragma once | ||
|
||
#include "model_mgmt.h" | ||
|
||
bassmang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
#include <cstring> | ||
|
||
namespace model_transport = reinforcement_learning::model_management; | ||
|
||
namespace rl_net_native | ||
{ | ||
|
||
namespace constants | ||
{ | ||
const char* const BINDING_DATA_TRANSPORT = "BINDING_DATA_TRANSPORT"; | ||
} | ||
class binding_static_model : public model_transport::i_data_transport | ||
{ | ||
public: | ||
binding_static_model(const char* vw_model, const size_t len); | ||
int get_data(model_transport::model_data& data, reinforcement_learning::api_status* status = nullptr) override; | ||
|
||
private: | ||
const char* vw_model; | ||
const size_t len; | ||
}; | ||
} // namespace rl_net_native |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
using System; | ||
using System.Threading; | ||
using System.Runtime.InteropServices; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
|
||
using Rl.Net.Native; | ||
|
||
|
@@ -10,16 +12,38 @@ public sealed class FactoryContext : NativeObject<FactoryContext> | |
[DllImport("rlnetnative")] | ||
private static extern IntPtr CreateFactoryContext(); | ||
|
||
[DllImport("rlnetnative")] | ||
private static extern IntPtr CreateFactoryContextWithStaticModel(IntPtr weights, int length); | ||
|
||
[DllImport("rlnetnative")] | ||
private static extern void DeleteFactoryContext(IntPtr context); | ||
|
||
[DllImport("rlnetnative")] | ||
private static extern IntPtr SetFactoryContextBindingSenderFactory(IntPtr context, sender_create_fn create_Fn, sender_vtable vtable); | ||
|
||
public List<byte> Weights { get; private set; } | ||
|
||
public FactoryContext() : base(new New<FactoryContext>(CreateFactoryContext), new Delete<FactoryContext>(DeleteFactoryContext)) | ||
{ | ||
} | ||
|
||
public FactoryContext(IEnumerable<byte> weightsEnumerable) : base( | ||
new New<FactoryContext>(() => { | ||
var weightsArray = weightsEnumerable.ToArray(); | ||
GCHandle handle = GCHandle.Alloc(weightsArray, GCHandleType.Pinned); | ||
try { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a way to check if this makes a copy here? Unclear to me... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a copy happening here but I believe it is necessary since they pass in a List which doesn't have a way to obtain a pointer to it's internal array There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a way to avoid the copy on .NET 5 or above with List<> types, but not generally, no. |
||
IntPtr ptr = handle.AddrOfPinnedObject(); | ||
return CreateFactoryContextWithStaticModel(ptr, weightsArray.Length); | ||
} | ||
finally { | ||
if (handle.IsAllocated) | ||
handle.Free(); | ||
} | ||
}), | ||
new Delete<FactoryContext>(DeleteFactoryContext)) | ||
{ | ||
} | ||
|
||
private GCHandleLifetime registeredSenderCreateHandle; | ||
|
||
public void SetSenderFactory<TSender>(Func<IReadOnlyConfiguration, ErrorCallback, TSender> createSender) where TSender : ISender | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
using System; | ||
using System.Runtime.InteropServices; | ||
using System.Collections.Generic; | ||
|
||
using Rl.Net.Native; | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wondering if model_data class can use a method to encapsulate these operations. i.e. alloc + copy + set size
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
put a set_data function in data_data to handle this