Skip to content

Commit

Permalink
Merge pull request #107 from MLanguage/var-unicity
Browse files Browse the repository at this point in the history
Unify SSA duplicated variables of Mir in Bir
  • Loading branch information
Raphaël Monat authored Feb 25, 2022
2 parents 603531c + 59e9f68 commit c4e6dd9
Show file tree
Hide file tree
Showing 27 changed files with 841 additions and 541 deletions.
124 changes: 62 additions & 62 deletions src/mlang/backend_compilers/bir_to_c.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
You should have received a copy of the GNU General Public License along with
this program. If not, see <https://www.gnu.org/licenses/>. *)

open Mir
open Bir

let none_value = "m_undefined"

Expand Down Expand Up @@ -45,37 +45,39 @@ type offset =
| PassPointer
| None

let generate_variable (var_indexes : int Mir.VariableMap.t) (offset : offset)
(fmt : Format.formatter) (var : Variable.t) : unit =
let generate_variable (var_indexes : int VariableMap.t) (offset : offset)
(fmt : Format.formatter) (var : variable) : unit =
let mvar = var_to_mir var in
let var_index =
match Mir.VariableMap.find_opt var var_indexes with
match VariableMap.find_opt var var_indexes with
| Some i -> i
| None ->
Errors.raise_error
(Format.asprintf "Variable %s not found in TGV"
(Pos.unmark var.Mir.Variable.name))
(Pos.unmark mvar.Mir.Variable.name))
in
match offset with
| PassPointer ->
Format.fprintf fmt "(TGV + %d/*%s*/)" var_index
(Pos.unmark var.Mir.Variable.name)
(Pos.unmark mvar.Mir.Variable.name)
| _ ->
Format.fprintf fmt "TGV[%d/*%s*/%s]" var_index
(Pos.unmark var.Mir.Variable.name)
(Pos.unmark mvar.Mir.Variable.name)
(match offset with
| None -> ""
| GetValueVar offset -> " + " ^ offset
| GetValueConst offset -> " + " ^ string_of_int offset
| PassPointer -> assert false)

let generate_raw_name (v : Variable.t) : string =
match v.alias with Some v -> v | None -> Pos.unmark v.Variable.name
let generate_raw_name (v : variable) : string =
let v = var_to_mir v in
match v.alias with Some v -> v | None -> Pos.unmark v.Mir.Variable.name

let generate_name (v : Variable.t) : string = "v_" ^ generate_raw_name v
let generate_name (v : variable) : string = "v_" ^ generate_raw_name v

let rec generate_c_expr (e : expression Pos.marked)
(var_indexes : int Mir.VariableMap.t) :
string * (LocalVariable.t * expression Pos.marked) list =
(var_indexes : int VariableMap.t) :
string * (Mir.LocalVariable.t * expression Pos.marked) list =
match Pos.unmark e with
| Comparison (op, e1, e2) ->
let se1, s1 = generate_c_expr e1 var_indexes in
Expand All @@ -92,7 +94,9 @@ let rec generate_c_expr (e : expression Pos.marked)
(Format.asprintf "%s(%s)" (generate_unop op) se, s)
| Index (var, e) ->
let se, s = generate_c_expr e var_indexes in
let size = Option.get (Pos.unmark var).Mir.Variable.is_table in
let size =
Option.get (var_to_mir (Pos.unmark var)).Mir.Variable.is_table
in
( Format.asprintf "m_array_index(%a, %s, %d)"
(generate_variable var_indexes PassPointer)
(Pos.unmark var) se size,
Expand Down Expand Up @@ -134,26 +138,25 @@ let rec generate_c_expr (e : expression Pos.marked)
| Literal Undefined -> (Format.asprintf "%s" none_value, [])
| Var var ->
(Format.asprintf "%a" (generate_variable var_indexes None) var, [])
| LocalVar lvar -> (Format.asprintf "LOCAL[%d]" lvar.LocalVariable.id, [])
| LocalVar lvar -> (Format.asprintf "LOCAL[%d]" lvar.Mir.LocalVariable.id, [])
| GenericTableIndex -> (Format.asprintf "m_literal(generic_index)", [])
| Error -> assert false (* should not happen *)
| LocalLet (lvar, e1, e2) ->
let _, s1 = generate_c_expr e1 var_indexes in
let se2, s2 = generate_c_expr e2 var_indexes in
(Format.asprintf "%s" se2, s1 @ ((lvar, e1) :: s2))

let format_local_vars_defs (var_indexes : int Mir.VariableMap.t)
let format_local_vars_defs (var_indexes : int VariableMap.t)
(fmt : Format.formatter)
(defs : (LocalVariable.t * expression Pos.marked) list) =
(defs : (Mir.LocalVariable.t * expression Pos.marked) list) =
List.iter
(fun (lvar, e) ->
let se, _ = generate_c_expr e var_indexes in
Format.fprintf fmt "LOCAL[%d] = %s;@\n" lvar.LocalVariable.id se)
Format.fprintf fmt "LOCAL[%d] = %s;@\n" lvar.Mir.LocalVariable.id se)
defs

