Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
bassmang committed Feb 8, 2024
1 parent 5162279 commit 150e9f6
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 32 deletions.
16 changes: 1 addition & 15 deletions bindings/cs/rl.net.native/binding_static_model.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
#include "binding_static_model.h"

#include "err_constants.h"
#include "model_mgmt.h"

using namespace rl_net_native;
using namespace reinforcement_learning;

Expand All @@ -11,16 +8,5 @@ binding_static_model::binding_static_model(const char* vw_model, const size_t le
int binding_static_model::get_data(
model_transport::model_data& data, reinforcement_learning::api_status* status)
{
if (this->vw_model == nullptr || this->len == 0)
{
return reinforcement_learning::error_code::static_model_load_error;
}

char* buffer = data.alloc(this->len);
if (buffer == nullptr) { return reinforcement_learning::error_code::static_model_load_error; }

memcpy(buffer, this->vw_model, this->len);
data.data_sz(this->len);

return reinforcement_learning::error_code::success;
return data.set_data(vw_model, len);
}
1 change: 1 addition & 0 deletions bindings/cs/rl.net.native/binding_static_model.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "err_constants.h"
#include "model_mgmt.h"

#include <cstring>
Expand Down
15 changes: 6 additions & 9 deletions bindings/cs/rl.net.native/rl.net.factory_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,16 @@ void cleanup_data_transport_factory(data_transport_factory_t* data_transport_fac
}
}

API factory_context_t* CreateFactoryContextWithStaticModel(const char* weights, const size_t len)
API factory_context_t* CreateFactoryContextWithStaticModel(const char* vw_model, const size_t len)
{
using namespace reinforcement_learning::model_management;
auto context = CreateFactoryContext();

auto data_transport_factory_fn = [weights, len](std::unique_ptr<model_management::i_data_transport>& retval,
const utility::configuration& configuration, i_trace* trace_logger,
api_status* status) -> int
{
char* weightsCopy = new char[len];
std::memcpy(weightsCopy, weights, len);
retval.reset(new rl_net_native::binding_static_model(weightsCopy, len));
return error_code::success;
auto data_transport_factory_fn = [vw_model, len](std::unique_ptr<model_management::i_data_transport>& retval, const utility::configuration& configuration, i_trace* trace_logger, api_status* status) -> int {
char* vw_model_copy = new char[len];
std::memcpy(vw_model_copy, vw_model, len);
retval.reset(new rl_net_native::binding_static_model(vw_model_copy, len));
return error_code::success;
};

data_transport_factory_t* data_transport_factory =
Expand Down
2 changes: 1 addition & 1 deletion bindings/cs/rl.net.native/rl.net.factory_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ typedef struct factory_context
extern "C"
{
API factory_context_t* CreateFactoryContext();
API factory_context_t* CreateFactoryContextWithStaticModel(const char* weights, const size_t len);
API factory_context_t* CreateFactoryContextWithStaticModel(const char* vw_model, const size_t len);
API void DeleteFactoryContext(factory_context_t* context);
API void SetFactoryContextBindingSenderFactory(
factory_context_t* context, rl_net_native::sender_create_fn create_fn, rl_net_native::sender_vtable_t vtable);
Expand Down
12 changes: 5 additions & 7 deletions bindings/cs/rl.net/FactoryContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,25 @@ public sealed class FactoryContext : NativeObject<FactoryContext>
private static extern IntPtr CreateFactoryContext();

[DllImport("rlnetnative")]
private static extern IntPtr CreateFactoryContextWithStaticModel(IntPtr weights, int length);
private static extern IntPtr CreateFactoryContextWithStaticModel(IntPtr vw_model, int len);

[DllImport("rlnetnative")]
private static extern void DeleteFactoryContext(IntPtr context);

[DllImport("rlnetnative")]
private static extern IntPtr SetFactoryContextBindingSenderFactory(IntPtr context, sender_create_fn create_Fn, sender_vtable vtable);

public List<byte> Weights { get; private set; }

public FactoryContext() : base(new New<FactoryContext>(CreateFactoryContext), new Delete<FactoryContext>(DeleteFactoryContext))
{
}

public FactoryContext(IEnumerable<byte> weightsEnumerable) : base(
public FactoryContext(IEnumerable<byte> vwModelEnumerable) : base(
new New<FactoryContext>(() => {
var weightsArray = weightsEnumerable.ToArray();
GCHandle handle = GCHandle.Alloc(weightsArray, GCHandleType.Pinned);
var vwModelArray = vwModelEnumerable.ToArray();
GCHandle handle = GCHandle.Alloc(vwModelArray, GCHandleType.Pinned);
try {
IntPtr ptr = handle.AddrOfPinnedObject();
return CreateFactoryContextWithStaticModel(ptr, weightsArray.Length);
return CreateFactoryContextWithStaticModel(ptr, vwModelArray.Length);
}
finally {
if (handle.IsAllocated)
Expand Down
4 changes: 4 additions & 0 deletions include/model_mgmt.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once
#include "err_constants.h"
#include "multistep.h"

#include <cstddef>
Expand Down Expand Up @@ -34,6 +35,9 @@ class model_data
char* alloc(size_t desired);
void free();

// Set data
int set_data(const char* vw_model, size_t len);

model_data() = default;
~model_data() = default;

Expand Down
14 changes: 14 additions & 0 deletions rlclientlib/model_mgmt/model_mgmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,19 @@ char* model_data::alloc(const size_t desired)

void model_data::free() { _data.clear(); }

int model_data::set_data(const char* vw_model, size_t len)
{
if (vw_model == nullptr || len == 0)
{
return reinforcement_learning::error_code::static_model_load_error;
}

char* buffer = this->alloc(len);
if (buffer == nullptr) { return reinforcement_learning::error_code::static_model_load_error; }

memcpy(buffer, vw_model, len);
this->data_sz(len);
}

} // namespace model_management
} // namespace reinforcement_learning

0 comments on commit 150e9f6

Please sign in to comment.