Starting input system

This commit is contained in:
william 2023-04-16 17:33:28 -04:00
parent 5df672398b
commit 811aa8d117
15 changed files with 575 additions and 139 deletions

View File

@ -7,6 +7,7 @@
<sourceFolder url="file://$MODULE_DIR$/messages/src" isTestSource="false" /> <sourceFolder url="file://$MODULE_DIR$/messages/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/server/src" isTestSource="false" /> <sourceFolder url="file://$MODULE_DIR$/server/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" /> <sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/linux/src" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/target" /> <excludeFolder url="file://$MODULE_DIR$/target" />
</content> </content>
<orderEntry type="inheritedJdk" /> <orderEntry type="inheritedJdk" />

169
Cargo.lock generated
View File

@ -2,6 +2,18 @@
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 3
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]] [[package]]
name = "client" name = "client"
version = "0.1.0" version = "0.1.0"
@ -10,19 +22,170 @@ dependencies = [
] ]
[[package]] [[package]]
name = "libc" name = "hermit-abi"
version = "0.2.140" version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99227334921fae1a979cf0bfdfcc6b3e5ce376ef57e16fb6fb3ea2ed6095f80c" checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
[[package]]
name = "input"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeb3afdf1f8137428002b354eaf87aa629178995683941d94b04c6d145ec8937"
dependencies = [
"bitflags",
"input-sys",
"io-lifetimes",
"libc",
"log",
"udev",
]
[[package]]
name = "input-sys"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f6c2a17e8aba7217660e32863af87b0febad811d4b8620ef76b386603fddc2"
dependencies = [
"libc",
]
[[package]]
name = "io-lifetimes"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220"
dependencies = [
"hermit-abi",
"libc",
"windows-sys",
]
[[package]]
name = "libc"
version = "0.2.141"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5"
[[package]]
name = "libudev-sys"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c8469b4a23b962c1396b9b451dda50ef5b283e8dd309d69033475fa9b334324"
dependencies = [
"libc",
"pkg-config",
]
[[package]]
name = "linux"
version = "0.1.0"
dependencies = [
"libc",
]
[[package]]
name = "log"
version = "0.4.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e"
dependencies = [
"cfg-if",
]
[[package]] [[package]]
name = "messages" name = "messages"
version = "0.1.0" version = "0.1.0"
[[package]]
name = "pkg-config"
version = "0.3.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
[[package]] [[package]]
name = "server" name = "server"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"input",
"libc", "libc",
"linux",
"messages", "messages",
] ]
[[package]]
name = "udev"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ebdbbd670373442a12fe9ef7aeb53aec4147a5a27a00bbc3ab639f08f48191a"
dependencies = [
"libc",
"libudev-sys",
"pkg-config",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-targets"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3"
[[package]]
name = "windows_i686_gnu"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241"
[[package]]
name = "windows_i686_msvc"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a"

9
linux/Cargo.toml Normal file
View File

@ -0,0 +1,9 @@
[package]
name = "linux"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
libc = "0.2.140"

13
linux/src/buf_utils.rs Normal file
View File

@ -0,0 +1,13 @@
pub fn read_i32(buf: &[u8]) -> i32 {
(buf[3] as i32) << 24 |
(buf[2] as i32) << 16 |
(buf[1] as i32) << 8 |
buf[0] as i32
}
pub fn read_u32(buf: &[u8]) -> u32 {
(buf[3] as u32) << 24 |
(buf[2] as u32) << 16 |
(buf[1] as u32) << 8 |
buf[0] as u32
}

140
linux/src/epoll.rs Normal file
View File

@ -0,0 +1,140 @@
use std::io;
use std::os::fd::RawFd;
use libc::epoll_event;
use crate::syscall;
const EVENTS_CAPACITY: usize = 1024;
const WAIT_MAX_EVENTS: i32 = 1024;
const WAIT_TIMEOUT: i32 = 1000;
#[repr(i32)]
#[derive(Copy, Clone)]
pub enum EpollFlags {
// Events
In = libc::EPOLLIN,
Err = libc::EPOLLERR,
Hup = libc::EPOLLHUP,
Out = libc::EPOLLOUT,
Pri = libc::EPOLLPRI,
RdHup = libc::EPOLLRDHUP,
// Flags
Et = libc::EPOLLET,
Exclusive = libc::EPOLLEXCLUSIVE,
OneShot = libc::EPOLLONESHOT,
WakeUp = libc::EPOLLWAKEUP,
}
#[repr(i32)]
pub enum EpollOp {
Add = libc::EPOLL_CTL_ADD,
Modify = libc::EPOLL_CTL_MOD,
Delete = libc::EPOLL_CTL_DEL,
}
pub struct Epoll {
fd: RawFd,
flag: u32,
pub events: Vec<epoll_event>,
}
impl Epoll {
pub fn new() -> io::Result<Self> {
Self::with_flags(&[])
}
pub fn with_flags(flags: &[EpollFlags]) -> io::Result<Self> {
let flag = epoll_flags_as_u32(flags);
match epoll_create() {
Ok(fd) => Ok(Epoll {
fd,
flag,
events: Vec::with_capacity(EVENTS_CAPACITY),
}),
Err(e) => Err(e),
}
}
pub fn add_interest(&self, fd: RawFd, key: u64, flags: &[EpollFlags]) -> io::Result<()> {
epoll_add_interest(self.fd, fd, self.create_epoll_event(key, flags))
}
pub fn modify_interest(&self, fd: RawFd, key: u64, flags: &[EpollFlags]) -> io::Result<()> {
epoll_modify_interest(self.fd, fd, self.create_epoll_event(key, flags))
}
fn create_epoll_event(&self, key: u64, flags: &[EpollFlags]) -> epoll_event {
epoll_event {
events: self.flag | epoll_flags_as_u32(flags),
u64: key as u64,
}
}
pub fn wait(&mut self) -> io::Result<()> {
self.events.clear();
match epoll_wait(self.fd, &mut self.events, WAIT_MAX_EVENTS, WAIT_TIMEOUT) {
Ok(res) => {
// safe as long as the kernel does nothing wrong - copied from mio
unsafe { self.events.set_len(res) }
Ok(())
}
Err(e) => Err(e)
}
}
}
impl EpollFlags {
pub fn flag_match(&self, event: u32) -> bool {
let expected_flag = *self as i32;
event as i32 & expected_flag == expected_flag
}
}
fn epoll_create() -> io::Result<RawFd> {
let fd = syscall!(epoll_create1(0))?;
if let Ok(flags) = syscall!(fcntl(fd, libc::F_GETFD)) {
let _ = syscall!(fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC));
}
Ok(fd)
}
fn epoll_wait(epoll_fd: RawFd, events: &mut Vec<epoll_event>, max_events: i32, timeout: i32) -> io::Result<usize> {
match syscall!(epoll_wait(
epoll_fd,
events.as_mut_ptr() as *mut libc::epoll_event,
max_events,
timeout as libc::c_int
)) {
Ok(v) => Ok(v as usize),
Err(e) => Err(e)
}
}
fn epoll_add_interest(epoll_fd: RawFd, fd: RawFd, event: epoll_event) -> io::Result<()> {
epoll_ctl(epoll_fd, fd, event, EpollOp::Add)
}
fn epoll_modify_interest(epoll_fd: RawFd, fd: RawFd, event: epoll_event) -> io::Result<()> {
epoll_ctl(epoll_fd, fd, event, EpollOp::Modify)
}
fn epoll_delete_interest(epoll_fd: RawFd, fd: RawFd) -> io::Result<()> {
let event = epoll_event { events: 0, u64: 0 }; // event will be ignored
epoll_ctl(epoll_fd, fd, event, EpollOp::Delete)
}
fn epoll_ctl(epoll_fd: RawFd, fd: RawFd, mut event: epoll_event, op: EpollOp) -> io::Result<()> {
syscall!(epoll_ctl(epoll_fd, op as i32, fd, &mut event))?;
Ok(())
}
fn epoll_flags_as_u32(flags: &[EpollFlags]) -> u32 {
let mut val = 0i32;
flags.into_iter().for_each(|flag| val |= *flag as i32);
val as u32
}

