1
1
use std:: env:: { consts:: EXE_SUFFIX , split_paths} ;
2
2
use std:: ffi:: { OsStr , OsString } ;
3
3
use std:: fmt;
4
- use std:: io:: Write ;
4
+ use std:: io:: { self , Write } ;
5
5
use std:: os:: windows:: ffi:: { OsStrExt , OsStringExt } ;
6
6
use std:: path:: Path ;
7
7
use std:: process:: Command ;
8
- use std:: sync:: { Arc , Mutex } ;
8
+ use std:: sync:: { Arc , LockResult , Mutex , MutexGuard } ;
9
9
10
10
use anyhow:: { anyhow, Context , Result } ;
11
11
use tracing:: { info, warn} ;
@@ -20,6 +20,7 @@ use crate::utils::utils;
20
20
use crate :: utils:: Notification ;
21
21
22
22
use winreg:: enums:: { RegType , HKEY_CURRENT_USER , KEY_READ , KEY_WRITE } ;
23
+ use winreg:: types:: { FromRegValue , ToRegValue } ;
23
24
use winreg:: { RegKey , RegValue } ;
24
25
25
26
pub ( crate ) fn ensure_prompt ( process : & Process ) -> Result < ( ) > {
@@ -807,16 +808,85 @@ pub(crate) fn delete_rustup_and_cargo_home(process: &Process) -> Result<()> {
807
808
Ok ( ( ) )
808
809
}
809
810
810
- #[ cfg( test) ]
811
- mod tests {
812
- use std:: ffi:: OsString ;
813
- use std:: os:: windows:: ffi:: OsStrExt ;
811
+ #[ cfg( any( test, feature = "test" ) ) ]
812
+ pub fn get_path ( ) -> io:: Result < Option < RegValue > > {
813
+ USER_PATH . get ( )
814
+ }
815
+
816
+ #[ cfg( any( test, feature = "test" ) ) ]
817
+ pub struct RegistryGuard < ' a > {
818
+ _locked : LockResult < MutexGuard < ' a , ( ) > > ,
819
+ id : & ' static RegistryValueId ,
820
+ prev : Option < RegValue > ,
821
+ }
822
+
823
+ #[ cfg( any( test, feature = "test" ) ) ]
824
+ impl < ' a > RegistryGuard < ' a > {
825
+ pub fn new ( id : & ' static RegistryValueId ) -> io:: Result < Self > {
826
+ Ok ( Self {
827
+ _locked : REGISTRY_LOCK . lock ( ) ,
828
+ id,
829
+ prev : id. get ( ) ?,
830
+ } )
831
+ }
832
+ }
833
+
834
+ #[ cfg( any( test, feature = "test" ) ) ]
835
+ impl < ' a > Drop for RegistryGuard < ' a > {
836
+ fn drop ( & mut self ) {
837
+ self . id . set ( self . prev . as_ref ( ) ) . unwrap ( ) ;
838
+ }
839
+ }
840
+
841
+ #[ cfg( any( test, feature = "test" ) ) ]
842
+ static REGISTRY_LOCK : Mutex < ( ) > = Mutex :: new ( ( ) ) ;
843
+
844
+ #[ cfg( any( test, feature = "test" ) ) ]
845
+ pub const USER_PATH : RegistryValueId = RegistryValueId {
846
+ sub_key : "Environment" ,
847
+ value_name : "PATH" ,
848
+ } ;
814
849
815
- use winreg:: enums:: { RegType , HKEY_CURRENT_USER , KEY_READ , KEY_WRITE } ;
816
- use winreg:: { RegKey , RegValue } ;
850
+ #[ cfg( any( test, feature = "test" ) ) ]
851
+ pub struct RegistryValueId {
852
+ pub sub_key : & ' static str ,
853
+ pub value_name : & ' static str ,
854
+ }
855
+
856
+ #[ cfg( any( test, feature = "test" ) ) ]
857
+ impl RegistryValueId {
858
+ pub fn get_value < T : FromRegValue > ( & self ) -> io:: Result < Option < T > > {
859
+ self . get ( ) ?. map ( |v| T :: from_reg_value ( & v) ) . transpose ( )
860
+ }
861
+
862
+ pub fn get ( & self ) -> io:: Result < Option < RegValue > > {
863
+ let sub_key = RegKey :: predef ( HKEY_CURRENT_USER )
864
+ . open_subkey_with_flags ( self . sub_key , KEY_READ | KEY_WRITE ) ?;
865
+ match sub_key. get_raw_value ( self . value_name ) {
866
+ Ok ( val) => Ok ( Some ( val) ) ,
867
+ Err ( ref e) if e. kind ( ) == io:: ErrorKind :: NotFound => Ok ( None ) ,
868
+ Err ( e) => Err ( e) ,
869
+ }
870
+ }
871
+
872
+ pub fn set_value ( & self , new : Option < impl ToRegValue > ) -> io:: Result < ( ) > {
873
+ self . set ( new. map ( |s| s. to_reg_value ( ) ) . as_ref ( ) )
874
+ }
817
875
876
+ pub fn set ( & self , new : Option < & RegValue > ) -> io:: Result < ( ) > {
877
+ let sub_key = RegKey :: predef ( HKEY_CURRENT_USER )
878
+ . open_subkey_with_flags ( self . sub_key , KEY_READ | KEY_WRITE ) ?;
879
+ match new {
880
+ Some ( new) => sub_key. set_raw_value ( self . value_name , new) ,
881
+ None => sub_key. delete_value ( self . value_name ) ,
882
+ }
883
+ }
884
+ }
885
+
886
+ #[ cfg( test) ]
887
+ mod tests {
888
+ use super :: * ;
818
889
use crate :: currentprocess:: TestProcess ;
819
- use crate :: test:: { RegistryGuard , USER_PATH } ;
820
890
821
891
fn wide ( str : & str ) -> Vec < u16 > {
822
892
OsString :: from ( str) . encode_wide ( ) . collect ( )
0 commit comments