diff --git a/kubi-udp/src/client.rs b/kubi-udp/src/client.rs index 91c2222..0174048 100644 --- a/kubi-udp/src/client.rs +++ b/kubi-udp/src/client.rs @@ -3,7 +3,7 @@ use std::{ net::{UdpSocket, SocketAddr}, time::{Instant, Duration}, marker::PhantomData, - collections::VecDeque, + collections::{VecDeque, vec_deque::Drain as DrainDeque}, }; use bincode::{Encode, Decode}; use crate::{ @@ -12,7 +12,7 @@ use crate::{ common::ClientId }; -#[derive(Default, Clone)] +#[derive(Default, Clone, Debug)] #[repr(u8)] pub enum DisconnectReason { #[default] @@ -35,9 +35,17 @@ pub struct ClientConfig { pub timeout: Duration, pub heartbeat_interval: Duration, } +impl Default for ClientConfig { + fn default() -> Self { + Self { + timeout: Duration::from_secs(5), + heartbeat_interval: Duration::from_secs(3), + } + } +} pub enum ClientEvent where T: Encode + Decode { - Connected, + Connected(ClientId), Disconnected(DisconnectReason), MessageReceived(T) } @@ -52,7 +60,7 @@ pub struct Client where S: Encode + Decode, R: Encode + Decode { client_id: Option, disconnect_reason: DisconnectReason, event_queue: VecDeque>, - _s: PhantomData<*const S>, + _s: PhantomData, } impl Client where S: Encode + Decode, R: Encode + Decode { pub fn new(addr: SocketAddr, config: ClientConfig) -> Result { @@ -156,7 +164,7 @@ impl Client where S: Encode + Decode, R: Encode + Decode { ServerPacket::Connected(client_id) => { self.client_id = Some(client_id); self.status = ClientStatus::Connected; - self.event_queue.push_back(ClientEvent::Connected); + self.event_queue.push_back(ClientEvent::Connected(client_id)); return Ok(()) }, ServerPacket::Disconnected(reason) => { @@ -180,7 +188,7 @@ impl Client where S: Encode + Decode, R: Encode + Decode { pub fn get_event(&mut self) -> Option> { self.event_queue.pop_front() } - pub fn process_events(&mut self) -> impl Iterator> + '_ { + pub fn process_events(&mut self) -> DrainDeque> { self.event_queue.drain(..) } } diff --git a/kubi-udp/src/server.rs b/kubi-udp/src/server.rs index f9a8322..d2d563c 100644 --- a/kubi-udp/src/server.rs +++ b/kubi-udp/src/server.rs @@ -2,7 +2,7 @@ use std::{ net::{UdpSocket, SocketAddr}, time::Instant, marker::PhantomData, - collections::VecDeque + collections::{VecDeque, vec_deque::Drain as DrainDeque} }; use anyhow::{Result, bail}; use bincode::{Encode, Decode}; @@ -105,8 +105,9 @@ impl Server where S: Encode + Decode, R: Encode + Decode { } Ok(()) } - pub fn send_message(&mut self) { - + pub fn send_message(&mut self, id: ClientId, message: S) -> anyhow::Result<()> { + self.send_packet(IdServerPacket(Some(id), ServerPacket::Data(message)))?; + Ok(()) } pub fn bind(addr: SocketAddr, config: ServerConfig) -> anyhow::Result { assert!(config.max_clients <= MAX_CLIENTS); @@ -183,4 +184,11 @@ impl Server where S: Encode + Decode, R: Encode + Decode { } Ok(()) } + + pub fn get_event(&mut self) -> Option> { + self.event_queue.pop_front() + } + pub fn process_events(&mut self) -> DrainDeque> { + self.event_queue.drain(..) + } } diff --git a/kubi-udp/tests/test.rs b/kubi-udp/tests/test.rs new file mode 100644 index 0000000..82d1bbd --- /dev/null +++ b/kubi-udp/tests/test.rs @@ -0,0 +1,87 @@ +use kubi_udp::{ + server::{Server, ServerConfig, ServerEvent}, + client::{Client, ClientConfig, ClientEvent}, +}; +use std::{thread, time::Duration}; + +const TEST_ADDR: &str = "127.0.0.1:12345"; + +type CtsMessage = u32; +type StcMessage = u64; + +const CTS_MSG: CtsMessage = 0xbeef_face; +const STC_MSG: StcMessage = 0xdead_beef_cafe_face; + +#[test] +fn test_connection() { + //Create server and client + let mut server: Server = Server::bind( + TEST_ADDR.parse().expect("Invalid TEST_ADDR"), + ServerConfig::default() + ).expect("Failed to create server"); + let mut client: Client = Client::new( + TEST_ADDR.parse().unwrap(), + ClientConfig::default() + ).expect("Failed to create client"); + + //Start server update thread + let server_handle = thread::spawn(move || { + let mut message_received = false; + loop { + server.update(); + let events: Vec<_> = server.process_events().collect(); + for event in events { + match event { + ServerEvent::Connected(id) => { + assert_eq!(id.get(), 1, "Unexpected client id"); + server.send_message(id, STC_MSG); + }, + ServerEvent::Disconnected(id) => { + assert!(message_received, "Client {id} disconnected from the server before sending the message") + }, + ServerEvent::MessageReceived { from, message } => { + assert_eq!(message, CTS_MSG, "Received message not equal"); + message_received = true; + break; + }, + _ => () + } + } + } + }); + + //Wait a bit + thread::sleep(Duration::from_secs(1)); + + //Connect client + client.connect().expect("Client connect failed"); + + //Start updating the client + let client_handle = thread::spawn(move || { + let mut message_received = false; + loop { + client.update(); + let events: Vec<_> = client.process_events().collect(); + for event in events { + match event { + ClientEvent::Connected(id) => { + assert_eq!(id.get(), 1, "Unexpected client id"); + client.send_message(CTS_MSG); + }, + ClientEvent::Disconnected(reason) => { + assert!(message_received, "Client lost connection to the server before sending the message with reason: {reason:?}") + }, + ClientEvent::MessageReceived(data) => { + assert_eq!(data, STC_MSG, "Received message not equal"); + message_received = true; + break; + }, + _ => () + } + } + } + }); + + client_handle.join().unwrap(); + server_handle.join().unwrap(); +}