From 923e8b7218da119dd608c6eb72eb197a6c2f1c9b Mon Sep 17 00:00:00 2001 From: Talha Qamar Date: Thu, 19 Sep 2024 14:00:25 +0500 Subject: [PATCH] added some basic validation --- dev/src/idl/mod.rs | 10 +- dev/src/idl/parser.rs | 10 +- dev/src/idl/protocol.rs | 169 +++++++++++++++++++++++++++-- dev/src/idl/types.rs | 35 +++++- sysdata/idl/test/src/protocol.aldi | 3 + 5 files changed, 204 insertions(+), 23 deletions(-) create mode 100644 sysdata/idl/test/src/protocol.aldi diff --git a/dev/src/idl/mod.rs b/dev/src/idl/mod.rs index b1098dc2..4d1217e7 100644 --- a/dev/src/idl/mod.rs +++ b/dev/src/idl/mod.rs @@ -2,7 +2,7 @@ mod parser; mod types; mod protocol; -use crate::idl::parser::parse; +use crate::idl::{parser::parse, types::get_protocols}; use std::io::Read; use logos::{Lexer, Logos, Skip}; @@ -54,8 +54,7 @@ enum Token { #[token(";", get_token_position)] SemiColon((usize, usize)), - #[token(",", get_token_position)] - Comma((usize, usize)), + #[token(",", get_token_position)] Comma((usize, usize)), #[token("=", get_token_position)] Equal((usize, usize)), @@ -89,7 +88,10 @@ pub fn build_idl(name: String) { } } - println!("{:#?}", parse(tokens)); + let protocols = get_protocols(parse(tokens)); + let data : Vec = vec![1, 5, 12, 12, 12, 12, 3, 28, 8, 28]; + println!("{:#?}", &protocols); + protocols.validate("Foo", "bar" , data).unwrap(); } fn open_protocol(name: String) -> String { diff --git a/dev/src/idl/parser.rs b/dev/src/idl/parser.rs index d73acf1e..8f15d823 100644 --- a/dev/src/idl/parser.rs +++ b/dev/src/idl/parser.rs @@ -41,15 +41,15 @@ pub struct EnumMember { #[derive(Debug, Clone, PartialEq)] pub struct ProtocolDeclaration{ - name : String, - interface : Vec, + pub name : String, + pub interface : Vec, } #[derive(Debug, Clone, PartialEq)] pub struct FuncDeclaration { - name : String, - arg_list : Vec, - return_type : String, + pub name : String, + pub arg_list : Vec, + pub return_type : String, } macro_rules! consume { diff --git a/dev/src/idl/protocol.rs b/dev/src/idl/protocol.rs index 7620922f..02c06b17 100644 --- a/dev/src/idl/protocol.rs +++ b/dev/src/idl/protocol.rs @@ -1,14 +1,167 @@ -pub struct Protocol { +use std::collections::HashMap; +use crate::idl::types::Type; + +#[derive(Debug, Clone, PartialEq)] +pub struct Function{ + pub arguments : Vec, } -impl Protocol { - pub fn is_empty(&self) -> bool { - true +#[derive(Debug, Clone, PartialEq)] +pub struct Protocol{ + interface : HashMap, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Protocols { + protocols : HashMap, + symbol_table : HashMap, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum ValidationError{ + IncorrectVersion, + InvalidHeader, + FunctionDoesNotExist, + ProtocolDoesNotExist, + InvalidSize, + InvalidArgument, + NonExistentType(String), +} + +impl Protocols { + pub fn new(symbol_table: HashMap) -> Self { + let protocols = HashMap::new(); + Self { protocols, symbol_table } + } + + pub fn add_protocol(&mut self, name : String, interface : HashMap) { + self.protocols.insert(name, Protocol::new(interface)); } - pub fn validate_data(&self, data: Vec) -> bool { - if !data.is_empty() && self.is_empty() { - return false; + pub fn validate(&self, protocol_name : &str, function_name : &str, data : Vec) -> Result<(), ValidationError>{ + match self.protocols.get(protocol_name) { + Some(s) => s.validate(function_name, data, &self.symbol_table), + None => Err(ValidationError::ProtocolDoesNotExist), + } + } + +} + + + +impl Protocol { + pub fn new(interface: HashMap) -> Self { + Self {interface} + } + fn validate(&self, function_name : &str, data : Vec, symbols : &HashMap) -> Result<(), ValidationError> { + match self.interface.get(function_name){ + Some(s) => s.validate(data, symbols), + None => Err(ValidationError::FunctionDoesNotExist), } - true } } + +impl Function { + fn validate(&self, data : Vec, symbols : &HashMap) -> Result<(), ValidationError> { + let mut types = Vec::new(); + for arg in self.arguments.iter() { + let type_value = symbols.get(arg); + if let Some(type_value) = type_value { + types.push(type_value); + } + else{ + return Err(ValidationError::NonExistentType(arg.to_string())); + } + } + let mut data = data.iter(); + if let Some(ver) = data.next() { + if *ver == 1 { + // We got the correct version number + // Now to parse individual argumebts + let mut types = types.iter().peekable(); + loop { + let type_byte = data.next(); + if let Some(type_byte) = type_byte { + let type_value = types.next(); + if type_value.is_none() { + return Err(ValidationError::InvalidSize); + } + let type_value = type_value.unwrap(); + let data_type = match type_byte { + 0 => Some(Type::U64), + 1 => Some(Type::U32), + 2 => Some(Type::U16), + 3 => Some(Type::U8), + 4 => Some(Type::I64), + 5 => Some(Type::I32), + 6 => Some(Type::I16), + 7 => Some(Type::I8), + 8 => Some(Type::Bool), + 9 => Some(Type::F32), + 10 => Some(Type::F64), + 11 => Some(Type::Str), + _ => None, + }; + if data_type.is_none() || data_type.as_ref().unwrap() != *type_value { + println!("{:#?}", *type_value); + return Err(ValidationError::InvalidArgument); + } + + match data_type.unwrap(){ + Type::U64 | Type::I64 | Type::F64 => { + data.next(); + data.next(); + data.next(); + data.next(); + data.next(); + data.next(); + data.next(); + data.next(); + }, + Type::U32 | Type::I32 | Type::F32 => { + data.next(); + data.next(); + data.next(); + data.next(); + }, + Type::U16 | Type::I16 => { + data.next(); + data.next(); + }, + Type::U8 | Type::I8 | Type::Bool => { + data.next(); + }, + Type::Str => todo!(), + _ => panic!("Should not be possible"), + } + } else if types.peek().is_none() { + return Ok(()); + } + break; + } + } + else { + return Err(ValidationError::IncorrectVersion); + } + } else { + return Err(ValidationError::InvalidHeader); + } + Ok(()) + } +} + + + + + + + + + + + + + + + + + diff --git a/dev/src/idl/types.rs b/dev/src/idl/types.rs index 25c7588b..6343aa8b 100644 --- a/dev/src/idl/types.rs +++ b/dev/src/idl/types.rs @@ -1,10 +1,11 @@ use std::collections::HashMap; -use crate::idl::protocol::Protocol; use crate::idl::parser::AST; +use super::protocol::{Function, Protocols}; + #[derive(Debug, Clone, PartialEq)] -enum Type { +pub enum Type { U64, U32, U16, @@ -19,6 +20,7 @@ enum Type { F32, F64, + Str, Struct(StructType), Enum(EnumType), @@ -26,20 +28,34 @@ enum Type { } #[derive(Debug, Clone, PartialEq)] -struct StructType { +pub struct StructType { members : HashMap } #[derive(Debug, Clone, PartialEq)] -struct EnumType { +pub struct EnumType { members : HashMap } -pub fn get_protocols(ast : AST) -> Vec { - let protocols : Vec = Vec::new(); +fn add_builtin_types(symbol_table : &mut HashMap) { + symbol_table.insert("u8".to_string(), Type::U8); + symbol_table.insert("u16".to_string(), Type::U16); + symbol_table.insert("u32".to_string(), Type::U32); + symbol_table.insert("u64".to_string(), Type::U64); + symbol_table.insert("i8".to_string(), Type::I8); + symbol_table.insert("i16".to_string(), Type::I16); + symbol_table.insert("i32".to_string(), Type::I32); + symbol_table.insert("i64".to_string(), Type::I64); + symbol_table.insert("bool".to_string(), Type::Bool); + symbol_table.insert("f32".to_string(), Type::F32); + symbol_table.insert("f64".to_string(), Type::F64); +} + +pub fn get_protocols(ast : AST) -> Protocols{ let mut symbol_table : HashMap = HashMap::new(); let declarations = ast.0; + add_builtin_types(&mut symbol_table); // First Pass // We just populate the symbol table here for decl in declarations.iter() { @@ -64,9 +80,16 @@ pub fn get_protocols(ast : AST) -> Vec { super::parser::Declaration::ProtocolDeclaration(_) => {}, } } + + let mut protocols = Protocols::new(symbol_table); for decl in declarations.iter(){ match decl { super::parser::Declaration::ProtocolDeclaration(p) => { + let mut funcs : HashMap = HashMap::new(); + for i in p.interface.iter(){ + funcs.insert(i.name.to_string(), Function{arguments : i.arg_list.clone()}); + } + protocols.add_protocol(p.name.to_string(), funcs); }, _ => {} } diff --git a/sysdata/idl/test/src/protocol.aldi b/sysdata/idl/test/src/protocol.aldi new file mode 100644 index 00000000..3b787c2f --- /dev/null +++ b/sysdata/idl/test/src/protocol.aldi @@ -0,0 +1,3 @@ +protocol Foo{ + fn bar(i32, u8, bool) -> void; +}