Tests for Node Deserialization

This commit is contained in:
Jeremy Wall 2022-08-23 22:03:48 -04:00
parent 4ba7ee816b
commit d723d427e6
3 changed files with 71 additions and 66 deletions

View File

@ -1,3 +1,4 @@
{
"rust-analyzer.checkOnSave.features": "all"
"rust-analyzer.checkOnSave.features": "all",
"rust-analyzer.cargo.features": "all"
}

View File

@ -27,7 +27,7 @@ use crate::hash::HashWriter;
/// Nodes are tied to a specific implementation of the HashWriter trait which is itself tied
/// to the DAG they are stored in guaranteeing that the same Hashing implementation is used
/// for each node in the DAG.
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Eq)]
pub struct Node<HW, const HASH_LEN: usize>
where
HW: HashWriter<HASH_LEN>,
@ -54,6 +54,16 @@ where
}
}
fn coerce_non_const_generic_set<const HASH_LEN: usize>(
set: &BTreeSet<[u8; HASH_LEN]>,
) -> BTreeSet<&[u8]> {
let mut coerced_item = BTreeSet::new();
for arr in set {
coerced_item.insert(arr.as_slice());
}
coerced_item
}
impl<HW, const HASH_LEN: usize> Serialize for Node<HW, HASH_LEN>
where
HW: HashWriter<HASH_LEN>,
@ -63,15 +73,18 @@ where
S: Serializer,
{
let mut structor = serializer.serialize_struct("Node", 4)?;
structor.serialize_field("id", self.id.as_slice())?;
structor.serialize_field("item", &self.item)?;
structor.serialize_field("item_id", self.item_id.as_slice())?;
// TODO(jwall): structor.serialize_field("dependency_ids", &self.dependency_ids)?;
structor.serialize_field(
"dependency_ids",
&coerce_non_const_generic_set(&self.dependency_ids),
)?;
structor.end()
}
}
fn coerce_array<const HASH_LEN: usize>(slice: &[u8]) -> Result<[u8; HASH_LEN], String> {
fn coerce_const_generic_array<const HASH_LEN: usize>(
slice: &[u8],
) -> Result<[u8; HASH_LEN], String> {
let mut coerced_item: [u8; HASH_LEN] = [0; HASH_LEN];
if slice.len() > coerced_item.len() {
return Err(format!(
@ -85,12 +98,12 @@ fn coerce_array<const HASH_LEN: usize>(slice: &[u8]) -> Result<[u8; HASH_LEN], S
Ok(coerced_item)
}
fn coerce_set<const HASH_LEN: usize>(
fn coerce_const_generic_set<const HASH_LEN: usize>(
set: BTreeSet<&[u8]>,
) -> Result<BTreeSet<[u8; HASH_LEN]>, String> {
let mut coerced_item = BTreeSet::new();
for slice in set {
coerced_item.insert(coerce_array(slice)?);
coerced_item.insert(coerce_const_generic_array(slice)?);
}
Ok(coerced_item)
}
@ -107,9 +120,7 @@ where
#[serde(field_identifier, rename_all = "lowercase")]
#[allow(non_camel_case_types)]
enum Field {
Id,
Item,
Item_Id,
Dependency_Ids,
}
@ -129,53 +140,25 @@ where
where
A: serde::de::SeqAccess<'de>,
{
let id: [u8; HASH_LEN] = coerce_array(
seq.next_element::<&[u8]>()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?,
)
.map_err(serde::de::Error::custom)?;
let item = seq
.next_element::<Vec<u8>>()?
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
let item_id: [u8; HASH_LEN] = coerce_array(
seq.next_element::<&[u8]>()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?,
)
.map_err(serde::de::Error::custom)?;
let dependency_ids: BTreeSet<[u8; HASH_LEN]> = coerce_set(
let dependency_ids: BTreeSet<[u8; HASH_LEN]> = coerce_const_generic_set(
seq.next_element::<BTreeSet<&[u8]>>()?
.ok_or_else(|| serde::de::Error::invalid_length(3, &self))?,
)
.map_err(serde::de::Error::custom)?;
Ok(Self::Value {
id,
item,
item_id,
dependency_ids,
_phantom: PhantomData,
})
Ok(Self::Value::new(item, dependency_ids))
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let mut id: Option<[u8; HASH_LEN]> = None;
let mut item: Option<Vec<u8>> = None;
let mut item_id: Option<[u8; HASH_LEN]> = None;
let mut dependency_ids: Option<BTreeSet<[u8; HASH_LEN]>> = None;
while let Some(key) = map.next_key()? {
match key {
Field::Id => {
if id.is_some() {
return Err(serde::de::Error::duplicate_field("id"));
} else {
id = Some(
coerce_array(map.next_value()?)
.map_err(serde::de::Error::custom)?,
);
}
}
Field::Item => {
if item.is_some() {
return Err(serde::de::Error::duplicate_field("item"));
@ -183,50 +166,28 @@ where
item = Some(map.next_value()?);
}
}
Field::Item_Id => {
if item_id.is_some() {
return Err(serde::de::Error::duplicate_field("item_id"));
} else {
item_id = Some(
coerce_array(map.next_value()?)
.map_err(serde::de::Error::custom)?,
);
}
}
Field::Dependency_Ids => {
if dependency_ids.is_some() {
return Err(serde::de::Error::duplicate_field("dependency_ids"));
} else {
dependency_ids = Some(
coerce_set(map.next_value()?)
coerce_const_generic_set(map.next_value()?)
.map_err(serde::de::Error::custom)?,
);
}
}
}
}
let id = id.ok_or_else(|| serde::de::Error::missing_field("id"))?;
let item = item.ok_or_else(|| serde::de::Error::missing_field("item"))?;
let item_id = item_id.ok_or_else(|| serde::de::Error::missing_field("item_id"))?;
let dependency_ids = dependency_ids
.ok_or_else(|| serde::de::Error::missing_field("dependency_ids"))?;
Ok(Self::Value {
id,
item,
item_id,
dependency_ids,
_phantom: PhantomData,
})
Ok(Self::Value::new(item, dependency_ids))
}
}
const FIELDS: &'static [&'static str] = &["id", "item", "item_id", "dependency_ids"];
deserializer.deserialize_struct(
"Duration",
FIELDS,
NodeVisitor::<HW, HASH_LEN>(PhantomData),
)
const FIELDS: &'static [&'static str] = &["item", "dependency_ids"];
deserializer.deserialize_struct("Node", FIELDS, NodeVisitor::<HW, HASH_LEN>(PhantomData))
}
}

View File

@ -161,3 +161,46 @@ fn test_node_comparison_no_shared_graph() {
NodeCompare::Uncomparable
);
}
#[cfg(feature = "cbor")]
mod cbor_serialization_tests {
use super::TestDag;
use crate::prelude::*;
use ciborium::{de::from_reader, ser::into_writer};
use std::collections::{hash_map::DefaultHasher, BTreeSet};
#[test]
fn test_node_deserializaton() {
let mut dag = TestDag::new();
let simple_node_id = dag.add_node("simple", BTreeSet::new()).unwrap();
let mut dep_set = BTreeSet::new();
dep_set.insert(simple_node_id);
let root_node_id = dag.add_node("root", dep_set).unwrap();
let simple_node_to_serialize = dag.get_node_by_id(&simple_node_id).unwrap().unwrap();
let root_node_to_serialize = dag.get_node_by_id(&root_node_id).unwrap().unwrap();
let mut simple_node_vec: Vec<u8> = Vec::new();
let mut root_node_vec: Vec<u8> = Vec::new();
into_writer(&simple_node_to_serialize, &mut simple_node_vec).unwrap();
into_writer(&root_node_to_serialize, &mut root_node_vec).unwrap();
let simple_node_de: Node<DefaultHasher, 8> =
from_reader(simple_node_vec.as_slice()).unwrap();
let root_node_de: Node<DefaultHasher, 8> = from_reader(root_node_vec.as_slice()).unwrap();
assert_eq!(simple_node_to_serialize.id(), simple_node_de.id());
assert_eq!(simple_node_to_serialize.item_id(), simple_node_de.item_id());
assert_eq!(simple_node_to_serialize.item(), simple_node_de.item());
assert_eq!(
simple_node_to_serialize.dependency_ids(),
simple_node_de.dependency_ids()
);
assert_eq!(root_node_to_serialize.id(), root_node_de.id());
assert_eq!(root_node_to_serialize.item_id(), root_node_de.item_id());
assert_eq!(root_node_to_serialize.item(), root_node_de.item());
assert_eq!(
root_node_to_serialize.dependency_ids(),
root_node_de.dependency_ids()
);
}
}