use crate::sketchbook::observations::_var_value::VarValue;
use crate::sketchbook::{ids::ObservationId, utils::assert_name_valid};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
/// An observation: an ordered sequence of [`VarValue`]s together with an identifier,
/// a display name and a free-form annotation.
///
/// Instances are created through the constructors below, which validate the name with
/// `assert_name_valid` and the id with `ObservationId::new`.
/// NOTE(review): the derived `Deserialize` fills the fields directly, so a deserialized
/// `name` does not go through `assert_name_valid` — confirm this is acceptable.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct Observation {
    /// Identifier of this observation.
    id: ObservationId,
    /// Display name; the plain constructors (`new`, `try_from_str`, ...) reuse the id string.
    name: String,
    /// Free-form annotation text; empty unless set explicitly.
    annotation: String,
    /// The observed values, in order (presumably one per model variable — confirm with callers).
    values: Vec<VarValue>,
}
impl Observation {
    /// Create a new observation with the given `values` and `id`.
    ///
    /// The `id` string is also used as the observation's name, and the annotation is empty.
    ///
    /// # Errors
    ///
    /// Returns an error if `id` fails [`assert_name_valid`] or is not a valid [`ObservationId`].
    pub fn new(values: Vec<VarValue>, id: &str) -> Result<Self, String> {
        Self::new_annotated(values, id, id, "")
    }

    /// Create a new observation with an explicit `id`, `name` and annotation `annot`.
    ///
    /// # Errors
    ///
    /// Returns an error if `name` fails [`assert_name_valid`] (checked first), or if `id` is
    /// not a valid [`ObservationId`].
    pub fn new_annotated(
        values: Vec<VarValue>,
        id: &str,
        name: &str,
        annot: &str,
    ) -> Result<Self, String> {
        assert_name_valid(name)?;
        Ok(Self {
            id: ObservationId::new(id)?,
            name: name.to_string(),
            annotation: annot.to_string(),
            values,
        })
    }

    /// Create an observation with `n` values, all [`VarValue::True`].
    ///
    /// # Errors
    ///
    /// Same as [`Observation::new`].
    pub fn new_full_ones(n: usize, id: &str) -> Result<Self, String> {
        Self::new(vec![VarValue::True; n], id)
    }

    /// Create an observation with `n` values, all [`VarValue::False`].
    ///
    /// # Errors
    ///
    /// Same as [`Observation::new`].
    pub fn new_full_zeros(n: usize, id: &str) -> Result<Self, String> {
        Self::new(vec![VarValue::False; n], id)
    }

    /// Create an observation with `n` values, all [`VarValue::Any`].
    ///
    /// # Errors
    ///
    /// Same as [`Observation::new`].
    pub fn new_full_unspecified(n: usize, id: &str) -> Result<Self, String> {
        Self::new(vec![VarValue::Any; n], id)
    }

    /// Parse an observation from a string with one character per value (e.g. `"01*"`).
    ///
    /// The `id` is also used as the name; the annotation is empty.
    ///
    /// # Errors
    ///
    /// Returns an error if any character is not a valid [`VarValue`], or under the same
    /// conditions as [`Observation::new`].
    pub fn try_from_str(observation_str: &str, id: &str) -> Result<Self, String> {
        let values = Self::parse_observation_values(observation_str)?;
        Self::new(values, id)
    }

    /// Parse an observation from a string (one character per value) with an explicit
    /// `name` and annotation `annot`.
    ///
    /// # Errors
    ///
    /// Returns an error if any character is not a valid [`VarValue`], or under the same
    /// conditions as [`Observation::new_annotated`].
    pub fn try_from_str_annotated(
        observation_str: &str,
        id: &str,
        name: &str,
        annot: &str,
    ) -> Result<Self, String> {
        // Build directly through `new_annotated` rather than `try_from_str` + setters: the
        // latter additionally validated `id` as a *name*, so a valid id that is not a valid
        // name rejected an otherwise valid observation, and the name was validated twice.
        let values = Self::parse_observation_values(observation_str)?;
        Self::new_annotated(values, id, name, annot)
    }