let generate_var_def (var_indexes : int Mir.VariableMap.t)
(var : Mir.Variable.t) (data : Mir.variable_data) (oc : Format.formatter) :
unit =
let generate_var_def (var_indexes : int VariableMap.t) (var : variable)
(data : variable_data) (oc : Format.formatter) : unit =
match data.var_definition with
| SimpleVar e ->
let se, defs = generate_c_expr e var_indexes in
Expand All @@ -165,7 +168,7 @@ let generate_var_def (var_indexes : int Mir.VariableMap.t)
| TableVar (_, IndexTable es) ->
Format.fprintf oc "%a"
(fun fmt ->
IndexMap.iter (fun i v ->
Mir.IndexMap.iter (fun i v ->
let sv, defs = generate_c_expr v var_indexes in
Format.fprintf fmt "%a%a = %s;@\n"
(format_local_vars_defs var_indexes)
Expand All @@ -188,8 +191,8 @@ let generate_var_def (var_indexes : int Mir.VariableMap.t)
* (Pos.get_position e) *)
| InputVar -> assert false

let generate_var_cond (var_indexes : int Mir.VariableMap.t)
(cond : condition_data) (oc : Format.formatter) =
let generate_var_cond (var_indexes : int VariableMap.t) (cond : condition_data)
(oc : Format.formatter) =
if (fst cond.cond_error).typ = Mast.Anomaly then
let scond, defs = generate_c_expr cond.cond_expr var_indexes in
let percent = Re.Pcre.regexp "%" in
Expand All @@ -211,16 +214,15 @@ let generate_var_cond (var_indexes : int Mir.VariableMap.t)
let error_descr =
Re.Pcre.substitute ~rex:percent ~subst:(fun _ -> "%%") error_descr
in
Format.fprintf fmt "%s: %s" (Pos.unmark err.Error.name) error_descr)
Format.fprintf fmt "%s: %s" (Pos.unmark err.Mir.Error.name) error_descr)
(fst cond.cond_error)

let fresh_cond_counter = ref 0

let rec generate_stmt (program : Bir.program)
(var_indexes : int Mir.VariableMap.t) (oc : Format.formatter)
(stmt : Bir.stmt) =
let rec generate_stmt (program : program) (var_indexes : int VariableMap.t)
(oc : Format.formatter) (stmt : stmt) =
match Pos.unmark stmt with
| Bir.SAssign (var, vdata) -> generate_var_def var_indexes var vdata oc
| SAssign (var, vdata) -> generate_var_def var_indexes var vdata oc
| SConditional (cond, tt, ff) ->
let pos = Pos.get_position stmt in
let fname =
Expand Down Expand Up @@ -253,44 +255,42 @@ let rec generate_stmt (program : Bir.program)
ff
| SVerif v -> generate_var_cond var_indexes v oc
| SRuleCall r ->
let rule = Bir.RuleMap.find r program.rules in
let rule = RuleMap.find r program.rules in
generate_rule_function_header ~definition:false oc rule
| SFunctionCall (f, _) ->
Format.fprintf oc "if(%s(output, TGV, LOCAL)) {return -1;};\n" f