131
linux/src/inotify.rs Normal file
View File

@ -0,0 +1,131 @@
use std::ffi::{CStr, CString};
use std::fs::File;
use std::io;
use std::io::Read;
use std::os::fd::{AsRawFd, FromRawFd, RawFd};
use libc::inotify_event;
use crate::buf_utils::{read_i32, read_u32};
use crate::syscall;
const INOTIFY_EVENT_BUFFER_CAPACITY: usize = 4096;
const INOTIFY_EVENT_SIZE: usize = std::mem::size_of::<inotify_event>();
#[derive(Copy, Clone, Debug)]
#[repr(u32)]
pub enum InotifyEvent {
Access = libc::IN_ACCESS,
Attrib = libc::IN_ATTRIB,
CloseWrite = libc::IN_CLOSE_WRITE,
CloseNoWrite = libc::IN_CLOSE_NOWRITE,
Create = libc::IN_CREATE,
Delete = libc::IN_DELETE,
DeleteSelf = libc::IN_DELETE_SELF,
Modify = libc::IN_MODIFY,
MoveSelf = libc::IN_MOVE_SELF,
MovedFrom = libc::IN_MOVED_FROM,
MovedTo = libc::IN_MOVED_TO,
Open = libc::IN_OPEN,
}
const IN_ALL_EVENTS: [InotifyEvent; 12] = [
InotifyEvent::Access,
InotifyEvent::Attrib,
InotifyEvent::CloseWrite,
InotifyEvent::CloseNoWrite,
InotifyEvent::Create,
InotifyEvent::Delete,
InotifyEvent::DeleteSelf,
InotifyEvent::Modify,
InotifyEvent::MoveSelf,
InotifyEvent::MovedFrom,
InotifyEvent::MovedTo,
InotifyEvent::Open,
];
pub struct Inotify {
file: File,
pub events: Vec<inotify_event>,
}
pub struct Watch {
wd: RawFd,
}
impl Inotify {
pub fn new() -> io::Result<Self> {
let fd = inotify_init()?;
let file = unsafe { File::from_raw_fd(fd) };
Ok(Inotify {
file,
events: Vec::new(),
})
}
pub fn watch(&self, path: &str, events: &[InotifyEvent]) -> io::Result<Watch> {
let wd = inotify_add_watch(self.file.as_raw_fd(), path, events)?;
Ok(Watch { wd })
}
pub fn wait(&mut self) -> io::Result<()> {
let mut buf = [0u8; INOTIFY_EVENT_BUFFER_CAPACITY];
let read_size = self.file.read(&mut buf)?; // Should block
self.events.clear();
// Read events
for i in 0..(read_size / INOTIFY_EVENT_SIZE) {
let offset = INOTIFY_EVENT_SIZE * i;
let event_buf = &buf[offset..offset + INOTIFY_EVENT_SIZE];
self.events.push(inotify_event {
wd: read_i32(&event_buf),
mask: read_u32(&event_buf[4..]),
cookie: read_u32(&event_buf[8..]),
len: read_u32(&event_buf[12..]),
});
}
Ok(())
}
}
impl Watch {}
fn inotify_init() -> io::Result<RawFd> {
syscall!(inotify_init())
}
fn inotify_add_watch(fd: RawFd, path: &str, events: &[InotifyEvent]) -> io::Result<RawFd> {
let c_path = CString::new(path).unwrap();
let mask = events_as_mask(events);
let wd = syscall!(inotify_add_watch(fd, c_path.as_ptr(), mask))?;
Ok(wd)
}
fn inotify_rm_watch(fd: RawFd, wd: RawFd) -> io::Result<()> {
syscall!(inotify_rm_watch(fd, wd))?;
Ok(())
}
fn events_as_mask(events: &[InotifyEvent]) -> u32 {
let mut mask = 0;
events.into_iter().for_each(|event| mask |= *event as u32);
mask
}
pub fn mask_as_events(mask: u32) -> Vec<InotifyEvent> {
let mut events = Vec::new();
for event in IN_ALL_EVENTS {
let event_u32 = event as u32;
if mask & event_u32 == event_u32 {
events.push(event);
}
}
events
}

