diff --git a/src/algorithm.rs b/src/algorithm.rs index e009a38..10d2164 100644 --- a/src/algorithm.rs +++ b/src/algorithm.rs @@ -1,6 +1,8 @@ use std::fmt; use std::str::FromStr; +use error::{ConnectionError, ConnectionResult}; + /// Slice of implemented key exchange algorithms, ordered by preference pub static KEY_EXCHANGE: &[KeyExchangeAlgorithm] = &[ @@ -26,13 +28,16 @@ pub static COMPRESSION: &[CompressionAlgorithm] = &[CompressionAlgorithm::None, CompressionAlgorithm::Zlib]; /// Find the best matching algorithm -pub fn negotiate(server: &[A], client: &[A]) -> Option { +pub fn negotiate( + server: &[A], + client: &[A], +) -> ConnectionResult { for algorithm in client.iter() { if server.iter().any(|a| a == algorithm) { - return Some(*algorithm); + return Ok(*algorithm); } } - None + Err(ConnectionError::NegotiationError) } #[derive(Clone, Copy, PartialEq, Debug)] diff --git a/src/connection.rs b/src/connection.rs index 0a171f8..07e28f7 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,9 +1,9 @@ use std::io::{self, BufRead, BufReader, Read, Write}; +use error::{ConnectionError, ConnectionResult}; use key_exchange::{self, KeyExchange, KeyExchangeResult}; use message::MessageType; use packet::{Packet, ReadPacketExt, WritePacketExt}; -use public_key::KeyPair; #[derive(PartialEq)] enum ConnectionState { @@ -42,7 +42,7 @@ impl Connection { } } - pub fn run(&mut self, mut stream: &mut Read) -> io::Result<()> { + pub fn run(&mut self, mut stream: &mut Read) -> ConnectionResult<()> { self.stream.write(self.my_id.as_bytes())?; self.stream.flush()?; @@ -53,7 +53,7 @@ impl Connection { } loop { - let packet = Packet::read_from(&mut stream).unwrap(); + let packet = Packet::read_from(&mut stream)?; println!("packet: {:?}", packet); self.process(&packet); } @@ -73,62 +73,64 @@ impl Connection { Ok(id.trim_right().to_owned()) } - pub fn process(&mut self, packet: &Packet) { + pub fn process(&mut self, packet: &Packet) -> ConnectionResult<()> { match packet.msg_type() { MessageType::KexInit => { println!("Starting Key Exchange!"); - self.kex_init(packet); + self.kex_init(packet) } MessageType::KeyExchange(_) => { - if let Some(ref mut kex) = self.key_exchange { - match kex.process(packet) - { - KeyExchangeResult::Ok(Some(packet)) => { - packet.write_to(&mut self.stream); + let ref mut kex = self.key_exchange.as_mut().ok_or( + ConnectionError::KeyExchangeError, + )?; + + match kex.process(packet) + { + KeyExchangeResult::Ok(Some(packet)) => { + packet.write_to(&mut self.stream)?; + } + KeyExchangeResult::Error(Some(packet)) => { + packet.write_to(&mut self.stream)?; + } + KeyExchangeResult::Done(packet) => { + if let Some(packet) = packet { + packet.write_to(&mut self.stream)?; } - KeyExchangeResult::Error(Some(packet)) => { - packet.write_to(&mut self.stream); - } - KeyExchangeResult::Done(Some(packet)) => { - packet.write_to(&mut self.stream); - } - KeyExchangeResult::Ok(None) | - KeyExchangeResult::Error(None) | - KeyExchangeResult::Done(None) => {} - }; - } - else { - warn!("Received KeyExchange packet without KexInit"); - } + self.state = ConnectionState::Established; + } + KeyExchangeResult::Ok(None) | + KeyExchangeResult::Error(None) => {} + }; + Ok(()) } _ => { println!("Unhandled packet: {:?}", packet); + Err(ConnectionError::KeyExchangeError) } } } - pub fn kex_init(&mut self, packet: &Packet) { + pub fn kex_init(&mut self, packet: &Packet) -> ConnectionResult<()> { use algorithm::*; let mut reader = packet.reader(); - let cookie = reader.read_bytes(16); - let kex_algos = reader.read_enum_list::(); - let srv_host_key_algos = reader.read_enum_list::(); - let enc_algos_c2s = reader.read_enum_list::(); - let enc_algos_s2c = reader.read_enum_list::(); - let mac_algos_c2s = reader.read_enum_list::(); - let mac_algos_s2c = reader.read_enum_list::(); - let comp_algos_c2s = reader.read_enum_list::(); - let comp_algos_s2c = reader.read_enum_list::(); + let _ = reader.read_bytes(16)?; // Cookie. Throw it away. + let kex_algos = reader.read_enum_list::()?; + let srv_host_key_algos = reader.read_enum_list::()?; + let enc_algos_c2s = reader.read_enum_list::()?; + let enc_algos_s2c = reader.read_enum_list::()?; + let mac_algos_c2s = reader.read_enum_list::()?; + let mac_algos_s2c = reader.read_enum_list::()?; + let comp_algos_c2s = reader.read_enum_list::()?; + let comp_algos_s2c = reader.read_enum_list::()?; - let kex_algo = negotiate(KEY_EXCHANGE, kex_algos.unwrap().as_slice()); + let kex_algo = negotiate(KEY_EXCHANGE, kex_algos.as_slice())?; let srv_host_key_algo = - negotiate(HOST_KEY, srv_host_key_algos.unwrap().as_slice()); - let enc_algo = negotiate(ENCRYPTION, enc_algos_s2c.unwrap().as_slice()); - let mac_algo = negotiate(MAC, mac_algos_s2c.unwrap().as_slice()); - let comp_algo = - negotiate(COMPRESSION, comp_algos_s2c.unwrap().as_slice()); + negotiate(HOST_KEY, srv_host_key_algos.as_slice())?; + let enc_algo = negotiate(ENCRYPTION, enc_algos_s2c.as_slice())?; + let mac_algo = negotiate(MAC, mac_algos_s2c.as_slice())?; + let comp_algo = negotiate(COMPRESSION, comp_algos_s2c.as_slice())?; println!("Negotiated Kex Algorithm: {:?}", kex_algo); println!("Negotiated Host Key Algorithm: {:?}", srv_host_key_algo); @@ -137,7 +139,7 @@ impl Connection { println!("Negotiated Comp Algorithm: {:?}", comp_algo); use rand::{OsRng, Rng}; - let mut rng = OsRng::new().unwrap(); + let mut rng = OsRng::new()?; let cookie: Vec = rng.gen_iter::().take(16).collect(); let mut packet = Packet::new(MessageType::KexInit); @@ -156,10 +158,11 @@ impl Connection { w.write_bool(false)?; w.write_uint32(0)?; Ok(()) - }); + })?; self.state = ConnectionState::KeyExchange; self.key_exchange = Some(Box::new(key_exchange::Curve25519::new())); - packet.write_to(&mut self.stream); + packet.write_to(&mut self.stream)?; + Ok(()) } } diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..845a13b --- /dev/null +++ b/src/error.rs @@ -0,0 +1,39 @@ +use std::convert::From; +use std::error::Error; +use std::fmt; +use std::io; + +pub type ConnectionResult = Result; + +#[derive(Debug)] +pub enum ConnectionError { + IoError(io::Error), + ProtocolError, + NegotiationError, + KeyExchangeError, +} + +impl fmt::Display for ConnectionError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "connection error: {}", (self as &Error).description()) + } +} + +impl Error for ConnectionError { + fn description(&self) -> &str { + use self::ConnectionError::*; + match self + { + &IoError(_) => "io error", + &ProtocolError => "protocol error", + &NegotiationError => "negotiation error", + &KeyExchangeError => "key exchange error", + } + } +} + +impl From for ConnectionError { + fn from(err: io::Error) -> ConnectionError { + ConnectionError::IoError(err) + } +} diff --git a/src/lib.rs b/src/lib.rs index 7573ca8..937b89e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ extern crate num_bigint; #[macro_use] extern crate log; +mod error; mod algorithm; mod packet; mod message; diff --git a/src/message.rs b/src/message.rs index 97e1692..21b4e97 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,5 +1,3 @@ -use std::fmt::Debug; - #[derive(PartialEq, Clone, Copy, Debug)] pub enum MessageType { Disconnect, diff --git a/src/packet.rs b/src/packet.rs index 5ed4cca..d46e85e 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -109,11 +109,11 @@ pub trait ReadPacketExt: ReadBytesExt { } fn read_utf8(&mut self) -> Result { - Ok( - str::from_utf8(self.read_string()?.as_slice()) - .unwrap_or("") - .to_owned(), - ) + str::from_utf8(self.read_string()?.as_slice()) + .map(|s| s.to_owned()) + .map_err(|_| { + io::Error::new(io::ErrorKind::InvalidData, "invalid utf-8") + }) } fn read_bool(&mut self) -> Result { diff --git a/src/server.rs b/src/server.rs index 04dfe29..67c9b82 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,8 +1,7 @@ -use std::io::{self, Write}; +use std::io; use std::net::TcpListener; use connection::{Connection, ConnectionType}; -use packet::Packet; use public_key::KeyPair; pub struct ServerConfig { @@ -41,7 +40,10 @@ impl Server { stream.try_clone().unwrap(), ); - connection.run(&mut stream); + let result = connection.run(&mut stream); + if let Some(error) = result.err() { + println!("sshd: {}", error) + } } Ok(())