and generate_stmts (program : Bir.program) (var_indexes : int Mir.VariableMap.t)
(oc : Format.formatter) (stmts : Bir.stmt list) =
and generate_stmts (program : program) (var_indexes : int VariableMap.t)
(oc : Format.formatter) (stmts : stmt list) =
Format.pp_print_list (generate_stmt program var_indexes) oc stmts

and generate_rule_function_header ~(definition : bool) (oc : Format.formatter)
(rule : Bir.rule) =
(rule : rule) =
let arg_type = if definition then "m_value *" else "" in
let ret_type = if definition then "void " else "" in
Format.fprintf oc "%sm_rule_%s(%sTGV, %sLOCAL)%s@\n" ret_type rule.rule_name
arg_type arg_type
(if definition then "" else ";")

let generate_rule_function (program : Bir.program)
(var_indexes : int Mir.VariableMap.t) (oc : Format.formatter)
(rule : Bir.rule) =
let generate_rule_function (program : program) (var_indexes : int VariableMap.t)
(oc : Format.formatter) (rule : rule) =
Format.fprintf oc "%a@[<v 2>{@ %a@]@;}@\n"
(generate_rule_function_header ~definition:true)
rule
(generate_stmts program var_indexes)
rule.rule_stmts

let generate_rule_functions (program : Bir.program)
(var_indexes : int Mir.VariableMap.t) (oc : Format.formatter)
(rules : Bir.rule Bir.RuleMap.t) =
let generate_rule_functions (program : program)
(var_indexes : int VariableMap.t) (oc : Format.formatter)
(rules : rule RuleMap.t) =
Format.pp_print_list ~pp_sep:Format.pp_print_cut
(generate_rule_function program var_indexes)
oc
(Bir.RuleMap.bindings rules |> List.map snd)
(RuleMap.bindings rules |> List.map snd)

let generate_mpp_function (program : Bir.program)
(var_indexes : int Mir.VariableMap.t) (oc : Format.formatter)
(f : Bir.function_name) =
let stmts = Bir.FunctionMap.find f program.mpp_functions in
let generate_mpp_function (program : program) (var_indexes : int VariableMap.t)
(oc : Format.formatter) (f : function_name) =
let stmts = FunctionMap.find f program.mpp_functions in
Format.fprintf oc
"@[<hv 4>int %s(m_output*output, m_value* TGV, m_value* LOCAL) {@,\
m_value cond;@,\
Expand All @@ -301,7 +301,7 @@ let generate_mpp_function (program : Bir.program)
stmts

let generate_mpp_functions (program : Bir.program) (oc : Format.formatter)
(var_indexes : int Mir.VariableMap.t) =
(var_indexes : int VariableMap.t) =
Bir.FunctionMap.iter
(fun fname _ -> generate_mpp_function program var_indexes oc fname)
(Bir_interface.context_agnostic_mpp_functions program)
Expand All @@ -311,8 +311,8 @@ let generate_main_function_signature (oc : Format.formatter)
Format.fprintf oc "int m_extracted(m_output *output, const m_input *input)%s"
(if add_semicolon then ";" else "")

let generate_main_function_signature_and_var_decls (p : Bir.program)
(var_indexes : int Mir.VariableMap.t) (var_table_size : int)
let generate_main_function_signature_and_var_decls (p : program)
(var_indexes : int VariableMap.t) (var_table_size : int)
(oc : Format.formatter) (function_spec : Bir_interface.bir_function) =
let input_vars =
List.map fst (VariableMap.bindings function_spec.func_variable_inputs)
Expand All @@ -324,7 +324,7 @@ let generate_main_function_signature_and_var_decls (p : Bir.program)
(* here, we need to generate a table that can host all the local vars. the
index inside the table will be the id of the local var so we generate a
table big enough so that the highest id is always in bounds *)
let size_locals = Bir.get_locals_size p + 1 in
let size_locals = get_locals_size p + 1 in
Format.fprintf oc "m_value *LOCAL = malloc(%d * sizeof(m_value));@\n@\n"
size_locals;
Format.fprintf oc "m_value *TGV = malloc(%d * sizeof(m_value));@\n@\n"
Expand All @@ -341,8 +341,8 @@ let generate_main_function_signature_and_var_decls (p : Bir.program)

