Allow to deserialize/serialize into enums

Close #164
This commit is contained in:
Vincent Prouillet 2017-04-24 21:48:02 +09:00
parent 046a3e5117
commit 7cb357c168
4 changed files with 155 additions and 3 deletions

135
src/de.rs
View file

@ -11,6 +11,7 @@ use std::str;
use std::vec; use std::vec;
use serde::de; use serde::de;
use serde::de::IntoDeserializer;
use tokens::{Tokenizer, Token, Error as TokenError}; use tokens::{Tokenizer, Token, Error as TokenError};
use datetime::{SERDE_STRUCT_FIELD_NAME, SERDE_STRUCT_NAME}; use datetime::{SERDE_STRUCT_FIELD_NAME, SERDE_STRUCT_NAME};
@ -121,6 +122,12 @@ enum ErrorKind {
/// type. /// type.
Custom, Custom,
/// TODO
ExpectedMapEnd,
ExpectedEnum,
ExpectedMapColon,
ExpectedString,
#[doc(hidden)] #[doc(hidden)]
__Nonexhaustive, __Nonexhaustive,
} }
@ -145,6 +152,7 @@ impl<'de, 'b> de::Deserializer<'de> for &'b mut Deserializer<'de> {
values: None, values: None,
array: false, array: false,
}; };
while let Some(line) = self.line()? { while let Some(line) = self.line()? {
match line { match line {
Line::Table { at, mut header, array } => { Line::Table { at, mut header, array } => {
@ -192,9 +200,40 @@ impl<'de, 'b> de::Deserializer<'de> for &'b mut Deserializer<'de> {
}) })
} }
fn deserialize_enum<V>(
self,
_name: &'static str,
_variants: &'static [&'static str],
visitor: V
) -> Result<V::Value, Error>
where V: de::Visitor<'de>
{
if self.peek_char()? == '"' {
// Visit a unit variant.
match self.next()?.unwrap() {
Token::String { ref val, ..} => {
visitor.visit_enum(val.clone().into_deserializer())
},
_ => Err(Error::from_kind(ErrorKind::ExpectedString))
}
} else if self.next_char()? == '{' {
// Visit a newtype variant, tuple variant, or struct variant.
let value = visitor.visit_enum(Enum::new(self))?;
// Parse the matching close brace.
if self.next_char()? == '}' {
Ok(value)
} else {
Err(Error::from_kind(ErrorKind::ExpectedMapEnd))
}
} else {
Err(Error::from_kind(ErrorKind::ExpectedEnum))
}
}
forward_to_deserialize_any! { forward_to_deserialize_any! {
bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq
bytes byte_buf map struct unit enum newtype_struct bytes byte_buf map struct unit newtype_struct
ignored_any unit_struct tuple_struct tuple option identifier ignored_any unit_struct tuple_struct tuple option identifier
} }
} }
@ -574,6 +613,82 @@ impl<'de> de::MapAccess<'de> for InlineTableDeserializer<'de> {
} }
} }
struct Enum<'a, 'de: 'a> {
de: &'a mut Deserializer<'de>,
}
impl<'a, 'de> Enum<'a, 'de> {
fn new(de: &'a mut Deserializer<'de>) -> Self {
Enum { de: de }
}
}
// `EnumAccess` is provided to the `Visitor` to give it the ability to determine
// which variant of the enum is supposed to be deserialized.
//
// Note that all enum deserialization methods in Serde refer exclusively to the
// "externally tagged" enum representation.
impl<'de, 'a> de::EnumAccess<'de> for Enum<'a, 'de> {
type Error = Error;
type Variant = Self;
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Error>
where V: de::DeserializeSeed<'de>
{
// The `deserialize_enum` method parsed a `{` character so we are
// currently inside of a map. The seed will be deserializing itself from
// the key of the map.
let val = seed.deserialize(&mut *self.de)?;
// Parse the colon separating map key from value.
if self.de.next_char()? == ':' {
Ok((val, self))
} else {
Err(Error::from_kind(ErrorKind::ExpectedMapColon))
}
}
}
// `VariantAccess` is provided to the `Visitor` to give it the ability to see
// the content of the single variant that it decided to deserialize.
impl<'de, 'a> de::VariantAccess<'de> for Enum<'a, 'de> {
type Error = Error;
// If the `Visitor` expected this variant to be a unit variant, the input
// should have been the plain string case handled in `deserialize_enum`.
fn unit_variant(self) -> Result<(), Error> {
Err(Error::from_kind(ErrorKind::ExpectedString))
}
// Newtype variants are represented in JSON as `{ NAME: VALUE }` so
// deserialize the value here.
fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Error>
where T: de::DeserializeSeed<'de>
{
seed.deserialize(self.de)
}
// Tuple variants are represented in JSON as `{ NAME: [DATA...] }` so
// deserialize the sequence of data here.
fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Error>
where V: de::Visitor<'de>
{
de::Deserializer::deserialize_seq(self.de, visitor)
}
// Struct variants are represented in JSON as `{ NAME: { K: V, ... } }` so
// deserialize the inner map here.
fn struct_variant<V>(
self,
_fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Error>
where V: de::Visitor<'de>
{
de::Deserializer::deserialize_map(self.de, visitor)
}
}
impl<'a> Deserializer<'a> { impl<'a> Deserializer<'a> {
/// Creates a new deserializer which will be deserializing the string /// Creates a new deserializer which will be deserializing the string
/// provided. /// provided.
@ -959,6 +1074,16 @@ impl<'a> Deserializer<'a> {
self.tokens.peek().map_err(|e| self.token_error(e)) self.tokens.peek().map_err(|e| self.token_error(e))
} }
fn peek_char(&mut self) -> Result<char, Error> {
self.input.chars().next().ok_or(Error::from_kind(ErrorKind::UnexpectedEof))
}
fn next_char(&mut self) -> Result<char, Error> {
let ch = self.peek_char()?;
self.input = &self.input[ch.len_utf8()..];
Ok(ch)
}
fn eof(&self) -> Error { fn eof(&self) -> Error {
self.error(self.input.len(), ErrorKind::UnexpectedEof) self.error(self.input.len(), ErrorKind::UnexpectedEof)
} }
@ -1092,6 +1217,10 @@ impl fmt::Display for Error {
ErrorKind::RedefineAsArray => "table redefined as array".fmt(f)?, ErrorKind::RedefineAsArray => "table redefined as array".fmt(f)?,
ErrorKind::EmptyTableKey => "empty table key found".fmt(f)?, ErrorKind::EmptyTableKey => "empty table key found".fmt(f)?,
ErrorKind::Custom => self.inner.message.fmt(f)?, ErrorKind::Custom => self.inner.message.fmt(f)?,
ErrorKind::ExpectedMapEnd => "expected end of map".fmt(f)?,
ErrorKind::ExpectedEnum => "expected enum".fmt(f)?,
ErrorKind::ExpectedMapColon => "expected map colon".fmt(f)?,
ErrorKind::ExpectedString => "expected string".fmt(f)?,
ErrorKind::__Nonexhaustive => panic!(), ErrorKind::__Nonexhaustive => panic!(),
} }
@ -1134,6 +1263,10 @@ impl error::Error for Error {
ErrorKind::RedefineAsArray => "table redefined as array", ErrorKind::RedefineAsArray => "table redefined as array",
ErrorKind::EmptyTableKey => "empty table key found", ErrorKind::EmptyTableKey => "empty table key found",
ErrorKind::Custom => "a custom error", ErrorKind::Custom => "a custom error",
ErrorKind::ExpectedMapEnd => "expected end of map",
ErrorKind::ExpectedEnum => "expected enum",
ErrorKind::ExpectedMapColon => "expected map colon",
ErrorKind::ExpectedString => "expected string",
ErrorKind::__Nonexhaustive => panic!(), ErrorKind::__Nonexhaustive => panic!(),
} }
} }