15
linux/src/lib.rs Normal file
View File

@ -0,0 +1,15 @@
pub mod epoll;
pub mod inotify;
mod buf_utils;
#[macro_export]
macro_rules! syscall {
($fn: ident ( $($arg: expr),* $(,)* ) ) => {{
let res = unsafe { libc::$fn($($arg, )*) };
if res == -1 {
Err(std::io::Error::last_os_error())
} else {
Ok(res)
}
}};
}

View File

@ -4,5 +4,7 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
libc = "0.2.140"
messages = { path = "../messages" } messages = { path = "../messages" }
linux = { path = "../linux" }
input = "0.8.2"
libc = "0.2.141"

View File

@ -0,0 +1 @@
pub mod pointer;

View File

@ -0,0 +1,29 @@
pub enum PointerEvent {
Motion(PointerMotionEvent),
Button(PointerButtonEvent),
Scroll(PointerScrollEvent),
}
pub enum Axis {
Horizontal = 0,
Vertical = 1,
}
pub struct PointerButtonEvent {
pub button: u32,
pub seat_button_count: u32,
pub button_state: u32,
}
pub struct PointerMotionEvent {
pub dx: f64,
pub dx_unaccelerated: f64,
pub dy: f64,
pub dy_unaccelerated: f64,
}
pub struct PointerScrollEvent {
pub axis: Axis,
pub scroll_value: f64,
pub scroll_value_v120: f64,
}

