diff --git a/src/branchrule.rs b/src/branchrule.rs index c64ab59..185c168 100644 --- a/src/branchrule.rs +++ b/src/branchrule.rs @@ -1,7 +1,5 @@ use crate::ffi; -use crate::variable::Variable; use scip_sys::SCIP_Result; -use std::rc::Rc; /// A trait for defining custom branching rules. pub trait BranchRule { @@ -45,8 +43,8 @@ impl From for SCIP_Result { /// A candidate for branching. #[derive(Debug, Clone, PartialEq)] pub struct BranchingCandidate { - /// The variable to branch on. - pub var: Rc, + /// The index of the variable to branch on in the current subproblem. + pub var_prob_id: usize, /// The LP solution value of the variable. pub lp_sol_val: f64, /// The fractional part of the LP solution value of the variable. @@ -176,4 +174,49 @@ mod tests { assert!(solved.n_nodes() > 1); } + + struct HighestBoundBranchRule { + model: Model, + } + + impl BranchRule for HighestBoundBranchRule { + fn execute(&mut self, candidates: Vec) -> BranchingResult { + let mut max_bound = f64::NEG_INFINITY; + let mut max_candidate = None; + for candidate in candidates { + let var = self.model.var_in_prob(candidate.var_prob_id).unwrap(); + let bound = var.ub(); + if bound > max_bound { + max_bound = bound; + max_candidate = Some(candidate); + } + } + + if let Some(candidate) = max_candidate { + BranchingResult::BranchOn(candidate) + } else { + BranchingResult::DidNotRun + } + } + } + + #[test] + fn highest_bound_branch_rule() { + let model = Model::new() + .hide_output() + .set_longint_param("limits/nodes", 2) + .unwrap() // only call brancher once + .include_default_plugins() + .read_prob("data/test/gen-ip054.mps") + .unwrap(); + + let br = HighestBoundBranchRule { + model: model.clone_for_plugins(), + }; + let solved = model + .include_branch_rule("", "", 100000, 1000, 1., Box::new(br)) + .solve(); + + assert!(solved.n_nodes() > 1); + } } diff --git a/src/model.rs b/src/model.rs index efc0186..0345a1e 100644 --- a/src/model.rs +++ b/src/model.rs @@ -105,7 +105,21 @@ impl Model { pub fn read_prob(mut self, filename: &str) -> Result, Retcode> { let scip = self.scip.clone(); scip.read_prob(filename)?; - let vars = Rc::new(RefCell::new(self.scip.vars())); + let vars = Rc::new(RefCell::new( + self.scip + .vars() + .into_iter() + .map(|(i, v)| { + ( + i, + Rc::new(Variable { + raw: v, + scip: scip.clone(), + }), + ) + }) + .collect(), + )); let conss = Rc::new(RefCell::new( self.scip .conss() @@ -201,6 +215,10 @@ impl Model { .scip .create_var(lb, ub, obj, name, var_type) .expect("Failed to create variable in state ProblemCreated"); + let var = Variable { + raw: var, + scip: self.scip.clone(), + }; let var_id = var.index(); let var = Rc::new(var); self.state.vars.borrow_mut().insert(var_id, var.clone()); @@ -391,6 +409,10 @@ impl Model { .scip .create_var_solving(lb, ub, obj, name, var_type) .expect("Failed to create variable in state ProblemCreated"); + let var = Variable { + raw: var, + scip: self.scip.clone(), + }; let var_id = var.index(); let var = Rc::new(var); self.state.vars.borrow_mut().insert(var_id, var.clone()); @@ -444,11 +466,31 @@ impl Model { .scip .create_priced_var(lb, ub, obj, name, var_type) .expect("Failed to create variable in state ProblemCreated"); + let var = Variable { + raw: var, + scip: self.scip.clone(), + }; let var = Rc::new(var); let var_id = var.index(); self.state.vars.borrow_mut().insert(var_id, var.clone()); var } + + /// Gets the variable in current problem given its index (in the problem). + /// + /// # Arguments + /// * `var_prob_id` - The index of the variable in the problem. + /// + /// # Returns + /// A reference-counted pointer to the variable. + pub fn var_in_prob(&self, var_prob_id: usize) -> Option { + unsafe { + ScipPtr::var_from_id(self.scip.raw, var_prob_id).map(|v| Variable { + raw: v, + scip: self.scip.clone(), + }) + } + } } impl Model { diff --git a/src/scip.rs b/src/scip.rs index d13814a..5995e1d 100644 --- a/src/scip.rs +++ b/src/scip.rs @@ -6,7 +6,7 @@ use crate::{ }; use crate::{scip_call, HeurTiming, Heuristic}; use core::panic; -use scip_sys::{SCIP_Cons, SCIP_SOL}; +use scip_sys::{SCIP_Cons, SCIP_Var, Scip, SCIP_SOL}; use std::collections::BTreeMap; use std::ffi::{c_int, CStr, CString}; use std::mem::MaybeUninit; @@ -136,7 +136,7 @@ impl ScipPtr { Ok(()) } - pub(crate) fn vars(&self) -> BTreeMap> { + pub(crate) fn vars(&self) -> BTreeMap { // NOTE: this method should only be called once per SCIP instance let n_vars = self.n_vars(); let mut vars = BTreeMap::new(); @@ -146,8 +146,9 @@ impl ScipPtr { unsafe { ffi::SCIPcaptureVar(self.raw, scip_var); } - let var = Rc::new(Variable { raw: scip_var }); - vars.insert(var.index(), var); + let var = scip_var; + let var_id = unsafe { ffi::SCIPvarGetIndex(var) } as usize; + vars.insert(var_id, var); } vars } @@ -198,7 +199,7 @@ impl ScipPtr { obj: f64, name: &str, var_type: VarType, - ) -> Result { + ) -> Result<*mut SCIP_Var, Retcode> { let name = CString::new(name).unwrap(); let mut var_ptr = MaybeUninit::uninit(); scip_call! { ffi::SCIPcreateVarBasic( @@ -212,7 +213,7 @@ impl ScipPtr { ) }; let var_ptr = unsafe { var_ptr.assume_init() }; scip_call! { ffi::SCIPaddVar(self.raw, var_ptr) }; - Ok(Variable { raw: var_ptr }) + Ok(var_ptr) } pub(crate) fn create_var_solving( @@ -222,7 +223,7 @@ impl ScipPtr { obj: f64, name: &str, var_type: VarType, - ) -> Result { + ) -> Result<*mut SCIP_Var, Retcode> { let name = CString::new(name).unwrap(); let mut var_ptr = MaybeUninit::uninit(); scip_call! { ffi::SCIPcreateVarBasic( @@ -240,7 +241,7 @@ impl ScipPtr { scip_call! { ffi::SCIPgetTransformedVar(self.raw, var_ptr, transformed_var.as_mut_ptr()) }; let trans_var_ptr = unsafe { transformed_var.assume_init() }; scip_call! { ffi::SCIPreleaseVar(self.raw, &mut var_ptr) }; - Ok(Variable { raw: trans_var_ptr }) + Ok(trans_var_ptr) } pub(crate) fn create_priced_var( @@ -250,7 +251,7 @@ impl ScipPtr { obj: f64, name: &str, var_type: VarType, - ) -> Result { + ) -> Result<*mut SCIP_Var, Retcode> { let name = CString::new(name).unwrap(); let mut var_ptr = MaybeUninit::uninit(); scip_call! { ffi::SCIPcreateVarBasic( @@ -268,7 +269,7 @@ impl ScipPtr { scip_call! { ffi::SCIPgetTransformedVar(self.raw, var_ptr, transformed_var.as_mut_ptr()) }; let trans_var_ptr = unsafe { transformed_var.assume_init() }; scip_call! { ffi::SCIPreleaseVar(self.raw, &mut var_ptr) }; - Ok(Variable { raw: trans_var_ptr }) + Ok(trans_var_ptr) } pub(crate) fn create_cons( @@ -451,6 +452,15 @@ impl ScipPtr { Ok(scip_cons) } + pub(crate) unsafe fn var_from_id(scip: *mut Scip, var_prob_id: usize) -> Option<*mut SCIP_Var> { + let n_vars = ffi::SCIPgetNVars(scip) as usize; + let var = *ffi::SCIPgetVars(scip).add(var_prob_id); + if var_prob_id >= n_vars { + None + } else { + Some(var) + } + } pub(crate) fn create_cons_indicator( &self, bin_var: Rc, @@ -499,7 +509,9 @@ impl ScipPtr { Ok(()) } - pub(crate) fn lp_branching_cands(scip: *mut ffi::SCIP) -> Vec { + pub(crate) unsafe fn lp_branching_cands( + scip: *mut ffi::SCIP, + ) -> Vec<(*mut SCIP_Var, f64, f64)> { let mut lpcands = MaybeUninit::uninit(); let mut lpcandssol = MaybeUninit::uninit(); // let mut lpcandsfrac = MaybeUninit::uninit(); @@ -525,24 +537,25 @@ impl ScipPtr { let mut cands = Vec::with_capacity(nlpcands as usize); for i in 0..nlpcands { let var_ptr = unsafe { *lpcands.add(i as usize) }; - let var = Rc::new(Variable { raw: var_ptr }); + let var = var_ptr; let lp_sol_val = unsafe { *lpcandssol.add(i as usize) }; let frac = lp_sol_val.fract(); - cands.push(BranchingCandidate { - var, - lp_sol_val, - frac, - }); + cands.push((var, lp_sol_val, frac)); } cands } - pub(crate) fn branch_var_val( + pub(crate) unsafe fn branch_var_val( scip: *mut ffi::SCIP, - var: *mut ffi::SCIP_VAR, + var_prob_id: usize, val: f64, ) -> Result<(), Retcode> { - scip_call! { ffi::SCIPbranchVarVal(scip, var, val, std::ptr::null_mut(), std::ptr::null_mut(),std::ptr::null_mut()) }; + let var = ScipPtr::var_from_id(scip, var_prob_id); + if var.is_none() { + return Err(Retcode::Error); + } + let var = var.unwrap(); + scip_call! { ffi::SCIPbranchVarVal(scip, var, val, std::ptr::null_mut(), std::ptr::null_mut(),std::ptr::null_mut()) } Ok(()) } @@ -642,11 +655,20 @@ impl ScipPtr { let data_ptr = unsafe { ffi::SCIPbranchruleGetData(branchrule) }; assert!(!data_ptr.is_null()); let rule_ptr = data_ptr as *mut Box; - let cands = ScipPtr::lp_branching_cands(scip); + let cands = unsafe { ScipPtr::lp_branching_cands(scip) } + .into_iter() + .map(|(scip_var, lp_sol_val, frac)| BranchingCandidate { + var_prob_id: unsafe { ffi::SCIPvarGetProbindex(scip_var) } as usize, + lp_sol_val, + frac, + }) + .collect::>(); let branching_res = unsafe { (*rule_ptr).execute(cands) }; if let BranchingResult::BranchOn(cand) = branching_res.clone() { - ScipPtr::branch_var_val(scip, cand.var.raw, cand.lp_sol_val).unwrap(); + unsafe { + ScipPtr::branch_var_val(scip, cand.var_prob_id, cand.lp_sol_val).unwrap(); + } }; if branching_res == BranchingResult::CustomBranching { @@ -923,7 +945,8 @@ impl ScipPtr { var.raw }; - scip_call! { ffi::SCIPaddCoefLinear(self.raw, cons_ptr, var_ptr, coef) }; + scip_call! { ffi::SCIPaddCoefLinear(self.raw, cons_ptr, var_ptr, coef) } + Ok(()) } diff --git a/src/variable.rs b/src/variable.rs index d799e24..e5bf455 100644 --- a/src/variable.rs +++ b/src/variable.rs @@ -1,16 +1,28 @@ use crate::ffi; +use crate::scip::ScipPtr; use core::panic; use scip_sys::SCIP_Status; +use std::rc::Rc; /// A type alias for a variable ID. pub type VarId = usize; /// A wrapper for a mutable reference to a SCIP variable. -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug)] +#[allow(dead_code)] pub struct Variable { pub(crate) raw: *mut ffi::SCIP_VAR, + pub(crate) scip: Rc, } +impl PartialEq for Variable { + fn eq(&self, other: &Self) -> bool { + self.index() == other.index() && self.raw == other.raw + } +} + +impl Eq for Variable {} + impl Variable { #[cfg(feature = "raw")] /// Returns a raw pointer to the underlying `ffi::SCIP_VAR` struct. @@ -133,7 +145,7 @@ impl From for VarStatus { #[cfg(test)] mod tests { use super::*; - use crate::Model; + use crate::{Model, ObjSense}; #[test] fn var_data() { @@ -151,4 +163,19 @@ mod tests { #[cfg(feature = "raw")] assert!(!var.inner().is_null()); } + + #[test] + fn var_memory_safety() { + let mut model = Model::new() + .hide_output() + .include_default_plugins() + .create_prob("test") + .set_obj_sense(ObjSense::Maximize); + + let x1 = model.add_var(0., f64::INFINITY, 3., "x1", VarType::Integer); + + drop(model); + + assert_eq!(x1.name(), "x1"); + } }