Support deserializing spanned keys (#333)

* Store key spans in the deserializer

* Support deserializing spanned keys

* Store key spans of the table header as well

* Support nested table key spans as well
This commit is contained in:
est31 2019-09-16 23:32:45 +02:00 committed by Alex Crichton
parent 55ca6c5e30
commit 12db6aa93f
3 changed files with 161 additions and 36 deletions

122
src/de.rs
View file

@ -22,7 +22,7 @@ use crate::spanned;
use crate::tokens::{Error as TokenError, Span, Token, Tokenizer}; use crate::tokens::{Error as TokenError, Span, Token, Tokenizer};
/// Type Alias for a TOML Table pair /// Type Alias for a TOML Table pair
type TablePair<'a> = (Cow<'a, str>, Value<'a>); type TablePair<'a> = ((Span, Cow<'a, str>), Value<'a>);
/// Deserializes a byte slice into a type. /// Deserializes a byte slice into a type.
/// ///
@ -318,9 +318,16 @@ impl<'de, 'b> de::Deserializer<'de> for &'b mut Deserializer<'de> {
} }
} }
/// Reports whether two table headers name the same dotted key path.
///
/// Only the key text of each segment is compared; the `Span` recorded next to
/// each segment is ignored, so the same header written at two different
/// source positions still compares equal. Headers with a different number of
/// segments are never equal.
fn headers_equal<'a, 'b>(hdr_a: &[(Span, Cow<'a, str>)], hdr_b: &[(Span, Cow<'b, str>)]) -> bool {
    let names_a = hdr_a.iter().map(|(_, name)| &**name);
    let names_b = hdr_b.iter().map(|(_, name)| &**name);
    // `Iterator::eq` yields false as soon as the sequences differ in length
    // or in any element, which covers the explicit length guard as well.
    names_a.eq(names_b)
}
struct Table<'a> { struct Table<'a> {
at: usize, at: usize,
header: Vec<Cow<'a, str>>, header: Vec<(Span, Cow<'a, str>)>,
values: Option<Vec<TablePair<'a>>>, values: Option<Vec<TablePair<'a>>>,
array: bool, array: bool,
} }
@ -351,7 +358,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
loop { loop {
assert!(self.next_value.is_none()); assert!(self.next_value.is_none());
if let Some((key, value)) = self.values.next() { if let Some((key, value)) = self.values.next() {
let ret = seed.deserialize(StrDeserializer::new(key.clone()))?; let ret = seed.deserialize(StrDeserializer::spanned(key.clone()))?;
self.next_value = Some((key, value)); self.next_value = Some((key, value));
return Ok(Some(ret)); return Ok(Some(ret));
} }
@ -366,7 +373,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
return false; return false;
} }
match t.header.get(..self.depth) { match t.header.get(..self.depth) {
Some(header) => header == prefix, Some(header) => headers_equal(&header, &prefix),
None => false, None => false,
} }
}) })
@ -382,9 +389,17 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
// Test to see if we're duplicating our parent's table, and if so // Test to see if we're duplicating our parent's table, and if so
// then this is an error in the toml format // then this is an error in the toml format
if self.cur_parent != pos { if self.cur_parent != pos {
if self.tables[self.cur_parent].header == self.tables[pos].header { if headers_equal(
&self.tables[self.cur_parent].header,
&self.tables[pos].header,
) {
let at = self.tables[pos].at; let at = self.tables[pos].at;
let name = self.tables[pos].header.join("."); let name = self.tables[pos]
.header
.iter()
.map(|k| k.1.to_owned())
.collect::<Vec<_>>()
.join(".");
return Err(self.de.error(at, ErrorKind::DuplicateTable(name))); return Err(self.de.error(at, ErrorKind::DuplicateTable(name)));
} }
@ -408,7 +423,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
// decoding. // decoding.
if self.depth != table.header.len() { if self.depth != table.header.len() {
let key = &table.header[self.depth]; let key = &table.header[self.depth];
let key = seed.deserialize(StrDeserializer::new(key.clone()))?; let key = seed.deserialize(StrDeserializer::spanned(key.clone()))?;
return Ok(Some(key)); return Ok(Some(key));
} }
@ -437,7 +452,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
match seed.deserialize(ValueDeserializer::new(v)) { match seed.deserialize(ValueDeserializer::new(v)) {
Ok(v) => return Ok(v), Ok(v) => return Ok(v),
Err(mut e) => { Err(mut e) => {
e.add_key_context(&k); e.add_key_context(&k.1);
return Err(e); return Err(e);
} }
} }
@ -458,7 +473,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
de: &mut *self.de, de: &mut *self.de,
}); });
res.map_err(|mut e| { res.map_err(|mut e| {
e.add_key_context(&self.tables[self.cur - 1].header[self.depth]); e.add_key_context(&self.tables[self.cur - 1].header[self.depth].1);
e e
}) })
} }
@ -482,7 +497,10 @@ impl<'de, 'b> de::SeqAccess<'de> for MapVisitor<'de, 'b> {
.iter() .iter()
.enumerate() .enumerate()
.skip(self.cur_parent + 1) .skip(self.cur_parent + 1)
.find(|&(_, table)| table.array && table.header == self.tables[self.cur_parent].header) .find(|&(_, table)| {
let tables_eq = headers_equal(&table.header, &self.tables[self.cur_parent].header);
table.array && tables_eq
})
.map(|p| p.0) .map(|p| p.0)
.unwrap_or(self.max); .unwrap_or(self.max);
@ -560,9 +578,9 @@ impl<'de, 'b> de::Deserializer<'de> for MapVisitor<'de, 'b> {
if table.header.len() == 0 { if table.header.len() == 0 {
return Err(self.de.error(self.cur, ErrorKind::EmptyTableKey)); return Err(self.de.error(self.cur, ErrorKind::EmptyTableKey));
} }
let name = table.header[table.header.len() - 1].to_owned(); let name = table.header[table.header.len() - 1].1.to_owned();
visitor.visit_enum(DottedTableDeserializer { visitor.visit_enum(DottedTableDeserializer {
name: name, name,
value: Value { value: Value {
e: E::DottedTable(values), e: E::DottedTable(values),
start: 0, start: 0,
@ -579,12 +597,27 @@ impl<'de, 'b> de::Deserializer<'de> for MapVisitor<'de, 'b> {
} }
struct StrDeserializer<'a> { struct StrDeserializer<'a> {
span: Option<Span>,
key: Cow<'a, str>, key: Cow<'a, str>,
} }
impl<'a> StrDeserializer<'a> { impl<'a> StrDeserializer<'a> {
/// Builds a key deserializer that remembers where the key was found.
///
/// `inner` is a `(span, key)` pair; the span is stored as `Some(..)` so it
/// is available later alongside the key text.
fn spanned(inner: (Span, Cow<'a, str>)) -> StrDeserializer<'a> {
    let (span, key) = inner;
    StrDeserializer {
        span: Some(span),
        key,
    }
}
fn new(key: Cow<'a, str>) -> StrDeserializer<'a> { fn new(key: Cow<'a, str>) -> StrDeserializer<'a> {
StrDeserializer { key } StrDeserializer { span: None, key }
}
}
impl<'a, 'b> de::IntoDeserializer<'a, Error> for StrDeserializer<'a> {
type Deserializer = Self;
fn into_deserializer(self) -> Self::Deserializer {
self
} }
} }
@ -601,9 +634,31 @@ impl<'de> de::Deserializer<'de> for StrDeserializer<'de> {
} }
} }
/// Deserializes the key as a struct.
///
/// When the requested struct is identified by `spanned::NAME` with exactly
/// the fields `spanned::START`, `spanned::END`, `spanned::VALUE` (in that
/// order) and this key carries a span, the visitor is handed a
/// `SpannedDeserializer` map holding the span's start, the key itself
/// (as a span-less `StrDeserializer`), and the span's end. In every other
/// case the request is forwarded to `deserialize_any`.
fn deserialize_struct<V>(
    self,
    name: &'static str,
    fields: &'static [&'static str],
    visitor: V,
) -> Result<V::Value, Error>
where
    V: de::Visitor<'de>,
{
    let wants_spanned =
        name == spanned::NAME && fields == [spanned::START, spanned::END, spanned::VALUE];

    match self.span {
        Some(span) if wants_spanned => visitor.visit_map(SpannedDeserializer {
            phantom_data: PhantomData,
            start: Some(span.start),
            // The inner value is the bare key; it no longer needs a span.
            value: Some(StrDeserializer::new(self.key)),
            end: Some(span.end),
        }),
        _ => self.deserialize_any(visitor),
    }
}
serde::forward_to_deserialize_any! { serde::forward_to_deserialize_any! {
bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq
bytes byte_buf map struct option unit newtype_struct bytes byte_buf map option unit newtype_struct
ignored_any unit_struct tuple_struct tuple enum identifier ignored_any unit_struct tuple_struct tuple enum identifier
} }
} }
@ -690,13 +745,13 @@ impl<'de> de::Deserializer<'de> for ValueDeserializer<'de> {
.iter() .iter()
.filter_map(|key_value| { .filter_map(|key_value| {
let (ref key, ref _val) = *key_value; let (ref key, ref _val) = *key_value;
if !fields.contains(&&(**key)) { if !fields.contains(&&*(key.1)) {
Some(key.clone()) Some(key.clone())
} else { } else {
None None
} }
}) })
.collect::<Vec<Cow<'de, str>>>(); .collect::<Vec<_>>();
if !extra_fields.is_empty() { if !extra_fields.is_empty() {
return Err(Error::from_kind( return Err(Error::from_kind(
@ -704,7 +759,7 @@ impl<'de> de::Deserializer<'de> for ValueDeserializer<'de> {
ErrorKind::UnexpectedKeys { ErrorKind::UnexpectedKeys {
keys: extra_fields keys: extra_fields
.iter() .iter()
.map(|k| k.to_string()) .map(|k| k.1.to_string())
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
available: fields, available: fields,
}, },
@ -943,7 +998,7 @@ impl<'de> de::MapAccess<'de> for InlineTableDeserializer<'de> {
None => return Ok(None), None => return Ok(None),
}; };
self.next_value = Some(value); self.next_value = Some(value);
seed.deserialize(StrDeserializer::new(key)).map(Some) seed.deserialize(StrDeserializer::spanned(key)).map(Some)
} }
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Error> fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Error>
@ -976,7 +1031,7 @@ impl<'de> de::EnumAccess<'de> for InlineTableDeserializer<'de> {
} }
}; };
seed.deserialize(StrDeserializer::new(key)) seed.deserialize(StrDeserializer::new(key.1))
.map(|val| (val, TableEnumDeserializer { value })) .map(|val| (val, TableEnumDeserializer { value }))
} }
} }
@ -1027,13 +1082,13 @@ impl<'de> de::VariantAccess<'de> for TableEnumDeserializer<'de> {
let tuple_values = values let tuple_values = values
.into_iter() .into_iter()
.enumerate() .enumerate()
.map(|(index, (key, value))| match key.parse::<usize>() { .map(|(index, (key, value))| match key.1.parse::<usize>() {
Ok(key_index) if key_index == index => Ok(value), Ok(key_index) if key_index == index => Ok(value),
Ok(_) | Err(_) => Err(Error::from_kind( Ok(_) | Err(_) => Err(Error::from_kind(
Some(value.start), Some(key.0.start),
ErrorKind::ExpectedTupleIndex { ErrorKind::ExpectedTupleIndex {
expected: index, expected: index,
found: key.to_string(), found: key.1.to_string(),
}, },
)), )),
}) })
@ -1350,14 +1405,14 @@ impl<'a> Deserializer<'a> {
.as_ref() .as_ref()
.and_then(|values| values.last()) .and_then(|values| values.last())
.map(|&(_, ref val)| val.end) .map(|&(_, ref val)| val.end)
.unwrap_or_else(|| header.len()); .unwrap_or_else(|| header.1.len());
Ok(( Ok((
Value { Value {
e: E::DottedTable(table.values.unwrap_or_else(Vec::new)), e: E::DottedTable(table.values.unwrap_or_else(Vec::new)),
start, start,
end, end,
}, },
Some(header.clone()), Some(header.1.clone()),
)) ))
} }
Some(_) => self.value().map(|val| (val, None)), Some(_) => self.value().map(|val| (val, None)),
@ -1672,14 +1727,11 @@ impl<'a> Deserializer<'a> {
Ok((span, ret)) Ok((span, ret))
} }
fn table_key(&mut self) -> Result<Cow<'a, str>, Error> { fn table_key(&mut self) -> Result<(Span, Cow<'a, str>), Error> {
self.tokens self.tokens.table_key().map_err(|e| self.token_error(e))
.table_key()
.map(|t| t.1)
.map_err(|e| self.token_error(e))
} }
fn dotted_key(&mut self) -> Result<Vec<Cow<'a, str>>, Error> { fn dotted_key(&mut self) -> Result<Vec<(Span, Cow<'a, str>)>, Error> {
let mut result = Vec::new(); let mut result = Vec::new();
result.push(self.table_key()?); result.push(self.table_key()?);
self.eat_whitespace()?; self.eat_whitespace()?;
@ -1705,7 +1757,7 @@ impl<'a> Deserializer<'a> {
/// * `values`: The `Vec` to store the value in. /// * `values`: The `Vec` to store the value in.
fn add_dotted_key( fn add_dotted_key(
&self, &self,
mut key_parts: Vec<Cow<'a, str>>, mut key_parts: Vec<(Span, Cow<'a, str>)>,
value: Value<'a>, value: Value<'a>,
values: &mut Vec<TablePair<'a>>, values: &mut Vec<TablePair<'a>>,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -1714,7 +1766,7 @@ impl<'a> Deserializer<'a> {
values.push((key, value)); values.push((key, value));
return Ok(()); return Ok(());
} }
match values.iter_mut().find(|&&mut (ref k, _)| *k == key) { match values.iter_mut().find(|&&mut (ref k, _)| *k.1 == key.1) {
Some(&mut ( Some(&mut (
_, _,
Value { Value {
@ -2038,7 +2090,7 @@ enum Line<'a> {
header: Header<'a>, header: Header<'a>,
array: bool, array: bool,
}, },
KeyValue(Vec<Cow<'a, str>>, Value<'a>), KeyValue(Vec<(Span, Cow<'a, str>)>, Value<'a>),
} }
struct Header<'a> { struct Header<'a> {
@ -2058,13 +2110,13 @@ impl<'a> Header<'a> {
} }
} }
fn next(&mut self) -> Result<Option<Cow<'a, str>>, TokenError> { fn next(&mut self) -> Result<Option<(Span, Cow<'a, str>)>, TokenError> {
self.tokens.eat_whitespace()?; self.tokens.eat_whitespace()?;
if self.first || self.tokens.eat(Token::Period)? { if self.first || self.tokens.eat(Token::Period)? {
self.first = false; self.first = false;
self.tokens.eat_whitespace()?; self.tokens.eat_whitespace()?;
self.tokens.table_key().map(|t| t.1).map(Some) self.tokens.table_key().map(|t| t).map(Some)
} else { } else {
self.tokens.expect(Token::RightBracket)?; self.tokens.expect(Token::RightBracket)?;
if self.array { if self.array {

View file

@ -28,7 +28,7 @@ pub(crate) const VALUE: &str = "$__toml_private_value";
/// assert_eq!(u.s.into_inner(), String::from("value")); /// assert_eq!(u.s.into_inner(), String::from("value"));
/// } /// }
/// ``` /// ```
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Spanned<T> { pub struct Spanned<T> {
/// The start range. /// The start range.
start: usize, start: usize,

View file

@ -85,3 +85,76 @@ fn test_spanned_field() {
// ending at something other than the absolute end // ending at something other than the absolute end
good::<u32>("foo = 42\nnoise = true", "42", Some(8)); good::<u32>("foo = 42\nnoise = true", "42", Some(8));
} }
/// Checks that both keys and values of a table deserialize into `Spanned`
/// with spans that index back into the original input, for a `[foo]` table
/// header and for an inline table.
#[test]
fn test_spanned_table() {
    #[derive(Deserialize)]
    struct Foo {
        foo: HashMap<Spanned<String>, Spanned<String>>,
    }

    // Parses `input` and verifies every key/value span against the source text.
    fn check_spans(input: &str) {
        let parsed: Foo = toml::from_str(input).unwrap();
        for (key, val) in &parsed.foo {
            let key_text = &input[key.start()..key.end()];
            assert_eq!(key_text, key.get_ref());
            // The value slice is narrowed by one byte on each side; in these
            // inputs that skips the quote characters around each string.
            let val_text = &input[(val.start() + 1)..(val.end() - 1)];
            assert_eq!(val_text, val.get_ref());
        }
    }

    let header_form = "
[foo]
a = 'b'
bar = 'baz'
c = 'd'
e = \"f\"
";
    check_spans(header_form);

    let inline_form = "
foo = { a = 'b', bar = 'baz', c = 'd', e = \"f\" }
";
    check_spans(inline_form);
}
/// Checks span tracking for keys of nested tables: the outer map keys, the
/// inner map keys, and the inner values must all carry spans that index back
/// into the original input.
#[test]
fn test_spanned_nested() {
    #[derive(Deserialize)]
    struct Foo {
        foo: HashMap<Spanned<String>, HashMap<Spanned<String>, Spanned<String>>>,
    }

    // Parses `input` and verifies the spans at both nesting levels.
    fn check_spans(input: &str) {
        let parsed: Foo = toml::from_str(input).unwrap();
        for (outer_key, inner_map) in &parsed.foo {
            let outer_text = &input[outer_key.start()..outer_key.end()];
            assert_eq!(outer_text, outer_key.get_ref());

            for (inner_key, inner_val) in inner_map {
                let inner_text = &input[inner_key.start()..inner_key.end()];
                assert_eq!(inner_text, inner_key.get_ref());
                // The value slice is narrowed by one byte on each side; in
                // these inputs that skips the quotes around each string.
                let val_text = &input[(inner_val.start() + 1)..(inner_val.end() - 1)];
                assert_eq!(val_text, inner_val.get_ref());
            }
        }
    }

    let dotted_headers = "
[foo.a]
a = 'b'
c = 'd'
e = \"f\"
[foo.bar]
baz = 'true'
";
    check_spans(dotted_headers);

    let inline_children = "
[foo]
foo = { a = 'b', bar = 'baz', c = 'd', e = \"f\" }
bazz = {}
g = { h = 'i' }
";
    check_spans(inline_children);
}