import numpy as np
import pandas as pd
from scipy import stats
from statsmodels.stats.multitest import multipletests
from typing import Dict, List, Tuple, Optional, Union, Any
from dataclasses import dataclass
from scipy.stats.contingency import association
import warnings
@dataclass
class StatTestResult:
    """Container for the outcome of one statistical test.

    ``is_significant`` is coerced to a built-in ``bool`` when the object is
    constructed, so NumPy booleans (e.g. from ``p_value <= alpha`` on NumPy
    scalars) do not leak into results. Later attribute assignments are not
    coerced.
    """

    # Test statistic (e.g. the chi-square value or the bootstrap mean).
    statistic: float
    # p-value of the test.
    p_value: float
    # Whether the result was judged significant.
    is_significant: bool
    # Human-readable name of the test that produced the result.
    test_name: str
    # Optional critical value of the test statistic.
    critical_value: Optional[float] = None
    # Optional effect size (e.g. Cramer's V).
    effect_size: Optional[float] = None
    # Optional (lower, upper) confidence interval.
    confidence_interval: Optional[Tuple[float, float]] = None

    def __post_init__(self) -> None:
        flag = self.is_significant
        self.is_significant = True if flag else False
class StatisticalTester:
    """Performs significance testing of per-group metrics against a reference group.

    Supported tests (``AVAILABLE_TESTS``):

    * ``"chi_square"`` -- omnibus and pairwise chi-square tests on
      confusion-matrix counts (TP/FP/TN/FN), switching to Fisher's exact test
      for sparse 2x2 tables.
    * ``"bootstrap_test"`` -- tests whether bootstrapped metric differences
      differ from zero.

    P-values can be corrected for multiple comparisons with any method in
    ``ADJUSTMENT_METHODS``.
    """

    AVAILABLE_TESTS = {
        "chi_square": "Chi-square test",
        "bootstrap_test": "Bootstrap test",
    }
    ADJUSTMENT_METHODS = {
        "bonferroni": "Bonferroni correction",
        "fdr_bh": "Benjamini-Hochberg FDR",
        "holm": "Holm-Bonferroni",
        "none": "No correction",
    }
    # Accepted values of config["tail_type"]; a missing key means "two_tailed".
    TAIL_TYPES = ("two_tailed", "one_tail_less", "one_tail_greater")
    # With fewer bootstrap replicates than this, the percentile method is
    # replaced by a normal approximation (and a warning is emitted).
    # Override per call with config["min_bootstraps"].
    MIN_PERCENTILE_BOOTSTRAPS = 5000
    # metric -> (cells summed into the first column, cells summed into the
    # second column) of the group-by-outcome contingency table. Column names
    # join the cells with " + " (e.g. "TP + TN").
    _CONTINGENCY_COLUMNS = {
        "Accuracy": (("TP", "TN"), ("FP", "FN")),
        "Precision": (("TP",), ("FP",)),
        "Recall": (("TP",), ("FN",)),
        "F1 Score": (("TP",), ("FP", "FN")),
        "Specificity": (("TN",), ("FP",)),
        "FN Rate": (("FN",), ("TP",)),
        "FP Rate": (("FP",), ("TN",)),
        "Predicted Prevalence": (("TP", "FP"), ("FN", "TN")),
        "Negative Predictive Value": (("TN",), ("FN",)),
    }

    def __init__(self):
        """Initializes StatisticalTester with default test implementations."""
        self._test_implementations = {
            "chi_square": self._chi_square_test,
            "bootstrap_test": self._bootstrap_test,
        }
        self.METRIC_LIST = [
            "Recall",
            "Precision",
            "Accuracy",
            "F1 Score",
            "Specificity",
            "FP Rate",
            "FN Rate",
            "Predicted Prevalence",
        ]

    # ------------------------------------------------------------------
    # Bootstrap test
    # ------------------------------------------------------------------

    def _tail_type(self, config: Dict[str, Any]) -> str:
        """Returns the validated ``tail_type`` from ``config`` (default two_tailed).

        Raises:
            ValueError: If the tail type is not one of ``TAIL_TYPES``.
        """
        tail_type = config.get("tail_type", "two_tailed")
        if tail_type not in self.TAIL_TYPES:
            raise ValueError(
                "Must specify two_tailed, one_tail_less or one_tail_greater for "
                f"the tail_type, got {tail_type!r}"
            )
        return tail_type

    def _bootstrap_test(
        self, data: List[float], config: Dict[str, Any]
    ) -> "StatTestResult":
        """Tests whether bootstrapped metric differences differ from zero.

        With at least ``min_bootstraps`` (non-NaN) replicates, the empirical
        percentile interval and p-value are used. Otherwise a normal
        approximation N(mean, std) of the bootstrap distribution is used and a
        warning is emitted. In both cases the interval is the (1 - alpha)
        interval implied by ``tail_type``. One-sided tests leave the
        untested side unbounded (+/- inf). The result is significant when zero
        lies outside the interval.

        Args:
            data: Bootstrap replicates of one metric difference.
            config: Reads ``alpha``, optional ``tail_type`` (default
                ``"two_tailed"``) and optional ``min_bootstraps``.

        Returns:
            StatTestResult with the replicate mean as the statistic.

        Raises:
            ValueError: If no non-NaN replicate is given or ``tail_type`` is
                invalid.
        """
        samples = np.asarray(data, dtype=float)
        finite = samples[~np.isnan(samples)]
        if finite.size < samples.size:
            # NaN replicates would otherwise turn the interval into NaN, and
            # NaN comparisons made the result "significant".
            warnings.warn(
                f"Ignoring {samples.size - finite.size} NaN bootstrap replicate(s)."
            )
        if finite.size == 0:
            raise ValueError("Bootstrap test requires at least one non-NaN replicate.")
        samples = finite

        tail_type = self._tail_type(config)
        lower, higher = self.get_ci_bounds(config)
        mu = float(np.mean(samples))
        sigma = float(np.std(samples))
        min_boot = config.get("min_bootstraps", self.MIN_PERCENTILE_BOOTSTRAPS)

        p_value: Optional[float] = None
        if samples.size >= min_boot:
            ci_lower, ci_upper = map(float, np.percentile(samples, [lower, higher]))
            p_value = self.calc_p_value_bootstrap(samples, config)
        else:
            warnings.warn(
                "Warning: Calculation may not be correct, please increase number of bootstraps"
            )
            if sigma > 0:
                ci_lower = float(stats.norm.ppf(lower / 100, loc=mu, scale=sigma))
                ci_upper = float(stats.norm.ppf(higher / 100, loc=mu, scale=sigma))
                z = mu / sigma  # null hypothesis: mean difference == 0
                if tail_type == "one_tail_less":
                    p_value = float(stats.norm.cdf(z))
                elif tail_type == "one_tail_greater":
                    p_value = float(stats.norm.sf(z))
                else:
                    p_value = float(2 * stats.norm.sf(abs(z)))
            else:
                # Zero spread: every replicate equals mu, the interval collapses.
                ci_lower = ci_upper = mu

        # A one-sided interval is unbounded on the side that is not tested.
        if tail_type == "one_tail_less":
            ci_lower = -np.inf
        elif tail_type == "one_tail_greater":
            ci_upper = np.inf

        is_significant = not (ci_lower <= 0 <= ci_upper)
        if p_value is None:
            # Degenerate distribution: the decision is certain either way.
            p_value = 0.0 if is_significant else 1.0

        return StatTestResult(
            statistic=mu,
            p_value=p_value,
            is_significant=is_significant,
            test_name="bootstrap_mean",
            confidence_interval=(ci_lower, ci_upper),
        )

    def get_ci_bounds(self, config: dict) -> tuple:
        """Returns the (lower, upper) percentiles of the (1 - alpha) interval.

        * ``two_tailed``: (100 * alpha / 2, 100 * (1 - alpha / 2))
        * ``one_tail_less``: (0, 100 * (1 - alpha)) -- upper one-sided bound
        * ``one_tail_greater``: (100 * alpha, 100) -- lower one-sided bound

        A missing ``tail_type`` is treated as ``two_tailed``.

        Raises:
            ValueError: If ``tail_type`` is not recognised.
        """
        tail_type = self._tail_type(config)
        alpha = config["alpha"]
        if tail_type == "two_tailed":
            return (alpha / 2) * 100, (1 - (alpha / 2)) * 100
        if tail_type == "one_tail_less":
            return 0, (1 - alpha) * 100
        return alpha * 100, 100

    def calc_p_value_bootstrap(self, data: list, config: dict) -> float:
        """Computes the percentile-bootstrap p-value for H0: mean difference == 0.

        Replicates equal to zero count as evidence for the null, which matches
        the "interval contains zero" significance check.

        * ``one_tail_greater``: share of replicates <= 0
        * ``one_tail_less``: share of replicates >= 0
        * ``two_tailed``: twice the smaller of the two shares, capped at 1

        Raises:
            ValueError: If ``data`` is empty or ``tail_type`` is invalid.
        """
        samples = np.asarray(data, dtype=float)
        if samples.size == 0:
            raise ValueError("Cannot compute a bootstrap p-value from no replicates.")
        tail_type = self._tail_type(config)
        share_not_above = float(np.mean(samples <= 0))
        share_not_below = float(np.mean(samples >= 0))
        if tail_type == "one_tail_greater":
            return share_not_above
        if tail_type == "one_tail_less":
            return share_not_below
        return min(1.0, 2 * min(share_not_above, share_not_below))

    # ------------------------------------------------------------------
    # Chi-square test
    # ------------------------------------------------------------------

    def _chi_square_test(
        self,
        metrics: Dict[str, Any],
        config: Dict[str, Any],
    ) -> "Dict[str, StatTestResult]":
        """Runs one chi-square (or Fisher's exact) test per metric in METRIC_LIST.

        Args:
            metrics: Mapping group -> confusion-matrix counts (TP/FP/TN/FN).
            config: Configuration; ``alpha`` (default 0.05) sets significance.

        Returns:
            Mapping metric name -> StatTestResult.
        """
        data = pd.DataFrame(metrics).T  # one row per group, one column per count
        alpha = config.get("alpha", 0.05)
        results = {}
        for metric in self.METRIC_LIST:
            table = self.get_contingency_table(data, metric)
            statistic, p_value, test_name = self._test_contingency_table(table, metric)
            results[metric] = StatTestResult(
                statistic=statistic,
                p_value=p_value,
                is_significant=p_value <= alpha,
                test_name=test_name,
            )
        return results

    @staticmethod
    def _drop_empty_margins(table: pd.DataFrame) -> pd.DataFrame:
        """Drops all-zero rows and columns from a contingency table.

        An all-zero row (a group with no counts for the metric) or column makes
        an expected frequency zero, which ``chi2_contingency`` rejects.
        """
        return table.loc[table.sum(axis=1) > 0, table.sum(axis=0) > 0]

    def _test_contingency_table(
        self, table: pd.DataFrame, metric: str
    ) -> Tuple[float, float, str]:
        """Tests one group-by-outcome table for independence.

        Returns:
            ``(statistic, p_value, test_name)``. The statistic is NaN when
            Fisher's exact test replaces the chi-square test.
        """
        table = self._drop_empty_margins(table)
        if min(table.shape) < 2:
            # Fewer than two non-empty groups or outcome columns: there is
            # nothing to compare, so report no evidence against H0.
            warnings.warn(
                f"Metric '{metric}': contingency table has fewer than two "
                f"non-empty rows or columns; reporting p-value 1.0."
            )
            return 0.0, 1.0, "Chi-Square Test"

        chi2, p_value, _, expected = stats.chi2_contingency(table)
        # Cochran's rule: chi-square approximation is unreliable when more
        # than 20% of expected cell counts are < 5. Use Fisher's exact (2x2)
        # or warn (larger).
        # See Kim HY (2017), "Statistical notes for clinical researchers:
        # Chi-squared test and Fisher's exact test." Restor Dent Endod
        # 42(2):152-155. https://pmc.ncbi.nlm.nih.gov/articles/PMC5426219/
        low_expected_pct = (expected < 5).sum() / expected.size
        if low_expected_pct > 0.20:
            if table.shape == (2, 2):
                # Fisher's exact is well-defined for 2x2, swap in transparently
                _, p_value = stats.fisher_exact(table)
                return np.nan, p_value, "Fisher's Exact Test"
            warnings.warn(
                f"Metric '{metric}': {int(low_expected_pct * 100)}% of "
                f"expected cell counts < 5 (min expected = "
                f"{expected.min():.2f}). Chi-square approximation may be "
                f"unreliable for this {table.shape[0]} x "
                f"{table.shape[1]} table per Cochran's rule. "
                f"Consider Fisher's exact test."
            )
        return chi2, p_value, "Chi-Square Test"

    def get_contingency_table(self, data: pd.DataFrame, metric: str) -> pd.DataFrame:
        """Builds the group-by-outcome contingency table for ``metric``.

        Rows are the groups (index of ``data``). The two columns are the
        confusion-matrix cells, or sums of cells, listed for the metric in
        ``_CONTINGENCY_COLUMNS``, e.g.::

            Recall:              Accuracy:
                   | TP | FN |          | TP + TN | FP + FN |
            GRP1   |    |    |   GRP1   |         |         |
            GRP2   |    |    |   GRP2   |         |         |

        ``data`` is not modified.

        Raises:
            ValueError: If ``metric`` is not supported.
        """
        try:
            column_spec = self._CONTINGENCY_COLUMNS[metric]
        except KeyError:
            raise ValueError(
                f"Unsupported metric for contingency table: {metric!r}"
            ) from None
        return pd.DataFrame(
            {
                " + ".join(cells): data[list(cells)].sum(axis=1, skipna=False)
                for cells in column_spec
            }
        )

    def _calculate_effect_size(self, metrics: Dict, metric: str) -> float:
        """Calculates the Cramer's V effect size using scipy.

        All-zero rows and columns are dropped first, as in the chi-square test.

        Returns:
            float: Cramer's V, or NaN if fewer than two rows or columns remain.
        """
        data = pd.DataFrame(metrics).T
        table = self._drop_empty_margins(self.get_contingency_table(data, metric))
        if min(table.shape) < 2:
            return float("nan")
        counts = table.to_numpy()
        # scipy's association() only accepts integer arrays; counts often
        # arrive as integral floats (e.g. 12.0), which are cast losslessly.
        if np.issubdtype(counts.dtype, np.floating) and np.all(
            counts == np.round(counts)
        ):
            counts = counts.astype(np.int64)
        return float(association(counts, method="cramer"))

    # ------------------------------------------------------------------
    # Multiple-comparison adjustment
    # ------------------------------------------------------------------

    def _adjust_p_values(
        self,
        results: "Dict[str, Dict[str, StatTestResult]]",
        method: str,
        alpha: float,
        boot: bool = False,
    ) -> "Dict[str, Dict[str, StatTestResult]]":
        """Adjusts p-values in place for multiple comparisons.

        Args:
            results: Mapping group -> metric -> result (mutated in place).
            method: ``"bonferroni"``, ``"fdr_bh"`` or ``"holm"``; any other
                value leaves ``results`` untouched.
            alpha: Significance level applied to the adjusted p-values.
            boot: If True, adjust jointly across every (group, metric) pair.
                Otherwise adjust each metric separately across the pairwise
                groups, leaving ``"omnibus"`` unadjusted.

        Returns:
            ``results`` (the same object).
        """
        if method not in ("bonferroni", "fdr_bh", "holm"):
            return results
        if boot:
            keys = [
                (group, metric)
                for group, group_results in results.items()
                for metric in group_results
            ]
            self._apply_adjustment(results, keys, method, alpha)
            return results

        pairwise_groups = [g for g in results if g != "omnibus"]
        metric_names = {m for group in pairwise_groups for m in results[group]}
        for metric in metric_names:
            keys = [(g, metric) for g in pairwise_groups if metric in results[g]]
            self._apply_adjustment(results, keys, method, alpha)
        return results

    @staticmethod
    def _apply_adjustment(
        results: "Dict[str, Dict[str, StatTestResult]]",
        keys: List[Tuple[str, str]],
        method: str,
        alpha: float,
    ) -> None:
        """Replaces the p-values at ``keys`` (group, metric) with adjusted ones."""
        if not keys:
            return
        p_values = [results[group][metric].p_value for group, metric in keys]
        adjusted = multipletests(p_values, alpha=alpha, method=method)[1]
        for (group, metric), p_adj in zip(keys, adjusted):
            result = results[group][metric]
            result.p_value = float(p_adj)
            # bool(): StatTestResult only coerces is_significant in __post_init__.
            result.is_significant = bool(p_adj <= alpha)

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------

    def analyze_metrics(
        self,
        metrics_data: Union[Dict, List[Dict]],
        reference_group: str,
        test_config: Dict[str, Any],
        task: Optional[str] = None,
        differences: Optional[dict] = None,
    ) -> "Dict[str, Dict[str, StatTestResult]]":
        """Analyzes metrics for statistical significance against a reference group.

        Args:
            metrics_data: A list signals a bootstrapped run, where the
                replicates tested come from ``differences``. A dict maps
                group -> confusion-matrix counts.
            reference_group: Group every other group is compared against.
            test_config: Requires ``test_type`` and ``alpha``;
                ``adjust_method`` defaults to ``"none"``. Not modified.
            task: Must be ``"binary_classification"`` for dict input.
            differences: Per-replicate metric differences; required for list
                input.

        Returns:
            Results keyed by group (plus ``"omnibus"`` for chi-square), then
            by metric.

        Raises:
            ValueError: On invalid configuration, an unsupported task, missing
                ``differences``, or a ``test_type`` that does not match the
                kind of ``metrics_data``.
        """
        config = {**test_config}
        config.setdefault("adjust_method", "none")
        self._validate_config(config)
        if isinstance(metrics_data, list):
            if config["test_type"] != "bootstrap_test":
                raise ValueError(
                    "Bootstrapped (list) metrics require test_type 'bootstrap_test'."
                )
            if differences is None:
                raise ValueError(
                    "`differences` must be provided for bootstrapped metrics."
                )
            results = self._analyze_bootstrapped_metrics(
                differences, reference_group, config
            )
        elif task == "binary_classification":
            if config["test_type"] != "chi_square":
                raise ValueError(
                    "Non-bootstrapped metrics require test_type 'chi_square'."
                )
            results = self._analyze_single_metrics(
                metrics_data, reference_group, config
            )
        else:
            raise ValueError(
                "Task not supported for non-bootstrapped metrics. "
                "Use bootstrapped metrics."
            )
        # NOTE(review): for chi-square, effect sizes are attached before this
        # adjustment, so a pairwise result that loses significance here keeps
        # its effect size.
        if config["adjust_method"] != "none":
            results = self.adjusting_p_vals(config, results)
        return results

    def adjusting_p_vals(self, config, results):
        """Adjusts p-values in ``results`` for multiple comparisons.

        Bootstrap results are adjusted jointly across every (group, metric)
        pair. Chi-square results are adjusted per metric across the pairwise
        groups. Adjusting a single p-value leaves it unchanged, so a single
        group tested on several metrics is still corrected.
        """
        boot = config["test_type"] == "bootstrap_test"
        return self._adjust_p_values(
            results, config.get("adjust_method", "none"), config["alpha"], boot=boot
        )

    def _validate_config(self, config: Dict[str, Any]):
        """Validates the configuration dictionary.

        Requires ``test_type`` (a key of ``AVAILABLE_TESTS``) and ``alpha``
        (strictly between 0 and 1). ``adjust_method`` is optional (absent means
        ``"none"``) but must be a key of ``ADJUSTMENT_METHODS``.

        Raises:
            ValueError: If any check fails.
        """
        for key in ("test_type", "alpha"):
            if key not in config:
                raise ValueError(f"Missing required configuration key: {key}")
        if config["test_type"] not in self.AVAILABLE_TESTS:
            raise ValueError(
                f"Invalid test type: {config['test_type']}. Available tests: {self.AVAILABLE_TESTS.keys()}"
            )
        alpha = config["alpha"]
        if not 0 < alpha < 1:
            raise ValueError(f"alpha must be strictly between 0 and 1, got {alpha!r}")
        adjust_method = config.get("adjust_method", "none")
        if adjust_method not in self.ADJUSTMENT_METHODS:
            raise ValueError(
                f"Invalid adjustment method: {adjust_method}. Available methods: {self.ADJUSTMENT_METHODS.keys()}"
            )

    def cohens_d(self, data_1, data_2):
        """Calculates Cohen's d between two samples.

        The pooled standard deviation is the root mean of the two population
        variances. Returns 0.0 when it is zero.
        """
        mean_diff = np.mean(data_1) - np.mean(data_2)
        pooled_std = np.sqrt((np.var(data_1) + np.var(data_2)) / 2)
        return mean_diff / pooled_std if pooled_std > 0 else 0.0

    # ------------------------------------------------------------------
    # Analysis drivers
    # ------------------------------------------------------------------

    def _analyze_single_metrics(
        self, metrics: Dict, reference_group: str, config: Dict[str, Any]
    ) -> "Dict[str, Dict[str, StatTestResult]]":
        """Runs omnibus and pairwise tests on non-bootstrapped confusion matrices.

        1. An omnibus test across all groups.
        2. If any metric is significant in the omnibus test, a pairwise test
           of every non-reference group against the reference group.
        3. Cramer's V for metrics significant in the omnibus test (and, per
           group, in the pairwise test). All other results get ``effect_size``
           and ``confidence_interval`` set to None.

        Raises:
            ValueError: If ``reference_group`` is not a key of ``metrics``.
        """
        test_func = self._test_implementations[config["test_type"]]
        cm_keys = ("TP", "FP", "TN", "FN")
        counts = {
            group: {k: v for k, v in group_metrics.items() if k in cm_keys}
            for group, group_metrics in metrics.items()
        }
        if reference_group not in counts:
            raise ValueError(f"Reference group {reference_group!r} not found in metrics.")
        ref_counts = {reference_group: counts[reference_group]}

        # 1) omnibus test
        results = {"omnibus": test_func(counts, config)}

        # 2) pairwise tests, gated on any omnibus significance
        if any(r.is_significant for r in results["omnibus"].values()):
            for group in counts:
                if group != reference_group:
                    results[group] = test_func({**ref_counts, group: counts[group]}, config)

        # 3) effect sizes, gated per metric on omnibus significance
        pairwise_groups = [g for g in results if g != "omnibus"]
        for metric in self.METRIC_LIST:
            omnibus = results["omnibus"][metric]
            if omnibus.is_significant:
                omnibus.effect_size = self._calculate_effect_size(counts, metric)
                for group in pairwise_groups:
                    pairwise = results[group][metric]
                    if pairwise.is_significant:
                        pairwise.effect_size = self._calculate_effect_size(
                            {**ref_counts, group: counts[group]}, metric
                        )
                    else:
                        pairwise.effect_size = None
                        pairwise.confidence_interval = None
            else:
                for res in [omnibus, *(results[g][metric] for g in pairwise_groups)]:
                    res.effect_size = None
                    res.confidence_interval = None
        return results

    def _analyze_bootstrapped_metrics(
        self,
        metrics_diff: List[Dict[str, Any]],
        reference_group: str,
        config: Dict[str, Any],
    ) -> "Dict[str, Dict[str, StatTestResult]]":
        """Tests the bootstrapped metric differences of each non-reference group.

        Args:
            metrics_diff: One dict per bootstrap replicate, mapping
                group -> {metric: difference}. Not modified.
            reference_group: Group whose entries are skipped.
            config: ``metrics`` lists the metric keys to test.

        Returns:
            Mapping group -> metric -> result of the configured test applied
            to that group's replicates of that metric.
        """
        test_func = self._test_implementations[config["test_type"]]
        metric_names = config["metrics"]
        # group -> metric -> replicate values, in replicate order
        aggregated: Dict[str, Dict[str, List[float]]] = {}
        for replicate in metrics_diff:
            for group, group_metrics in replicate.items():
                # Skip (rather than pop) so the caller's dicts stay intact.
                if group == reference_group:
                    continue
                if group not in aggregated:
                    aggregated[group] = {metric: [] for metric in metric_names}
                for metric in metric_names:
                    aggregated[group][metric].append(group_metrics[metric])
        return {
            group: {metric: test_func(samples[metric], config) for metric in metric_names}
            for group, samples in aggregated.items()
        }