Skip to content

Temporal Model

neureptrace.temporal_model

build_state_trace(frame, *, stay_probability, class_names, prob_columns)

Decode posterior and Viterbi state traces for observed probability sequences.

Source code in `src/neureptrace/temporal_model.py` (lines 378–407)
def build_state_trace(frame: pd.DataFrame, *, stay_probability: float, class_names: list[str], prob_columns: list[str]) -> pd.DataFrame:
    """Decode posterior and Viterbi state traces for observed probability sequences.

    The frame is split into sequences using the columns returned by
    ``_sequence_key_columns`` (rows sorted by those keys and ``time``; NaN keys
    are kept as their own group). Each sequence's ``prob_columns`` matrix is
    normalized and decoded twice with the same ``stay_probability``: once with
    forward-backward (per-row state posteriors) and once with Viterbi (the single
    best state path).

    Args:
        frame: Observation rows; must contain ``time``, ``decoder`` and
            ``prob_columns``. Optionally ``emission_mode`` and the pass-through
            metadata columns listed below.
        stay_probability: Self-transition probability used by both decoders.
        class_names: Human-readable label for each state index.
        prob_columns: Probability columns, in state-index order.

    Returns:
        One row per input observation with the sequence keys, ``decoder``,
        ``emission_mode`` (``"calibrated"`` when the column is absent), ``time``,
        ``viterbi_state``/``viterbi_class``/``viterbi_posterior``, any present
        pass-through columns, and ``state_<i>``/``posterior_state_<i>`` for every
        state.

    Raises:
        ValueError: propagated from ``validate_unique_sequence_times`` when a
            sequence repeats a time bin.
    """
    passthrough_columns = ("source_path", "source_file", "session", "run", "sample_index", "true_class", "predicted_class")
    key_columns = _sequence_key_columns(frame)
    validate_unique_sequence_times(frame, key_columns)
    ordered = frame.sort_values([*key_columns, "time"])

    records: list[dict] = []
    for group_key, group in ordered.groupby(key_columns, sort=True, dropna=False):
        # A single grouping column may still yield a scalar key depending on pandas version.
        key_tuple = group_key if isinstance(group_key, tuple) else (group_key,)
        identity = dict(zip(key_columns, key_tuple, strict=True))

        emissions = _normalize_probabilities(group[prob_columns].to_numpy())
        posterior = _forward_backward(emissions, stay_probability)
        best_path = _viterbi_path(emissions, stay_probability)

        for position, (_, observation) in enumerate(group.iterrows()):
            best_state = int(best_path[position])
            # Key order here fixes the output column order; keep it stable.
            record = dict(identity)
            record["decoder"] = str(observation["decoder"])
            record["emission_mode"] = str(observation["emission_mode"]) if "emission_mode" in observation else "calibrated"
            record["time"] = float(observation["time"])
            record["viterbi_state"] = best_state
            record["viterbi_class"] = class_names[best_state]
            record["viterbi_posterior"] = float(posterior[position, best_state])
            record.update({name: observation[name] for name in passthrough_columns if name in observation})
            for index, label in enumerate(class_names):
                record[f"state_{index}"] = label
                record[f"posterior_state_{index}"] = float(posterior[position, index])
            records.append(record)

    return pd.DataFrame(records)

fit_sticky_switching_model(sequences, *, stay_grid_size=200)

Fit a sticky switching model by grid-searching the state persistence.

