Tests for Node Deserialization

Jeremy Wall 2022-08-23 22:03:48 -04:00
parent 4ba7ee816b
commit d723d427e6
3 changed files with 71 additions and 66 deletions

View File

@@ -1,3 +1,4 @@
 {
-    "rust-analyzer.checkOnSave.features": "all"
+    "rust-analyzer.checkOnSave.features": "all",
+    "rust-analyzer.cargo.features": "all"
 }

View File

@@ -27,7 +27,7 @@ use crate::hash::HashWriter;
 /// Nodes are tied to a specific implementation of the HashWriter trait which is itself tied
 /// to the DAG they are stored in guaranteeing that the same Hashing implementation is used
 /// for each node in the DAG.
-#[derive(Debug, PartialEq)]
+#[derive(Debug, PartialEq, Eq)]
 pub struct Node<HW, const HASH_LEN: usize>
 where
     HW: HashWriter<HASH_LEN>,
@@ -54,6 +54,16 @@ where
     }
 }
 
+fn coerce_non_const_generic_set<const HASH_LEN: usize>(
+    set: &BTreeSet<[u8; HASH_LEN]>,
+) -> BTreeSet<&[u8]> {
+    let mut coerced_item = BTreeSet::new();
+    for arr in set {
+        coerced_item.insert(arr.as_slice());
+    }
+    coerced_item
+}
+
 impl<HW, const HASH_LEN: usize> Serialize for Node<HW, HASH_LEN>
 where
     HW: HashWriter<HASH_LEN>,
@@ -63,15 +73,18 @@ where
         S: Serializer,
     {
         let mut structor = serializer.serialize_struct("Node", 4)?;
-        structor.serialize_field("id", self.id.as_slice())?;
         structor.serialize_field("item", &self.item)?;
-        structor.serialize_field("item_id", self.item_id.as_slice())?;
-        // TODO(jwall): structor.serialize_field("dependency_ids", &self.dependency_ids)?;
+        structor.serialize_field(
+            "dependency_ids",
+            &coerce_non_const_generic_set(&self.dependency_ids),
+        )?;
         structor.end()
     }
 }
 
-fn coerce_array<const HASH_LEN: usize>(slice: &[u8]) -> Result<[u8; HASH_LEN], String> {
+fn coerce_const_generic_array<const HASH_LEN: usize>(
+    slice: &[u8],
+) -> Result<[u8; HASH_LEN], String> {
     let mut coerced_item: [u8; HASH_LEN] = [0; HASH_LEN];
     if slice.len() > coerced_item.len() {
         return Err(format!(
@@ -85,12 +98,12 @@ fn coerce_array<const HASH_LEN: usize>(slice: &[u8]) -> Result<[u8; HASH_LEN], String> {
     Ok(coerced_item)
 }
 
-fn coerce_set<const HASH_LEN: usize>(
+fn coerce_const_generic_set<const HASH_LEN: usize>(
     set: BTreeSet<&[u8]>,
 ) -> Result<BTreeSet<[u8; HASH_LEN]>, String> {
     let mut coerced_item = BTreeSet::new();
     for slice in set {
-        coerced_item.insert(coerce_array(slice)?);
+        coerced_item.insert(coerce_const_generic_array(slice)?);
     }
     Ok(coerced_item)
 }
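
Note: coerce_non_const_generic_set and coerce_const_generic_set are inverses of one another. The first borrows each const-generic [u8; HASH_LEN] array as a plain &[u8] so serde can serialize the set uniformly, and the second rebuilds the fixed-size arrays on the way back in, erroring if a slice does not fit. A minimal round-trip sketch, assuming both helpers are in scope (in the crate they are private to this module):

// Sketch only: exercises the two coercion helpers defined in the diff above.
fn coercion_round_trip() -> Result<(), String> {
    use std::collections::BTreeSet;

    let mut ids: BTreeSet<[u8; 8]> = BTreeSet::new();
    ids.insert([1, 2, 3, 4, 5, 6, 7, 8]);

    // Borrow the fixed-size arrays as &[u8] for serialization...
    let as_slices: BTreeSet<&[u8]> = coerce_non_const_generic_set(&ids);
    // ...and coerce them back into [u8; 8] arrays, failing on a length mismatch.
    let restored: BTreeSet<[u8; 8]> = coerce_const_generic_set(as_slices)?;

    assert_eq!(ids, restored);
    Ok(())
}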
@@ -107,9 +120,7 @@ where
         #[serde(field_identifier, rename_all = "lowercase")]
         #[allow(non_camel_case_types)]
         enum Field {
-            Id,
             Item,
-            Item_Id,
             Dependency_Ids,
         }
 
@@ -129,53 +140,25 @@ where
             where
                 A: serde::de::SeqAccess<'de>,
             {
-                let id: [u8; HASH_LEN] = coerce_array(
-                    seq.next_element::<&[u8]>()?
-                        .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?,
-                )
-                .map_err(serde::de::Error::custom)?;
                 let item = seq
                     .next_element::<Vec<u8>>()?
                     .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
-                let item_id: [u8; HASH_LEN] = coerce_array(
-                    seq.next_element::<&[u8]>()?
-                        .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?,
-                )
-                .map_err(serde::de::Error::custom)?;
-                let dependency_ids: BTreeSet<[u8; HASH_LEN]> = coerce_set(
+                let dependency_ids: BTreeSet<[u8; HASH_LEN]> = coerce_const_generic_set(
                     seq.next_element::<BTreeSet<&[u8]>>()?
                         .ok_or_else(|| serde::de::Error::invalid_length(3, &self))?,
                 )
                 .map_err(serde::de::Error::custom)?;
-                Ok(Self::Value {
-                    id,
-                    item,
-                    item_id,
-                    dependency_ids,
-                    _phantom: PhantomData,
-                })
+                Ok(Self::Value::new(item, dependency_ids))
             }
 
             fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
             where
                 A: serde::de::MapAccess<'de>,
             {
-                let mut id: Option<[u8; HASH_LEN]> = None;
                 let mut item: Option<Vec<u8>> = None;
-                let mut item_id: Option<[u8; HASH_LEN]> = None;
                 let mut dependency_ids: Option<BTreeSet<[u8; HASH_LEN]>> = None;
                 while let Some(key) = map.next_key()? {
                     match key {
-                        Field::Id => {
-                            if id.is_some() {
-                                return Err(serde::de::Error::duplicate_field("id"));
-                            } else {
-                                id = Some(
-                                    coerce_array(map.next_value()?)
-                                        .map_err(serde::de::Error::custom)?,
-                                );
-                            }
-                        }
                         Field::Item => {
                             if item.is_some() {
                                 return Err(serde::de::Error::duplicate_field("item"));
@@ -183,50 +166,28 @@ where
                                 item = Some(map.next_value()?);
                             }
                         }
-                        Field::Item_Id => {
-                            if item_id.is_some() {
-                                return Err(serde::de::Error::duplicate_field("item_id"));
-                            } else {
-                                item_id = Some(
-                                    coerce_array(map.next_value()?)
-                                        .map_err(serde::de::Error::custom)?,
-                                );
-                            }
-                        }
                         Field::Dependency_Ids => {
                             if dependency_ids.is_some() {
                                 return Err(serde::de::Error::duplicate_field("dependency_ids"));
                             } else {
                                 dependency_ids = Some(
-                                    coerce_set(map.next_value()?)
+                                    coerce_const_generic_set(map.next_value()?)
                                         .map_err(serde::de::Error::custom)?,
                                 );
                             }
                         }
                     }
                 }
-                let id = id.ok_or_else(|| serde::de::Error::missing_field("id"))?;
                 let item = item.ok_or_else(|| serde::de::Error::missing_field("item"))?;
-                let item_id = item_id.ok_or_else(|| serde::de::Error::missing_field("item_id"))?;
                 let dependency_ids = dependency_ids
                     .ok_or_else(|| serde::de::Error::missing_field("dependency_ids"))?;
-                Ok(Self::Value {
-                    id,
-                    item,
-                    item_id,
-                    dependency_ids,
-                    _phantom: PhantomData,
-                })
+                Ok(Self::Value::new(item, dependency_ids))
             }
         }
 
-        const FIELDS: &'static [&'static str] = &["id", "item", "item_id", "dependency_ids"];
-        deserializer.deserialize_struct(
-            "Duration",
-            FIELDS,
-            NodeVisitor::<HW, HASH_LEN>(PhantomData),
-        )
+        const FIELDS: &'static [&'static str] = &["item", "dependency_ids"];
+        deserializer.deserialize_struct("Node", FIELDS, NodeVisitor::<HW, HASH_LEN>(PhantomData))
     }
 }
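
With id and item_id dropped from the wire format, the deserializer rebuilds each node via Self::Value::new(item, dependency_ids), so the round trip relies on both hashes being deterministic functions of the item bytes and the dependency ids. How node.rs actually computes them is not part of this diff; the toy sketch below only illustrates that principle, and its use of DefaultHasher and its choice of hash inputs are assumptions, not the crate's real scheme.

// Toy illustration only: shows why "id" and "item_id" need not be serialized.
// The real Node::new uses the DAG's HashWriter; the inputs here are assumed.
use std::collections::{hash_map::DefaultHasher, BTreeSet};
use std::hash::Hasher;

fn derive_ids(item: &[u8], dependency_ids: &BTreeSet<[u8; 8]>) -> ([u8; 8], [u8; 8]) {
    // item_id: a digest of the item bytes alone.
    let mut hasher = DefaultHasher::new();
    hasher.write(item);
    let item_id = hasher.finish().to_le_bytes();

    // id: a digest of the item bytes plus every dependency id, in set order,
    // so equal content with equal dependencies always yields the same id.
    let mut hasher = DefaultHasher::new();
    hasher.write(item);
    for dep in dependency_ids {
        hasher.write(dep);
    }
    let id = hasher.finish().to_le_bytes();

    (item_id, id)
}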

View File

@@ -161,3 +161,46 @@ fn test_node_comparison_no_shared_graph() {
         NodeCompare::Uncomparable
     );
 }
+
+#[cfg(feature = "cbor")]
+mod cbor_serialization_tests {
+    use super::TestDag;
+    use crate::prelude::*;
+    use ciborium::{de::from_reader, ser::into_writer};
+    use std::collections::{hash_map::DefaultHasher, BTreeSet};
+
+    #[test]
+    fn test_node_deserialization() {
+        let mut dag = TestDag::new();
+        let simple_node_id = dag.add_node("simple", BTreeSet::new()).unwrap();
+        let mut dep_set = BTreeSet::new();
+        dep_set.insert(simple_node_id);
+        let root_node_id = dag.add_node("root", dep_set).unwrap();
+        let simple_node_to_serialize = dag.get_node_by_id(&simple_node_id).unwrap().unwrap();
+        let root_node_to_serialize = dag.get_node_by_id(&root_node_id).unwrap().unwrap();
+
+        let mut simple_node_vec: Vec<u8> = Vec::new();
+        let mut root_node_vec: Vec<u8> = Vec::new();
+        into_writer(&simple_node_to_serialize, &mut simple_node_vec).unwrap();
+        into_writer(&root_node_to_serialize, &mut root_node_vec).unwrap();
+
+        let simple_node_de: Node<DefaultHasher, 8> =
+            from_reader(simple_node_vec.as_slice()).unwrap();
+        let root_node_de: Node<DefaultHasher, 8> = from_reader(root_node_vec.as_slice()).unwrap();
+        assert_eq!(simple_node_to_serialize.id(), simple_node_de.id());
+        assert_eq!(simple_node_to_serialize.item_id(), simple_node_de.item_id());
+        assert_eq!(simple_node_to_serialize.item(), simple_node_de.item());
+        assert_eq!(
+            simple_node_to_serialize.dependency_ids(),
+            simple_node_de.dependency_ids()
+        );
+
+        assert_eq!(root_node_to_serialize.id(), root_node_de.id());
+        assert_eq!(root_node_to_serialize.item_id(), root_node_de.item_id());
+        assert_eq!(root_node_to_serialize.item(), root_node_de.item());
+        assert_eq!(
+            root_node_to_serialize.dependency_ids(),
+            root_node_de.dependency_ids()
+        );
+    }
+}
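
Because the new test module is gated behind the cbor feature, it only compiles and runs when that feature is enabled, for example with cargo test --features cbor (assuming a Cargo feature named exactly as in the cfg attribute). That is presumably also why the rust-analyzer.cargo.features setting above was widened to "all", so rust-analyzer checks the feature-gated test code.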