From dfdc63d42460c1d0a4e5cbe4036ad7aed7c6bf80 Mon Sep 17 00:00:00 2001 From: Bilal Elmoussaoui Date: Thu, 13 Feb 2025 12:35:40 +0100 Subject: [PATCH] slightly better error handling by implementing the proper traits for error types and propagating error where possible --- src/constants.rs | 29 ++++++++++++---------- src/example.rs | 2 +- src/lib.rs | 8 +++---- src/transaction.rs | 60 ++++++++++++++++++++++------------------------ 4 files changed, 50 insertions(+), 49 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index efdc6fb..e54d57d 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -24,19 +24,6 @@ pub enum ErrorResponse { } impl ErrorResponse { - pub fn as_string(self) -> String { - match self { - ErrorResponse::NoSpace => "No Space left on device", - ErrorResponse::CommandAborted => "Command aborted", - ErrorResponse::InvalidInstruction => "Invalid instruction", - ErrorResponse::AuthRequired => "Authentication required", - ErrorResponse::WrongSyntax => "Wrong syntax", - ErrorResponse::GenericError => "Generic Error", - ErrorResponse::NoSuchObject => "No such Object", - } - .to_string() - } - pub fn any_match(code: u16) -> Option { for resp in ErrorResponse::iter() { if code == resp as u16 { @@ -47,6 +34,22 @@ impl ErrorResponse { } } +impl std::fmt::Display for ErrorResponse { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::NoSpace => f.write_str("No Space left on device"), + Self::CommandAborted => f.write_str("Command aborted"), + Self::InvalidInstruction => f.write_str("Invalid instruction"), + Self::AuthRequired => f.write_str("Authentication required"), + Self::WrongSyntax => f.write_str("Wrong syntax"), + Self::GenericError => f.write_str("Generic Error"), + Self::NoSuchObject => f.write_str("No such Object"), + } + } +} + +impl std::error::Error for ErrorResponse {} + #[derive(Debug, EnumIter, Clone, Copy)] #[repr(u16)] pub enum SuccessResponse { diff --git a/src/example.rs b/src/example.rs index 951c16d..1775ceb 100644 --- a/src/example.rs +++ b/src/example.rs @@ -30,7 +30,7 @@ fn main() { for yubikey in yubikeys { let device_label: &str = yubikey; println!("Found device with label {}", device_label); - let session = OathSession::new(yubikey); + let session = OathSession::new(yubikey).unwrap(); println!("YubiKey version is {:?}", session.get_version()); for c in session.list_oath_codes().unwrap() { println!("{}", c); diff --git a/src/lib.rs b/src/lib.rs index aefd4d1..26a56bd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -133,8 +133,8 @@ impl<'a> RefreshableOathCredential<'a> { } impl<'a> OathSession<'a> { - pub fn new(name: &str) -> Self { - let transaction_context = TransactionContext::from_name(name); + pub fn new(name: &str) -> Result { + let transaction_context = TransactionContext::from_name(name)?; let info_buffer = transaction_context .apdu_read_all(0, INS_SELECT, 0x04, 0, Some(&OATH_AID)) .unwrap(); @@ -145,7 +145,7 @@ impl<'a> OathSession<'a> { println!("{:?}: {:?}", tag, data); } - OathSession { + Ok(Self { version: clone_with_lifetime( info_map.get(&(Tag::Version as u8)).unwrap_or(&vec![0u8; 0]), ) @@ -160,7 +160,7 @@ impl<'a> OathSession<'a> { .leak(), name: name.to_string(), transaction_context, - } + }) } pub fn get_version(&self) -> &[u8] { diff --git a/src/transaction.rs b/src/transaction.rs index 41dffe4..5dc1165 100644 --- a/src/transaction.rs +++ b/src/transaction.rs @@ -18,23 +18,23 @@ pub enum FormattableErrorResponse { } impl FormattableErrorResponse { - pub fn from_apdu_response(sw1: u8, sw2: u8) -> FormattableErrorResponse { + pub fn from_apdu_response(sw1: u8, sw2: u8) -> Self { let code: u16 = (sw1 as u16 | sw2 as u16) << 8; if let Some(e) = ErrorResponse::any_match(code) { - return FormattableErrorResponse::Protocol(e); + return Self::Protocol(e); } if SuccessResponse::any_match(code) .or(SuccessResponse::any_match(sw1.into())) .is_some() { - return FormattableErrorResponse::NoError; + return Self::NoError; } - FormattableErrorResponse::Unknown(String::from("Unknown error")) + Self::Unknown(String::from("Unknown error")) } pub fn is_ok(&self) -> bool { - *self == FormattableErrorResponse::NoError + *self == Self::NoError } - pub fn as_opt(self) -> Option { + pub fn as_opt(self) -> Option { if self.is_ok() { None } else { @@ -42,25 +42,27 @@ impl FormattableErrorResponse { } } - fn from_transmit(err: pcsc::Error) -> FormattableErrorResponse { - FormattableErrorResponse::PcscError(err) + fn from_transmit(err: pcsc::Error) -> Self { + Self::PcscError(err) } +} - fn as_string(&self) -> String { - match self { - FormattableErrorResponse::NoError => "ok".to_string(), - FormattableErrorResponse::Unknown(msg) => msg.to_owned(), - FormattableErrorResponse::Protocol(error_response) => error_response.as_string(), - FormattableErrorResponse::PcscError(error) => format!("{}", error), - FormattableErrorResponse::ParsingError(msg) => msg.to_owned(), - FormattableErrorResponse::DeviceMismatchError => "Devices do not match".to_string(), - } +impl From for FormattableErrorResponse { + fn from(value: pcsc::Error) -> Self { + Self::PcscError(value) } } impl Display for FormattableErrorResponse { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(&self.as_string()) + match self { + Self::NoError => f.write_str("ok"), + Self::Unknown(msg) => f.write_str(msg), + Self::Protocol(error_response) => f.write_fmt(format_args!("{}", error_response)), + Self::PcscError(error) => f.write_fmt(format_args!("{}", error)), + Self::ParsingError(msg) => f.write_str(msg), + Self::DeviceMismatchError => f.write_str("Devices do not match"), + } } } @@ -139,26 +141,22 @@ pub struct TransactionContext { } impl TransactionContext { - pub fn from_name(name: &str) -> Self { - // FIXME: error handling here - + pub fn from_name(name: &str) -> Result { // Establish a PC/SC context - let ctx = pcsc::Context::establish(pcsc::Scope::User).unwrap(); + let ctx = pcsc::Context::establish(pcsc::Scope::User)?; // Connect to the card - let card = ctx - .connect( - &CString::new(name).unwrap(), - pcsc::ShareMode::Shared, - pcsc::Protocols::ANY, - ) - .unwrap(); + let card = ctx.connect( + &CString::new(name).unwrap(), + pcsc::ShareMode::Shared, + pcsc::Protocols::ANY, + )?; - TransactionContextBuilder { + Ok(TransactionContextBuilder { card, transaction_builder: |c| c.transaction().unwrap(), } - .build() + .build()) } pub fn apdu(