Source code in `src/neureptrace/temporal_model.py` (lines 226–248)
def fit_sticky_switching_model(sequences: list[np.ndarray], *, stay_grid_size: int = 200) -> dict[str, float]:
    """Fit a sticky switching model by grid-searching the state persistence.

    Every candidate stay probability produced by ``_stay_grid`` is scored with
    ``_total_log_likelihood`` across all sequences; the best-scoring one is kept.

    Args:
        sequences: Probability sequences, each a 2-D array whose columns are
            states; all sequences must have the same number of columns.
        stay_grid_size: Number of candidate stay probabilities requested from
            ``_stay_grid``.

    Returns:
        A dict of floats: sequence/observation/state counts, the best stay
        probability, its total and per-observation log-likelihood, the
        per-observation log-likelihood of a uniform emission (``-log(n_states)``),
        and the per-observation gain over that uniform baseline.

    Raises:
        ValueError: if ``sequences`` is empty, any sequence is not 2-D, the
            sequences disagree on the number of states, or they contain no
            observations at all.
    """
    # Validate before indexing sequences[0] / shape[1], which would otherwise
    # fail with an opaque IndexError.
    if len(sequences) == 0:
        raise ValueError("At least one probability sequence is required.")
    if any(np.ndim(sequence) != 2 for sequence in sequences):
        raise ValueError("Probability sequences must be 2-D arrays of shape (n_time, n_states).")
    n_states = sequences[0].shape[1]
    if any(sequence.shape[1] != n_states for sequence in sequences):
        raise ValueError("All probability sequences must have the same number of states.")
    n_observations = int(sum(len(sequence) for sequence in sequences))
    # Without observations the per-observation normalization below divides by zero;
    # fail before running the (potentially expensive) grid search.
    if n_observations == 0:
        raise ValueError("Probability sequences contain no observations.")

    grid = _stay_grid(n_states, stay_grid_size)
    log_likelihoods = np.array([_total_log_likelihood(sequences, stay_probability) for stay_probability in grid])
    best_index = int(np.argmax(log_likelihoods))
    log_likelihood = float(log_likelihoods[best_index])
    uniform_log_likelihood_per_observation = -float(np.log(n_states))
    log_likelihood_per_observation = log_likelihood / n_observations
    return {
        "n_sequences": float(len(sequences)),
        "n_observations": float(n_observations),
        "n_states": float(n_states),
        "best_stay_probability": float(grid[best_index]),
        "log_likelihood": log_likelihood,
        "log_likelihood_per_observation": log_likelihood_per_observation,
        "uniform_log_likelihood_per_observation": uniform_log_likelihood_per_observation,
        "persistence_gain_per_observation": log_likelihood_per_observation - uniform_log_likelihood_per_observation,
    }

fit_temporal_models(observation_csvs, *, effect_window=(0.1, 0.8), baseline_window=(-0.1, 0.0), n_permutations=100, random_seed=13, stay_grid_size=200, out_summary=None, out_states=None)

Fit sticky switching models to probability observation CSVs and controls.

Source code in `src/neureptrace/temporal_model.py` (lines 410–477)
def fit_temporal_models(
    observation_csvs: list[Path],
    *,
    effect_window: tuple[float, float] = (0.1, 0.8),
    baseline_window: tuple[float, float] = (-0.1, 0.0),
    n_permutations: int = 100,
    random_seed: int = 13,
    stay_grid_size: int = 200,
    out_summary: Path | None = None,
    out_states: Path | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame | None]:
    """Fit sticky switching models to probability observation CSVs and controls.

    For every group produced by ``_model_group_columns`` this emits, in order:
    an ``observed_effect`` fit on the effect window; a ``baseline_window`` fit
    when the baseline window yields usable sequences; and, when
    ``n_permutations > 0``, one row each for the ``shuffled_time`` and
    ``shuffled_label`` controls (seeded ``random_seed`` and ``random_seed + 1``).

    Args:
        observation_csvs: Observation CSVs passed to ``read_probability_observations``.
        effect_window: Time window used for the observed fit and the controls.
        baseline_window: Time window used for the baseline fit.
        n_permutations: Permutations per control; ``0`` disables the controls.
        random_seed: Base seed for the control permutations.
        stay_grid_size: Grid size forwarded to every model fit.
        out_summary: If given, the summary table is written here as CSV
            (parent directories are created).
        out_states: If given, state traces are decoded for each group's effect
            window and written here as CSV.

    Returns:
        ``(summary, states)``; ``states`` is ``None`` unless ``out_states`` was
        given and at least one trace was built.
    """
    observations = read_probability_observations(observation_csvs)
    prob_columns = probability_columns(observations)
    group_columns = _model_group_columns(observations)
    controls = ("shuffled_time", "shuffled_label")

    summary_rows: list[dict] = []
    trace_frames: list[pd.DataFrame] = []

    for group_key, decoder_frame in observations.groupby(group_columns, sort=True):
        raw_values = group_key if isinstance(group_key, tuple) else (group_key,)
        group_values = {column: str(value) for column, value in zip(group_columns, raw_values, strict=True)}
        class_names = _class_names(decoder_frame, prob_columns)

        effect_frame = _filter_time_window(decoder_frame, effect_window)
        effect_sequences = _sequences_from_frame(effect_frame, prob_columns)
        observed_fit = fit_sticky_switching_model(effect_sequences, stay_grid_size=stay_grid_size)
        summary_rows.append(_model_row(group_values, "observed_effect", observed_fit))

        # The baseline is best-effort: an empty window or one that cannot be
        # turned into sequences simply produces no baseline row.
        baseline_frame = _filter_time_window(decoder_frame, baseline_window)
        baseline_sequences = []
        if not baseline_frame.empty:
            try:
                baseline_sequences = _sequences_from_frame(baseline_frame, prob_columns)
            except ValueError:
                baseline_sequences = []
        if baseline_sequences:
            baseline_fit = fit_sticky_switching_model(baseline_sequences, stay_grid_size=stay_grid_size)
            summary_rows.append(_model_row(group_values, "baseline_window", baseline_fit))

        if n_permutations > 0:
            observed_gain = observed_fit["persistence_gain_per_observation"]
            for offset, control in enumerate(controls):
                control_fits = _fit_control(
                    effect_sequences,
                    control=control,
                    n_permutations=n_permutations,
                    random_seed=random_seed + offset,
                    stay_grid_size=stay_grid_size,
                )
                summary_rows.append(_control_row(group_values, control, control_fits, observed_gain=observed_gain))

        if out_states is not None:
            trace = build_state_trace(
                effect_frame,
                stay_probability=observed_fit["best_stay_probability"],
                class_names=class_names,
                prob_columns=prob_columns,
            )
            trace_frames.append(trace)

    summary = pd.DataFrame(summary_rows)
    if out_summary is not None:
        out_summary.parent.mkdir(parents=True, exist_ok=True)
        summary.to_csv(out_summary, index=False)

    states = None
    if trace_frames:
        states = pd.concat(trace_frames, ignore_index=True)
        if out_states is not None:
            out_states.parent.mkdir(parents=True, exist_ok=True)
            states.to_csv(out_states, index=False)
    return summary, states