    /// Parse every character of `values_str` into one [`VarValue`], preserving order.
    fn parse_observation_values(values_str: &str) -> Result<Vec<VarValue>, String> {
        // Scratch space for viewing each `char` as a `&str` without allocating a `String`.
        let mut char_buf = [0u8; 4];
        let mut values = Vec::with_capacity(values_str.len());
        for c in values_str.chars() {
            values.push(VarValue::from_str(c.encode_utf8(&mut char_buf))?);
        }
        Ok(values)
    }
}
impl Observation {
    /// Replace the observation's name.
    ///
    /// # Errors
    ///
    /// Fails, leaving the current name untouched, if `name` does not pass
    /// [`assert_name_valid`].
    pub fn set_name(&mut self, name: &str) -> Result<(), String> {
        assert_name_valid(name)?;
        self.name = String::from(name);
        Ok(())
    }

    /// Replace the observation's annotation text.
    pub fn set_annotation(&mut self, annotation: &str) {
        self.annotation = String::from(annotation);
    }

    /// Overwrite the value stored at position `index`.
    ///
    /// # Errors
    ///
    /// Fails if `index` is not a valid position (i.e. `index >= num_values()`).
    pub fn set_value(&mut self, index: usize, value: VarValue) -> Result<(), String> {
        match self.values.get_mut(index) {
            Some(slot) => {
                *slot = value;
                Ok(())
            }
            None => Err("Index is larger than number of values.".to_string()),
        }
    }

    /// Parse `value` into a [`VarValue`] and store it at position `index`.
    ///
    /// # Errors
    ///
    /// Fails if `value` cannot be parsed (checked first), or if `index` is out of range.
    pub fn set_value_by_str(&mut self, index: usize, value: &str) -> Result<(), String> {
        self.set_value(index, VarValue::from_str(value)?)
    }

    /// Replace all values at once; the new vector must have the same length as the old one.
    ///
    /// # Errors
    ///
    /// Fails, leaving the values untouched, if the lengths differ.
    pub fn set_all_values(&mut self, values: Vec<VarValue>) -> Result<(), String> {
        if values.len() == self.num_values() {
            self.values = values;
            Ok(())
        } else {
            Err("Vectors of old and new values differ in length.".to_string())
        }
    }

    /// Replace all values using a string with one character per value (e.g. `"01*"`).
    ///
    /// # Errors
    ///
    /// Fails if any character cannot be parsed (checked first), or if the parsed length
    /// differs from the current number of values.
    pub fn set_all_values_by_str(&mut self, values: &str) -> Result<(), String> {
        let mut char_buf = [0u8; 4];
        let mut parsed = Vec::new();
        for symbol in values.chars() {
            let value = VarValue::from_str(symbol.encode_utf8(&mut char_buf))?;
            parsed.push(value);
        }
        self.set_all_values(parsed)
    }

    /// Replace the observation's identifier with an already-validated [`ObservationId`].
    pub fn set_id(&mut self, id: ObservationId) {
        self.id = id;
    }

    /// Replace the observation's identifier, parsing it from `id`.
    ///
    /// # Errors
    ///
    /// Fails if `id` is not a valid [`ObservationId`].
    pub fn set_id_by_str(&mut self, id: &str) -> Result<(), String> {
        self.set_id(ObservationId::new(id)?);
        Ok(())
    }

    /// Remove the value at position `index`, shifting later values to the left.
    ///
    /// # Errors
    ///
    /// Fails if `index >= num_values()`.
    pub fn remove_nth_value(&mut self, index: usize) -> Result<(), String> {
        if index < self.num_values() {
            self.values.remove(index);
            Ok(())
        } else {
            Err("Index is larger than number of values.".to_string())
        }
    }

    /// Insert `value` at position `index`, shifting later values to the right.
    ///
    /// `index == num_values()` appends at the end.
    ///
    /// # Errors
    ///
    /// Fails if `index > num_values()`.
    pub fn add_value(&mut self, index: usize, value: VarValue) -> Result<(), String> {
        if index <= self.num_values() {
            self.values.insert(index, value);
            Ok(())
        } else {
            Err("Index is larger than number of values.".to_string())
        }
    }
}
impl Observation {
    /// The observation's display name.
    pub fn get_name(&self) -> &str {
        self.name.as_str()
    }

