diff --git a/src/connection.rs b/src/connection.rs index 0fc3559..241a0a9 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -61,11 +61,11 @@ impl<'a> Connection { } } - pub fn run(&mut self, mut stream: TcpStream) -> Result<()> { - self.send_id(&mut stream)?; - self.read_id(&stream)?; + pub fn run(&mut self, mut stream: &mut S) -> Result<()> { + self.send_id(stream)?; + self.read_id(stream)?; - let mut reader = BufReader::new(&stream); + let mut reader = BufReader::new(stream); loop { let packet = self.recv(&mut reader)?; @@ -142,7 +142,7 @@ impl<'a> Connection { Ok(()) } - fn send_id(&mut self, stream: &mut TcpStream) -> io::Result<()> { + fn send_id(&mut self, stream: &mut Write) -> io::Result<()> { let id = format!("SSH-2.0-RedoxSSH_{}", env!("CARGO_PKG_VERSION")); info!("Identifying as {:?}", id); @@ -155,26 +155,19 @@ impl<'a> Connection { Ok(()) } - fn read_id(&mut self, stream: &TcpStream) -> io::Result<()> { - use std::ascii::AsciiExt; + fn read_id(&mut self, stream: &mut Read) -> io::Result<()> { + use std::str; - let mut id = String::new(); + let mut buf = [0; 255]; + let count = stream.read(&mut buf)?; - for byte in stream.take(255).bytes() { - match byte - { - Ok(b'\n') => break, - Ok(b) if b.is_ascii() => id.push(b as char), - Ok(_) => {} - Err(_) => break, - } - } - - let id = id.trim().to_owned(); + let id = str::from_utf8(&buf[0..count]).map(str::trim).or(Err( + io::Error::new(io::ErrorKind::InvalidData, "invalid id"), + ))?; if id.starts_with("SSH-") { info!("Peer identifies as {:?}", id); - self.hash_data.client_id = Some(id); + self.hash_data.client_id = Some(id.to_owned()); Ok(()) } else { @@ -338,10 +331,10 @@ impl<'a> Connection { { "pty-req" => Some(ChannelRequest::Pty { term: reader.read_utf8()?, - char_width: reader.read_uint32()?, - row_height: reader.read_uint32()?, - pixel_width: reader.read_uint32()?, - pixel_height: reader.read_uint32()?, + chars: reader.read_uint32()? as u16, + rows: reader.read_uint32()? as u16, + pixel_width: reader.read_uint32()? as u16, + pixel_height: reader.read_uint32()? as u16, modes: reader.read_string()?, }), "shell" => Some(ChannelRequest::Shell), @@ -375,22 +368,7 @@ impl<'a> Connection { let mut channel = self.channels.get_mut(&channel_id).unwrap(); channel.data(data.as_slice())?; - let data = channel.read()?; - - if data.len() > 0 { - let mut res = Packet::new(MessageType::ChannelData); - res.with_writer(&|w| { - w.write_uint32(0)?; - w.write_bytes(data.as_slice())?; - Ok(()) - })?; - - Ok(Some(res)) - } - else { - Ok(None) - } - + Ok(None) } fn kex_init(&mut self, packet: Packet) -> Result> { diff --git a/src/server.rs b/src/server.rs index 623d9e5..590e279 100644 --- a/src/server.rs +++ b/src/server.rs @@ -26,7 +26,7 @@ impl Server { TcpListener::bind((&*self.config.host, self.config.port))?; loop { - let (stream, addr) = listener.accept()?; + let (mut stream, addr) = listener.accept()?; let config = self.config.clone(); debug!("Incoming connection from {}", addr); @@ -35,7 +35,7 @@ impl Server { let mut connection = Connection::new(ConnectionType::Server(config)); - let result = connection.run(stream); + let result = connection.run(&mut stream); if let Some(error) = result.err() { println!("sshd: {}", error)