probability_columns(frame)

Return probability-vector columns in class-index order.

Source code in `src/neureptrace/temporal_model.py` (lines 34–44)
def probability_columns(frame: pd.DataFrame) -> list[str]:
    """Return probability-vector columns in class-index order.

    Columns named ``prob_class_<suffix>`` are selected (non-string column labels
    are ignored). Columns whose suffix is a decimal integer come first, ordered
    numerically, so ``prob_class_2`` precedes ``prob_class_10``. Any remaining
    ``prob_class_*`` columns follow, ordered by suffix text.

    Raises:
        ValueError: if the frame has no ``prob_class_*`` columns.
    """
    prefix = "prob_class_"
    columns = [column for column in frame.columns if isinstance(column, str) and column.startswith(prefix)]
    if not columns:
        raise ValueError("Observation CSVs must contain probability columns named 'prob_class_*'.")

    def sort_key(column: str) -> tuple[int, int, str]:
        suffix = column.removeprefix(prefix)
        # isdecimal, not isdigit: every decimal string parses with int(), whereas
        # isdigit also accepts characters such as superscripts that int() rejects.
        if suffix.isdecimal():
            return (0, int(suffix), suffix)
        # Non-numeric suffixes always sort after numeric ones, whatever the index magnitude.
        return (1, 0, suffix)

    return sorted(columns, key=sort_key)

read_probability_observations(csv_paths)

Read held-out probability observation CSVs emitted by NeuRepTrace.

