diff --git a/src/connection.rs b/src/connection.rs index 241a0a9..12289ed 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,6 +1,5 @@ use std::collections::{BTreeMap, VecDeque}; use std::io::{self, BufReader, Read, Write}; -use std::net::TcpStream; use std::sync::Arc; use channel::{Channel, ChannelId, ChannelRequest}; @@ -249,10 +248,7 @@ impl<'a> Connection { ); let mut res = Packet::new(MessageType::ServiceAccept); - res.with_writer(&|w| { - w.write_bytes(name.as_slice())?; - Ok(()) - })?; + res.write_bytes(name.as_slice())?; Ok(Some(res)) } @@ -279,11 +275,9 @@ impl<'a> Connection { } else { let mut res = Packet::new(MessageType::UserAuthFailure); - res.with_writer(&|w| { - w.write_string("password")?; - w.write_bool(false)?; - Ok(()) - })?; + res.write_string("password")?; + res.write_bool(false)?; + Ok(Some(res)) } } @@ -305,13 +299,10 @@ impl<'a> Connection { let channel = Channel::new(id, peer_id, window_size, max_packet_size); let mut res = Packet::new(MessageType::ChannelOpenConfirmation); - res.with_writer(&|w| { - w.write_uint32(peer_id)?; - w.write_uint32(channel.id())?; - w.write_uint32(channel.window_size())?; - w.write_uint32(channel.max_packet_size())?; - Ok(()) - })?; + res.write_uint32(peer_id)?; + res.write_uint32(channel.id())?; + res.write_uint32(channel.window_size())?; + res.write_uint32(channel.max_packet_size())?; debug!("Open {:?}", channel); @@ -352,7 +343,7 @@ impl<'a> Connection { if want_reply { let mut res = Packet::new(MessageType::ChannelSuccess); - res.with_writer(&|w| w.write_uint32(0))?; + res.write_uint32(0)?; Ok(Some(res)) } else { @@ -417,22 +408,19 @@ impl<'a> Connection { let cookie: Vec = rng.gen_iter::().take(16).collect(); let mut packet = Packet::new(MessageType::KexInit); - packet.with_writer(&|w| { - w.write_raw_bytes(cookie.as_slice())?; - w.write_list(KEY_EXCHANGE)?; - w.write_list(HOST_KEY)?; - w.write_list(ENCRYPTION)?; - w.write_list(ENCRYPTION)?; - w.write_list(MAC)?; - w.write_list(MAC)?; - w.write_list(COMPRESSION)?; - w.write_list(COMPRESSION)?; - w.write_string("")?; - w.write_string("")?; - w.write_bool(false)?; - w.write_uint32(0)?; - Ok(()) - })?; + packet.write_raw_bytes(cookie.as_slice())?; + packet.write_list(KEY_EXCHANGE)?; + packet.write_list(HOST_KEY)?; + packet.write_list(ENCRYPTION)?; + packet.write_list(ENCRYPTION)?; + packet.write_list(MAC)?; + packet.write_list(MAC)?; + packet.write_list(COMPRESSION)?; + packet.write_list(COMPRESSION)?; + packet.write_string("")?; + packet.write_string("")?; + packet.write_bool(false)?; + packet.write_uint32(0)?; self.state = ConnectionState::KeyExchange; self.key_exchange = kex_algo.instance(); diff --git a/src/key_exchange/curve25519.rs b/src/key_exchange/curve25519.rs index f4c7481..8b18d69 100644 --- a/src/key_exchange/curve25519.rs +++ b/src/key_exchange/curve25519.rs @@ -118,14 +118,9 @@ impl KeyExchange for Curve25519 { let hash = self.hash(&[hash_data.as_slice()]); let signature = config.as_ref().key.sign(&hash).unwrap(); - packet - .with_writer(&|w| { - w.write_bytes(public_key.as_slice())?; - w.write_bytes(&server_public)?; - w.write_bytes(signature.as_slice())?; // Signature - Ok(()) - }) - .unwrap(); + packet.write_bytes(public_key.as_slice()).unwrap(); + packet.write_bytes(&server_public).unwrap(); + packet.write_bytes(signature.as_slice()).unwrap(); // Signature self.exchange_hash = Some(hash); self.shared_secret = Some(shared_secret); diff --git a/src/packet.rs b/src/packet.rs index 83233ed..6df3f68 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -97,19 +97,6 @@ impl Packet { } } - pub fn writer<'a>(&'a mut self) -> &'a mut Write { - match self - { - &mut Packet::Raw(ref mut data, _) => data, - &mut Packet::Payload(ref mut payload) => payload, - } - } - - pub fn with_writer(&mut self, f: &Fn(&mut Write) -> Result<()>) - -> Result<()> { - f(self.writer()) - } - pub fn reader<'a>(&'a self) -> BufReader<&'a [u8]> { match self { @@ -146,6 +133,24 @@ impl Packet { } } +impl Write for Packet { + fn write(&mut self, buf: &[u8]) -> Result { + match self + { + &mut Packet::Payload(ref mut payload) => payload.write(buf), + &mut Packet::Raw(ref mut data, ref mut payload_len) => { + let count = data.write(buf)?; + *payload_len += count; + Ok(count) + } + } + } + + fn flush(&mut self) -> Result<()> { + Ok(()) + } +} + pub trait ReadPacketExt: ReadBytesExt { fn read_string(&mut self) -> Result> { let len = self.read_u32::()?;