Format.fprintf oc "m_value cond;@\n@\n"

let generate_return (var_indexes : int Mir.VariableMap.t)
(oc : Format.formatter) (function_spec : Bir_interface.bir_function) =
let generate_return (var_indexes : int VariableMap.t) (oc : Format.formatter)
(function_spec : Bir_interface.bir_function) =
let returned_variables =
List.map fst (VariableMap.bindings function_spec.func_outputs)
in
Expand Down Expand Up @@ -481,7 +481,7 @@ let generate_input_type (oc : Format.formatter)
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt var ->
Format.fprintf fmt "m_value %s; // %s" (generate_name var)
(Pos.unmark var.Variable.descr)))
(Pos.unmark (var_to_mir var).Mir.Variable.descr)))
input_vars

let generate_empty_output_prototype (oc : Format.formatter)
Expand Down Expand Up @@ -543,7 +543,7 @@ let generate_get_output_index_func (oc : Format.formatter)
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (var, i) ->
Format.fprintf fmt "if (strcmp(\"%s\", name) == 0) { return %d; }"
(Pos.unmark var.Mir.Variable.name)
(Pos.unmark (var_to_mir var).Mir.Variable.name)
i))
(List.mapi (fun i x -> (x, i)) output_vars)

Expand All @@ -569,7 +569,7 @@ let generate_get_output_name_from_index_func (oc : Format.formatter)
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (var, i) ->
Format.fprintf fmt "if (index == %d) { return \"%s\"; }" i
(Pos.unmark var.Mir.Variable.name)))
(Pos.unmark (var_to_mir var).Mir.Variable.name)))
(List.mapi (fun i x -> (x, i)) output_vars)

let generate_get_output_num_prototype (oc : Format.formatter)
Expand Down Expand Up @@ -600,15 +600,15 @@ let generate_output_type (oc : Format.formatter)
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt var ->
Format.fprintf fmt "m_value %s; // %s" (generate_name var)
(Pos.unmark var.Variable.descr)))
(Pos.unmark (var_to_mir var).Mir.Variable.descr)))
output_vars

let generate_implem_header oc header_filename =
Format.fprintf oc "// File generated by the Mlang compiler\n\n";
Format.fprintf oc "#include <string.h>\n";
Format.fprintf oc "#include \"%s\"\n\n" header_filename

let generate_c_program (program : Bir.program)
let generate_c_program (program : program)
(function_spec : Bir_interface.bir_function) (filename : string) : unit =
if Filename.extension filename <> ".c" then
Errors.raise_error
Expand All @@ -633,25 +633,25 @@ let generate_c_program (program : Bir.program)
close_out _oc;
let _oc = open_out filename in
let oc = Format.formatter_of_out_channel _oc in
Format.fprintf oc "%a%a%a%a%a%a%a%a%a%a%a%a%a%a%a%a"
generate_implem_header header_filename
Format.fprintf oc "%a%a%a%a%a%a%a%a%a%a%a%a%a%a%a%a"
generate_implem_header header_filename
generate_empty_input_func function_spec
generate_input_from_array_func function_spec
generate_get_input_index_func function_spec
generate_input_from_array_func function_spec
generate_get_input_index_func function_spec
generate_get_input_name_from_index_func function_spec
generate_get_input_num_func function_spec
generate_output_to_array_func function_spec
generate_get_input_num_func function_spec
generate_output_to_array_func function_spec
generate_get_output_index_func function_spec
generate_get_output_name_from_index_func function_spec
generate_get_output_num_func function_spec
generate_get_output_num_func function_spec
generate_empty_output_func function_spec
(generate_rule_functions program var_indexes)
program.rules
(generate_mpp_functions program)
var_indexes
(generate_main_function_signature_and_var_decls program var_indexes
var_table_size) function_spec
(generate_stmts program var_indexes)
(generate_stmts program var_indexes)
(Bir.main_statements program)
(generate_return var_indexes) function_spec;
close_out _oc[@@ocamlformat "disable"]
Loading

0 comments on commit c4e6dd9

Please sign in to comment.