Source code in `src/neureptrace/temporal_model.py` (lines 78–115)
def read_probability_observations(csv_paths: list[Path]) -> pd.DataFrame:
    """Read held-out probability observation CSVs emitted by NeuRepTrace.

    Each CSV must have a ``time`` column, ``prob_class_*`` columns (checked by
    ``_validate_probability_matrix``), and either ``sequence_id`` or
    ``sample_index``. When ``sequence_id`` is absent it is copied from
    ``sample_index``. The metadata columns below are created when missing;
    when present, their blank cells are filled with the same per-file default.
    All of them are then cast to ``str``:

    * ``subject`` -> the file stem
    * ``decoder`` -> ``"decoder"``
    * ``emission_mode`` -> ``"calibrated"``
    * ``source_file`` -> the file name
    * ``source_path`` -> ``str(csv_path)``

    Args:
        csv_paths: Paths of the CSVs to read, concatenated in order.

    Returns:
        All rows from every CSV, with a fresh ``RangeIndex``.

    Raises:
        ValueError: if ``csv_paths`` is empty, a CSV lacks ``time``, lacks both
            ``sequence_id`` and ``sample_index``, has no probability columns, or
            fails probability validation.
    """
    if not csv_paths:
        raise ValueError("At least one observation CSV path is required.")

    frames = []
    for csv_path in csv_paths:
        frame = pd.read_csv(csv_path)
        missing = [column for column in ("time",) if column not in frame.columns]
        if missing:
            raise ValueError(f"{csv_path} is missing required columns: {missing}")
        prob_columns = probability_columns(frame)
        _validate_probability_matrix(frame[prob_columns].to_numpy())
        if "sequence_id" not in frame.columns:
            if "sample_index" not in frame.columns:
                raise ValueError(f"{csv_path} is missing 'sequence_id' or 'sample_index'.")
            frame["sequence_id"] = frame["sample_index"]

        # Insertion order matters: it fixes the column order of newly added columns.
        defaults = {
            "subject": csv_path.stem,
            "decoder": "decoder",
            "emission_mode": "calibrated",
            "source_file": csv_path.name,
            "source_path": str(csv_path),
        }
        for column, default in defaults.items():
            if column in frame.columns:
                # Fill blanks before casting: astype(str) alone would turn NaN into the
                # literal "nan" and merge blank rows from unrelated files into one group.
                frame[column] = frame[column].fillna(default)
            else:
                frame[column] = default
            frame[column] = frame[column].astype(str)
        frames.append(frame)
    return pd.concat(frames, ignore_index=True)

sequence_key_columns(frame, *, require_sequence_id=True)

Return columns that identify one probability/state sequence.

`source_path`/`source_file` and session-level columns are part of the sequence identity, so reused `sequence_id` values from different input files, sessions, or runs are not silently concatenated.

Source code in `src/neureptrace/temporal_model.py` (lines 132–145)
def sequence_key_columns(frame: pd.DataFrame, *, require_sequence_id: bool = True) -> list[str]:
    """Return columns that identify one probability/state sequence.

    ``source_path``/``source_file`` and session-level columns are part of the
    sequence identity, so reused ``sequence_id`` values from different input
    files, sessions, or runs are not silently concatenated.

    The result keeps the order of ``SEQUENCE_KEY_COLUMN_CANDIDATES``, restricted
    to the candidates actually present in ``frame``.

    Raises:
        ValueError: if ``require_sequence_id`` is true and ``sequence_id`` is
            absent, or if no candidate key column is present at all.
    """

    available = frame.columns
    present_keys = [candidate for candidate in SEQUENCE_KEY_COLUMN_CANDIDATES if candidate in available]
    missing_sequence_id = require_sequence_id and "sequence_id" not in available
    if missing_sequence_id:
        raise ValueError("Observation rows must contain sequence_id or sample_index.")
    if len(present_keys) == 0:
        raise ValueError("Observation rows must contain at least one sequence key column.")
    return present_keys

validate_unique_sequence_times(frame, key_columns)

Fail fast when a sequence identity contains duplicate time bins.

Source code in `src/neureptrace/temporal_model.py` (lines 148–162)
def validate_unique_sequence_times(frame: pd.DataFrame, key_columns: list[str]) -> None:
    """Raise if any sequence identity repeats a time bin.

    An empty frame, or one without a ``time`` column, is accepted unchecked.
    Otherwise every row whose ``(*key_columns, time)`` combination occurs more
    than once is treated as a duplicate. Up to five distinct offending
    combinations are quoted in the error.

    Raises:
        ValueError: when at least one duplicate ``(*key_columns, time)`` exists.
    """

    if "time" not in frame.columns or frame.empty:
        return
    identity = list(key_columns) + ["time"]
    # keep=False flags every member of a duplicate set, not just the repeats.
    is_duplicate = frame.duplicated(identity, keep=False)
    if is_duplicate.any():
        offending = frame.loc[is_duplicate, identity]
        examples = offending.drop_duplicates().head(5).to_dict("records")
        message = (
            "Duplicate time rows found within a sequence identity. "
            "Include source_path/source_file/session/run in the sequence key or deduplicate rows before analysis. "
            f"Examples: {examples}"
        )
        raise ValueError(message)