    /// The observation's annotation text (empty if none was set).
    pub fn get_annotation(&self) -> &str {
        self.annotation.as_str()
    }

    /// All stored values, in order.
    pub fn get_values(&self) -> &Vec<VarValue> {
        &self.values
    }

    /// The observation's identifier.
    pub fn get_id(&self) -> &ObservationId {
        &self.id
    }

    /// Total number of stored values.
    pub fn num_values(&self) -> usize {
        self.values.len()
    }

    /// Number of values equal to [`VarValue::Any`].
    pub fn num_unspecified_values(&self) -> usize {
        self.count_values_equal_to(VarValue::Any)
    }

    /// Number of values that are not [`VarValue::Any`].
    pub fn num_specified_values(&self) -> usize {
        self.num_values() - self.num_unspecified_values()
    }

    /// Number of values equal to [`VarValue::True`].
    pub fn num_ones(&self) -> usize {
        self.count_values_equal_to(VarValue::True)
    }

    /// Number of values equal to [`VarValue::False`].
    pub fn num_zeros(&self) -> usize {
        self.count_values_equal_to(VarValue::False)
    }

    /// Borrow the value at position `index`.
    ///
    /// # Errors
    ///
    /// Fails if `index >= num_values()`.
    pub fn value_at_idx(&self, index: usize) -> Result<&VarValue, String> {
        self.values
            .get(index)
            .ok_or_else(|| "Index is larger than number of values.".to_string())
    }

    /// Concatenate the string form (`VarValue::as_str`) of every value, e.g. `"01*"`.
    pub fn to_values_string(&self) -> String {
        self.values.iter().map(|v| v.as_str()).collect()
    }

    /// Render as `<id>(<values>)`, e.g. `"obs1(01*)"`.
    pub fn to_debug_string(&self) -> String {
        format!("{}({})", self.id, self.to_values_string())
    }

    /// Count how many stored values compare equal to `target`.
    fn count_values_equal_to(&self, target: VarValue) -> usize {
        self.values.iter().filter(|v| **v == target).count()
    }
}
#[cfg(test)]
mod tests {
    use crate::sketchbook::observations::{Observation, VarValue};

    /// Parsing a value string yields the same observation as building it from a vector.
    #[test]
    fn test_observation_from_str() {
        let observation_str = "001**";
        let id = "observation_id";
        let expected_values = vec![
            VarValue::False,
            VarValue::False,
            VarValue::True,
            VarValue::Any,
            VarValue::Any,
        ];
        let expected_obs = Observation::new(expected_values, id).unwrap();
        assert_eq!(
            Observation::try_from_str(observation_str, id).unwrap(),
            expected_obs
        );
    }

    /// The `new_full_*` shortcuts match their string-parsed equivalents.
    #[test]
    fn test_creating_shortcuts() {
        let obs = Observation::new_full_ones(4, "o").unwrap();
        let expected_obs = Observation::try_from_str("1111", "o").unwrap();
        assert_eq!(obs, expected_obs);
        let obs = Observation::new_full_zeros(4, "o").unwrap();
        let expected_obs = Observation::try_from_str("0000", "o").unwrap();
        assert_eq!(obs, expected_obs);
        let obs = Observation::new_full_unspecified(4, "o").unwrap();
        let expected_obs = Observation::try_from_str("****", "o").unwrap();
        assert_eq!(obs, expected_obs);
    }

    /// Annotated constructors set id, name and annotation; plain ones reuse the id as name.
    #[test]
    fn test_annotated_constructors() {
        let obs = Observation::try_from_str_annotated("01*", "o", "obs_name", "note").unwrap();
        assert_eq!(obs.get_id().as_str(), "o");
        assert_eq!(obs.get_name(), "obs_name");
        assert_eq!(obs.get_annotation(), "note");
        assert_eq!(obs.to_values_string(), "01*");
        let expected = Observation::new_annotated(
            vec![VarValue::False, VarValue::True, VarValue::Any],
            "o",
            "obs_name",
            "note",
        )
        .unwrap();
        assert_eq!(obs, expected);

        let plain = Observation::try_from_str("01*", "o").unwrap();
        assert_eq!(plain.get_name(), "o");
        assert_eq!(plain.get_annotation(), "");
    }

