diff --git a/ssh-key/src/algorithm.rs b/ssh-key/src/algorithm.rs index 100434b..bffb019 100644 --- a/ssh-key/src/algorithm.rs +++ b/ssh-key/src/algorithm.rs @@ -9,7 +9,7 @@ use encoding::{Label, LabelError}; #[cfg(feature = "alloc")] use { - alloc::vec::Vec, + alloc::{borrow::ToOwned, string::String, vec::Vec}, sha2::{Digest, Sha256, Sha512}, }; @@ -179,7 +179,7 @@ impl Algorithm { CERT_SK_ECDSA_SHA2_P256 => Ok(Algorithm::SkEcdsaSha2NistP256), CERT_SK_SSH_ED25519 => Ok(Algorithm::SkEd25519), #[cfg(feature = "alloc")] - _ => Ok(Algorithm::Other(AlgorithmName::from_certificate_str(id)?)), + _ => Ok(Algorithm::Other(AlgorithmName::from_certificate_type(id)?)), #[cfg(not(feature = "alloc"))] _ => Err(Error::AlgorithmUnknown), } @@ -214,7 +214,8 @@ impl Algorithm { /// See [PROTOCOL.certkeys] for more information. /// /// [PROTOCOL.certkeys]: https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.certkeys?annotate=HEAD - pub fn as_certificate_str(&self) -> &str { + #[cfg(feature = "alloc")] + pub fn to_certificate_type(&self) -> String { match self { Algorithm::Dsa => CERT_DSA, Algorithm::Ecdsa { curve } => match curve { @@ -226,9 +227,9 @@ impl Algorithm { Algorithm::Rsa { .. } => CERT_RSA, Algorithm::SkEcdsaSha2NistP256 => CERT_SK_ECDSA_SHA2_P256, Algorithm::SkEd25519 => CERT_SK_SSH_ED25519, - #[cfg(feature = "alloc")] - Algorithm::Other(algorithm) => algorithm.certificate_str(), + Algorithm::Other(algorithm) => return algorithm.certificate_type(), } + .to_owned() } /// Is the algorithm DSA? diff --git a/ssh-key/src/algorithm/name.rs b/ssh-key/src/algorithm/name.rs index d85bd3b..a7ba8ab 100644 --- a/ssh-key/src/algorithm/name.rs +++ b/ssh-key/src/algorithm/name.rs @@ -33,26 +33,30 @@ const MAX_CERT_STR_LEN: usize = MAX_ALGORITHM_NAME_LEN + CERT_STR_SUFFIX.len(); pub struct AlgorithmName { /// The string identifier which corresponds to this algorithm. id: String, - /// The string identifier which corresponds to the OpenSSH certificate format. - /// - /// This is derived from the algorithm name by inserting `"-cert-v01"` immediately after the - /// name preceding the at-symbol (`@`). - certificate_str: String, } impl AlgorithmName { + /// Create a new algorithm identifier. + pub fn new(id: impl Into) -> Result { + let id = id.into(); + validate_algorithm_id(&id, MAX_ALGORITHM_NAME_LEN)?; + split_algorithm_id(&id)?; + Ok(Self { id }) + } + /// Get the string identifier which corresponds to this algorithm name. pub fn as_str(&self) -> &str { &self.id } /// Get the string identifier which corresponds to the OpenSSH certificate format. - pub fn certificate_str(&self) -> &str { - &self.certificate_str + pub fn certificate_type(&self) -> String { + let (name, domain) = split_algorithm_id(&self.id).expect("format checked in constructor"); + format!("{name}{CERT_STR_SUFFIX}@{domain}") } /// Create a new [`AlgorithmName`] from an OpenSSH certificate format string identifier. - pub fn from_certificate_str(id: &str) -> Result { + pub fn from_certificate_type(id: &str) -> Result { validate_algorithm_id(id, MAX_CERT_STR_LEN)?; // Derive the algorithm name from the certificate format string identifier: @@ -63,10 +67,7 @@ impl AlgorithmName { let algorithm_name = format!("{name}@{domain}"); - Ok(Self { - id: algorithm_name, - certificate_str: id.into(), - }) + Ok(Self { id: algorithm_name }) } } @@ -74,16 +75,7 @@ impl FromStr for AlgorithmName { type Err = LabelError; fn from_str(id: &str) -> Result { - validate_algorithm_id(id, MAX_ALGORITHM_NAME_LEN)?; - - // Derive the certificate format string identifier from the algorithm name: - let (name, domain) = split_algorithm_id(id)?; - let certificate_str = format!("{name}{CERT_STR_SUFFIX}@{domain}"); - - Ok(Self { - id: id.into(), - certificate_str, - }) + Self::new(id) } } diff --git a/ssh-key/src/certificate.rs b/ssh-key/src/certificate.rs index 936697d..5e39bef 100644 --- a/ssh-key/src/certificate.rs +++ b/ssh-key/src/certificate.rs @@ -176,7 +176,7 @@ impl Certificate { let mut cert = Certificate::decode(&mut reader)?; // Verify that the algorithm in the Base64-encoded data matches the text - if encapsulation.algorithm_id != cert.algorithm().as_certificate_str() { + if encapsulation.algorithm_id != cert.algorithm().to_certificate_type() { return Err(Error::AlgorithmUnknown); } @@ -193,7 +193,7 @@ impl Certificate { /// Encode OpenSSH certificate to a [`String`]. pub fn to_openssh(&self) -> Result { - SshFormat::encode_string(self.algorithm().as_certificate_str(), self, self.comment()) + SshFormat::encode_string(&self.algorithm().to_certificate_type(), self, self.comment()) } /// Serialize OpenSSH certificate as raw bytes. @@ -429,7 +429,7 @@ impl Certificate { /// Encode the portion of the certificate "to be signed" by the CA /// (or to be verified against an existing CA signature) fn encode_tbs(&self, writer: &mut impl Writer) -> encoding::Result<()> { - self.algorithm().as_certificate_str().encode(writer)?; + self.algorithm().to_certificate_type().encode(writer)?; self.nonce.encode(writer)?; self.public_key.encode_key_data(writer)?; self.serial.encode(writer)?; @@ -473,7 +473,7 @@ impl Decode for Certificate { impl Encode for Certificate { fn encoded_len(&self) -> encoding::Result { [ - self.algorithm().as_certificate_str().encoded_len()?, + self.algorithm().to_certificate_type().encoded_len()?, self.nonce.encoded_len()?, self.public_key.encoded_key_data_len()?, self.serial.encoded_len()?, diff --git a/ssh-key/tests/algorithm_name.rs b/ssh-key/tests/algorithm_name.rs index 3ff5ccd..98a942b 100644 --- a/ssh-key/tests/algorithm_name.rs +++ b/ssh-key/tests/algorithm_name.rs @@ -12,11 +12,11 @@ fn additional_algorithm_name() { let name = AlgorithmName::from_str(NAME).unwrap(); assert_eq!(name.as_str(), NAME); - assert_eq!(name.certificate_str(), CERT_STR); + assert_eq!(name.certificate_type(), CERT_STR); - let name = AlgorithmName::from_certificate_str(CERT_STR).unwrap(); + let name = AlgorithmName::from_certificate_type(CERT_STR).unwrap(); assert_eq!(name.as_str(), NAME); - assert_eq!(name.certificate_str(), CERT_STR); + assert_eq!(name.certificate_type(), CERT_STR); } #[test] @@ -48,7 +48,7 @@ fn invalid_algorithm_name() { for name in INVALID_CERT_STRS { assert!( - AlgorithmName::from_certificate_str(&name).is_err(), + AlgorithmName::from_certificate_type(&name).is_err(), "{:?} should be an invalid certificate str", name );