42
server/src/input/mod.rs Normal file
View File

@ -0,0 +1,42 @@
mod events;
use std::fs::{File, OpenOptions};
use std::io;
use std::os::fd::OwnedFd;
use std::os::unix::fs::OpenOptionsExt;
use std::path::Path;
use input::{Libinput, LibinputInterface};
use libc::{O_RDONLY, O_RDWR, O_WRONLY};
const UDEV_SEAT: &str = "seat0";
struct Interface;
impl LibinputInterface for Interface {
fn open_restricted(&mut self, path: &Path, flags: i32) -> Result<OwnedFd, i32> {
OpenOptions::new()
.custom_flags(flags)
.read((flags & O_RDONLY == O_RDONLY) | (flags & O_RDWR == O_RDWR))
.write((flags & O_WRONLY == O_WRONLY) | (flags & O_RDWR == O_RDWR))
.open(path)
.map(|file| file.into())
.map_err(|err| err.raw_os_error().unwrap())
}
fn close_restricted(&mut self, fd: OwnedFd) {
File::from(fd);
}
}
pub(crate) fn read_inputs() -> io::Result<()> {
let mut input = Libinput::new_with_udev(Interface);
input.udev_assign_seat(UDEV_SEAT).unwrap();
loop {
input.dispatch().unwrap();
for event in &mut input {
// event.
println!("Got event: {:?}", event);
}
}
}

View File

@ -5,16 +5,20 @@ use std::net::SocketAddr;
use messages::client_registration::ClientRegistration; use messages::client_registration::ClientRegistration;
use messages::serialization::{DeserializeMessage, read_message_data, SerializeMessage}; use messages::serialization::{DeserializeMessage, read_message_data, SerializeMessage};
use crate::input::read_inputs;
use crate::net::tcp_server::{NextIntent, TcpClient, TcpServer}; use crate::net::tcp_server::{NextIntent, TcpClient, TcpServer};
mod net; mod net;
mod input;
mod client; mod client;
fn main() -> io::Result<()> { fn main() -> io::Result<()> {
let addr = SocketAddr::from(([127, 0, 0, 1], 4433)); // let addr = SocketAddr::from(([127, 0, 0, 1], 4433));
let mut server: TcpServer<KvmClient> = TcpServer::new(addr)?; // let mut server: TcpServer<KvmClient> = TcpServer::new(addr)?;
//
// server.listen()?;
server.listen()?; read_inputs()?;
Ok(()) Ok(())
} }

View File

