From 9480693a10abb6bfae909f5bdf374a3cc4375108 Mon Sep 17 00:00:00 2001 From: Thomas Gatzweiler Date: Sat, 22 Jul 2017 21:58:40 +0200 Subject: [PATCH] Async IO Hack --- src/bin/sshd.rs | 1 + src/channel.rs | 115 +++++++++++++++---------------- src/connection.rs | 54 ++++++++++++--- src/lib.rs | 12 ++-- src/server.rs | 6 +- src/sys/linux/mod.rs | 19 +++++ src/sys/linux/pty.rs | 160 +++++++++++++++++++++++++++++++++++++++++++ src/sys/redox.rs | 31 --------- src/sys/redox/mod.rs | 14 ++++ src/sys/redox/pty.rs | 19 +++++ src/sys/unix.rs | 75 -------------------- 11 files changed, 325 insertions(+), 181 deletions(-) create mode 100644 src/sys/linux/mod.rs create mode 100644 src/sys/linux/pty.rs delete mode 100644 src/sys/redox.rs create mode 100644 src/sys/redox/mod.rs create mode 100644 src/sys/redox/pty.rs delete mode 100644 src/sys/unix.rs diff --git a/src/bin/sshd.rs b/src/bin/sshd.rs index b54896d..516dcde 100644 --- a/src/bin/sshd.rs +++ b/src/bin/sshd.rs @@ -2,6 +2,7 @@ extern crate ssh; extern crate log; use std::env; +use std::error::Error; use std::fs::File; use std::io::{self, Write}; use std::process; diff --git a/src/channel.rs b/src/channel.rs index 73139bb..3227ee5 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -1,25 +1,24 @@ -use std::fs::{File, OpenOptions}; -use std::io::{self, Read, Write}; -use std::os::unix::io::{FromRawFd, IntoRawFd, RawFd}; +use std::fs::OpenOptions; +use std::io; +use std::os::unix::io::{FromRawFd, IntoRawFd}; use std::os::unix::process::CommandExt; -use std::path::PathBuf; use std::process::{self, Stdio}; -use std::thread::{self, JoinHandle}; +use std::sync::mpsc; use sys; +use connection::ConnectionEvent; + pub type ChannelId = u32; -#[derive(Debug)] pub struct Channel { id: ChannelId, peer_id: ChannelId, process: Option, - pty: Option<(RawFd, PathBuf)>, - master: Option, + pty: Option, window_size: u32, peer_window_size: u32, max_packet_size: u32, - read_thread: Option>, + events: mpsc::Sender, } #[derive(Debug)] @@ -38,18 +37,17 @@ pub enum ChannelRequest { impl Channel { pub fn new( id: ChannelId, peer_id: ChannelId, peer_window_size: u32, - max_packet_size: u32 + max_packet_size: u32, events: mpsc::Sender ) -> Channel { Channel { id: id, peer_id: peer_id, process: None, - master: None, pty: None, window_size: peer_window_size, peer_window_size: peer_window_size, max_packet_size: max_packet_size, - read_thread: None, + events: events, } } @@ -57,6 +55,10 @@ impl Channel { self.id } + pub fn peer_id(&self) -> ChannelId { + self.peer_id + } + pub fn window_size(&self) -> u32 { self.window_size } @@ -75,83 +77,78 @@ impl Channel { pixel_height, .. } => { - let (master_fd, tty_path) = sys::getpty(); + if let Ok(mut pty) = sys::Pty::get() { + pty.set_winsize(chars, rows, pixel_width, pixel_height); - sys::set_winsize( - master_fd, - chars, - rows, - pixel_width, - pixel_height, - ); + let events = self.events.clone(); + let id = self.id; - self.read_thread = Some(thread::spawn(move || { - use libc::dup; - let master2 = unsafe { dup(master_fd) }; + pty.subscribe(move || { + events.send(ConnectionEvent::ChannelData(id)).map_err( + |_| (), + ) + }); - println!("dup result: {}", dup as u32); - let mut master = unsafe { File::from_raw_fd(master2) }; - loop { - use std::str::from_utf8_unchecked; - let mut buf = [0; 4096]; - let count = master.read(&mut buf).unwrap(); - if count == 0 { - break; - } - println!("Read: {}", unsafe { - from_utf8_unchecked(&buf[0..count]) - }); - } - - println!("Quitting read thread."); - })); - - self.pty = Some((master_fd, tty_path)); - self.master = Some(unsafe { File::from_raw_fd(master_fd) }); + self.pty = Some(pty); + } } ChannelRequest::Shell => { - if let Some(&(_, ref tty_path)) = self.pty.as_ref() { + if let Some(ref pty) = self.pty { let stdin = OpenOptions::new() .read(true) .write(true) - .open(&tty_path) + .open(pty.path()) .unwrap() .into_raw_fd(); let stdout = OpenOptions::new() .read(true) .write(true) - .open(&tty_path) + .open(pty.path()) .unwrap() .into_raw_fd(); let stderr = OpenOptions::new() .read(true) .write(true) - .open(&tty_path) + .open(pty.path()) .unwrap() .into_raw_fd(); - process::Command::new("login") - .stdin(unsafe { Stdio::from_raw_fd(stdin) }) - .stdout(unsafe { Stdio::from_raw_fd(stdout) }) - .stderr(unsafe { Stdio::from_raw_fd(stderr) }) - .before_exec(|| sys::before_exec()) - .spawn() - .unwrap(); + self.process = Some( + process::Command::new("login") + .stdin(unsafe { Stdio::from_raw_fd(stdin) }) + .stdout(unsafe { Stdio::from_raw_fd(stdout) }) + .stderr(unsafe { Stdio::from_raw_fd(stderr) }) + .before_exec(|| sys::before_exec()) + .spawn() + .unwrap(), + ); } } } debug!("Channel Request: {:?}", request); } - pub fn data(&mut self, data: &[u8]) -> io::Result<()> { - if let Some(ref mut master) = self.master { - master.write_all(data)?; - master.flush() + pub fn write(&mut self, data: &[u8]) -> io::Result<()> { + match self.pty + { + Some(ref mut pty) => pty.write(data), + _ => Ok(()), } - else { - Ok(()) + } + + pub fn read(&mut self, data: &mut [u8]) -> io::Result { + match self.pty + { + Some(ref mut pty) => pty.read(data), + _ => Ok(0), } } } + +impl Drop for Channel { + fn drop(&mut self) { + self.process.take().map(|mut p| p.kill()); + } +} diff --git a/src/connection.rs b/src/connection.rs index 12289ed..153ac31 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,6 +1,7 @@ use std::collections::{BTreeMap, VecDeque}; use std::io::{self, BufReader, Read, Write}; use std::sync::Arc; +use std::sync::mpsc; use channel::{Channel, ChannelId, ChannelRequest}; use encryption::{AesCtr, Decryptor, Encryption}; @@ -18,6 +19,11 @@ enum ConnectionState { Established, } +pub enum ConnectionEvent { + ChannelData(ChannelId), + StreamData, +} + #[derive(Clone)] pub enum ConnectionType { Server(Arc), @@ -42,10 +48,14 @@ pub struct Connection { seq: (u32, u32), tx_queue: VecDeque, channels: BTreeMap, + pub events_tx: mpsc::Sender, + events_rx: mpsc::Receiver, } impl<'a> Connection { pub fn new(conn_type: ConnectionType) -> Connection { + let (events_tx, events_rx) = mpsc::channel(); + Connection { conn_type: conn_type, hash_data: HashData::default(), @@ -57,6 +67,8 @@ impl<'a> Connection { seq: (0, 0), tx_queue: VecDeque::new(), channels: BTreeMap::new(), + events_rx: events_rx, + events_tx: events_tx, } } @@ -67,23 +79,43 @@ impl<'a> Connection { let mut reader = BufReader::new(stream); loop { - let packet = self.recv(&mut reader)?; - let response = self.process(packet)?; + match self.events_rx.recv() + { + Ok(ConnectionEvent::ChannelData(id)) => { + if let Some(ref mut channel) = self.channels.get_mut(&id) { + let mut buf = [0; 4096]; + let count = channel.read(&mut buf)?; + if count > 0 { + let mut res = Packet::new(MessageType::ChannelData); + res.write_uint32(channel.peer_id())?; + res.write_bytes(&buf[..count])?; + } + } + } + Ok(ConnectionEvent::StreamData) => { + let packet = self.recv(&mut reader)?; + let response = self.process(packet)?; - let mut stream = reader.get_mut(); + let mut stream = reader.get_mut(); - if let Some(packet) = response { - self.send(&mut stream, packet)?; + if let Some(packet) = response { + self.send(&mut stream, packet)?; + } + } + Err(_) => {} } // Send additional packets from the queue let mut packets: Vec = self.tx_queue.drain(..).collect(); + let mut stream = reader.get_mut(); for packet in packets.drain(..) { self.send(&mut stream, packet)?; } } } + + fn recv(&mut self, mut stream: &mut Read) -> Result { let packet = if let Some((ref mut c2s, _)) = self.encryption { let mut decryptor = Decryptor::new(&mut **c2s, &mut stream); @@ -296,7 +328,13 @@ impl<'a> Connection { 0 }; - let channel = Channel::new(id, peer_id, window_size, max_packet_size); + let channel = Channel::new( + id, + peer_id, + window_size, + max_packet_size, + self.events_tx.clone(), + ); let mut res = Packet::new(MessageType::ChannelOpenConfirmation); res.write_uint32(peer_id)?; @@ -304,7 +342,7 @@ impl<'a> Connection { res.write_uint32(channel.window_size())?; res.write_uint32(channel.max_packet_size())?; - debug!("Open {:?}", channel); + debug!("Open Channel {}", id); self.channels.insert(id, channel); @@ -357,7 +395,7 @@ impl<'a> Connection { let data = reader.read_string()?; let mut channel = self.channels.get_mut(&channel_id).unwrap(); - channel.data(data.as_slice())?; + channel.write(data.as_slice())?; Ok(None) } diff --git a/src/lib.rs b/src/lib.rs index cc4eed8..cae8ecf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,8 +9,6 @@ extern crate syscall; #[cfg(not(target_os = "redox"))] extern crate libc; -mod error; -mod algorithm; mod packet; mod message; mod connection; @@ -19,15 +17,17 @@ mod encryption; mod mac; mod channel; +pub mod error; +pub mod algorithm; pub mod public_key; pub mod server; +pub use self::server::{Server, ServerConfig}; + #[cfg(target_os = "redox")] -#[path = "sys/redox.rs"] +#[path = "sys/redox/mod.rs"] pub mod sys; #[cfg(not(target_os = "redox"))] -#[path = "sys/unix.rs"] +#[path = "sys/linux/mod.rs"] pub mod sys; - -pub use self::server::{Server, ServerConfig}; diff --git a/src/server.rs b/src/server.rs index 590e279..7a80480 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,10 +1,12 @@ use std::io; use std::net::TcpListener; +use std::os::unix::io::AsRawFd; use std::sync::Arc; use std::thread; -use connection::{Connection, ConnectionType}; +use connection::{Connection, ConnectionEvent, ConnectionType}; use public_key::KeyPair; +use sys; pub struct ServerConfig { pub host: String, @@ -38,7 +40,7 @@ impl Server { let result = connection.run(&mut stream); if let Some(error) = result.err() { - println!("sshd: {}", error) + println!("sshd: {}", error); } }); } diff --git a/src/sys/linux/mod.rs b/src/sys/linux/mod.rs new file mode 100644 index 0000000..25dc050 --- /dev/null +++ b/src/sys/linux/mod.rs @@ -0,0 +1,19 @@ +use std::io::Result; +use std::os::unix::io::RawFd; + +mod pty; +pub use self::pty::Pty; + +pub fn before_exec() -> Result<()> { + use libc; + unsafe { + libc::setsid(); + libc::ioctl(0, libc::TIOCSCTTY, 1); + } + Ok(()) +} + +pub fn fork() -> usize { + use libc; + unsafe { libc::fork() as usize } +} diff --git a/src/sys/linux/pty.rs b/src/sys/linux/pty.rs new file mode 100644 index 0000000..909b604 --- /dev/null +++ b/src/sys/linux/pty.rs @@ -0,0 +1,160 @@ +use libc; +use std::ffi::CStr; +use std::fs::{File, OpenOptions}; +use std::io::{self, Read, Write}; +use std::os::unix::fs::OpenOptionsExt; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd}; +use std::path::PathBuf; +use std::sync::mpsc; +use std::thread::{self, JoinHandle}; + +pub struct Pty { + master: File, + path: PathBuf, + sub_thread: Option>, + sub_thread_tx: Option>, +} + +enum ThreadCommand { + WaitForData, + Stop, +} + +impl Pty { + pub fn get() -> Result { + const TIOCPKT: libc::c_ulong = 0x5420; + + let master_fd = OpenOptions::new() + .read(true) + .write(true) + .custom_flags(libc::O_NONBLOCK) + .open("/dev/ptmx") + .unwrap() + .into_raw_fd(); + + unsafe { + use std::io::Error; + let mut flag: libc::c_int = 1; + + if libc::ioctl( + master_fd, + TIOCPKT, + &mut flag as *mut libc::c_int, + ) < 0 + { + error!("ioctl: {:?}", Error::last_os_error()); + return Err(()); + } + if libc::grantpt(master_fd) < 0 { + error!("grantpt: {:?}", Error::last_os_error()); + return Err(()); + } + if libc::unlockpt(master_fd) < 0 { + error!("unlockpt: {:?}", Error::last_os_error()); + return Err(()); + } + } + + let tty_path = unsafe { + PathBuf::from( + CStr::from_ptr(libc::ptsname(master_fd)) + .to_string_lossy() + .into_owned(), + ) + }; + + let master = unsafe { File::from_raw_fd(master_fd) }; + + Ok(Pty { + master: master, + path: tty_path, + sub_thread: None, + sub_thread_tx: None, + }) + } + + pub fn subscribe(&mut self, callback: F) + where + F: Fn() -> Result<(), ()> + Send + 'static, + { + let (thread_tx, thread_rx) = mpsc::channel(); + + let mut pollfd = libc::pollfd { + fd: self.master.as_raw_fd(), + events: libc::POLLIN, + revents: 0, + }; + + self.sub_thread = Some(thread::spawn(move || { + loop { + match thread_rx.recv() + { + Ok(ThreadCommand::WaitForData) => {} + Ok(ThreadCommand::Stop) => return, + Err(_) => return, + } + + // Clear receive queue + while !thread_rx.try_recv().is_err() {} + + unsafe { libc::poll(&mut pollfd as *mut libc::pollfd, 1, -1) }; + if callback().is_err() { + return; + } + } + })); + + self.sub_thread_tx = Some(thread_tx); + } + + pub fn path<'a>(&'a self) -> &'a PathBuf { + &self.path + } + + pub fn set_winsize(&self, row: u16, col: u16, xpixel: u16, ypixel: u16) { + let size = libc::winsize { + ws_row: row, + ws_col: col, + ws_xpixel: xpixel, + ws_ypixel: ypixel, + }; + + unsafe { + let fd = self.master.as_raw_fd(); + libc::ioctl(fd, libc::TIOCSWINSZ, &size as *const libc::winsize); + } + } + + pub fn write(&mut self, data: &[u8]) -> io::Result<()> { + self.master.write_all(data)?; + self.master.flush() + } + + pub fn read(&mut self, data: &mut [u8]) -> io::Result { + match self.master.read(data) + { + Ok(count) => { + self.sub_thread_tx.as_ref().map(|tx| { + tx.send(ThreadCommand::WaitForData) + }); + Ok(count) + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.sub_thread_tx.as_ref().map(|tx| { + tx.send(ThreadCommand::WaitForData) + }); + Ok(0) + } + Err(e) => Err(e), + } + } +} + +impl Drop for Pty { + fn drop(&mut self) { + self.sub_thread_tx.take().map( + |tx| tx.send(ThreadCommand::Stop), + ); + self.sub_thread = None; + } +} diff --git a/src/sys/redox.rs b/src/sys/redox.rs deleted file mode 100644 index 1b0ef5b..0000000 --- a/src/sys/redox.rs +++ /dev/null @@ -1,31 +0,0 @@ -use std::io::Result; -use std::os::unix::io::RawFd; -use std::path::PathBuf; - -pub fn before_exec() -> Result<()> { - Ok(()) -} - -pub fn fork() -> usize { - extern crate syscall; - unsafe { syscall::clone(0).unwrap() } -} - -pub fn set_winsize(fd: RawFd, row: u16, col: u16, xpixel: u16, ypixel: u16) {} - -pub fn getpty() -> (RawFd, PathBuf) { - use syscall; - - let master = syscall::open("pty:", syscall::O_RDWR | syscall::O_CREAT) - .unwrap(); - - let mut buf: [u8; 4096] = [0; 4096]; - - let count = syscall::fpath(master, &mut buf).unwrap(); - ( - master, - PathBuf::from(unsafe { - String::from_utf8_unchecked(Vec::from(&buf[..count])) - }), - ) -} diff --git a/src/sys/redox/mod.rs b/src/sys/redox/mod.rs new file mode 100644 index 0000000..ee4fb75 --- /dev/null +++ b/src/sys/redox/mod.rs @@ -0,0 +1,14 @@ +use std::os::unix::io::RawFd; + +pub mod pty; + +pub fn before_exec() -> Result<()> { + Ok(()) +} + +pub fn fork() -> usize { + extern crate syscall; + unsafe { syscall::clone(0).unwrap() } +} + +pub fn set_winsize(fd: RawFd, row: u16, col: u16, xpixel: u16, ypixel: u16) {} diff --git a/src/sys/redox/pty.rs b/src/sys/redox/pty.rs new file mode 100644 index 0000000..fa569eb --- /dev/null +++ b/src/sys/redox/pty.rs @@ -0,0 +1,19 @@ +use std::io::Result; +use std::os::unix::io::RawFd; +use std::path::PathBuf; + +pub fn getpty() -> (RawFd, PathBuf) { + use syscall; + + let master = syscall::open( + "pty:", + syscall::O_RDWR | syscall::O_CREAT | syscall::O_NONBLOCK, + ).unwrap(); + + let mut buf: [u8; 4096] = [0; 4096]; + + let count = syscall::fpath(master, &mut buf).unwrap(); + let path = String::from_utf8(Vec::from(&buf[..count]).or(())).unwrap(); + + (master, PathBuf::from(path)) +} diff --git a/src/sys/unix.rs b/src/sys/unix.rs deleted file mode 100644 index 79b7946..0000000 --- a/src/sys/unix.rs +++ /dev/null @@ -1,75 +0,0 @@ -use std::io::Result; -use std::os::unix::io::RawFd; -use std::path::PathBuf; - -pub fn before_exec() -> Result<()> { - use libc; - unsafe { - libc::setsid(); - libc::ioctl(0, libc::TIOCSCTTY, 1); - } - Ok(()) -} - -pub fn fork() -> usize { - use libc; - unsafe { libc::fork() as usize } -} - -pub fn set_winsize(fd: RawFd, row: u16, col: u16, xpixel: u16, ypixel: u16) { - use libc; - unsafe { - let size = libc::winsize { - ws_row: row, - ws_col: col, - ws_xpixel: xpixel, - ws_ypixel: ypixel, - }; - libc::ioctl(fd, libc::TIOCSWINSZ, &size as *const libc::winsize); - } -} - -pub fn getpty() -> (RawFd, PathBuf) { - use libc; - use std::ffi::CStr; - use std::fs::OpenOptions; - use std::io::Error; - use std::os::unix::io::IntoRawFd; - - const TIOCPKT: libc::c_ulong = 0x5420; - extern "C" { - fn ptsname(fd: libc::c_int) -> *const libc::c_char; - fn grantpt(fd: libc::c_int) -> libc::c_int; - fn unlockpt(fd: libc::c_int) -> libc::c_int; - fn ioctl(fd: libc::c_int, request: libc::c_ulong, ...) -> libc::c_int; - } - - let master_fd = OpenOptions::new() - .read(true) - .write(true) - .open("/dev/ptmx") - .unwrap() - .into_raw_fd(); - - unsafe { - let mut flag: libc::c_int = 1; - if ioctl(master_fd, TIOCPKT, &mut flag as *mut libc::c_int) < 0 { - panic!("ioctl: {:?}", Error::last_os_error()); - } - if grantpt(master_fd) < 0 { - panic!("grantpt: {:?}", Error::last_os_error()); - } - if unlockpt(master_fd) < 0 { - panic!("unlockpt: {:?}", Error::last_os_error()); - } - } - - let tty_path = unsafe { - PathBuf::from( - CStr::from_ptr(ptsname(master_fd)) - .to_string_lossy() - .into_owned(), - ) - }; - (master_fd, tty_path) -}