From 4f0a0fb5773005785f2f3e3c0bf412594883eadd Mon Sep 17 00:00:00 2001 From: Gagan Trivedi Date: Tue, 28 Oct 2025 14:13:40 +0530 Subject: [PATCH 1/8] wip: Add evaluation context --- Cargo.toml | 3 + src/engine.rs | 349 ++++++++++----------- src/engine_eval/context.rs | 188 +++++++++++ src/engine_eval/mappers.rs | 330 +++++++++++++++++++ src/engine_eval/mod.rs | 17 + src/engine_eval/result.rs | 46 +++ src/engine_eval/segment_evaluator.rs | 453 +++++++++++++++++++++++++++ src/lib.rs | 1 + tests/engine_eval_test.rs | 178 +++++++++++ tests/engine_tests/engine-test-data | 2 +- tests/engine_tests/engine_tests.rs | 103 ------ tests/tests.rs | 3 - 12 files changed, 1383 insertions(+), 290 deletions(-) create mode 100644 src/engine_eval/context.rs create mode 100644 src/engine_eval/mappers.rs create mode 100644 src/engine_eval/mod.rs create mode 100644 src/engine_eval/result.rs create mode 100644 src/engine_eval/segment_evaluator.rs create mode 100644 tests/engine_eval_test.rs delete mode 100644 tests/engine_tests/engine_tests.rs delete mode 100644 tests/tests.rs diff --git a/Cargo.toml b/Cargo.toml index 8fbfc8f..66094ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ maintenance = { status = "actively-developed" } [dependencies] serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +serde_json_path = "0.7" chrono = { version = "0.4", features = ["serde"] } md-5 = "0.10.1" num-bigint = "0.4" @@ -20,6 +21,8 @@ num-traits = "0.2.14" uuid = { version = "0.8", features = ["serde", "v4"] } regex = "1" semver = "1.0" +sha2 = "0.10" [dev-dependencies] rstest = "0.12.0" +json_comments = "0.2" diff --git a/src/engine.rs b/src/engine.rs index fc50873..3464f49 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -1,213 +1,196 @@ -use super::environments; -use super::error; -use super::features; -use super::identities; -use super::segments::evaluator; -use crate::features::Feature; -use crate::features::FeatureState; +use crate::engine_eval::context::{EngineEvaluationContext, FeatureContext}; +use crate::engine_eval::result::{EvaluationResult, FlagResult, SegmentResult}; +use crate::engine_eval::segment_evaluator::is_context_in_segment; +use crate::utils::hashing; use std::collections::HashMap; -//Returns a vector of feature states for a given environment -pub fn get_environment_feature_states( - environment: environments::Environment, -) -> Vec { - if environment.project.hide_disabled_flags { - return environment - .feature_states - .iter() - .filter(|fs| fs.enabled) - .map(|fs| fs.clone()) - .collect(); - } - return environment.feature_states; +/// Holds a feature context with its associated segment name for priority comparison +struct FeatureContextWithSegment { + feature_context: FeatureContext, + segment_name: String, } -// Returns a specific feature state for a given feature_name in a given environment -// If exists else returns a FeatureStateNotFound error -pub fn get_environment_feature_state( - environment: environments::Environment, - feature_name: &str, -) -> Result { - let fs = environment - .feature_states - .iter() - .filter(|fs| fs.feature.name == feature_name) - .next() - .ok_or(error::Error::new(error::ErrorKind::FeatureStateNotFound)); - return Ok(fs?.clone()); +/// Helper to get priority or default +fn get_priority_or_default(priority: Option) -> f64 { + priority.unwrap_or(f64::INFINITY) // Weakest possible priority } -// Returns a vector of feature state models based on the environment, any matching -// segments and any specific identity overrides -pub fn get_identity_feature_states( - environment: &environments::Environment, - identity: &identities::Identity, - override_traits: Option<&Vec>, -) -> Vec { - let feature_states = - get_identity_feature_states_map(environment, identity, override_traits).into_values(); - if environment.project.hide_disabled_flags { - return feature_states.filter(|fs| fs.enabled).collect(); +/// Gets matching segments and their overrides +fn get_matching_segments_and_overrides( + ec: &EngineEvaluationContext, +) -> ( + Vec, + HashMap, +) { + let mut segments = Vec::new(); + let mut segment_feature_contexts: HashMap = HashMap::new(); + + // Process segments + for segment_context in ec.segments.values() { + if !is_context_in_segment(ec, segment_context) { + continue; + } + + // Add segment to results + segments.push(SegmentResult { + name: segment_context.name.clone(), + metadata: segment_context.metadata.clone(), + }); + + // Process segment overrides + for override_fc in &segment_context.overrides { + let feature_name = &override_fc.name; + + // Check if we should update the segment feature context + let should_update = if let Some(existing) = segment_feature_contexts.get(feature_name) { + let existing_priority = get_priority_or_default(existing.feature_context.priority); + let override_priority = get_priority_or_default(override_fc.priority); + override_priority < existing_priority + } else { + true + }; + + if should_update { + segment_feature_contexts.insert( + feature_name.clone(), + FeatureContextWithSegment { + feature_context: override_fc.clone(), + segment_name: segment_context.name.clone(), + }, + ); + } + } } - return feature_states.collect(); -} -// Returns a specific feature state based on the environment, any matching -// segments and any specific identity overrides -// If exists else returns a FeatureStateNotFound error -pub fn get_identity_feature_state( - environment: &environments::Environment, - identity: &identities::Identity, - feature_name: &str, - override_traits: Option<&Vec>, -) -> Result { - let feature_states = - get_identity_feature_states_map(environment, identity, override_traits).into_values(); - let fs = feature_states - .filter(|fs| fs.feature.name == feature_name) - .next() - .ok_or(error::Error::new(error::ErrorKind::FeatureStateNotFound)); - - return Ok(fs?.clone()); + (segments, segment_feature_contexts) } -fn get_identity_feature_states_map( - environment: &environments::Environment, - identity: &identities::Identity, - override_traits: Option<&Vec>, -) -> HashMap { - let mut feature_states: HashMap = HashMap::new(); +/// Gets flag results from feature contexts and segment overrides +fn get_flag_results( + ec: &EngineEvaluationContext, + segment_feature_contexts: &HashMap, +) -> HashMap { + let mut flags = HashMap::new(); + + // Get identity key if identity exists + // If identity key is not provided, construct it from environment key and identifier + let identity_key: Option = ec.identity.as_ref().map(|i| { + if i.key.is_empty() { + format!("{}_{}", ec.environment.key, i.identifier) + } else { + i.key.clone() + } + }); - // Get feature states from the environment - for fs in environment.feature_states.clone() { - feature_states.insert(fs.feature.clone(), fs); + // Process all features + for feature_context in ec.features.values() { + // Check if we have a segment override for this feature + if let Some(segment_fc) = segment_feature_contexts.get(&feature_context.name) { + // Use segment override + let fc = &segment_fc.feature_context; + let reason = format!("TARGETING_MATCH; segment={}", segment_fc.segment_name); + flags.insert( + feature_context.name.clone(), + FlagResult { + enabled: fc.enabled, + name: fc.name.clone(), + reason, + value: fc.value.clone(), + metadata: fc.metadata.clone(), + }, + ); + } else { + // Use default feature context + let flag_result = + get_flag_result_from_feature_context(feature_context, identity_key.as_ref()); + flags.insert(feature_context.name.clone(), flag_result); + } } - // Override with any feature states defined by matching segments - let identity_segments = - evaluator::get_identity_segments(environment, identity, override_traits); - for matching_segments in identity_segments { - for feature_state in matching_segments.feature_states { - let existing = feature_states.get(&feature_state.feature); - if existing.is_some() { - if existing.unwrap().is_higher_segment_priority(&feature_state) { - continue; - } + flags +} + +pub fn get_evaluation_result(ec: &EngineEvaluationContext) -> EvaluationResult { + // Process segments + let (segments, segment_feature_contexts) = get_matching_segments_and_overrides(ec); + + // Get flag results + let flags = get_flag_results(ec, &segment_feature_contexts); + + EvaluationResult { flags, segments } +} + +/// Creates a FlagResult from a FeatureContext +fn get_flag_result_from_feature_context( + feature_context: &FeatureContext, + identity_key: Option<&String>, +) -> FlagResult { + let mut reason = "DEFAULT".to_string(); + let mut value = feature_context.value.clone(); + + // Handle multivariate features + if !feature_context.variants.is_empty() + && identity_key.is_some() + && !feature_context.key.is_empty() + { + // Sort variants by priority (lower priority value = higher priority) + let mut sorted_variants = feature_context.variants.clone(); + sorted_variants.sort_by(|a, b| { + let pa = get_priority_or_default(a.priority); + let pb = get_priority_or_default(b.priority); + pa.partial_cmp(&pb).unwrap() + }); + + // Calculate hash percentage for the identity and feature combination + let object_ids = vec![feature_context.key.as_str(), identity_key.unwrap().as_str()]; + let hash_percentage = hashing::get_hashed_percentage_for_object_ids(object_ids, 1); + + // Select variant based on weighted distribution + let mut cumulative_weight = 0.0; + for variant in &sorted_variants { + cumulative_weight += variant.weight; + if (hash_percentage as f64) <= cumulative_weight { + value = variant.value.clone(); + reason = format!("SPLIT; weight={}", variant.weight); + break; } - feature_states.insert(feature_state.feature.clone(), feature_state); } } - // Override with any feature states defined directly the identity - for feature_state in identity.identity_features.clone() { - feature_states.insert(feature_state.feature.clone(), feature_state); + + FlagResult { + enabled: feature_context.enabled, + name: feature_context.name.clone(), + value, + reason, + metadata: feature_context.metadata.clone(), } - return feature_states; } #[cfg(test)] mod tests { use super::*; - static IDENTITY_JSON: &str = r#"{ - "identifier": "test_user", - "environment_api_key": "test_api_key", - "created_date": "2022-03-02T12:31:05.309861", - "identity_features": [], - "identity_traits": [], - "identity_uuid":"" - }"#; - static ENVIRONMENT_JSON: &str = r#" - { - "api_key": "test_key", - "project": { - "name": "Test project", - "organisation": { - "feature_analytics": false, - "name": "Test Org", - "id": 1, - "persist_trait_data": true, - "stop_serving_flags": false - }, - "id": 1, - "hide_disabled_flags": true, - "segments": [] - }, - "segment_overrides": [], - "id": 1, - "feature_states": [ - { - "multivariate_feature_state_values": [], - "feature_state_value": true, - "django_id": 1, - "feature": { - "name": "feature1", - "type": null, - "id": 1 - }, - "enabled": false - }, - { - "multivariate_feature_state_values": [], - "feature_state_value": null, - "django_id": 2, - "feature": { - "name": "feature_2", - "type": null, - "id": 2 - }, - "enabled": true - } - ] -}"#; + use crate::engine_eval::context::EnvironmentContext; #[test] - fn get_environment_feature_states_only_return_enabled_fs_if_hide_disabled_flags_is_true() { - let environment: environments::Environment = - serde_json::from_str(ENVIRONMENT_JSON).unwrap(); - - let environment_feature_states = get_environment_feature_states(environment); - assert_eq!(environment_feature_states.len(), 1); - assert_eq!(environment_feature_states[0].django_id.unwrap(), 2); + fn test_get_priority_or_default() { + assert_eq!(get_priority_or_default(Some(1.0)), 1.0); + assert_eq!(get_priority_or_default(None), f64::INFINITY); } #[test] - fn get_environment_feature_state_returns_correct_feature_state() { - let environment: environments::Environment = - serde_json::from_str(ENVIRONMENT_JSON).unwrap(); - let feature_name = "feature_2"; - let feature_state = get_environment_feature_state(environment, feature_name).unwrap(); - assert_eq!(feature_state.feature.name, feature_name) - } - - #[test] - fn get_environment_feature_state_returns_error_if_feature_state_does_not_exists() { - let environment: environments::Environment = - serde_json::from_str(ENVIRONMENT_JSON).unwrap(); - let feature_name = "feature_that_does_not_exists"; - let err = get_environment_feature_state(environment, feature_name) - .err() - .unwrap(); - assert_eq!(err.kind, error::ErrorKind::FeatureStateNotFound) - } + fn test_get_evaluation_result_empty_context() { + let ec = EngineEvaluationContext { + environment: EnvironmentContext { + key: "test".to_string(), + name: "test".to_string(), + }, + features: HashMap::new(), + segments: HashMap::new(), + identity: None, + }; - #[test] - fn get_identity_feature_state_returns_correct_feature_state() { - let environment: environments::Environment = - serde_json::from_str(ENVIRONMENT_JSON).unwrap(); - let feature_name = "feature_2"; - let identity: identities::Identity = serde_json::from_str(IDENTITY_JSON).unwrap(); - let feature_state = - get_identity_feature_state(&environment, &identity, feature_name, None).unwrap(); - assert_eq!(feature_state.feature.name, feature_name) - } - #[test] - fn get_identity_feature_state_returns_error_if_feature_state_does_not_exists() { - let environment: environments::Environment = - serde_json::from_str(ENVIRONMENT_JSON).unwrap(); - let feature_name = "feature_that_does_not_exists"; - let identity: identities::Identity = serde_json::from_str(IDENTITY_JSON).unwrap(); - let err = get_identity_feature_state(&environment, &identity, feature_name, None) - .err() - .unwrap(); - assert_eq!(err.kind, error::ErrorKind::FeatureStateNotFound) + let result = get_evaluation_result(&ec); + assert_eq!(result.flags.len(), 0); + assert_eq!(result.segments.len(), 0); } } diff --git a/src/engine_eval/context.rs b/src/engine_eval/context.rs new file mode 100644 index 0000000..ea42630 --- /dev/null +++ b/src/engine_eval/context.rs @@ -0,0 +1,188 @@ +use crate::types::FlagsmithValue; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Represents metadata information about a feature. +#[derive(Clone, Debug, Serialize, Deserialize, Default)] +pub struct FeatureMetadata { + /// The feature ID. + #[serde(default)] + pub feature_id: u32, +} + +/// Represents a multivariate value for a feature flag. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FeatureValue { + /// The value of the feature. + pub value: FlagsmithValue, + /// The weight of the feature value variant, as a percentage number (i.e. 100.0). + pub weight: f64, + /// Priority of the feature flag variant. Lower values indicate a higher priority when multiple variants apply to the same context key. + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, +} + +/// Represents a feature context for feature flag evaluation. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FeatureContext { + /// String key used for hashing in percentage splits. + pub key: String, + /// The name of the feature. + pub name: String, + /// Whether the feature is enabled. + pub enabled: bool, + /// The default value for the feature. + pub value: FlagsmithValue, + /// Priority for this feature context. Lower values indicate higher priority. + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, + /// Multivariate feature variants. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub variants: Vec, + /// Metadata about the feature. + #[serde(default)] + pub metadata: FeatureMetadata, +} + +/// Represents environment metadata. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct EnvironmentContext { + /// The environment API key. + pub key: String, + /// The environment name. + pub name: String, +} + +/// Represents identity context for feature flag evaluation. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct IdentityContext { + /// The identity identifier. + pub identifier: String, + /// String key used for hashing in percentage splits. + /// If not provided during deserialization, it will be constructed as "environment_key_identifier". + #[serde(default)] + pub key: String, + /// Identity traits as a map of trait keys to values. + #[serde(default)] + pub traits: HashMap, +} + +/// Segment rule condition operators. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum ConditionOperator { + Equal, + NotEqual, + GreaterThan, + GreaterThanInclusive, + LessThan, + LessThanInclusive, + Contains, + NotContains, + In, + Regex, + PercentageSplit, + Modulo, + IsSet, + IsNotSet, +} + +// Helper function to deserialize value that can be a string or array +fn deserialize_condition_value<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + use serde_json::Value; + let value: Value = serde::Deserialize::deserialize(deserializer)?; + Ok(match value { + Value::String(s) => s, + Value::Array(_) | Value::Object(_) | Value::Number(_) | Value::Bool(_) => { + // Serialize non-string values back to JSON string + serde_json::to_string(&value).unwrap_or_default() + } + Value::Null => String::new(), + }) +} + +/// Represents a condition for segment rule evaluation. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Condition { + /// The operator for this condition. + pub operator: ConditionOperator, + /// The property to evaluate (can be a JSONPath expression starting with $.). + pub property: String, + /// The value to compare against (can be a string or serialized JSON). + #[serde(deserialize_with = "deserialize_condition_value")] + pub value: String, +} + +/// Segment rule types. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum SegmentRuleType { + All, + Any, + None, +} + +/// Represents a segment rule (can be recursive). +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SegmentRule { + /// The type of rule (ALL, ANY, NONE). + #[serde(rename = "type")] + pub rule_type: SegmentRuleType, + /// Conditions for this rule. + #[serde(default)] + pub conditions: Vec, + /// Nested rules. + #[serde(default)] + pub rules: Vec, +} + +/// Segment metadata. +#[derive(Clone, Debug, Serialize, Deserialize, Default)] +pub struct SegmentMetadata { + /// Segment ID. + #[serde(skip_serializing_if = "Option::is_none")] + pub segment_id: Option, + /// Source of the segment (api or identity_override). + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option, +} + +/// Represents a segment context for feature flag evaluation. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SegmentContext { + /// Key used for percentage split segmentation. + pub key: String, + /// The name of the segment. + pub name: String, + /// Metadata about the segment. + #[serde(default)] + pub metadata: SegmentMetadata, + /// Feature overrides for the segment. + #[serde(default)] + pub overrides: Vec, + /// Rules that define the segment. + pub rules: Vec, +} + +/// Engine evaluation context that holds pre-processed environment data +/// for efficient feature flag evaluation. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct EngineEvaluationContext { + /// Environment metadata. + pub environment: EnvironmentContext, + + /// Feature contexts indexed by feature name. + #[serde(default)] + pub features: HashMap, + + /// Segment contexts indexed by segment key. + #[serde(default)] + pub segments: HashMap, + + /// Optional identity context for evaluation. + #[serde(skip_serializing_if = "Option::is_none")] + pub identity: Option, +} diff --git a/src/engine_eval/mappers.rs b/src/engine_eval/mappers.rs new file mode 100644 index 0000000..dc2afcf --- /dev/null +++ b/src/engine_eval/mappers.rs @@ -0,0 +1,330 @@ +use super::context::{ + Condition, ConditionOperator, EngineEvaluationContext, EnvironmentContext, FeatureContext, + FeatureMetadata, FeatureValue, IdentityContext, SegmentContext, SegmentMetadata, SegmentRule, + SegmentRuleType, +}; +use crate::environments::Environment; +use crate::features::{FeatureState, MultivariateFeatureStateValue}; +use crate::identities::{Identity, Trait}; +use crate::segments::{Segment, SegmentRule as OldSegmentRule}; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; + +/// Maps an Environment to an EngineEvaluationContext +/// +/// # Arguments +/// * `environment` - The environment to convert +/// +/// # Returns +/// A new engine evaluation context +pub fn environment_to_context(environment: Environment) -> EngineEvaluationContext { + let mut ctx = EngineEvaluationContext { + environment: EnvironmentContext { + key: environment.api_key.clone(), + name: environment.api_key.clone(), + }, + features: HashMap::new(), + segments: HashMap::new(), + identity: None, + }; + + // Map feature states to feature contexts + for fs in &environment.feature_states { + let fc = map_feature_state_to_feature_context(fs); + ctx.features.insert(fc.name.clone(), fc); + } + + // Map project segments to segment contexts + for segment in &environment.project.segments { + let sc = map_segment_to_segment_context(segment); + ctx.segments.insert(sc.key.clone(), sc); + } + + // Map identity overrides to segments + if !environment.identity_overrides.is_empty() { + let identity_segments = map_identity_overrides_to_segments(&environment.identity_overrides); + for (key, segment) in identity_segments { + ctx.segments.insert(key, segment); + } + } + + ctx +} + +/// Maps a FeatureState to a FeatureContext +fn map_feature_state_to_feature_context(fs: &FeatureState) -> FeatureContext { + let key = if let Some(django_id) = fs.django_id { + django_id.to_string() + } else { + fs.featurestate_uuid.clone() + }; + + let mut fc = FeatureContext { + enabled: fs.enabled, + key, + name: fs.feature.name.clone(), + value: fs.get_value(None), + priority: None, + variants: map_multivariate_values_to_variants(&fs.multivariate_feature_state_values), + metadata: FeatureMetadata { + feature_id: fs.feature.id, + }, + }; + + // Set priority if this is a segment override + if let Some(feature_segment) = &fs.feature_segment { + fc.priority = Some(feature_segment.priority as f64); + } + + fc +} + +/// Maps multivariate feature state values to FeatureValue variants +fn map_multivariate_values_to_variants( + mv_values: &[MultivariateFeatureStateValue], +) -> Vec { + mv_values + .iter() + .map(|mv| FeatureValue { + value: mv.multivariate_feature_option.value.clone(), + weight: mv.percentage_allocation as f64, + priority: None, + }) + .collect() +} + +/// Maps a Segment to a SegmentContext +fn map_segment_to_segment_context(segment: &Segment) -> SegmentContext { + let mut sc = SegmentContext { + key: segment.id.to_string(), + name: segment.name.clone(), + metadata: SegmentMetadata { + segment_id: Some(segment.id as i32), + source: Some("api".to_string()), + }, + overrides: vec![], + rules: vec![], + }; + + // Map feature state overrides + for fs in &segment.feature_states { + sc.overrides.push(map_feature_state_to_feature_context(fs)); + } + + // Map segment rules + for rule in &segment.rules { + sc.rules.push(map_segment_rule_to_rule(rule)); + } + + sc +} + +/// Maps a legacy SegmentRule to the new SegmentRule format +fn map_segment_rule_to_rule(rule: &OldSegmentRule) -> SegmentRule { + let rule_type = map_rule_type(&rule.segment_rule_type); + + let conditions = rule + .conditions + .iter() + .map(|c| Condition { + operator: map_operator(&c.operator), + property: c.property.clone().unwrap_or_default(), + value: c.value.clone().unwrap_or_default(), + }) + .collect(); + + let rules = rule + .rules + .iter() + .map(|r| map_segment_rule_to_rule(r)) + .collect(); + + SegmentRule { + rule_type, + conditions, + rules, + } +} + +/// Maps a rule type string to SegmentRuleType enum +fn map_rule_type(rule_type: &str) -> SegmentRuleType { + match rule_type { + "ALL" => SegmentRuleType::All, + "ANY" => SegmentRuleType::Any, + "NONE" => SegmentRuleType::None, + _ => SegmentRuleType::All, + } +} + +/// Maps an operator string to ConditionOperator enum +fn map_operator(operator: &str) -> ConditionOperator { + match operator { + "EQUAL" => ConditionOperator::Equal, + "NOT_EQUAL" => ConditionOperator::NotEqual, + "GREATER_THAN" => ConditionOperator::GreaterThan, + "GREATER_THAN_INCLUSIVE" => ConditionOperator::GreaterThanInclusive, + "LESS_THAN" => ConditionOperator::LessThan, + "LESS_THAN_INCLUSIVE" => ConditionOperator::LessThanInclusive, + "CONTAINS" => ConditionOperator::Contains, + "NOT_CONTAINS" => ConditionOperator::NotContains, + "IN" => ConditionOperator::In, + "REGEX" => ConditionOperator::Regex, + "PERCENTAGE_SPLIT" => ConditionOperator::PercentageSplit, + "MODULO" => ConditionOperator::Modulo, + "IS_SET" => ConditionOperator::IsSet, + "IS_NOT_SET" => ConditionOperator::IsNotSet, + _ => ConditionOperator::Equal, + } +} + +/// Helper struct for grouping identity overrides +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +struct OverrideKey { + feature_name: String, + enabled: String, + feature_value: String, + feature_id: u32, +} + +/// Maps identity overrides to segment contexts +fn map_identity_overrides_to_segments(identities: &[Identity]) -> HashMap { + let mut features_to_identifiers: HashMap> = HashMap::new(); + let mut overrides_key_to_list: HashMap> = HashMap::new(); + + for identity in identities { + if identity.identity_features.is_empty() { + continue; + } + + // Create override keys from features + let mut overrides = Vec::new(); + for fs in &identity.identity_features { + // Use proper JSON serialization instead of Debug format + let feature_value = + serde_json::to_string(&fs.get_value(None)).unwrap_or_else(|_| "null".to_string()); + overrides.push(OverrideKey { + feature_name: fs.feature.name.clone(), + enabled: fs.enabled.to_string(), + feature_value, + feature_id: fs.feature.id, + }); + } + + // Sort overrides for consistent hashing + overrides.sort(); + + // Generate hash for this set of overrides + let overrides_hash = generate_hash(&overrides); + + // Group identifiers by their overrides + features_to_identifiers + .entry(overrides_hash.clone()) + .or_insert_with(Vec::new) + .push(identity.identifier.clone()); + + overrides_key_to_list.insert(overrides_hash, overrides); + } + + // Create segment contexts for each unique set of overrides + let mut segment_contexts = HashMap::new(); + + for (overrides_hash, identifiers) in features_to_identifiers { + let overrides = overrides_key_to_list.get(&overrides_hash).unwrap(); + + // Create segment context + let mut sc = SegmentContext { + key: String::new(), // Identity override segments never use % Split operator + name: "identity_overrides".to_string(), + metadata: SegmentMetadata { + segment_id: None, + source: Some("identity_override".to_string()), + }, + overrides: vec![], + rules: vec![SegmentRule { + rule_type: SegmentRuleType::All, + conditions: vec![Condition { + operator: ConditionOperator::In, + property: "$.identity.identifier".to_string(), + value: identifiers.join(","), + }], + rules: vec![], + }], + }; + + // Create feature overrides + for override_key in overrides { + let priority = f64::NEG_INFINITY; // Strongest possible priority + let feature_override = FeatureContext { + key: String::new(), // Identity overrides never carry multivariate options + name: override_key.feature_name.clone(), + enabled: override_key.enabled == "true", + value: serde_json::from_str(&override_key.feature_value).unwrap_or_default(), + priority: Some(priority), + variants: vec![], + metadata: FeatureMetadata { + feature_id: override_key.feature_id, + }, + }; + + sc.overrides.push(feature_override); + } + + segment_contexts.insert(overrides_hash, sc); + } + + segment_contexts +} + +/// Generates a hash from override keys for use as segment key +fn generate_hash(overrides: &[OverrideKey]) -> String { + let mut hasher = Sha256::new(); + + for override_key in overrides { + hasher.update(format!( + "{}:{}:{}:{};", + override_key.feature_id, + override_key.feature_name, + override_key.enabled, + override_key.feature_value + )); + } + + let result = hasher.finalize(); + // Use safe slicing - take first 16 chars without panicking + let hex = format!("{:x}", result); + hex.chars().take(16).collect() +} + +/// Adds identity data to an existing context +/// +/// # Arguments +/// * `context` - The context to enrich with identity data +/// * `identifier` - The identity identifier +/// * `traits` - The identity traits +/// +/// # Returns +/// A new context with identity information +pub fn add_identity_to_context( + context: &EngineEvaluationContext, + identifier: &str, + traits: &[Trait], +) -> EngineEvaluationContext { + let mut new_context = context.clone(); + + // Create traits map + let mut identity_traits = HashMap::new(); + for trait_obj in traits { + identity_traits.insert(trait_obj.trait_key.clone(), trait_obj.trait_value.clone()); + } + + // Create identity context + let environment_key = &new_context.environment.key; + let identity = IdentityContext { + identifier: identifier.to_string(), + key: format!("{}_{}", environment_key, identifier), + traits: identity_traits, + }; + + new_context.identity = Some(identity); + new_context +} diff --git a/src/engine_eval/mod.rs b/src/engine_eval/mod.rs new file mode 100644 index 0000000..9f4d226 --- /dev/null +++ b/src/engine_eval/mod.rs @@ -0,0 +1,17 @@ +/// Evaluation context module containing the EngineEvaluationContext struct +pub mod context; + +/// Result module containing evaluation result types +pub mod result; + +/// Segment evaluator module for evaluating segment rules +pub mod segment_evaluator; + +/// Mappers module for converting between old and new types +pub mod mappers; + +// Re-export commonly used types for convenience +pub use context::{EngineEvaluationContext, FeatureContext, FeatureMetadata}; +pub use mappers::{add_identity_to_context, environment_to_context}; +pub use result::{EvaluationResult, FlagResult, SegmentResult}; +pub use segment_evaluator::is_context_in_segment; diff --git a/src/engine_eval/result.rs b/src/engine_eval/result.rs new file mode 100644 index 0000000..0cc6cd7 --- /dev/null +++ b/src/engine_eval/result.rs @@ -0,0 +1,46 @@ +use super::context::{FeatureMetadata, SegmentMetadata}; +use crate::types::FlagsmithValue; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Represents the result of a feature flag evaluation. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct EvaluationResult { + /// Map of feature names to their evaluated flag results. + pub flags: HashMap, + + /// List of segments that matched during evaluation. + #[serde(default)] + pub segments: Vec, +} + +/// Represents the evaluated result for a single feature flag. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FlagResult { + /// Whether the feature is enabled. + pub enabled: bool, + + /// The name of the feature. + pub name: String, + + /// The reason for this evaluation result (e.g., "DEFAULT", "TARGETING_MATCH; segment=name", "SPLIT; weight=50"). + pub reason: String, + + /// The value of the feature flag. + pub value: FlagsmithValue, + + /// Metadata about the feature. + #[serde(default)] + pub metadata: FeatureMetadata, +} + +/// Represents a segment that matched during evaluation. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SegmentResult { + /// The segment name. + pub name: String, + + /// Metadata about the segment. + #[serde(default)] + pub metadata: SegmentMetadata, +} diff --git a/src/engine_eval/segment_evaluator.rs b/src/engine_eval/segment_evaluator.rs new file mode 100644 index 0000000..ef20689 --- /dev/null +++ b/src/engine_eval/segment_evaluator.rs @@ -0,0 +1,453 @@ +use super::context::{ + Condition, ConditionOperator, EngineEvaluationContext, SegmentContext, SegmentRule, + SegmentRuleType, +}; +use crate::types::FlagsmithValue; +use crate::utils::hashing; +use regex::Regex; +use semver::Version; +use serde_json_path::JsonPath; + +/// Determines if the given evaluation context matches the segment rules +pub fn is_context_in_segment(ec: &EngineEvaluationContext, segment: &SegmentContext) -> bool { + if segment.rules.is_empty() { + return false; + } + + // All top-level rules must match + for rule in &segment.rules { + if !context_matches_segment_rule(ec, rule, &segment.key) { + return false; + } + } + + true +} + +/// Checks if the context matches a segment rule +fn context_matches_segment_rule( + ec: &EngineEvaluationContext, + rule: &SegmentRule, + segment_key: &str, +) -> bool { + // Check conditions if present + if !rule.conditions.is_empty() { + if !matches_conditions_by_rule_type(ec, &rule.conditions, &rule.rule_type, segment_key) { + return false; + } + } + + // Check nested rules + for nested_rule in &rule.rules { + if !context_matches_segment_rule(ec, nested_rule, segment_key) { + return false; + } + } + + true +} + +/// Checks if conditions match according to the rule type +fn matches_conditions_by_rule_type( + ec: &EngineEvaluationContext, + conditions: &[Condition], + rule_type: &SegmentRuleType, + segment_key: &str, +) -> bool { + for condition in conditions { + let condition_matches = context_matches_condition(ec, condition, segment_key); + + match rule_type { + SegmentRuleType::All => { + if !condition_matches { + return false; // Short-circuit: ALL requires all conditions to match + } + } + SegmentRuleType::None => { + if condition_matches { + return false; // Short-circuit: NONE requires no conditions to match + } + } + SegmentRuleType::Any => { + if condition_matches { + return true; // Short-circuit: ANY requires at least one condition to match + } + } + } + } + + // If we reach here: ALL/NONE passed all checks, ANY found no matches + *rule_type != SegmentRuleType::Any +} + +/// Checks if the context matches a specific condition +fn context_matches_condition( + ec: &EngineEvaluationContext, + condition: &Condition, + segment_key: &str, +) -> bool { + let context_value = if !condition.property.is_empty() { + get_context_value(ec, &condition.property) + } else { + None + }; + + match condition.operator { + ConditionOperator::PercentageSplit => { + match_percentage_split(ec, condition, segment_key, context_value.as_ref()) + } + ConditionOperator::In => match_in_operator(condition, context_value.as_ref()), + ConditionOperator::IsNotSet => context_value.is_none(), + ConditionOperator::IsSet => context_value.is_some(), + _ => { + if let Some(ref ctx_val) = context_value { + parse_and_match(&condition.operator, ctx_val, &condition.value) + } else { + false + } + } + } +} + +/// Gets a value from the context by property name or JSONPath +fn get_context_value(ec: &EngineEvaluationContext, property: &str) -> Option { + // If property starts with $., try to parse it as a JSONPath expression + if property.starts_with("$.") { + if let Some(value) = get_value_from_jsonpath(ec, property) { + return Some(value); + } + // If JSONPath parsing fails, fall through to treat it as a trait name + } + + // Check traits by property name + if let Some(ref identity) = ec.identity { + if let Some(trait_value) = identity.traits.get(property) { + return Some(trait_value.clone()); + } + } + + None +} + +/// Gets a value from the context using JSONPath +fn get_value_from_jsonpath(ec: &EngineEvaluationContext, path: &str) -> Option { + // Parse the JSONPath expression + let json_path = match JsonPath::parse(path) { + Ok(p) => p, + Err(_) => return None, + }; + + // Serialize the context to JSON + let context_json = match serde_json::to_value(ec) { + Ok(v) => v, + Err(_) => return None, + }; + + // Query the JSON using the path + let result = json_path.query(&context_json); + + // Get the first match (if any) + let node_list = result.all(); + if node_list.is_empty() { + return None; + } + + // Extract the value from the first match + let value = node_list[0]; + + // Convert to FlagsmithValue based on the JSON type + match value { + serde_json::Value::String(s) => Some(FlagsmithValue { + value: s.clone(), + value_type: crate::types::FlagsmithValueType::String, + }), + serde_json::Value::Number(n) => { + if n.is_f64() { + Some(FlagsmithValue { + value: n.to_string(), + value_type: crate::types::FlagsmithValueType::Float, + }) + } else { + Some(FlagsmithValue { + value: n.to_string(), + value_type: crate::types::FlagsmithValueType::Integer, + }) + } + } + serde_json::Value::Bool(b) => Some(FlagsmithValue { + value: b.to_string(), + value_type: crate::types::FlagsmithValueType::Bool, + }), + _ => None, + } +} + +/// Matches percentage split condition +fn match_percentage_split( + ec: &EngineEvaluationContext, + condition: &Condition, + segment_key: &str, + context_value: Option<&FlagsmithValue>, +) -> bool { + let float_value = match condition.value.parse::() { + Ok(v) => v, + Err(_) => return false, + }; + + // Build object IDs based on context + let context_str = context_value.map(|v| v.value.clone()); + let object_ids: Vec<&str> = if let Some(ref ctx_str) = context_str { + vec![segment_key, ctx_str.as_str()] + } else if let Some(ref identity) = ec.identity { + vec![segment_key, &identity.key] + } else { + return false; + }; + + let hash_percentage = hashing::get_hashed_percentage_for_object_ids(object_ids, 1); + (hash_percentage as f64) <= float_value +} + +/// Matches IN operator +fn match_in_operator(condition: &Condition, context_value: Option<&FlagsmithValue>) -> bool { + if context_value.is_none() { + return false; + } + + let ctx_value = context_value.unwrap(); + + // IN operator only works with string values, not booleans + use crate::types::FlagsmithValueType; + if ctx_value.value_type == FlagsmithValueType::Bool { + return false; + } + + let trait_value = &ctx_value.value; + + // Check if the value is in JSON array format (starts with '[') + if condition.value.trim().starts_with('[') { + // Try to parse as JSON array + if let Ok(array) = serde_json::from_str::>(&condition.value) { + return array.iter().any(|v| { + if let Some(s) = v.as_str() { + s == trait_value + } else if let Some(n) = v.as_i64() { + n.to_string() == *trait_value + } else if let Some(n) = v.as_f64() { + n.to_string() == *trait_value + } else { + false + } + }); + } + } + + // Fall back to comma-separated format + let values: Vec<&str> = condition.value.split(',').collect(); + values.contains(&trait_value.as_str()) +} + +/// Parses and matches values based on the operator using type-aware strategy +fn parse_and_match( + operator: &ConditionOperator, + trait_value: &FlagsmithValue, + condition_value: &str, +) -> bool { + use crate::types::FlagsmithValueType; + + // Handle special operators that work across all types + match operator { + ConditionOperator::Modulo => return evaluate_modulo(&trait_value.value, condition_value), + ConditionOperator::Regex => return evaluate_regex(&trait_value.value, condition_value), + ConditionOperator::Contains => return trait_value.value.contains(condition_value), + ConditionOperator::NotContains => return !trait_value.value.contains(condition_value), + _ => {} + } + + // Use type-aware strategy based on trait value type + match trait_value.value_type { + FlagsmithValueType::Bool => compare_bool(operator, &trait_value.value, condition_value), + FlagsmithValueType::Integer => { + compare_integer(operator, &trait_value.value, condition_value) + } + FlagsmithValueType::Float => compare_float(operator, &trait_value.value, condition_value), + FlagsmithValueType::String => compare_string(operator, &trait_value.value, condition_value), + _ => false, + } +} + +/// Parses a boolean string value with optional integer conversion +/// NOTE: Historical engine behavior - only "1" is treated as true, "0" is NOT treated as false +fn parse_bool(s: &str, allow_int_conversion: bool) -> Option { + match s.to_lowercase().as_str() { + "true" => Some(true), + "1" if allow_int_conversion => Some(true), + "false" => Some(false), + _ => None, + } +} + +/// Compares boolean values +fn compare_bool(operator: &ConditionOperator, trait_value: &str, condition_value: &str) -> bool { + if let (Some(b1), Some(b2)) = ( + parse_bool(trait_value, true), + parse_bool(condition_value, true), + ) { + match operator { + ConditionOperator::Equal => b1 == b2, + ConditionOperator::NotEqual => b1 != b2, + _ => false, + } + } else { + false + } +} + +/// Compares integer values +fn compare_integer(operator: &ConditionOperator, trait_value: &str, condition_value: &str) -> bool { + if let (Ok(i1), Ok(i2)) = (trait_value.parse::(), condition_value.parse::()) { + dispatch_operator(operator, i1, i2) + } else { + false + } +} + +/// Compares float values +fn compare_float(operator: &ConditionOperator, trait_value: &str, condition_value: &str) -> bool { + if let (Ok(f1), Ok(f2)) = (trait_value.parse::(), condition_value.parse::()) { + dispatch_operator(operator, f1, f2) + } else { + false + } +} + +/// Compares string values, with special handling for semver +fn compare_string(operator: &ConditionOperator, trait_value: &str, condition_value: &str) -> bool { + // Check for semver comparison + if condition_value.ends_with(":semver") { + let version_str = &condition_value[..condition_value.len() - 7]; + if let Ok(condition_version) = Version::parse(version_str) { + return evaluate_semver(operator, trait_value, &condition_version); + } + return false; + } + + // Try parsing as boolean for string types (strict - no integer conversion) + if let (Some(b1), Some(b2)) = ( + parse_bool(trait_value, false), + parse_bool(condition_value, false), + ) { + return match operator { + ConditionOperator::Equal => b1 == b2, + ConditionOperator::NotEqual => b1 != b2, + _ => false, + }; + } + + // Try parsing as integer + if let (Ok(i1), Ok(i2)) = (trait_value.parse::(), condition_value.parse::()) { + return dispatch_operator(operator, i1, i2); + } + + // Try parsing as float + if let (Ok(f1), Ok(f2)) = (trait_value.parse::(), condition_value.parse::()) { + return dispatch_operator(operator, f1, f2); + } + + // Fall back to string comparison + dispatch_operator(operator, trait_value, condition_value) +} + +/// Dispatches the operator to the appropriate comparison function +fn dispatch_operator( + operator: &ConditionOperator, + v1: T, + v2: T, +) -> bool { + match operator { + ConditionOperator::Equal => v1 == v2, + ConditionOperator::NotEqual => v1 != v2, + ConditionOperator::GreaterThan => v1 > v2, + ConditionOperator::LessThan => v1 < v2, + ConditionOperator::GreaterThanInclusive => v1 >= v2, + ConditionOperator::LessThanInclusive => v1 <= v2, + _ => false, + } +} + +/// Evaluates regex matching +fn evaluate_regex(trait_value: &str, condition_value: &str) -> bool { + if let Ok(re) = Regex::new(condition_value) { + return re.is_match(trait_value); + } + false +} + +/// Evaluates modulo operation +fn evaluate_modulo(trait_value: &str, condition_value: &str) -> bool { + let values: Vec<&str> = condition_value.split('|').collect(); + if values.len() != 2 { + return false; + } + + let divisor = match values[0].parse::() { + Ok(v) => v, + Err(_) => return false, + }; + + let remainder = match values[1].parse::() { + Ok(v) => v, + Err(_) => return false, + }; + + let trait_value_float = match trait_value.parse::() { + Ok(v) => v, + Err(_) => return false, + }; + + // Use epsilon comparison for float equality to handle precision errors + const EPSILON: f64 = 1e-10; + ((trait_value_float % divisor) - remainder).abs() < EPSILON +} + +/// Evaluates semantic version comparisons +fn evaluate_semver( + operator: &ConditionOperator, + trait_value: &str, + condition_version: &Version, +) -> bool { + let trait_version = match Version::parse(trait_value) { + Ok(v) => v, + Err(_) => return false, + }; + + match operator { + ConditionOperator::Equal => trait_version == *condition_version, + ConditionOperator::NotEqual => trait_version != *condition_version, + ConditionOperator::GreaterThan => trait_version > *condition_version, + ConditionOperator::LessThan => trait_version < *condition_version, + ConditionOperator::GreaterThanInclusive => trait_version >= *condition_version, + ConditionOperator::LessThanInclusive => trait_version <= *condition_version, + _ => false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dispatch_operator_integers() { + assert!(dispatch_operator(&ConditionOperator::Equal, 5, 5)); + assert!(!dispatch_operator(&ConditionOperator::Equal, 5, 6)); + assert!(dispatch_operator(&ConditionOperator::GreaterThan, 6, 5)); + assert!(!dispatch_operator(&ConditionOperator::GreaterThan, 5, 6)); + } + + #[test] + fn test_evaluate_modulo() { + assert!(evaluate_modulo("2", "2|0")); + assert!(!evaluate_modulo("3", "2|0")); + assert!(evaluate_modulo("35.0", "4|3")); + } +} diff --git a/src/lib.rs b/src/lib.rs index babf9f0..dd4bae9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod engine; +pub mod engine_eval; pub mod environments; pub mod error; pub mod features; diff --git a/tests/engine_eval_test.rs b/tests/engine_eval_test.rs new file mode 100644 index 0000000..f3d0716 --- /dev/null +++ b/tests/engine_eval_test.rs @@ -0,0 +1,178 @@ +use flagsmith_flag_engine::engine::get_evaluation_result; +use flagsmith_flag_engine::engine_eval::{EngineEvaluationContext, EvaluationResult}; +use json_comments::StripComments; +use rstest::*; +use serde_json; +use std::fs; +use std::io::Read; + +#[rstest] +fn test_engine_evaluation() { + // Get all test files + let test_dir = "tests/engine_tests/engine-test-data/test_cases"; + let test_files = fs::read_dir(test_dir).expect("Failed to read test directory"); + + let mut test_count = 0; + let mut passed = 0; + let mut failed = 0; + let mut failed_tests = Vec::new(); + + for entry in test_files { + let entry = entry.expect("Failed to read directory entry"); + let path = entry.path(); + + // Only process JSON and JSONC files + let extension = path.extension().and_then(|s| s.to_str()); + if extension != Some("json") && extension != Some("jsonc") { + continue; + } + + let test_name = path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown"); + + test_count += 1; + + // Read the test file + let test_data = + fs::read_to_string(&path).expect(&format!("Failed to read test file: {:?}", path)); + + // Strip comments if it's a JSONC file + let json_string = if extension == Some("jsonc") { + let mut stripped = String::new(); + StripComments::new(test_data.as_bytes()) + .read_to_string(&mut stripped) + .expect(&format!("Failed to strip comments from: {}", test_name)); + stripped + } else { + test_data + }; + + // Parse the JSON + let test_json: serde_json::Value = serde_json::from_str(&json_string) + .expect(&format!("Failed to parse JSON for: {}", test_name)); + + // Deserialize the context + let context_result: Result = + serde_json::from_value(test_json["context"].clone()); + + if context_result.is_err() { + println!( + "FAIL {}: Failed to deserialize context: {:?}", + test_name, + context_result.err() + ); + failed += 1; + failed_tests.push(test_name.to_string()); + continue; + } + + let context = context_result.unwrap(); + + // Get the evaluation result + let result = get_evaluation_result(&context); + + // Deserialize the expected result + let expected_result: Result = + serde_json::from_value(test_json["result"].clone()); + + if expected_result.is_err() { + println!( + "FAIL {}: Failed to deserialize expected result: {:?}", + test_name, + expected_result.err() + ); + failed += 1; + failed_tests.push(test_name.to_string()); + continue; + } + + let expected = expected_result.unwrap(); + + // Compare results + if compare_evaluation_results(&result, &expected, test_name) { + passed += 1; + println!("PASS {}", test_name); + } else { + failed += 1; + failed_tests.push(test_name.to_string()); + } + } + + // Print summary + println!("\n========== TEST SUMMARY =========="); + println!("Total tests: {}", test_count); + println!("Passed: {} ({}%)", passed, (passed * 100) / test_count); + println!("Failed: {} ({}%)", failed, (failed * 100) / test_count); + + if !failed_tests.is_empty() { + println!("\nFailed tests:"); + for test in &failed_tests { + println!(" - {}", test); + } + } + + println!("==================================\n"); + + // Assert that all tests passed + assert_eq!(failed, 0, "{} out of {} tests failed", failed, test_count); +} + +fn compare_evaluation_results( + result: &EvaluationResult, + expected: &EvaluationResult, + test_name: &str, +) -> bool { + let mut success = true; + + // Compare flags + if result.flags.len() != expected.flags.len() { + println!( + "FAIL {}: Flag count mismatch - got {}, expected {}", + test_name, + result.flags.len(), + expected.flags.len() + ); + success = false; + } + + for (flag_name, expected_flag) in &expected.flags { + match result.flags.get(flag_name) { + None => { + println!("FAIL {}: Missing flag: {}", test_name, flag_name); + success = false; + } + Some(actual_flag) => { + if actual_flag.enabled != expected_flag.enabled { + println!( + "FAIL {}: Flag '{}' enabled mismatch - got {}, expected {}", + test_name, flag_name, actual_flag.enabled, expected_flag.enabled + ); + success = false; + } + + if actual_flag.value != expected_flag.value { + println!( + "FAIL {}: Flag '{}' value mismatch - got {:?}, expected {:?}", + test_name, flag_name, actual_flag.value, expected_flag.value + ); + success = false; + } + } + } + } + + // Compare segments + if result.segments.len() != expected.segments.len() { + println!( + "FAIL {}: Segment count mismatch - got {}, expected {}", + test_name, + result.segments.len(), + expected.segments.len() + ); + success = false; + } + + success +} diff --git a/tests/engine_tests/engine-test-data b/tests/engine_tests/engine-test-data index 71a9631..4fe4c8d 160000 --- a/tests/engine_tests/engine-test-data +++ b/tests/engine_tests/engine-test-data @@ -1 +1 @@ -Subproject commit 71a963198d66d681d12f2bf92c42a3036ffe92a7 +Subproject commit 4fe4c8dc80cb0e165f679491b19cde11a357811c diff --git a/tests/engine_tests/engine_tests.rs b/tests/engine_tests/engine_tests.rs deleted file mode 100644 index ba92c44..0000000 --- a/tests/engine_tests/engine_tests.rs +++ /dev/null @@ -1,103 +0,0 @@ -use core::panic; -use flagsmith_flag_engine::engine; -use flagsmith_flag_engine::environments; -use flagsmith_flag_engine::environments::builders::build_environment_struct; -use flagsmith_flag_engine::identities; -use flagsmith_flag_engine::identities::builders::build_identity_struct; -use flagsmith_flag_engine::types::FlagsmithValueType; -use std::fs::File; -use std::io::BufReader; -use std::path::PathBuf; - -use rstest::*; - -#[fixture] -fn test_json() -> serde_json::Value { - // First, Let's convert the json file to serde value - let file_path = - "tests/engine_tests/engine-test-data/data/environment_n9fbf9h3v4fFgH3U3ngWhb.json"; - - let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - d.push(file_path); - let file = File::open(file_path).unwrap(); - let reader = BufReader::new(file); - let value: serde_json::Value = serde_json::from_reader(reader).unwrap(); - return value; -} - -#[rstest] -fn test_engine(test_json: serde_json::Value) { - fn check( - environment: &environments::Environment, - identity: &identities::Identity, - expected_response: serde_json::Value, - ) { - // Given - let expected_flags = expected_response["flags"].as_array().unwrap(); - - // When - let mut engine_response = - engine::get_identity_feature_states(&environment, &identity, None); - // Sort the feature states so that we can iterate over it and compare them with expected response - engine_response.sort_by_key(|fs| fs.feature.name.clone()); - - // Then - // engine returned right number of feature states - assert_eq!(engine_response.len(), expected_flags.len()); - for (index, fs) in engine_response.iter().enumerate() { - // and the values and enabled status of each of the feature states returned by the - // engine is as expected - assert_eq!( - fs.enabled, - expected_flags[index]["enabled"].as_bool().unwrap() - ); - - let identity_id = match identity.django_id { - Some(id) => id.to_string(), - None => identity.identity_uuid.clone(), - }; - - let fs_value = fs.get_value(Some(&identity_id)); - - match fs_value.value_type { - FlagsmithValueType::Bool => assert_eq!( - fs_value.value.parse::().unwrap(), - expected_flags[index]["feature_state_value"] - .as_bool() - .unwrap() - ), - FlagsmithValueType::Integer => assert_eq!( - fs_value.value.parse::().unwrap(), - expected_flags[index]["feature_state_value"] - .as_i64() - .unwrap() - ), - FlagsmithValueType::String => assert_eq!( - fs_value.value, - expected_flags[index]["feature_state_value"] - .as_str() - .unwrap() - ), - FlagsmithValueType::None => assert_eq!( - (), - expected_flags[index]["feature_state_value"] - .as_null() - .unwrap() - ), - FlagsmithValueType::Float => { - panic!("Floats are not allowed for feature state value") - } - } - } - } - let environment = build_environment_struct(test_json["environment"].clone()); - - for identity_and_response in test_json["identities_and_responses"].as_array().unwrap() { - let identity = build_identity_struct(identity_and_response["identity"].clone()); - check( - &environment, - &identity, - identity_and_response["response"].clone(), - ); - } -} diff --git a/tests/tests.rs b/tests/tests.rs deleted file mode 100644 index 348f4bb..0000000 --- a/tests/tests.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod engine_tests { - mod engine_tests; // Run engine integration tests -} From b95fec7399a76c8ad6e5654f7efbede54b857b6d Mon Sep 17 00:00:00 2001 From: Gagan Trivedi Date: Thu, 6 Nov 2025 11:59:35 +0530 Subject: [PATCH 2/8] fix multivariate segment override --- src/engine.rs | 24 ++++++++++-------------- tests/engine_tests/engine-test-data | 2 +- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index 3464f49..8441c25 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -86,23 +86,18 @@ fn get_flag_results( for feature_context in ec.features.values() { // Check if we have a segment override for this feature if let Some(segment_fc) = segment_feature_contexts.get(&feature_context.name) { - // Use segment override + // Use segment override with multivariate evaluation let fc = &segment_fc.feature_context; let reason = format!("TARGETING_MATCH; segment={}", segment_fc.segment_name); - flags.insert( - feature_context.name.clone(), - FlagResult { - enabled: fc.enabled, - name: fc.name.clone(), - reason, - value: fc.value.clone(), - metadata: fc.metadata.clone(), - }, - ); + let flag_result = get_flag_result_from_feature_context(fc, identity_key.as_ref(), reason); + flags.insert(feature_context.name.clone(), flag_result); } else { // Use default feature context - let flag_result = - get_flag_result_from_feature_context(feature_context, identity_key.as_ref()); + let flag_result = get_flag_result_from_feature_context( + feature_context, + identity_key.as_ref(), + "DEFAULT".to_string(), + ); flags.insert(feature_context.name.clone(), flag_result); } } @@ -124,8 +119,9 @@ pub fn get_evaluation_result(ec: &EngineEvaluationContext) -> EvaluationResult { fn get_flag_result_from_feature_context( feature_context: &FeatureContext, identity_key: Option<&String>, + default_reason: String, ) -> FlagResult { - let mut reason = "DEFAULT".to_string(); + let mut reason = default_reason; let mut value = feature_context.value.clone(); // Handle multivariate features diff --git a/tests/engine_tests/engine-test-data b/tests/engine_tests/engine-test-data index 4fe4c8d..dc60562 160000 --- a/tests/engine_tests/engine-test-data +++ b/tests/engine_tests/engine-test-data @@ -1 +1 @@ -Subproject commit 4fe4c8dc80cb0e165f679491b19cde11a357811c +Subproject commit dc6056289662aea2a8ef017fdf799f7b601a0f43 From b691ee352a1e5e39f4033c84804ec1eb977907d5 Mon Sep 17 00:00:00 2001 From: Gagan Trivedi Date: Thu, 6 Nov 2025 13:56:32 +0530 Subject: [PATCH 3/8] feat: Add ConditionValue enum - Implement ConditionValue enum to handle string or array values - Supports Single(String) and Multiple(Vec) variants - Custom deserializer handles JSON arrays, JSON array strings, and comma-separated strings - Helper methods: as_string(), as_vec(), contains_string() - Simplify IN operator implementation - Use ConditionValue's contains_string() for string matching - Remove redundant JSON array parsing logic --- src/engine.rs | 3 +- src/engine_eval/context.rs | 89 ++++++++++++++++++++++------ src/engine_eval/mappers.rs | 4 +- src/engine_eval/segment_evaluator.rs | 27 ++------- 4 files changed, 80 insertions(+), 43 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index 8441c25..ac48835 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -89,7 +89,8 @@ fn get_flag_results( // Use segment override with multivariate evaluation let fc = &segment_fc.feature_context; let reason = format!("TARGETING_MATCH; segment={}", segment_fc.segment_name); - let flag_result = get_flag_result_from_feature_context(fc, identity_key.as_ref(), reason); + let flag_result = + get_flag_result_from_feature_context(fc, identity_key.as_ref(), reason); flags.insert(feature_context.name.clone(), flag_result); } else { // Use default feature context diff --git a/src/engine_eval/context.rs b/src/engine_eval/context.rs index ea42630..1e545b4 100644 --- a/src/engine_eval/context.rs +++ b/src/engine_eval/context.rs @@ -87,21 +87,77 @@ pub enum ConditionOperator { IsNotSet, } -// Helper function to deserialize value that can be a string or array -fn deserialize_condition_value<'de, D>(deserializer: D) -> Result -where - D: serde::Deserializer<'de>, -{ - use serde_json::Value; - let value: Value = serde::Deserialize::deserialize(deserializer)?; - Ok(match value { - Value::String(s) => s, - Value::Array(_) | Value::Object(_) | Value::Number(_) | Value::Bool(_) => { - // Serialize non-string values back to JSON string - serde_json::to_string(&value).unwrap_or_default() +/// Represents a condition value that can be either a single string or an array of strings. +#[derive(Clone, Debug, Serialize)] +#[serde(untagged)] +pub enum ConditionValue { + /// Multiple values as an array + Multiple(Vec), + /// Single value as a string + Single(String), +} + +impl<'de> serde::Deserialize<'de> for ConditionValue { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde_json::Value; + let value: Value = serde::Deserialize::deserialize(deserializer)?; + + match value { + // If it's already an array, use Multiple + Value::Array(arr) => { + let strings: Vec = arr + .into_iter() + .map(|v| match v { + Value::String(s) => s, + _ => v.to_string(), + }) + .collect(); + Ok(ConditionValue::Multiple(strings)) + } + // If it's a string, check if it's a JSON array string + Value::String(s) => { + if s.trim().starts_with('[') { + // Try to parse as JSON array + if let Ok(arr) = serde_json::from_str::>(&s) { + return Ok(ConditionValue::Multiple(arr)); + } + } + // Otherwise treat as single string + Ok(ConditionValue::Single(s)) + } + // For other types, convert to string + _ => Ok(ConditionValue::Single(value.to_string())), + } + } +} + +impl ConditionValue { + /// Get the value as a single string (joins arrays with comma) + pub fn as_string(&self) -> String { + match self { + ConditionValue::Single(s) => s.clone(), + ConditionValue::Multiple(arr) => arr.join(","), + } + } + + /// Get values as a Vec (splits single strings by comma, or returns array as-is) + pub fn as_vec(&self) -> Vec { + match self { + ConditionValue::Single(s) => s.split(',').map(|s| s.trim().to_string()).collect(), + ConditionValue::Multiple(arr) => arr.clone(), + } + } + + /// Check if value contains a string (for string-based IN operator) + pub fn contains_string(&self, search: &str) -> bool { + match self { + ConditionValue::Single(s) => s.split(',').any(|v| v.trim() == search), + ConditionValue::Multiple(arr) => arr.iter().any(|v| v == search), } - Value::Null => String::new(), - }) + } } /// Represents a condition for segment rule evaluation. @@ -111,9 +167,8 @@ pub struct Condition { pub operator: ConditionOperator, /// The property to evaluate (can be a JSONPath expression starting with $.). pub property: String, - /// The value to compare against (can be a string or serialized JSON). - #[serde(deserialize_with = "deserialize_condition_value")] - pub value: String, + /// The value to compare against (can be a string or array of strings). + pub value: ConditionValue, } /// Segment rule types. diff --git a/src/engine_eval/mappers.rs b/src/engine_eval/mappers.rs index dc2afcf..a1ca476 100644 --- a/src/engine_eval/mappers.rs +++ b/src/engine_eval/mappers.rs @@ -129,7 +129,7 @@ fn map_segment_rule_to_rule(rule: &OldSegmentRule) -> SegmentRule { .map(|c| Condition { operator: map_operator(&c.operator), property: c.property.clone().unwrap_or_default(), - value: c.value.clone().unwrap_or_default(), + value: super::context::ConditionValue::Single(c.value.clone().unwrap_or_default()), }) .collect(); @@ -245,7 +245,7 @@ fn map_identity_overrides_to_segments(identities: &[Identity]) -> HashMap context_value.is_some(), _ => { if let Some(ref ctx_val) = context_value { - parse_and_match(&condition.operator, ctx_val, &condition.value) + parse_and_match(&condition.operator, ctx_val, &condition.value.as_string()) } else { false } @@ -189,7 +189,7 @@ fn match_percentage_split( segment_key: &str, context_value: Option<&FlagsmithValue>, ) -> bool { - let float_value = match condition.value.parse::() { + let float_value = match condition.value.as_string().parse::() { Ok(v) => v, Err(_) => return false, }; @@ -224,27 +224,8 @@ fn match_in_operator(condition: &Condition, context_value: Option<&FlagsmithValu let trait_value = &ctx_value.value; - // Check if the value is in JSON array format (starts with '[') - if condition.value.trim().starts_with('[') { - // Try to parse as JSON array - if let Ok(array) = serde_json::from_str::>(&condition.value) { - return array.iter().any(|v| { - if let Some(s) = v.as_str() { - s == trait_value - } else if let Some(n) = v.as_i64() { - n.to_string() == *trait_value - } else if let Some(n) = v.as_f64() { - n.to_string() == *trait_value - } else { - false - } - }); - } - } - - // Fall back to comma-separated format - let values: Vec<&str> = condition.value.split(',').collect(); - values.contains(&trait_value.as_str()) + // Use the ConditionValue's contains_string method for simple string matching + condition.value.contains_string(trait_value) } /// Parses and matches values based on the operator using type-aware strategy From 0525956bfdb549ea9692fac06a89677367e72534 Mon Sep 17 00:00:00 2001 From: Gagan Trivedi Date: Thu, 6 Nov 2025 14:00:16 +0530 Subject: [PATCH 4/8] chore: Fix clippy warnings in engine_eval module - Fix collapsible_if warning: collapse nested if statements - Fix manual_strip warning: use strip_suffix instead of manual slicing - Fix unwrap_or_default warning: use or_default() instead of or_insert_with(Vec::new) --- src/engine_eval/mappers.rs | 2 +- src/engine_eval/segment_evaluator.rs | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/engine_eval/mappers.rs b/src/engine_eval/mappers.rs index a1ca476..871c7ca 100644 --- a/src/engine_eval/mappers.rs +++ b/src/engine_eval/mappers.rs @@ -219,7 +219,7 @@ fn map_identity_overrides_to_segments(identities: &[Identity]) -> HashMap bool { // Check conditions if present - if !rule.conditions.is_empty() { - if !matches_conditions_by_rule_type(ec, &rule.conditions, &rule.rule_type, segment_key) { - return false; - } + if !rule.conditions.is_empty() + && !matches_conditions_by_rule_type(ec, &rule.conditions, &rule.rule_type, segment_key) + { + return false; } // Check nested rules @@ -305,8 +305,7 @@ fn compare_float(operator: &ConditionOperator, trait_value: &str, condition_valu /// Compares string values, with special handling for semver fn compare_string(operator: &ConditionOperator, trait_value: &str, condition_value: &str) -> bool { // Check for semver comparison - if condition_value.ends_with(":semver") { - let version_str = &condition_value[..condition_value.len() - 7]; + if let Some(version_str) = condition_value.strip_suffix(":semver") { if let Ok(condition_version) = Version::parse(version_str) { return evaluate_semver(operator, trait_value, &condition_version); } From d2d3ca3803c8ac944c1446bb7ea320e56c6929ca Mon Sep 17 00:00:00 2001 From: Gagan Trivedi Date: Fri, 7 Nov 2025 10:30:37 +0530 Subject: [PATCH 5/8] test: Add reason field validation in engine evaluation tests - Compare flag reason field in addition to enabled and value - Ensure evaluation reasons (DEFAULT, TARGETING_MATCH, SPLIT) are correct --- tests/engine_eval_test.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/engine_eval_test.rs b/tests/engine_eval_test.rs index f3d0716..2f21f6d 100644 --- a/tests/engine_eval_test.rs +++ b/tests/engine_eval_test.rs @@ -159,6 +159,14 @@ fn compare_evaluation_results( ); success = false; } + + if actual_flag.reason != expected_flag.reason { + println!( + "FAIL {}: Flag '{}' reason mismatch - got '{}', expected '{}'", + test_name, flag_name, actual_flag.reason, expected_flag.reason + ); + success = false; + } } } } From e47bb72ff68c41d06c452e77a655d3f6076f8bba Mon Sep 17 00:00:00 2001 From: Gagan Trivedi Date: Fri, 7 Nov 2025 11:16:28 +0530 Subject: [PATCH 6/8] refactor: improve engine_tests --- src/engine_eval/context.rs | 4 +-- src/engine_eval/result.rs | 6 ++-- tests/engine_eval_test.rs | 68 ++++++++------------------------------ 3 files changed, 18 insertions(+), 60 deletions(-) diff --git a/src/engine_eval/context.rs b/src/engine_eval/context.rs index 1e545b4..7d5e592 100644 --- a/src/engine_eval/context.rs +++ b/src/engine_eval/context.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; /// Represents metadata information about a feature. -#[derive(Clone, Debug, Serialize, Deserialize, Default)] +#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq)] pub struct FeatureMetadata { /// The feature ID. #[serde(default)] @@ -195,7 +195,7 @@ pub struct SegmentRule { } /// Segment metadata. -#[derive(Clone, Debug, Serialize, Deserialize, Default)] +#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq)] pub struct SegmentMetadata { /// Segment ID. #[serde(skip_serializing_if = "Option::is_none")] diff --git a/src/engine_eval/result.rs b/src/engine_eval/result.rs index 0cc6cd7..8f65635 100644 --- a/src/engine_eval/result.rs +++ b/src/engine_eval/result.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; /// Represents the result of a feature flag evaluation. -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub struct EvaluationResult { /// Map of feature names to their evaluated flag results. pub flags: HashMap, @@ -15,7 +15,7 @@ pub struct EvaluationResult { } /// Represents the evaluated result for a single feature flag. -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub struct FlagResult { /// Whether the feature is enabled. pub enabled: bool, @@ -35,7 +35,7 @@ pub struct FlagResult { } /// Represents a segment that matched during evaluation. -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub struct SegmentResult { /// The segment name. pub name: String, diff --git a/tests/engine_eval_test.rs b/tests/engine_eval_test.rs index 2f21f6d..837a80f 100644 --- a/tests/engine_eval_test.rs +++ b/tests/engine_eval_test.rs @@ -124,63 +124,21 @@ fn compare_evaluation_results( expected: &EvaluationResult, test_name: &str, ) -> bool { - let mut success = true; - - // Compare flags - if result.flags.len() != expected.flags.len() { - println!( - "FAIL {}: Flag count mismatch - got {}, expected {}", - test_name, - result.flags.len(), - expected.flags.len() - ); - success = false; + // Compare flags - simple equality check + if result.flags != expected.flags { + println!("FAIL {}: Flags mismatch", test_name); + println!(" Expected flags: {:?}", expected.flags); + println!(" Actual flags: {:?}", result.flags); + return false; } - for (flag_name, expected_flag) in &expected.flags { - match result.flags.get(flag_name) { - None => { - println!("FAIL {}: Missing flag: {}", test_name, flag_name); - success = false; - } - Some(actual_flag) => { - if actual_flag.enabled != expected_flag.enabled { - println!( - "FAIL {}: Flag '{}' enabled mismatch - got {}, expected {}", - test_name, flag_name, actual_flag.enabled, expected_flag.enabled - ); - success = false; - } - - if actual_flag.value != expected_flag.value { - println!( - "FAIL {}: Flag '{}' value mismatch - got {:?}, expected {:?}", - test_name, flag_name, actual_flag.value, expected_flag.value - ); - success = false; - } - - if actual_flag.reason != expected_flag.reason { - println!( - "FAIL {}: Flag '{}' reason mismatch - got '{}', expected '{}'", - test_name, flag_name, actual_flag.reason, expected_flag.reason - ); - success = false; - } - } - } - } - - // Compare segments - if result.segments.len() != expected.segments.len() { - println!( - "FAIL {}: Segment count mismatch - got {}, expected {}", - test_name, - result.segments.len(), - expected.segments.len() - ); - success = false; + // Compare segments - simple equality check + if result.segments != expected.segments { + println!("FAIL {}: Segments mismatch", test_name); + println!(" Expected segments: {:?}", expected.segments); + println!(" Actual segments: {:?}", result.segments); + return false; } - success + true } From 5138a332cfcdd9953b24c3c47489b16d263cbedf Mon Sep 17 00:00:00 2001 From: Gagan Trivedi Date: Fri, 7 Nov 2025 11:17:55 +0530 Subject: [PATCH 7/8] process segment in order --- src/engine.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index ac48835..5d3782d 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -25,8 +25,14 @@ fn get_matching_segments_and_overrides( let mut segments = Vec::new(); let mut segment_feature_contexts: HashMap = HashMap::new(); - // Process segments - for segment_context in ec.segments.values() { + // Sort segment keys for deterministic ordering + let mut segment_keys: Vec<_> = ec.segments.keys().collect(); + segment_keys.sort(); + + // Process segments in sorted order + for segment_key in segment_keys { + let segment_context = &ec.segments[segment_key]; + if !is_context_in_segment(ec, segment_context) { continue; } From fcc75c1af3b7ec94d3d985fe70695581f9a8d3c5 Mon Sep 17 00:00:00 2001 From: Gagan Trivedi Date: Fri, 7 Nov 2025 11:59:46 +0530 Subject: [PATCH 8/8] better error messages on test failure --- tests/engine_eval_test.rs | 109 ++++++++------------------------------ 1 file changed, 21 insertions(+), 88 deletions(-) diff --git a/tests/engine_eval_test.rs b/tests/engine_eval_test.rs index 837a80f..acf76c3 100644 --- a/tests/engine_eval_test.rs +++ b/tests/engine_eval_test.rs @@ -12,11 +12,6 @@ fn test_engine_evaluation() { let test_dir = "tests/engine_tests/engine-test-data/test_cases"; let test_files = fs::read_dir(test_dir).expect("Failed to read test directory"); - let mut test_count = 0; - let mut passed = 0; - let mut failed = 0; - let mut failed_tests = Vec::new(); - for entry in test_files { let entry = entry.expect("Failed to read directory entry"); let path = entry.path(); @@ -32,8 +27,6 @@ fn test_engine_evaluation() { .and_then(|n| n.to_str()) .unwrap_or("unknown"); - test_count += 1; - // Read the test file let test_data = fs::read_to_string(&path).expect(&format!("Failed to read test file: {:?}", path)); @@ -54,91 +47,31 @@ fn test_engine_evaluation() { .expect(&format!("Failed to parse JSON for: {}", test_name)); // Deserialize the context - let context_result: Result = - serde_json::from_value(test_json["context"].clone()); - - if context_result.is_err() { - println!( - "FAIL {}: Failed to deserialize context: {:?}", - test_name, - context_result.err() - ); - failed += 1; - failed_tests.push(test_name.to_string()); - continue; - } - - let context = context_result.unwrap(); + let context: EngineEvaluationContext = serde_json::from_value(test_json["context"].clone()) + .unwrap_or_else(|e| panic!("Failed to deserialize context for {}: {:?}", test_name, e)); // Get the evaluation result let result = get_evaluation_result(&context); // Deserialize the expected result - let expected_result: Result = - serde_json::from_value(test_json["result"].clone()); - - if expected_result.is_err() { - println!( - "FAIL {}: Failed to deserialize expected result: {:?}", - test_name, - expected_result.err() - ); - failed += 1; - failed_tests.push(test_name.to_string()); - continue; - } - - let expected = expected_result.unwrap(); - - // Compare results - if compare_evaluation_results(&result, &expected, test_name) { - passed += 1; - println!("PASS {}", test_name); - } else { - failed += 1; - failed_tests.push(test_name.to_string()); - } + let expected: EvaluationResult = serde_json::from_value(test_json["result"].clone()) + .unwrap_or_else(|e| { + panic!( + "Failed to deserialize expected result for {}: {:?}", + test_name, e + ) + }); + + // Compare results - panic immediately on mismatch + assert_eq!( + result.flags, expected.flags, + "Flags mismatch in {}", + test_name + ); + assert_eq!( + result.segments, expected.segments, + "Segments mismatch in {}", + test_name + ); } - - // Print summary - println!("\n========== TEST SUMMARY =========="); - println!("Total tests: {}", test_count); - println!("Passed: {} ({}%)", passed, (passed * 100) / test_count); - println!("Failed: {} ({}%)", failed, (failed * 100) / test_count); - - if !failed_tests.is_empty() { - println!("\nFailed tests:"); - for test in &failed_tests { - println!(" - {}", test); - } - } - - println!("==================================\n"); - - // Assert that all tests passed - assert_eq!(failed, 0, "{} out of {} tests failed", failed, test_count); -} - -fn compare_evaluation_results( - result: &EvaluationResult, - expected: &EvaluationResult, - test_name: &str, -) -> bool { - // Compare flags - simple equality check - if result.flags != expected.flags { - println!("FAIL {}: Flags mismatch", test_name); - println!(" Expected flags: {:?}", expected.flags); - println!(" Actual flags: {:?}", result.flags); - return false; - } - - // Compare segments - simple equality check - if result.segments != expected.segments { - println!("FAIL {}: Segments mismatch", test_name); - println!(" Expected segments: {:?}", expected.segments); - println!(" Actual segments: {:?}", result.segments); - return false; - } - - true }