
API Reference

Complete reference for analyze_dataset() and the EstimationResult it returns

Main Function: analyze_dataset()

def analyze_dataset(
    logged_data_path: Optional[str] = None,   # Path to JSONL (optional for Direct mode)
    fresh_draws_dir: Optional[str] = None,    # Directory with fresh draws
    estimator: str = "auto",                  # Estimator or "auto" for mode selection
    judge_field: str = "judge_score",         # Metadata field with judge scores
    oracle_field: str = "oracle_label",       # Metadata field with oracle labels
    estimator_config: Optional[Dict] = None,  # Estimator-specific config
    verbose: bool = False                     # Print detailed progress
) -> EstimationResult

Parameters

  • logged_data_path (str | None) - Path to JSONL file with logged data (optional for Direct mode)
  • fresh_draws_dir (str | None) - Directory containing fresh draw response files
  • estimator (str) - "auto" (default) or a specific estimator: "direct", "calibrated-ips", "stacked-dr", etc.
  • judge_field (str) - Field name for judge scores (default: "judge_score")
  • oracle_field (str) - Field name for oracle labels (default: "oracle_label")
  • estimator_config (Dict | None) - Optional estimator-specific configuration
  • verbose (bool) - Print detailed progress (default: False)

Automatic Mode Selection

Use estimator="auto" (default) and CJE will:

  • Detect the mode based on your data (Direct/IPS/DR)
  • Select the best estimator for that mode
  • Check logprob coverage (need ≥50% for IPS/DR)
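To see what auto selection chose, inspect result.metadata after the run. A minimal sketch; the "mode" and "estimator" key names are assumptions here (only "target_policies" is confirmed by the examples below), so verify against result.metadata.keys() on your install:

from cje import analyze_dataset

result = analyze_dataset("data.jsonl", verbose=True)

# Hypothetical keys; confirm with result.metadata.keys() in practice
print(result.metadata.get("mode"))       # e.g. "ips"
print(result.metadata.get("estimator"))  # e.g. "calibrated-ips"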

Return Type: EstimationResult

@dataclass
class EstimationResult:
    estimates: np.ndarray         # Shape: [n_policies], values in [0, 1]
    standard_errors: np.ndarray   # Complete SEs (IF + MC + oracle)
    diagnostics: IPSDiagnostics   # Health metrics
    metadata: Dict                # Run metadata
    influence_functions: Dict     # Per-sample contributions

    def ci(self, alpha=0.05) -> List[Tuple[float, float]]:
        """Get confidence intervals as (lower, upper) tuples."""

Result Fields

estimates

Policy value estimates as a NumPy array: one estimate per target policy, each in the [0, 1] range.

standard_errors

Complete standard errors including all uncertainty sources: influence function variance, Monte Carlo variance (for DR), and oracle uncertainty (when oracle coverage < 100%).
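Because these SEs already fold in every uncertainty source, a rough two-policy comparison can be sketched directly from them. A minimal sketch that treats the two estimates as independent; a paired test built on the per-sample influence_functions would account for cross-policy covariance and be tighter:

from cje import analyze_dataset
import numpy as np

result = analyze_dataset("data.jsonl")

# Approximate z-test comparing policy 0 vs. policy 1 using the complete SEs.
# Assumes independence across policies; conservative if they are
# positively correlated.
a, b = 0, 1
diff = result.estimates[a] - result.estimates[b]
se_diff = np.sqrt(result.standard_errors[a] ** 2 + result.standard_errors[b] ** 2)
print(f"diff = {diff:.3f} ± {1.96 * se_diff:.3f} (z = {diff / se_diff:.2f})")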

diagnostics

Health metrics including:

  • weight_ess - Effective sample size (0-1, higher is better)
  • ess_per_policy - ESS for each policy
  • overall_status - GOOD/WARNING/CRITICAL
  • calibration_rmse - Judge calibration quality

metadata

Run information, including the target policies, the mode selected, the estimator used, data sources, etc.

ci(alpha=0.05)

Convenience method returning confidence intervals as a list of (lower, upper) tuples; with the default alpha=0.05 these are 95% intervals.
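The alpha parameter in the signature above controls the level; for instance:

from cje import analyze_dataset

result = analyze_dataset("data.jsonl")

cis_95 = result.ci()            # alpha=0.05 -> 95% intervals (default)
cis_90 = result.ci(alpha=0.10)  # alpha=0.10 -> narrower 90% intervals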

Usage Examples

Basic Usage

from cje import analyze_dataset

# Automatic mode selection
result = analyze_dataset("data.jsonl")

# Access results
for i, policy in enumerate(result.metadata["target_policies"]):
    est = result.estimates[i]
    se = result.standard_errors[i]
    print(f"{policy}: {est:.3f} ± {1.96*se:.3f}")

# Get confidence intervals
cis = result.ci()  # Returns [(lower, upper), ...]
for i, (lower, upper) in enumerate(cis):
    policy = result.metadata["target_policies"][i]
    print(f"{policy}: [{lower:.3f}, {upper:.3f}]")

Check Diagnostics

# Check overall health
if result.diagnostics.overall_status.value == "CRITICAL":
    print("⚠️ Critical issues detected")
    print(result.diagnostics.summary())

# Check ESS for each policy
for policy, ess in result.diagnostics.ess_per_policy.items():
    if ess < 0.30:
        print(f"⚠️ Low ESS for {policy}: {ess:.1%}")

Custom Configuration

# Use a specific estimator with custom config
result = analyze_dataset(
    "logs.jsonl",
    fresh_draws_dir="responses/",
    estimator="stacked-dr",
    estimator_config={
        "n_folds": 10,  # More folds for stability
        "use_calibrated_weights": True
    }
)

CLI Commands

Basic Analysis

# Automatic mode selection
python -m cje analyze data.jsonl

# With fresh draws (for DR)
python -m cje analyze logs.jsonl --fresh-draws-dir responses/

# Specify estimator
python -m cje analyze logs.jsonl --estimator calibrated-ips

# Save results to JSON
python -m cje analyze data.jsonl -o results.json

Data Validation

# Check data format before running
python -m cje validate data.jsonl --verbose

Common Patterns

Compare Multiple Policies

result = analyze_dataset("data.jsonl")
policies = result.metadata["target_policies"]
best_idx = result.estimates.argmax()
print(f"Best policy: {policies[best_idx]}")

Export Results

import json

result = analyze_dataset("data.jsonl")
with open("results.json", "w") as f:
    json.dump({
        "estimates": result.estimates.tolist(),
        "standard_errors": result.standard_errors.tolist(),
        "ess": result.diagnostics.weight_ess
    }, f)

Reliability Gating

def gated_estimate(path: str) -> float:
    result = analyze_dataset(path)
    if result.diagnostics.weight_ess < 0.1:
        raise ValueError("Insufficient overlap for reliable estimation")
    return result.estimates[0]

Error Handling

ValueError: No data provided

At least one of logged_data_path or fresh_draws_dir must be provided.
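When paths come from configuration or user input, this is straightforward to guard against; a minimal sketch:

from cje import analyze_dataset

try:
    result = analyze_dataset(logged_data_path=None, fresh_draws_dir=None)
except ValueError as e:
    # Raised when neither data source is provided
    print(f"Configuration error: {e}")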

ValueError: Estimator requires fresh draws

DR estimators like stacked-dr require fresh draws.

Solution: Provide fresh_draws_dir or use calibrated-ips
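For example, keeping the same logs but dropping to the IPS-only estimator:

# No fresh draws available: fall back from stacked-dr to calibrated-ips
result = analyze_dataset("logs.jsonl", estimator="calibrated-ips")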

ValueError: Insufficient logprob coverage

Need ≥50% of samples with complete logprobs for IPS/DR modes.

Solution: Compute missing logprobs or use Direct mode with fresh draws
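Or, when logprobs cannot be computed, switch to Direct mode, which needs only fresh draws:

# Direct mode: logged_data_path and logprobs are not required
result = analyze_dataset(fresh_draws_dir="responses/", estimator="direct")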

Developer Documentation

For module-level documentation, implementation details, and extending CJE, see the README files in each module directory on GitHub.