@ -1,113 +0,0 @@
use std::io;
use std::os::fd::RawFd;
use libc::{epoll_event};
const EVENTS_CAPACITY: usize = 1024;
const WAIT_MAX_EVENTS: i32 = 1024;
const WAIT_TIMEOUT: i32 = 1000;
pub struct Epoll {
fd: RawFd,
pub events: Vec<epoll_event>,
}
#[derive(Copy, Clone)]
#[repr(i32)]
pub enum EpollEvent {
Read = libc::EPOLLIN,
Write = libc::EPOLLOUT,
Disconnect = libc::EPOLLRDHUP,
}
impl Epoll {
pub fn create() -> io::Result<Self> {
match epoll_create() {
Ok(fd) => Ok(Epoll {
fd,
events: Vec::with_capacity(EVENTS_CAPACITY),
}),
Err(e) => Err(e),
}
}
pub fn add_interest(&self, fd: RawFd, key: u16, events: &[EpollEvent]) -> io::Result<()> {
add_interest(self.fd, fd, create_oneshot_epoll_event(key, events))
}
pub fn modify_interest(&self, fd: RawFd, key: u16, events: &[EpollEvent]) -> io::Result<()> {
modify_interest(self.fd, fd, create_oneshot_epoll_event(key, events))
}
pub fn wait(&mut self) -> io::Result<()> {
self.events.clear();
match epoll_wait(self.fd, &mut self.events, WAIT_MAX_EVENTS, WAIT_TIMEOUT) {
Ok(res) => {
// safe as long as the kernel does nothing wrong - copied from mio
unsafe { self.events.set_len(res) }
Ok(())
}
Err(e) => Err(e)
}
}
}
pub fn match_epoll_event(event: u32, expected_event: EpollEvent) -> bool {
let expected_event = expected_event as i32;
event as i32 & expected_event == expected_event
}
macro_rules! syscall {
($fn: ident ( $($arg: expr),* $(,)* ) ) => {{
let res = unsafe { libc::$fn($($arg, )*) };
if res == -1 {
Err(std::io::Error::last_os_error())
} else {
Ok(res)
}
}};
}
fn epoll_create() -> io::Result<RawFd> {
let fd = syscall!(epoll_create1(0))?;
if let Ok(flags) = syscall!(fcntl(fd, libc::F_GETFD)) {
let _ = syscall!(fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC));
}
Ok(fd)
}
fn epoll_wait(epoll_fd: RawFd, events: &mut Vec<epoll_event>, max_events: i32, timeout: i32) -> io::Result<usize> {
match syscall!(epoll_wait(
epoll_fd,
events.as_mut_ptr() as *mut libc::epoll_event,
max_events,
timeout as libc::c_int
)) {
Ok(v) => Ok(v as usize),
Err(e) => Err(e)
}
}
fn add_interest(epoll_fd: RawFd, fd: RawFd, mut event: epoll_event) -> io::Result<()> {
syscall!(epoll_ctl(epoll_fd, libc::EPOLL_CTL_ADD, fd, &mut event))?;
Ok(())
}
fn modify_interest(epoll_fd: RawFd, fd: RawFd, mut event: epoll_event) -> io::Result<()> {
syscall!(epoll_ctl(epoll_fd, libc::EPOLL_CTL_MOD, fd, &mut event))?;
Ok(())
}
fn create_oneshot_epoll_event(key: u16, events: &[EpollEvent]) -> epoll_event {
epoll_event {
events: get_oneshot_events_flag(events),
u64: key as u64,
}
}
fn get_oneshot_events_flag(events: &[EpollEvent]) -> u32 {
let mut flag: i32 = libc::EPOLLONESHOT;
events.into_iter().for_each(|e| flag = flag | *e as i32);
flag as u32
}

View File

@ -1,3 +1,2 @@
pub mod tcp_server; pub mod tcp_server;
mod epoll;

View File

