Main Function: analyze_dataset()
def analyze_dataset(
logged_data_path: Optional[str] = None, # Path to JSONL (optional for Direct mode)
fresh_draws_dir: Optional[str] = None, # Directory with fresh draws
estimator: str = "auto", # Estimator or "auto" for mode selection
judge_field: str = "judge_score", # Metadata field with judge scores
oracle_field: str = "oracle_label", # Metadata field with oracle labels
estimator_config: Optional[Dict] = None, # Estimator-specific config
verbose: bool = False # Print detailed progress
) -> EstimationResult
Parameters
Parameter | Type | Description
---|---|---
logged_data_path | Optional[str] | Path to JSONL file with logged data (optional for Direct mode)
fresh_draws_dir | Optional[str] | Directory containing fresh draw response files
estimator | str | "auto" (default) or a manual choice: "direct", "calibrated-ips", "stacked-dr", etc.
judge_field | str | Field name for judge scores (default: "judge_score")
oracle_field | str | Field name for oracle labels (default: "oracle_label")
estimator_config | Optional[Dict] | Optional estimator-specific configuration
verbose | bool | Print detailed progress (default: False)
Automatic Mode Selection
Use estimator="auto" (default) and CJE will:
- Detect the mode based on your data (Direct/IPS/DR)
- Select the best estimator for that mode
- Check logprob coverage (need ≥50% for IPS/DR)
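The selection rule above can be sketched as follows. This is a simplified illustration based only on the criteria listed here; CJE's actual internal heuristics may differ.

```python
def select_mode(has_logged_data: bool, has_fresh_draws: bool,
                logprob_coverage: float) -> str:
    """Simplified sketch of automatic mode selection.

    logprob_coverage is the fraction of logged samples with
    complete log-probabilities (IPS/DR need >= 50%).
    """
    if not has_logged_data or logprob_coverage < 0.5:
        return "direct"   # fresh draws only, or insufficient overlap
    if has_fresh_draws:
        return "dr"       # doubly robust: logged data + fresh draws
    return "ips"          # importance weighting on logged data alone
```
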
Return Type: EstimationResult
@dataclass
class EstimationResult:
estimates: np.ndarray # Shape: [n_policies], values in [0,1]
standard_errors: np.ndarray # Complete SEs (IF + MC + oracle)
diagnostics: IPSDiagnostics # Health metrics
metadata: Dict # Run metadata
influence_functions: Dict # Per-sample contributions
def ci(self, alpha=0.05) -> List[Tuple[float, float]]:
"""Get confidence intervals as (lower, upper) tuples."""
Result Fields
estimates
Policy value estimates as numpy array. One estimate per target policy, in [0, 1] range.
standard_errors
Complete standard errors including all uncertainty sources: influence function variance, Monte Carlo variance (for DR), and oracle uncertainty (when oracle coverage < 100%).
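How these components combine is internal to CJE, but a natural sketch is to add independent variance components in quadrature. The decomposition below is an assumption for illustration, not CJE's exact formula.

```python
import math

def total_standard_error(if_var: float, mc_var: float, oracle_var: float) -> float:
    """Combine variance components into one SE (independence assumed).

    if_var: influence-function variance of the estimate
    mc_var: Monte Carlo variance (DR estimators; 0 otherwise)
    oracle_var: extra variance from oracle coverage < 100%
    """
    return math.sqrt(if_var + mc_var + oracle_var)
```
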
diagnostics
Health metrics including:
- weight_ess - Effective sample size (0-1, higher is better)
- ess_per_policy - ESS for each policy
- overall_status - GOOD/WARNING/CRITICAL
- calibration_rmse - Judge calibration quality
metadata
Run information including target policies, mode selected, estimator used, data sources, etc.
ci(alpha=0.05)
Convenience method returning confidence intervals as a list of (lower, upper) tuples; the default alpha=0.05 gives 95% intervals.
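For a given estimate and standard error, the interval presumably follows the usual normal approximation; a minimal sketch using the standard library (the exact computation inside `ci()` is an assumption here):

```python
from statistics import NormalDist

def normal_ci(estimate: float, se: float, alpha: float = 0.05):
    """Two-sided normal-approximation confidence interval."""
    z = NormalDist().inv_cdf(1 - alpha / 2)  # ~1.96 for alpha=0.05
    return (estimate - z * se, estimate + z * se)
```
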
Usage Examples
Basic Usage
from cje import analyze_dataset
# Automatic mode selection
result = analyze_dataset("data.jsonl")
# Access results
for i, policy in enumerate(result.metadata["target_policies"]):
est = result.estimates[i]
se = result.standard_errors[i]
print(f"{policy}: {est:.3f} ± {1.96*se:.3f}")
# Get confidence intervals
cis = result.ci() # Returns [(lower, upper), ...]
for i, (lower, upper) in enumerate(cis):
policy = result.metadata["target_policies"][i]
print(f"{policy}: [{lower:.3f}, {upper:.3f}]")
Check Diagnostics
# Check overall health
if result.diagnostics.overall_status.value == "CRITICAL":
print("⚠️ Critical issues detected")
print(result.diagnostics.summary())
# Check ESS for each policy
for policy, ess in result.diagnostics.ess_per_policy.items():
if ess < 0.30:
print(f"⚠️ Low ESS for {policy}: {ess:.1%}")
Custom Configuration
# Use specific estimator with custom config
result = analyze_dataset(
"logs.jsonl",
fresh_draws_dir="responses/",
estimator="stacked-dr",
estimator_config={
"n_folds": 10, # More folds for stability
"use_calibrated_weights": True
}
)
CLI Commands
Basic Analysis
# Automatic mode selection
python -m cje analyze data.jsonl
# With fresh draws (for DR)
python -m cje analyze logs.jsonl --fresh-draws-dir responses/
# Specify estimator
python -m cje analyze logs.jsonl --estimator calibrated-ips
# Save results to JSON
python -m cje analyze data.jsonl -o results.json
Data Validation
# Check data format before running
python -m cje validate data.jsonl --verbose
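For a quick programmatic pre-check before running the CLI, a minimal sketch that counts JSONL records carrying a judge score. It assumes scores live under the default `judge_score` key in a `metadata` dict, per the defaults above; CJE's own validator checks much more.

```python
import json

def quick_validate(path: str, judge_field: str = "judge_score") -> dict:
    """Count JSONL records and how many carry a judge score in metadata."""
    total = with_judge = 0
    with open(path) as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            total += 1
            if judge_field in record.get("metadata", {}):
                with_judge += 1
    return {"total": total, "with_judge": with_judge}
```
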
Common Patterns
Compare Multiple Policies
result = analyze_dataset("data.jsonl")
policies = result.metadata["target_policies"]
best_idx = result.estimates.argmax()
print(f"Best policy: {policies[best_idx]}")
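Picking the argmax ignores uncertainty. A hedged sketch of a pairwise z-test between two policies, assuming their estimates are independent (which ignores any covariance between policies):

```python
import math

def z_score_difference(est_a: float, se_a: float,
                       est_b: float, se_b: float) -> float:
    """z-statistic for est_a - est_b, assuming independent estimates."""
    return (est_a - est_b) / math.sqrt(se_a**2 + se_b**2)
```

If `abs(z) > 1.96`, the difference is significant at roughly the 5% level.
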
Export Results
import json
result = analyze_dataset("data.jsonl")
with open("results.json", "w") as f:
json.dump({
"estimates": result.estimates.tolist(),
"standard_errors": result.standard_errors.tolist(),
"ess": result.diagnostics.weight_ess
}, f)
Reliability Gating
def gated_estimate(path):
    result = analyze_dataset(path)
    if result.diagnostics.weight_ess < 0.1:
        raise ValueError("Insufficient overlap for reliable estimation")
    return result.estimates[0]

estimate = gated_estimate("data.jsonl")
Error Handling
ValueError: No data provided
At least one of logged_data_path or fresh_draws_dir must be provided.
ValueError: Estimator requires fresh draws
DR estimators like stacked-dr require fresh draws.
Solution: Provide fresh_draws_dir or use calibrated-ips.
ValueError: Insufficient logprob coverage
Need ≥50% of samples with complete logprobs for IPS/DR modes.
Solution: Compute missing logprobs or use Direct mode with fresh draws
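To diagnose coverage before running, you can count it yourself. The field names below are illustrative placeholders, not CJE's exact data schema:

```python
import json

def logprob_coverage(path: str, fields=("base_policy_logprob",)) -> float:
    """Fraction of JSONL records where every given logprob field is
    present and not None. Field names here are assumptions."""
    total = complete = 0
    with open(path) as f:
        for line in f:
            if not line.strip():
                continue
            rec = json.loads(line)
            total += 1
            if all(rec.get(k) is not None for k in fields):
                complete += 1
    return complete / total if total else 0.0
```

A result below 0.5 means IPS/DR modes will refuse to run; either compute the missing logprobs or fall back to Direct mode.
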
Developer Documentation
For module-level documentation, implementation details, and extending CJE, see the README files in each module directory on GitHub.