From 12db6aa93fea62eb3bd19f6813fb9d3c65be91bc Mon Sep 17 00:00:00 2001 From: est31 Date: Mon, 16 Sep 2019 23:32:45 +0200 Subject: [PATCH] Support deserializing spanned keys (#333) * Store key spans in the deserializer * Support deserializing spanned keys * Store key spans of the table header as well * Support nested table key spans as well --- src/de.rs | 122 +++++++++++++++++++++++++----------- src/spanned.rs | 2 +- test-suite/tests/spanned.rs | 73 +++++++++++++++++++++ 3 files changed, 161 insertions(+), 36 deletions(-) diff --git a/src/de.rs b/src/de.rs index 439a48a..9bb5204 100644 --- a/src/de.rs +++ b/src/de.rs @@ -22,7 +22,7 @@ use crate::spanned; use crate::tokens::{Error as TokenError, Span, Token, Tokenizer}; /// Type Alias for a TOML Table pair -type TablePair<'a> = (Cow<'a, str>, Value<'a>); +type TablePair<'a> = ((Span, Cow<'a, str>), Value<'a>); /// Deserializes a byte slice into a type. /// @@ -318,9 +318,16 @@ impl<'de, 'b> de::Deserializer<'de> for &'b mut Deserializer<'de> { } } +fn headers_equal<'a, 'b>(hdr_a: &[(Span, Cow<'a, str>)], hdr_b: &[(Span, Cow<'b, str>)]) -> bool { + if hdr_a.len() != hdr_b.len() { + return false; + } + hdr_a.iter().zip(hdr_b.iter()).all(|(h1, h2)| h1.1 == h2.1) +} + struct Table<'a> { at: usize, - header: Vec>, + header: Vec<(Span, Cow<'a, str>)>, values: Option>>, array: bool, } @@ -351,7 +358,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> { loop { assert!(self.next_value.is_none()); if let Some((key, value)) = self.values.next() { - let ret = seed.deserialize(StrDeserializer::new(key.clone()))?; + let ret = seed.deserialize(StrDeserializer::spanned(key.clone()))?; self.next_value = Some((key, value)); return Ok(Some(ret)); } @@ -366,7 +373,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> { return false; } match t.header.get(..self.depth) { - Some(header) => header == prefix, + Some(header) => headers_equal(&header, &prefix), None => false, } }) @@ -382,9 +389,17 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> { // Test to see if we're duplicating our parent's table, and if so // then this is an error in the toml format if self.cur_parent != pos { - if self.tables[self.cur_parent].header == self.tables[pos].header { + if headers_equal( + &self.tables[self.cur_parent].header, + &self.tables[pos].header, + ) { let at = self.tables[pos].at; - let name = self.tables[pos].header.join("."); + let name = self.tables[pos] + .header + .iter() + .map(|k| k.1.to_owned()) + .collect::>() + .join("."); return Err(self.de.error(at, ErrorKind::DuplicateTable(name))); } @@ -408,7 +423,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> { // decoding. if self.depth != table.header.len() { let key = &table.header[self.depth]; - let key = seed.deserialize(StrDeserializer::new(key.clone()))?; + let key = seed.deserialize(StrDeserializer::spanned(key.clone()))?; return Ok(Some(key)); } @@ -437,7 +452,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> { match seed.deserialize(ValueDeserializer::new(v)) { Ok(v) => return Ok(v), Err(mut e) => { - e.add_key_context(&k); + e.add_key_context(&k.1); return Err(e); } } @@ -458,7 +473,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> { de: &mut *self.de, }); res.map_err(|mut e| { - e.add_key_context(&self.tables[self.cur - 1].header[self.depth]); + e.add_key_context(&self.tables[self.cur - 1].header[self.depth].1); e }) } @@ -482,7 +497,10 @@ impl<'de, 'b> de::SeqAccess<'de> for MapVisitor<'de, 'b> { .iter() .enumerate() .skip(self.cur_parent + 1) - .find(|&(_, table)| table.array && table.header == self.tables[self.cur_parent].header) + .find(|&(_, table)| { + let tables_eq = headers_equal(&table.header, &self.tables[self.cur_parent].header); + table.array && tables_eq + }) .map(|p| p.0) .unwrap_or(self.max); @@ -560,9 +578,9 @@ impl<'de, 'b> de::Deserializer<'de> for MapVisitor<'de, 'b> { if table.header.len() == 0 { return Err(self.de.error(self.cur, ErrorKind::EmptyTableKey)); } - let name = table.header[table.header.len() - 1].to_owned(); + let name = table.header[table.header.len() - 1].1.to_owned(); visitor.visit_enum(DottedTableDeserializer { - name: name, + name, value: Value { e: E::DottedTable(values), start: 0, @@ -579,12 +597,27 @@ impl<'de, 'b> de::Deserializer<'de> for MapVisitor<'de, 'b> { } struct StrDeserializer<'a> { + span: Option, key: Cow<'a, str>, } impl<'a> StrDeserializer<'a> { + fn spanned(inner: (Span, Cow<'a, str>)) -> StrDeserializer<'a> { + StrDeserializer { + span: Some(inner.0), + key: inner.1, + } + } fn new(key: Cow<'a, str>) -> StrDeserializer<'a> { - StrDeserializer { key } + StrDeserializer { span: None, key } + } +} + +impl<'a, 'b> de::IntoDeserializer<'a, Error> for StrDeserializer<'a> { + type Deserializer = Self; + + fn into_deserializer(self) -> Self::Deserializer { + self } } @@ -601,9 +634,31 @@ impl<'de> de::Deserializer<'de> for StrDeserializer<'de> { } } + fn deserialize_struct( + self, + name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + if name == spanned::NAME && fields == [spanned::START, spanned::END, spanned::VALUE] { + if let Some(span) = self.span { + return visitor.visit_map(SpannedDeserializer { + phantom_data: PhantomData, + start: Some(span.start), + value: Some(StrDeserializer::new(self.key)), + end: Some(span.end), + }); + } + } + self.deserialize_any(visitor) + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq - bytes byte_buf map struct option unit newtype_struct + bytes byte_buf map option unit newtype_struct ignored_any unit_struct tuple_struct tuple enum identifier } } @@ -690,13 +745,13 @@ impl<'de> de::Deserializer<'de> for ValueDeserializer<'de> { .iter() .filter_map(|key_value| { let (ref key, ref _val) = *key_value; - if !fields.contains(&&(**key)) { + if !fields.contains(&&*(key.1)) { Some(key.clone()) } else { None } }) - .collect::>>(); + .collect::>(); if !extra_fields.is_empty() { return Err(Error::from_kind( @@ -704,7 +759,7 @@ impl<'de> de::Deserializer<'de> for ValueDeserializer<'de> { ErrorKind::UnexpectedKeys { keys: extra_fields .iter() - .map(|k| k.to_string()) + .map(|k| k.1.to_string()) .collect::>(), available: fields, }, @@ -943,7 +998,7 @@ impl<'de> de::MapAccess<'de> for InlineTableDeserializer<'de> { None => return Ok(None), }; self.next_value = Some(value); - seed.deserialize(StrDeserializer::new(key)).map(Some) + seed.deserialize(StrDeserializer::spanned(key)).map(Some) } fn next_value_seed(&mut self, seed: V) -> Result @@ -976,7 +1031,7 @@ impl<'de> de::EnumAccess<'de> for InlineTableDeserializer<'de> { } }; - seed.deserialize(StrDeserializer::new(key)) + seed.deserialize(StrDeserializer::new(key.1)) .map(|val| (val, TableEnumDeserializer { value })) } } @@ -1027,13 +1082,13 @@ impl<'de> de::VariantAccess<'de> for TableEnumDeserializer<'de> { let tuple_values = values .into_iter() .enumerate() - .map(|(index, (key, value))| match key.parse::() { + .map(|(index, (key, value))| match key.1.parse::() { Ok(key_index) if key_index == index => Ok(value), Ok(_) | Err(_) => Err(Error::from_kind( - Some(value.start), + Some(key.0.start), ErrorKind::ExpectedTupleIndex { expected: index, - found: key.to_string(), + found: key.1.to_string(), }, )), }) @@ -1350,14 +1405,14 @@ impl<'a> Deserializer<'a> { .as_ref() .and_then(|values| values.last()) .map(|&(_, ref val)| val.end) - .unwrap_or_else(|| header.len()); + .unwrap_or_else(|| header.1.len()); Ok(( Value { e: E::DottedTable(table.values.unwrap_or_else(Vec::new)), start, end, }, - Some(header.clone()), + Some(header.1.clone()), )) } Some(_) => self.value().map(|val| (val, None)), @@ -1672,14 +1727,11 @@ impl<'a> Deserializer<'a> { Ok((span, ret)) } - fn table_key(&mut self) -> Result, Error> { - self.tokens - .table_key() - .map(|t| t.1) - .map_err(|e| self.token_error(e)) + fn table_key(&mut self) -> Result<(Span, Cow<'a, str>), Error> { + self.tokens.table_key().map_err(|e| self.token_error(e)) } - fn dotted_key(&mut self) -> Result>, Error> { + fn dotted_key(&mut self) -> Result)>, Error> { let mut result = Vec::new(); result.push(self.table_key()?); self.eat_whitespace()?; @@ -1705,7 +1757,7 @@ impl<'a> Deserializer<'a> { /// * `values`: The `Vec` to store the value in. fn add_dotted_key( &self, - mut key_parts: Vec>, + mut key_parts: Vec<(Span, Cow<'a, str>)>, value: Value<'a>, values: &mut Vec>, ) -> Result<(), Error> { @@ -1714,7 +1766,7 @@ impl<'a> Deserializer<'a> { values.push((key, value)); return Ok(()); } - match values.iter_mut().find(|&&mut (ref k, _)| *k == key) { + match values.iter_mut().find(|&&mut (ref k, _)| *k.1 == key.1) { Some(&mut ( _, Value { @@ -2038,7 +2090,7 @@ enum Line<'a> { header: Header<'a>, array: bool, }, - KeyValue(Vec>, Value<'a>), + KeyValue(Vec<(Span, Cow<'a, str>)>, Value<'a>), } struct Header<'a> { @@ -2058,13 +2110,13 @@ impl<'a> Header<'a> { } } - fn next(&mut self) -> Result>, TokenError> { + fn next(&mut self) -> Result)>, TokenError> { self.tokens.eat_whitespace()?; if self.first || self.tokens.eat(Token::Period)? { self.first = false; self.tokens.eat_whitespace()?; - self.tokens.table_key().map(|t| t.1).map(Some) + self.tokens.table_key().map(|t| t).map(Some) } else { self.tokens.expect(Token::RightBracket)?; if self.array { diff --git a/src/spanned.rs b/src/spanned.rs index 1538f96..3318d28 100644 --- a/src/spanned.rs +++ b/src/spanned.rs @@ -28,7 +28,7 @@ pub(crate) const VALUE: &str = "$__toml_private_value"; /// assert_eq!(u.s.into_inner(), String::from("value")); /// } /// ``` -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Spanned { /// The start range. start: usize, diff --git a/test-suite/tests/spanned.rs b/test-suite/tests/spanned.rs index 5130a72..1186645 100644 --- a/test-suite/tests/spanned.rs +++ b/test-suite/tests/spanned.rs @@ -85,3 +85,76 @@ fn test_spanned_field() { // ending at something other than the absolute end good::("foo = 42\nnoise = true", "42", Some(8)); } + +#[test] +fn test_spanned_table() { + #[derive(Deserialize)] + struct Foo { + foo: HashMap, Spanned>, + } + + fn good(s: &str) { + let foo: Foo = toml::from_str(s).unwrap(); + + for (k, v) in foo.foo.iter() { + assert_eq!(&s[k.start()..k.end()], k.get_ref()); + assert_eq!(&s[(v.start() + 1)..(v.end() - 1)], v.get_ref()); + } + } + + good( + " + [foo] + a = 'b' + bar = 'baz' + c = 'd' + e = \"f\" + ", + ); + + good( + " + foo = { a = 'b', bar = 'baz', c = 'd', e = \"f\" } + ", + ); +} + +#[test] +fn test_spanned_nested() { + #[derive(Deserialize)] + struct Foo { + foo: HashMap, HashMap, Spanned>>, + } + + fn good(s: &str) { + let foo: Foo = toml::from_str(s).unwrap(); + + for (k, v) in foo.foo.iter() { + assert_eq!(&s[k.start()..k.end()], k.get_ref()); + for (n_k, n_v) in v.iter() { + assert_eq!(&s[n_k.start()..n_k.end()], n_k.get_ref()); + assert_eq!(&s[(n_v.start() + 1)..(n_v.end() - 1)], n_v.get_ref()); + } + } + } + + good( + " + [foo.a] + a = 'b' + c = 'd' + e = \"f\" + [foo.bar] + baz = 'true' + ", + ); + + good( + " + [foo] + foo = { a = 'b', bar = 'baz', c = 'd', e = \"f\" } + bazz = {} + g = { h = 'i' } + ", + ); +}