    /// An empty value string produces a valid observation with no values.
    #[test]
    fn test_empty_observation() {
        let obs = Observation::try_from_str("", "o").unwrap();
        assert_eq!(obs.num_values(), 0);
        assert_eq!(obs.num_specified_values(), 0);
        assert_eq!(obs.to_values_string(), "");
        assert_eq!(obs.to_debug_string(), "o()");
        assert!(obs.value_at_idx(0).is_err());
    }

    #[test]
    fn test_getters() {
        let obs = Observation::try_from_str("10*11*", "o").unwrap();
        assert_eq!(obs.num_values(), 6);
        assert_eq!(obs.num_ones(), 3);
        assert_eq!(obs.num_zeros(), 1);
        assert_eq!(obs.num_specified_values(), 4);
        assert_eq!(obs.num_unspecified_values(), 2);
        assert_eq!(obs.get_id().as_str(), "o");
        assert_eq!(obs.value_at_idx(0).unwrap().as_str(), "1");
        assert_eq!(obs.value_at_idx(5).unwrap().as_str(), "*");
        assert!(obs.value_at_idx(6).is_err());
    }

    #[test]
    fn test_setters() {
        let mut obs = Observation::try_from_str("10*11*", "o").unwrap();
        obs.set_id_by_str("p").unwrap();
        assert_eq!(obs.get_id().as_str(), "p");
        obs.set_value_by_str(1, "1").unwrap();
        assert_eq!(obs.to_values_string().as_str(), "11*11*");
        obs.set_all_values_by_str("111111").unwrap();
        assert_eq!(obs.to_values_string().as_str(), "111111");
    }

    /// Every indexed operation rejects out-of-range positions without modifying the values.
    #[test]
    fn test_index_bounds() {
        let mut obs = Observation::try_from_str("10*", "o").unwrap();
        assert!(obs.set_value(3, VarValue::True).is_err());
        assert!(obs.set_value_by_str(3, "1").is_err());
        assert!(obs.remove_nth_value(3).is_err());
        assert!(obs.add_value(4, VarValue::True).is_err());
        assert_eq!(obs.to_values_string(), "10*");

        // Inserting at exactly `num_values()` appends.
        obs.add_value(3, VarValue::False).unwrap();
        assert_eq!(obs.to_values_string(), "10*0");
        obs.set_value(3, VarValue::Any).unwrap();
        assert_eq!(obs.to_values_string(), "10**");
    }

    /// Bulk setters reject wrong lengths and unparsable input, leaving values untouched.
    #[test]
    fn test_set_all_values_errors() {
        let mut obs = Observation::try_from_str("10*", "o").unwrap();
        assert!(obs.set_all_values_by_str("11").is_err());
        assert!(obs.set_all_values_by_str("1a1").is_err());
        assert!(obs.set_all_values(vec![VarValue::True; 4]).is_err());
        assert_eq!(obs.to_values_string(), "10*");

        obs.set_all_values(vec![VarValue::Any; 3]).unwrap();
        assert_eq!(obs.to_values_string(), "***");
    }

    #[test]
    fn test_insert_remove_value() {
        let mut observation = Observation::try_from_str("001**", "id1").unwrap();
        observation.add_value(1, VarValue::True).unwrap();
        assert_eq!(observation.to_values_string(), "0101**");
        observation.remove_nth_value(3).unwrap();
        assert_eq!(observation.to_values_string(), "010**");
    }

    #[test]
    fn test_err_observation_from_str() {
        let observation_str1 = "0 1**";
        let observation_str2 = "0**a";
        assert!(Observation::try_from_str(observation_str1, "obs1").is_err());
        assert!(Observation::try_from_str(observation_str2, "obs2").is_err());
    }

    #[test]
    fn test_display_observations() {
        let values_str = "001**";
        let observation = Observation::try_from_str(values_str, "id1").unwrap();
        let expected_long = "id1(001**)".to_string();
        assert_eq!(observation.to_values_string(), values_str.to_string());
        assert_eq!(observation.to_debug_string(), expected_long);
    }
}