View file

@ -423,7 +423,7 @@ impl<'a, 'b> ser::Serializer for &'b mut Serializer<'a> {
_variant_index: u32, _variant_index: u32,
_variant: &'static str) _variant: &'static str)
-> Result<(), Self::Error> { -> Result<(), Self::Error> {
Err(Error::UnsupportedType) self.serialize_str(_variant)
} }
fn serialize_newtype_struct<T: ?Sized>(self, _name: &'static str, value: &T) fn serialize_newtype_struct<T: ?Sized>(self, _name: &'static str, value: &T)

View file

@ -694,7 +694,7 @@ impl ser::Serializer for Serializer {
_variant_index: u32, _variant_index: u32,
_variant: &'static str) _variant: &'static str)
-> Result<Value, ::ser::Error> { -> Result<Value, ::ser::Error> {
Err(::ser::Error::UnsupportedType) self.serialize_str(_variant)
} }
fn serialize_newtype_struct<T: ?Sized>(self, fn serialize_newtype_struct<T: ?Sized>(self,

View file

@ -329,6 +329,25 @@ fn parse_enum() {
} }
} }
#[test]
fn parse_enum_string() {
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
struct Foo { a: Sort }
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
#[serde(rename_all = "lowercase")]
enum Sort {
Asc,
Desc,
}
equivalent! {
Foo { a: Sort::Desc },
Table(map! { a: Value::String("desc".to_string()) }),
}
}
// #[test] // #[test]
// fn unused_fields() { // fn unused_fields() {
// #[derive(Serialize, Deserialize, PartialEq, Debug)] // #[derive(Serialize, Deserialize, PartialEq, Debug)]