Skip to content

Commit

Permalink
[mbedtls] Support mbedtls 3.x and fix bugs and compiler warnings. (Ha…
Browse files Browse the repository at this point in the history
…xeFoundation#11646)

* [mbedtls] Store bio functions in a GC root.

* [mbedtls] Fix incorrect alt name check.

* [mbedtls] Replace String_val with Bytes_val to prevent compiler warnings.

* [mbedtls] use SecTrustCopyAnchorCertificates to get root certs on macOS.

SecKeychainOpen is deprecated.

* [mbedtls] Remove unused includes and use angled brackets.

* [mbedtls] Fix more warnings.

* [mbedtls] Support mbedtls 3.x.
  • Loading branch information
Apprentice-Alchemist authored Apr 26, 2024
1 parent 547b510 commit c325889
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 62 deletions.
4 changes: 2 additions & 2 deletions libs/mbedtls/mbedtls.ml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ external mbedtls_ssl_setup : mbedtls_ssl_context -> mbedtls_ssl_config -> mbedtl
external mbedtls_ssl_write : mbedtls_ssl_context -> bytes -> int -> int -> mbedtls_result = "ml_mbedtls_ssl_write"

external mbedtls_pk_init : unit -> mbedtls_pk_context = "ml_mbedtls_pk_init"
external mbedtls_pk_parse_key : mbedtls_pk_context -> bytes -> string option -> mbedtls_result = "ml_mbedtls_pk_parse_key"
external mbedtls_pk_parse_keyfile : mbedtls_pk_context -> string -> string option -> mbedtls_result = "ml_mbedtls_pk_parse_keyfile"
external mbedtls_pk_parse_key : mbedtls_pk_context -> bytes -> string option -> mbedtls_ctr_drbg_context -> mbedtls_result = "ml_mbedtls_pk_parse_key"
external mbedtls_pk_parse_keyfile : mbedtls_pk_context -> string -> string option -> mbedtls_ctr_drbg_context -> mbedtls_result = "ml_mbedtls_pk_parse_keyfile"
external mbedtls_pk_parse_public_keyfile : mbedtls_pk_context -> string -> mbedtls_result = "ml_mbedtls_pk_parse_public_keyfile"
external mbedtls_pk_parse_public_key : mbedtls_pk_context -> bytes -> mbedtls_result = "ml_mbedtls_pk_parse_public_key"

Expand Down
103 changes: 51 additions & 52 deletions libs/mbedtls/mbedtls_stubs.c
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include <ctype.h>
#include <string.h>
#include <stdio.h>

Expand All @@ -18,13 +17,10 @@
#include <caml/callback.h>
#include <caml/custom.h>

#include "mbedtls/debug.h"
#include "mbedtls/error.h"
#include "mbedtls/config.h"
#include "mbedtls/ssl.h"
#include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/certs.h"
#include "mbedtls/oid.h"

#define PVoid_val(v) (*((void**) Data_custom_val(v)))
Expand Down Expand Up @@ -84,7 +80,7 @@ CAMLprim value ml_mbedtls_ctr_drbg_init(void) {

CAMLprim value ml_mbedtls_ctr_drbg_random(value p_rng, value output, value output_len) {
CAMLparam3(p_rng, output, output_len);
CAMLreturn(Val_int(mbedtls_ctr_drbg_random(CtrDrbg_val(p_rng), String_val(output), Int_val(output_len))));
CAMLreturn(Val_int(mbedtls_ctr_drbg_random(CtrDrbg_val(p_rng), Bytes_val(output), Int_val(output_len))));
}

CAMLprim value ml_mbedtls_ctr_drbg_seed(value ctx, value p_entropy, value custom) {
Expand Down Expand Up @@ -124,7 +120,7 @@ CAMLprim value ml_mbedtls_entropy_init(void) {

CAMLprim value ml_mbedtls_entropy_func(value data, value output, value len) {
CAMLparam3(data, output, len);
CAMLreturn(Val_int(mbedtls_entropy_func(PVoid_val(data), String_val(output), Int_val(len))));
CAMLreturn(Val_int(mbedtls_entropy_func(PVoid_val(data), Bytes_val(output), Int_val(len))));
}

// Certificate
Expand Down Expand Up @@ -171,7 +167,7 @@ CAMLprim value ml_mbedtls_x509_next(value chain) {

CAMLprim value ml_mbedtls_x509_crt_parse(value chain, value bytes) {
CAMLparam2(chain, bytes);
const char* buf = String_val(bytes);
const unsigned char* buf = Bytes_val(bytes);
int len = caml_string_length(bytes);
CAMLreturn(Val_int(mbedtls_x509_crt_parse(X509Crt_val(chain), buf, len + 1)));
}
Expand All @@ -191,16 +187,19 @@ CAMLprim value ml_mbedtls_x509_crt_parse_path(value chain, value path) {
value caml_string_of_asn1_buf(mbedtls_asn1_buf* dat) {
CAMLparam0();
CAMLlocal1(s);
s = caml_alloc_string(dat->len);
memcpy(String_val(s), dat->p, dat->len);
s = caml_alloc_initialized_string(dat->len, (const char *)dat->p);
CAMLreturn(s);
}

CAMLprim value hx_cert_get_alt_names(value chain) {
CAMLparam1(chain);
CAMLlocal1(obj);
mbedtls_x509_crt* cert = X509Crt_val(chain);
if (cert->ext_types & MBEDTLS_X509_EXT_SUBJECT_ALT_NAME == 0 || &cert->subject_alt_names == NULL) {
#if MBEDTLS_VERSION_MAJOR >= 3
if (!mbedtls_x509_crt_has_ext_type(cert, MBEDTLS_X509_EXT_SUBJECT_ALT_NAME)) {
#else
if ((cert->ext_types & MBEDTLS_X509_EXT_SUBJECT_ALT_NAME) == 0) {
#endif
obj = Atom(0);
} else {
mbedtls_asn1_sequence* cur = &cert->subject_alt_names;
Expand Down Expand Up @@ -366,29 +365,39 @@ CAMLprim value ml_mbedtls_pk_init(void) {
CAMLreturn(obj);
}

CAMLprim value ml_mbedtls_pk_parse_key(value ctx, value key, value password) {
CAMLparam3(ctx, key, password);
const char* pwd = NULL;
CAMLprim value ml_mbedtls_pk_parse_key(value ctx, value key, value password, value rng) {
CAMLparam4(ctx, key, password, rng);
const unsigned char* pwd = NULL;
size_t pwdlen = 0;
if (password != Val_none) {
pwd = String_val(Field(password, 0));
pwd = Bytes_val(Field(password, 0));
pwdlen = caml_string_length(Field(password, 0));
}
CAMLreturn(mbedtls_pk_parse_key(PkContext_val(ctx), String_val(key), caml_string_length(key) + 1, pwd, pwdlen));
#if MBEDTLS_VERSION_MAJOR >= 3
mbedtls_ctr_drbg_context *ctr_drbg = CtrDrbg_val(rng);
CAMLreturn(mbedtls_pk_parse_key(PkContext_val(ctx), Bytes_val(key), caml_string_length(key) + 1, pwd, pwdlen, mbedtls_ctr_drbg_random, NULL));
#else
CAMLreturn(mbedtls_pk_parse_key(PkContext_val(ctx), Bytes_val(key), caml_string_length(key) + 1, pwd, pwdlen));
#endif
}

CAMLprim value ml_mbedtls_pk_parse_keyfile(value ctx, value path, value password) {
CAMLparam3(ctx, path, password);
CAMLprim value ml_mbedtls_pk_parse_keyfile(value ctx, value path, value password, value rng) {
CAMLparam4(ctx, path, password, rng);
const char* pwd = NULL;
if (password != Val_none) {
pwd = String_val(Field(password, 0));
}
#if MBEDTLS_VERSION_MAJOR >= 3
mbedtls_ctr_drbg_context *ctr_drbg = CtrDrbg_val(rng);
CAMLreturn(mbedtls_pk_parse_keyfile(PkContext_val(ctx), String_val(path), pwd, mbedtls_ctr_drbg_random, ctr_drbg));
#else
CAMLreturn(mbedtls_pk_parse_keyfile(PkContext_val(ctx), String_val(path), pwd));
#endif
}

CAMLprim value ml_mbedtls_pk_parse_public_key(value ctx, value key) {
CAMLparam2(ctx, key);
CAMLreturn(mbedtls_pk_parse_public_key(PkContext_val(ctx), String_val(key), caml_string_length(key) + 1));
CAMLreturn(mbedtls_pk_parse_public_key(PkContext_val(ctx), Bytes_val(key), caml_string_length(key) + 1));
}

CAMLprim value ml_mbedtls_pk_parse_public_keyfile(value ctx, value path) {
Expand Down Expand Up @@ -446,23 +455,22 @@ CAMLprim value ml_mbedtls_ssl_handshake(value ssl) {

CAMLprim value ml_mbedtls_ssl_read(value ssl, value buf, value pos, value len) {
CAMLparam4(ssl, buf, pos, len);
CAMLreturn(Val_int(mbedtls_ssl_read(SslContext_val(ssl), String_val(buf) + Int_val(pos), Int_val(len))));
CAMLreturn(Val_int(mbedtls_ssl_read(SslContext_val(ssl), Bytes_val(buf) + Int_val(pos), Int_val(len))));
}

static int bio_write_cb(void* ctx, const unsigned char* buf, size_t len) {
CAMLparam0();
CAMLlocal3(r, s, vctx);
vctx = (value)ctx;
s = caml_alloc_string(len);
memcpy(String_val(s), buf, len);
vctx = *(value*)ctx;
s = caml_alloc_initialized_string(len, (const char*)buf);
r = caml_callback2(Field(vctx, 1), Field(vctx, 0), s);
CAMLreturn(Int_val(r));
}

static int bio_read_cb(void* ctx, unsigned char* buf, size_t len) {
CAMLparam0();
CAMLlocal3(r, s, vctx);
vctx = (value)ctx;
vctx = *(value*)ctx;
s = caml_alloc_string(len);
r = caml_callback2(Field(vctx, 2), Field(vctx, 0), s);
memcpy(buf, String_val(s), len);
Expand All @@ -476,7 +484,11 @@ CAMLprim value ml_mbedtls_ssl_set_bio(value ssl, value p_bio, value f_send, valu
Store_field(ctx, 0, p_bio);
Store_field(ctx, 1, f_send);
Store_field(ctx, 2, f_recv);
mbedtls_ssl_set_bio(SslContext_val(ssl), (void*)ctx, bio_write_cb, bio_read_cb, NULL);
// TODO: this allocation is leaked
value *location = malloc(sizeof(value));
*location = ctx;
caml_register_generational_global_root(location);
mbedtls_ssl_set_bio(SslContext_val(ssl), (void*)location, bio_write_cb, bio_read_cb, NULL);
CAMLreturn(Val_unit);
}

Expand All @@ -492,7 +504,7 @@ CAMLprim value ml_mbedtls_ssl_setup(value ssl, value conf) {

CAMLprim value ml_mbedtls_ssl_write(value ssl, value buf, value pos, value len) {
CAMLparam4(ssl, buf, pos, len);
CAMLreturn(Val_int(mbedtls_ssl_write(SslContext_val(ssl), String_val(buf) + Int_val(pos), Int_val(len))));
CAMLreturn(Val_int(mbedtls_ssl_write(SslContext_val(ssl), Bytes_val(buf) + Int_val(pos), Int_val(len))));
}

// glue
Expand Down Expand Up @@ -520,36 +532,23 @@ CAMLprim value hx_cert_load_defaults(value certificate) {
#endif

#ifdef __APPLE__
CFMutableDictionaryRef search;
CFArrayRef result;
SecKeychainRef keychain;
SecCertificateRef item;
CFDataRef dat;
// Load keychain
if (SecKeychainOpen("/System/Library/Keychains/SystemRootCertificates.keychain", &keychain) == errSecSuccess) {
// Search for certificates
search = CFDictionaryCreateMutable(NULL, 0, NULL, NULL);
CFDictionarySetValue(search, kSecClass, kSecClassCertificate);
CFDictionarySetValue(search, kSecMatchLimit, kSecMatchLimitAll);
CFDictionarySetValue(search, kSecReturnRef, kCFBooleanTrue);
CFDictionarySetValue(search, kSecMatchSearchList, CFArrayCreate(NULL, (const void **)&keychain, 1, NULL));
if (SecItemCopyMatching(search, (CFTypeRef *)&result) == errSecSuccess) {
CFIndex n = CFArrayGetCount(result);
for (CFIndex i = 0; i < n; i++) {
item = (SecCertificateRef)CFArrayGetValueAtIndex(result, i);

// Get certificate in DER format
dat = SecCertificateCopyData(item);
if (dat) {
r = mbedtls_x509_crt_parse_der(chain, (unsigned char *)CFDataGetBytePtr(dat), CFDataGetLength(dat));
CFRelease(dat);
if (r != 0) {
CAMLreturn(Val_int(r));
}
CFArrayRef certs;
if (SecTrustCopyAnchorCertificates(&certs) == errSecSuccess) {
CFIndex count = CFArrayGetCount(certs);
for(CFIndex i = 0; i < count; i++) {
SecCertificateRef item = (SecCertificateRef)CFArrayGetValueAtIndex(certs, i);

// Get certificate in DER format
CFDataRef data = SecCertificateCopyData(item);
if(data) {
r = mbedtls_x509_crt_parse_der(chain, (unsigned char *)CFDataGetBytePtr(data), CFDataGetLength(data));
CFRelease(data);
if (r != 0) {
CAMLreturn(Val_int(r));
}
}
}
CFRelease(keychain);
CFRelease(certs);
}
#endif

Expand Down
8 changes: 4 additions & 4 deletions src/macro/eval/evalSsl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,11 @@ let init_fields init_fields builtins =
"strerror",vfun1 (fun code -> encode_string (mbedtls_strerror (decode_int code)));
] [];
init_fields builtins (["mbedtls"],"PkContext") [] [
"parse_key",vifun2 (fun this key password ->
vint (mbedtls_pk_parse_key (as_pk_context this) (decode_bytes key) (match password with VNull -> None | _ -> Some (decode_string password)));
"parse_key",vifun3 (fun this key password rng ->
vint (mbedtls_pk_parse_key (as_pk_context this) (decode_bytes key) (match password with VNull -> None | _ -> Some (decode_string password)) (as_ctr_drbg rng));
);
"parse_keyfile",vifun2 (fun this path password ->
vint (mbedtls_pk_parse_keyfile (as_pk_context this) (decode_string path) (match password with VNull -> None | _ -> Some (decode_string password)));
"parse_keyfile",vifun3 (fun this path password rng ->
vint (mbedtls_pk_parse_keyfile (as_pk_context this) (decode_string path) (match password with VNull -> None | _ -> Some (decode_string password)) (as_ctr_drbg rng));
);
"parse_public_key",vifun1 (fun this key ->
vint (mbedtls_pk_parse_public_key (as_pk_context this) (decode_bytes key));
Expand Down
4 changes: 2 additions & 2 deletions std/eval/_std/mbedtls/PkContext.hx
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import haxe.io.Bytes;
extern class PkContext {
function new():Void;

function parse_key(key:Bytes, ?pwd:String):Int;
function parse_keyfile(path:String, ?password:String):Int;
function parse_key(key:Bytes, ?pwd:String, ctr_dbg: CtrDrbg):Int;
function parse_keyfile(path:String, ?password:String, ctr_dbg: CtrDrbg):Int;
function parse_public_key(key:Bytes):Int;
function parse_public_keyfile(path:String):Int;
}
4 changes: 2 additions & 2 deletions std/eval/_std/sys/ssl/Key.hx
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Key {
var code = if (isPublic) {
key.native.parse_public_keyfile(file);
} else {
key.native.parse_keyfile(file, pass);
key.native.parse_keyfile(file, pass, Mbedtls.getDefaultCtrDrbg());
}
if (code != 0) {
throw(mbedtls.Error.strerror(code));
Expand All @@ -51,7 +51,7 @@ class Key {
var code = if (isPublic) {
key.native.parse_public_key(data);
} else {
key.native.parse_key(data);
key.native.parse_key(data, null, Mbedtls.getDefaultCtrDrbg());
}
if (code != 0) {
throw(mbedtls.Error.strerror(code));
Expand Down

0 comments on commit c325889

Please sign in to comment.