@ -4,11 +4,11 @@ use std::io::{Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream}; use std::net::{SocketAddr, TcpListener, TcpStream};
use std::os::fd::{AsRawFd, RawFd}; use std::os::fd::{AsRawFd, RawFd};
use crate::net::epoll::{Epoll, EpollEvent, match_epoll_event}; use linux::epoll::{Epoll, EpollFlags};
// Based on: https://www.zupzup.org/epoll-with-rust/index.html // Based on: https://www.zupzup.org/epoll-with-rust/index.html
const KEY_NEW_CONNECTION: u16 = 0; const KEY_NEW_CONNECTION: u64 = 0;
pub trait TcpClient { pub trait TcpClient {
fn new() -> Self; fn new() -> Self;
@ -28,13 +28,13 @@ pub struct TcpServer<T>
listener: TcpListener, listener: TcpListener,
listener_fd: RawFd, listener_fd: RawFd,
epoll: Epoll, epoll: Epoll,
key: u16, key: u64,
request_contexts: HashMap<u16, TcpContext<T>>, request_contexts: HashMap<u64, TcpContext<T>>,
} }
struct TcpContext<T> struct TcpContext<T>
where T: TcpClient { where T: TcpClient {
key: u16, key: u64,
stream: TcpStream, stream: TcpStream,
client: T, client: T,
} }
@ -47,8 +47,8 @@ impl<T> TcpServer<T>
let listener_fd = listener.as_raw_fd(); let listener_fd = listener.as_raw_fd();
let epoll = Epoll::create()?; let epoll = Epoll::with_flags(&[EpollFlags::OneShot])?;
epoll.add_interest(listener_fd, KEY_NEW_CONNECTION, &[EpollEvent::Read])?; epoll.add_interest(listener_fd, KEY_NEW_CONNECTION, &[EpollFlags::In])?;
Ok(TcpServer { Ok(TcpServer {
addr, addr,
@ -73,7 +73,7 @@ impl<T> TcpServer<T>
let mut to_remove = Vec::new(); let mut to_remove = Vec::new();
for (events, u64) in events { for (events, u64) in events {
match *u64 as u16 { match *u64 {
KEY_NEW_CONNECTION => self.accept_connection()?, KEY_NEW_CONNECTION => self.accept_connection()?,
key => { key => {
if let Some(context) = self.request_contexts.get_mut(&key) { if let Some(context) = self.request_contexts.get_mut(&key) {
@ -104,16 +104,16 @@ impl<T> TcpServer<T>
let context = TcpContext { key, stream, client }; let context = TcpContext { key, stream, client };
self.epoll.add_interest(fd, key, &[EpollEvent::Read, EpollEvent::Disconnect])?; self.epoll.add_interest(fd, key, &[EpollFlags::In, EpollFlags::RdHup])?;
self.request_contexts.insert(key, context); self.request_contexts.insert(key, context);
} }
Err(e) => eprintln!("Couldn't accept: {e}") Err(e) => eprintln!("Couldn't accept: {e}")
}; };
self.epoll.modify_interest(self.listener_fd, KEY_NEW_CONNECTION, &[EpollEvent::Read]) self.epoll.modify_interest(self.listener_fd, KEY_NEW_CONNECTION, &[EpollFlags::In])
} }
fn get_next_key(&mut self) -> u16 { fn get_next_key(&mut self) -> u64 {
self.key += 1; self.key += 1;
self.key self.key
} }
@ -122,15 +122,15 @@ impl<T> TcpServer<T>
fn handle_event<T>(epoll: &Epoll, context: &mut TcpContext<T>, event: u32) -> bool fn handle_event<T>(epoll: &Epoll, context: &mut TcpContext<T>, event: u32) -> bool
where T: TcpClient { where T: TcpClient {
match event { match event {
v if match_epoll_event(v, EpollEvent::Read) => { v if EpollFlags::In.flag_match(v) => {
println!("Read"); println!("Read");
return handle_read_event(epoll, context); return handle_read_event(epoll, context);
} }
v if match_epoll_event(v, EpollEvent::Write) => { v if EpollFlags::Out.flag_match(v) => {
println!("Write"); println!("Write");
return handle_write_event(epoll, context); return handle_write_event(epoll, context);
} }
v if match_epoll_event(v, EpollEvent::Disconnect) => { v if EpollFlags::RdHup.flag_match(v) => {
println!("Disconnect"); println!("Disconnect");
return true; return true;
} }
@ -171,14 +171,14 @@ fn handle_write_event<T>(epoll: &Epoll, context: &mut TcpContext<T>) -> bool
fn set_interest<T>(epoll: &Epoll, context: &TcpContext<T>, next_intent: &NextIntent) fn set_interest<T>(epoll: &Epoll, context: &TcpContext<T>, next_intent: &NextIntent)
where T: TcpClient { where T: TcpClient {
let event = match next_intent { let event = match next_intent {
NextIntent::Read => EpollEvent::Read, NextIntent::Read => EpollFlags::In,
NextIntent::Write => EpollEvent::Write, NextIntent::Write => EpollFlags::Out,
}; };
epoll.modify_interest( epoll.modify_interest(
context.stream.as_raw_fd(), context.stream.as_raw_fd(),
context.key, context.key,
&[event, EpollEvent::Disconnect]) &[event, EpollFlags::RdHup])
.unwrap(); .unwrap();
} }