From 8c911fec23e4906a7c1bdc8d387af2ccd40ccf35 Mon Sep 17 00:00:00 2001 From: Thomas Gatzweiler Date: Thu, 20 Jul 2017 14:14:31 +0200 Subject: [PATCH] Don't use TcpStream::try_clone for now --- src/connection.rs | 548 ++++++++++++++++++++++++---------------------- src/server.rs | 10 +- 2 files changed, 289 insertions(+), 269 deletions(-) diff --git a/src/connection.rs b/src/connection.rs index c003d91..eb1ec61 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,9 +1,11 @@ -use std::io::{self, BufRead, BufReader, Read, Write}; +use std::collections::VecDeque; +use std::io::{self, BufReader, Read, Write}; +use std::net::TcpStream; use std::sync::Arc; use encryption::{AesCtr, Decryptor, Encryption}; -use error::{ConnectionError, ConnectionResult}; -use key_exchange::{self, KexResult, KeyExchange}; +use error::{ConnectionError, ConnectionResult as Result}; +use key_exchange::{KexResult, KeyExchange}; use mac::{Hmac, MacAlgorithm}; use message::MessageType; use packet::{Packet, ReadPacketExt, WritePacketExt}; @@ -34,122 +36,150 @@ pub struct Connection { pub hash_data: HashData, state: ConnectionState, key_exchange: Option>, - stream: Box, session_id: Option>, encryption: Option<(Box, Box)>, mac: Option<(Box, Box)>, seq: (u32, u32), + tx_queue: VecDeque, } impl<'a> Connection { - pub fn new(conn_type: ConnectionType, stream: Box) -> Connection { + pub fn new(conn_type: ConnectionType) -> Connection { Connection { conn_type: conn_type, hash_data: HashData::default(), state: ConnectionState::Initial, key_exchange: None, - stream: Box::new(stream), session_id: None, encryption: None, mac: None, seq: (0, 0), + tx_queue: VecDeque::new(), } } - pub fn run(&mut self, stream: &mut Read) -> ConnectionResult<()> { - let mut reader = BufReader::new(stream); + pub fn run(&mut self, mut stream: TcpStream) -> Result<()> { + self.send_id(&mut stream)?; + self.read_id(&stream)?; - self.send_id()?; - self.read_id(&mut reader)?; + let mut reader = BufReader::new(&stream); loop { - let packet = if let Some((ref mut c2s, _)) = self.encryption { - let mut decryptor = Decryptor::new(&mut **c2s, &mut reader); - Packet::read_from(&mut decryptor)? - } - else { - Packet::read_from(&mut reader)? - }; + let packet = self.recv(&mut reader)?; + let response = self.process(packet)?; - if let Some((ref mut mac, _)) = self.mac { - let mut sig = vec![0; mac.size()]; - reader.read_exact(&mut sig)?; + let mut stream = reader.get_mut(); - let mut sig_cmp = vec![0; mac.size()]; - mac.sign(packet.data(), self.seq.0, sig_cmp.as_mut_slice()); - - if sig != sig_cmp { - return Err(ConnectionError::IntegrityError); - } + if let Some(packet) = response { + self.send(&mut stream, packet)?; } - trace!("Packet {} received: {:?}", self.seq.0, packet); - self.process(packet)?; - - self.seq.0 = self.seq.0.wrapping_add(1); + // Send additional packets from the queue + let mut packets: Vec = self.tx_queue.drain(..).collect(); + for packet in packets.drain(..) { + self.send(&mut stream, packet)?; + } } } - pub fn send(&mut self, packet: Packet) -> io::Result<()> { + fn recv(&mut self, mut stream: &mut Read) -> Result { + let packet = if let Some((ref mut c2s, _)) = self.encryption { + let mut decryptor = Decryptor::new(&mut **c2s, &mut stream); + Packet::read_from(&mut decryptor)? + } + else { + Packet::read_from(&mut stream)? + }; + + if let Some((ref mut mac, _)) = self.mac { + let mut sig = vec![0; mac.size()]; + stream.read_exact(&mut sig)?; + + let mut sig_cmp = vec![0; mac.size()]; + mac.sign(packet.data(), self.seq.0, sig_cmp.as_mut_slice()); + + if sig != sig_cmp { + return Err(ConnectionError::IntegrityError); + } + } + + trace!("Packet {} received: {:?}", self.seq.0, packet); + + // Count up the received packet sequence number + self.seq.0 = self.seq.0.wrapping_add(1); + + Ok(packet) + } + + fn send(&mut self, mut stream: &mut Write, packet: Packet) + -> io::Result<()> { trace!("Sending packet {}: {:?}", self.seq.1, packet); let packet = packet.to_raw()?; if let Some((_, ref mut s2c)) = self.encryption { - let mut encrypted = vec![0; packet.data().len()]; s2c.encrypt(packet.data(), encrypted.as_mut_slice()); // Sending encrypted packet - self.stream.write_all(encrypted.as_slice())?; + stream.write_all(encrypted.as_slice())?; } else { - packet.write_to(&mut self.stream)?; + packet.write_to(&mut stream)?; } - self.seq.1 = self.seq.1.wrapping_add(1); - if let Some((_, ref mut mac)) = self.mac { let mut sig = vec![0; mac.size()]; mac.sign(packet.data(), self.seq.1, sig.as_mut_slice()); - self.stream.write_all(sig.as_slice())?; + stream.write_all(sig.as_slice())?; } + self.seq.1 = self.seq.1.wrapping_add(1); + Ok(()) } - fn send_id(&mut self) -> io::Result<()> { + fn send_id(&mut self, stream: &mut TcpStream) -> io::Result<()> { let id = format!("SSH-2.0-RedoxSSH_{}", env!("CARGO_PKG_VERSION")); info!("Identifying as {:?}", id); - self.stream.write(id.as_bytes())?; - self.stream.write(b"\r\n")?; - self.stream.flush()?; + stream.write(id.as_bytes())?; + stream.write(b"\r\n")?; + stream.flush()?; self.hash_data.server_id = Some(id); Ok(()) } - fn read_id(&mut self, mut reader: &mut BufRead) -> io::Result<()> { - // The identification string has a maximum length of 255 bytes - // TODO: Make sure to stop reading if the client sends too much + fn read_id(&mut self, stream: &TcpStream) -> io::Result<()> { + use std::ascii::AsciiExt; let mut id = String::new(); - while !id.starts_with("SSH-") { - reader.read_line(&mut id)?; + for byte in stream.take(255).bytes() { + match byte + { + Ok(b'\n') => break, + Ok(b) if b.is_ascii() => id.push(b as char), + Ok(_) => {} + Err(_) => break, + } } - let peer_id = id.trim_right().to_owned(); - info!("Peer identifies as {:?}", peer_id); - self.hash_data.client_id = Some(peer_id); + let id = id.trim().to_owned(); - Ok(()) + if id.starts_with("SSH-") { + info!("Peer identifies as {:?}", id); + self.hash_data.client_id = Some(id); + Ok(()) + } + else { + Err(io::Error::new(io::ErrorKind::InvalidData, "invalid id")) + } } - fn generate_key(&mut self, id: &[u8], len: usize) - -> ConnectionResult> { + fn generate_key(&mut self, id: &[u8], len: usize) -> Result> { use self::ConnectionError::KeyGenerationError; let kex = self.key_exchange.take().ok_or(KeyGenerationError)?; @@ -171,213 +201,17 @@ impl<'a> Connection { Ok(key) } - pub fn process(&mut self, packet: Packet) -> ConnectionResult<()> { + pub fn process(&mut self, packet: Packet) -> Result> { match packet.msg_type() { - MessageType::KexInit => { - debug!("Starting key exchange"); - self.kex_init(packet) - } - MessageType::NewKeys => { - debug!("Switching to new keys"); - - let iv_c2s = self.generate_key(b"A", 256)?; - let iv_s2c = self.generate_key(b"B", 256)?; - let enc_c2s = self.generate_key(b"C", 256)?; - let enc_s2c = self.generate_key(b"D", 256)?; - let mac_c2s = self.generate_key(b"E", 256)?; - let mac_s2c = self.generate_key(b"F", 256)?; - - self.encryption = - Some(( - Box::new( - AesCtr::new(enc_c2s.as_slice(), iv_c2s.as_slice()), - ), - Box::new( - AesCtr::new(enc_s2c.as_slice(), iv_s2c.as_slice()), - ), - )); - - self.mac = Some(( - Box::new(Hmac::new(mac_c2s.as_slice())), - Box::new(Hmac::new(mac_s2c.as_slice())), - )); - - Ok(()) - } - MessageType::ServiceRequest => { - let mut reader = packet.reader(); - let name = reader.read_string()?; - - trace!( - "{:?}", - ::std::str::from_utf8(&name.as_slice()).unwrap() - ); - - let mut res = Packet::new(MessageType::ServiceAccept); - res.with_writer(&|w| { - w.write_bytes(name.as_slice())?; - Ok(()) - })?; - - self.send(res)?; - Ok(()) - } - MessageType::UserAuthRequest => { - let mut reader = packet.reader(); - let name = reader.read_utf8()?; - let service = reader.read_utf8()?; - let method = reader.read_utf8()?; - - let success = if method == "password" { - assert!(reader.read_bool()? == false); - let pass = reader.read_utf8()?; - pass == "hunter2" - } - else { - false - }; - - if success { - self.send(Packet::new(MessageType::UserAuthSuccess))?; - } - else { - let mut res = Packet::new(MessageType::UserAuthFailure); - res.with_writer(&|w| { - w.write_string("password")?; - w.write_bool(false)?; - Ok(()) - })?; - - self.send(res)?; - } - - debug!("User Auth {:?}, {:?}, {:?}", name, service, method); - Ok(()) - } - MessageType::ChannelOpen => { - let mut reader = packet.reader(); - let channel_type = reader.read_utf8()?; - let sender_channel = reader.read_uint32()?; - let window_size = reader.read_uint32()?; - let max_packet_size = reader.read_uint32()?; - - let mut res = Packet::new(MessageType::ChannelOpenConfirmation); - res.with_writer(&|w| { - w.write_uint32(sender_channel)?; - w.write_uint32(0)?; - w.write_uint32(window_size)?; - w.write_uint32(max_packet_size)?; - Ok(()) - })?; - - self.send(res)?; - debug!( - "Channel Open {:?}, {:?}, {:?}, {:?}", - channel_type, - sender_channel, - window_size, - max_packet_size - ); - - Ok(()) - } - MessageType::ChannelRequest => { - let mut reader = packet.reader(); - let channel = reader.read_uint32()?; - let request = reader.read_utf8()?; - let want_reply = reader.read_bool()?; - - debug!( - "Channel Request {:?}, {:?}, {:?}", - channel, - request, - want_reply - ); - - if request == "pty-req" { - let term = reader.read_utf8()?; - let char_width = reader.read_uint32()?; - let row_height = reader.read_uint32()?; - let pixel_width = reader.read_uint32()?; - let pixel_height = reader.read_uint32()?; - let modes = reader.read_string()?; - - debug!( - "PTY request: {:?} {:?} {:?} {:?} {:?} {:?}", - term, - char_width, - row_height, - pixel_width, - pixel_height, - modes - ); - } - - if request == "shell" { - debug!("Shell request"); - } - - if want_reply { - let mut res = Packet::new(MessageType::ChannelSuccess); - res.with_writer(&|w| w.write_uint32(0))?; - self.send(res)?; - } - - Ok(()) - } - MessageType::ChannelData => { - let mut reader = packet.reader(); - let channel = reader.read_uint32()?; - let data = reader.read_string()?; - - let mut res = Packet::new(MessageType::ChannelData); - res.with_writer(&|w| { - w.write_uint32(0)?; - w.write_bytes(data.as_slice())?; - Ok(()) - })?; - - self.send(res)?; - - debug!( - "Channel {} Data ({} bytes): {:?}", - channel, - data.len(), - data - ); - Ok(()) - } - MessageType::KeyExchange(_) => { - let mut kex = self.key_exchange.take().ok_or( - ConnectionError::KeyExchangeError, - )?; - - match kex.process(self, packet) - { - KexResult::Done(packet) => { - self.state = ConnectionState::Established; - self.send(packet)?; - - if self.session_id.is_none() { - self.session_id = - kex.exchange_hash().map(|h| h.to_vec()); - } - - let packet = Packet::new(MessageType::NewKeys); - self.send(packet)?; - Ok(()) - } - KexResult::Ok(packet) => { - self.send(packet)?; - Ok(()) - } - KexResult::Error => Err(ConnectionError::KeyExchangeError), - }?; - - self.key_exchange = Some(kex); - Ok(()) - } + MessageType::KexInit => self.kex_init(packet), + MessageType::NewKeys => self.new_keys(packet), + MessageType::ServiceRequest => self.service_request(packet), + MessageType::UserAuthRequest => self.user_auth_request(packet), + MessageType::ChannelOpen => self.channel_open(packet), + MessageType::ChannelRequest => self.channel_request(packet), + MessageType::ChannelData => self.channel_data(packet), + MessageType::KeyExchange(_) => self.key_exchange(packet), _ => { error!("Unhandled packet: {:?}", packet); Err(ConnectionError::ProtocolError) @@ -385,7 +219,172 @@ impl<'a> Connection { } } - pub fn kex_init(&mut self, packet: Packet) -> ConnectionResult<()> { + fn new_keys(&mut self, packet: Packet) -> Result> { + debug!("Switching to new keys"); + + let iv_c2s = self.generate_key(b"A", 256)?; + let iv_s2c = self.generate_key(b"B", 256)?; + let enc_c2s = self.generate_key(b"C", 256)?; + let enc_s2c = self.generate_key(b"D", 256)?; + let mac_c2s = self.generate_key(b"E", 256)?; + let mac_s2c = self.generate_key(b"F", 256)?; + + self.encryption = + Some(( + Box::new(AesCtr::new(enc_c2s.as_slice(), iv_c2s.as_slice())), + Box::new(AesCtr::new(enc_s2c.as_slice(), iv_s2c.as_slice())), + )); + + self.mac = Some(( + Box::new(Hmac::new(mac_c2s.as_slice())), + Box::new(Hmac::new(mac_s2c.as_slice())), + )); + + Ok(None) + } + + fn service_request(&mut self, packet: Packet) -> Result> { + let mut reader = packet.reader(); + let name = reader.read_string()?; + + trace!("{:?}", ::std::str::from_utf8(&name.as_slice()).unwrap()); + + let mut res = Packet::new(MessageType::ServiceAccept); + res.with_writer(&|w| { + w.write_bytes(name.as_slice())?; + Ok(()) + })?; + + Ok(Some(res)) + } + + fn user_auth_request(&mut self, packet: Packet) -> Result> { + let mut reader = packet.reader(); + let name = reader.read_utf8()?; + let service = reader.read_utf8()?; + let method = reader.read_utf8()?; + + let success = if method == "password" { + assert!(reader.read_bool()? == false); + let pass = reader.read_utf8()?; + pass == "hunter2" + } + else { + false + }; + + debug!("User Auth {:?}, {:?}, {:?}", name, service, method); + + if success { + Ok(Some(Packet::new(MessageType::UserAuthSuccess))) + } + else { + let mut res = Packet::new(MessageType::UserAuthFailure); + res.with_writer(&|w| { + w.write_string("password")?; + w.write_bool(false)?; + Ok(()) + })?; + Ok(Some(res)) + } + } + + fn channel_open(&mut self, packet: Packet) -> Result> { + let mut reader = packet.reader(); + let channel_type = reader.read_utf8()?; + let sender_channel = reader.read_uint32()?; + let window_size = reader.read_uint32()?; + let max_packet_size = reader.read_uint32()?; + + let mut res = Packet::new(MessageType::ChannelOpenConfirmation); + res.with_writer(&|w| { + w.write_uint32(sender_channel)?; + w.write_uint32(0)?; + w.write_uint32(window_size)?; + w.write_uint32(max_packet_size)?; + Ok(()) + })?; + + debug!( + "Channel Open {:?}, {:?}, {:?}, {:?}", + channel_type, + sender_channel, + window_size, + max_packet_size + ); + + Ok(Some(res)) + } + + fn channel_request(&mut self, packet: Packet) -> Result> { + let mut reader = packet.reader(); + let channel = reader.read_uint32()?; + let request = reader.read_utf8()?; + let want_reply = reader.read_bool()?; + + debug!( + "Channel Request {:?}, {:?}, {:?}", + channel, + request, + want_reply + ); + + if request == "pty-req" { + let term = reader.read_utf8()?; + let char_width = reader.read_uint32()?; + let row_height = reader.read_uint32()?; + let pixel_width = reader.read_uint32()?; + let pixel_height = reader.read_uint32()?; + let modes = reader.read_string()?; + + debug!( + "PTY request: {:?} {:?} {:?} {:?} {:?} {:?}", + term, + char_width, + row_height, + pixel_width, + pixel_height, + modes + ); + } + + if request == "shell" { + debug!("Shell request"); + } + + if want_reply { + let mut res = Packet::new(MessageType::ChannelSuccess); + res.with_writer(&|w| w.write_uint32(0))?; + Ok(Some(res)) + } + else { + Ok(None) + } + } + + fn channel_data(&mut self, packet: Packet) -> Result> { + let mut reader = packet.reader(); + let channel = reader.read_uint32()?; + let data = reader.read_string()?; + + let mut res = Packet::new(MessageType::ChannelData); + res.with_writer(&|w| { + w.write_uint32(0)?; + w.write_bytes(data.as_slice())?; + Ok(()) + })?; + + debug!( + "Channel {} Data ({} bytes): {:?}", + channel, + data.len(), + data + ); + + Ok(Some(res)) + } + + fn kex_init(&mut self, packet: Packet) -> Result> { use algorithm::*; let (kex_algo, srv_host_key_algo, enc_algo, mac_algo, comp_algo) = { @@ -451,11 +450,36 @@ impl<'a> Connection { self.state = ConnectionState::KeyExchange; self.key_exchange = kex_algo.instance(); - packet.write_to(&mut self.stream)?; - // Save payload for hash generation - self.hash_data.server_kexinit = Some(packet.payload()); + self.hash_data.server_kexinit = Some(packet.data().to_vec()); - Ok(()) + Ok(Some(packet)) + } + + fn key_exchange(&mut self, packet: Packet) -> Result> { + let mut kex = self.key_exchange.take().ok_or( + ConnectionError::KeyExchangeError, + )?; + + let result = match kex.process(self, packet) + { + KexResult::Done(packet) => { + self.state = ConnectionState::Established; + + if self.session_id.is_none() { + self.session_id = kex.exchange_hash().map(|h| h.to_vec()); + } + + self.tx_queue.push_back(Packet::new(MessageType::NewKeys)); + + Ok(Some(packet)) + } + KexResult::Ok(packet) => Ok(Some(packet)), + KexResult::Error => Err(ConnectionError::KeyExchangeError), + }; + + + self.key_exchange = Some(kex); + result } } diff --git a/src/server.rs b/src/server.rs index e284291..623d9e5 100644 --- a/src/server.rs +++ b/src/server.rs @@ -32,14 +32,10 @@ impl Server { debug!("Incoming connection from {}", addr); thread::spawn(move || { - let mut read_stream = stream.try_clone().unwrap(); + let mut connection = + Connection::new(ConnectionType::Server(config)); - let mut connection = Connection::new( - ConnectionType::Server(config), - Box::new(stream), - ); - - let result = connection.run(&mut read_stream); + let result = connection.run(stream); if let Some(error) = result.err() { println!("sshd: {}", error)