diff --git a/src/decoder/serde.rs b/src/decoder/serde.rs index 01806d4..2f69eb9 100644 --- a/src/decoder/serde.rs +++ b/src/decoder/serde.rs @@ -89,6 +89,94 @@ impl de::Deserializer for Decoder { self.visit(visitor) } } + + fn visit_enum(&mut self, + _enum: &str, + variants: &[&str], + mut visitor: V) -> Result + where V: de::EnumVisitor, + { + // When decoding enums, this crate takes the strategy of trying to + // decode the current TOML as all of the possible variants, returning + // success on the first one that succeeds. + // + // Note that fidelity of the errors returned here is a little nebulous, + // but we try to return the error that had the relevant field as the + // longest field. This way we hopefully match an error against what was + // most likely being written down without losing too much info. + let mut first_error = None::; + + for variant in 0 .. variants.len() { + let mut de = VariantVisitor { + de: self.sub_decoder(self.toml.clone(), ""), + variant: variant, + }; + + match visitor.visit(&mut de) { + Ok(value) => { + self.toml = de.de.toml; + return Ok(value); + } + Err(e) => { + if let Some(ref first) = first_error { + let my_len = e.field.as_ref().map(|s| s.len()); + let first_len = first.field.as_ref().map(|s| s.len()); + if my_len <= first_len { + continue + } + } + first_error = Some(e); + } + } + } + + Err(first_error.unwrap_or_else(|| self.err(DecodeErrorKind::NoEnumVariants))) + } +} + +struct VariantVisitor { + de: Decoder, + variant: usize, +} + +impl de::VariantVisitor for VariantVisitor { + type Error = DecodeError; + + fn visit_variant(&mut self) -> Result + where V: de::Deserialize + { + use serde::de::value::ValueDeserializer; + + let mut de = self.variant.into_deserializer(); + + de::Deserialize::deserialize(&mut de).map_err(|e| se2toml(e, "variant")) + } + + fn visit_unit(&mut self) -> Result<(), DecodeError> { + de::Deserialize::deserialize(&mut self.de) + } + + fn visit_newtype(&mut self) -> Result + where T: de::Deserialize, + { + de::Deserialize::deserialize(&mut self.de) + } + + fn visit_tuple(&mut self, + _len: usize, + visitor: V) -> Result + where V: de::Visitor, + { + de::Deserializer::visit(&mut self.de, visitor) + } + + fn visit_struct(&mut self, + _fields: &'static [&'static str], + visitor: V) -> Result + where V: de::Visitor, + { + de::Deserializer::visit(&mut self.de, visitor) + } } struct SeqDeserializer<'a, I> {