Skip to content

Onset Detection

neureptrace.onset_detection detects the first threshold-crossing time in held-out probability-observation traces.

The basic detector estimates a score threshold from a baseline window, then reports the first time each trial/sequence crosses that threshold. The module also supports sustained-onset criteria and a sequence-level max-run threshold to reduce false detections from scanning many time bins.

Sustained-onset controls

Use the following options to require a more persistent representation onset:

python -m neureptrace.onset_detection \
  results/nod_sub-01_animate_observations.csv \
  --threshold-window -0.10 0.00 \
  --threshold-quantile 0.95 \
  --threshold-method max_run \
  --detection-start 0.00 \
  --min-consecutive 3 \
  --require-stable-prediction \
  --out-events results/nod_sub-01_animate_onset_events.csv \
  --out-summary results/nod_sub-01_animate_onset_summary.csv
  • --min-consecutive requires at least this many adjacent above-threshold windows.
  • --min-duration requires the above-threshold run to last at least the given duration in seconds.
  • --require-stable-prediction breaks an onset run when the predicted class changes across adjacent above-threshold bins.
  • --threshold-method max_run estimates the threshold from sequence-level baseline maxima under the same run criteria, rather than from pointwise baseline scores.

The event CSV includes the run length, run duration, run stop time, and peak score within the detection run. The summary CSV reports, per group, detection rates, false-alarm rates, post-zero detection rates, correct-at-detection rates, and the mean and median post-zero detection latency, together with post-zero run-duration and run-length statistics and the threshold settings that were used.

neureptrace.onset_detection

annotate_threshold_crossings(observations, *, threshold_window=DEFAULT_THRESHOLD_WINDOW, threshold_quantile=DEFAULT_THRESHOLD_QUANTILE, score_column='confidence', threshold_method='point', min_consecutive=1, min_duration=None, require_stable_prediction=False)

Annotate observation rows with baseline-derived threshold crossings.

Source code in src/neureptrace/onset_detection.py
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
def annotate_threshold_crossings(
    observations: pd.DataFrame,
    *,
    threshold_window: tuple[float, float] = DEFAULT_THRESHOLD_WINDOW,
    threshold_quantile: float = DEFAULT_THRESHOLD_QUANTILE,
    score_column: str = "confidence",
    threshold_method: str = "point",
    min_consecutive: int = 1,
    min_duration: float | None = None,
    require_stable_prediction: bool = False,
) -> pd.DataFrame:
    """Annotate observation rows with baseline-derived threshold crossings.

    The observations are split by the columns returned by ``_group_columns``
    and each group is annotated independently by ``_annotate_group_threshold``.

    Args:
        observations: Observation rows; must contain a ``time`` column.
        threshold_window: Baseline window forwarded to the per-group
            annotation.
        threshold_quantile: Quantile in ``[0, 1]`` forwarded to the per-group
            annotation.
        score_column: Name of the score column forwarded to the per-group
            annotation.
        threshold_method: One of ``THRESHOLD_METHODS``.
        min_consecutive: Run-length criterion; must be at least 1.
        min_duration: Optional run-duration criterion; must be non-negative
            when provided.
        require_stable_prediction: Run-stability criterion forwarded to the
            per-group annotation.

    Returns:
        The per-group annotated frames concatenated with a fresh integer
        index. When there are no groups to annotate, a copy of the
        observations (after ``_ensure_prediction_columns``) is returned.

    Raises:
        ValueError: If ``threshold_quantile`` is outside ``[0, 1]``, if
            ``threshold_method`` is unknown, if ``min_consecutive`` is below
            1, if ``min_duration`` is negative, or if ``observations`` has no
            ``time`` column.
    """

    if not 0.0 <= threshold_quantile <= 1.0:
        raise ValueError("threshold_quantile must be between 0 and 1.")
    if threshold_method not in THRESHOLD_METHODS:
        raise ValueError(f"threshold_method must be one of {THRESHOLD_METHODS}.")
    # Same run-criteria checks as detect_onsets, so invalid values are rejected
    # here instead of being forwarded into the threshold estimation.
    if min_consecutive < 1:
        raise ValueError("min_consecutive must be at least 1.")
    if min_duration is not None and min_duration < 0:
        raise ValueError("min_duration must be non-negative when provided.")
    if "time" not in observations.columns:
        raise ValueError("Observation rows must contain a time column.")

    observations = _ensure_prediction_columns(observations)
    group_columns = _group_columns(observations)
    # Without group columns the whole frame is treated as a single group.
    grouped = observations.groupby(group_columns, sort=True) if group_columns else [((), observations)]
    frames = [
        _annotate_group_threshold(
            group_frame,
            threshold_window=threshold_window,
            threshold_quantile=threshold_quantile,
            score_column=score_column,
            threshold_method=threshold_method,
            min_consecutive=min_consecutive,
            min_duration=min_duration,
            require_stable_prediction=require_stable_prediction,
        )
        for _, group_frame in grouped
    ]
    # pd.concat raises on an empty list, hence the explicit fallback.
    return pd.concat(frames, ignore_index=True) if frames else observations.copy()

detect_onsets(observations, *, threshold_window=DEFAULT_THRESHOLD_WINDOW, threshold_quantile=DEFAULT_THRESHOLD_QUANTILE, score_column='confidence', threshold_method='point', detection_start=None, detection_window=None, min_consecutive=1, min_duration=None, require_stable_prediction=False)

Find the first threshold-crossing time for each probability-observation sequence.

min_consecutive and min_duration can be used to suppress single-bin spikes by requiring the threshold crossing to be sustained. With require_stable_prediction=True, an onset run is also broken when the predicted class changes across adjacent above-threshold bins.

Source code in src/neureptrace/onset_detection.py
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
def detect_onsets(
    observations: pd.DataFrame,
    *,
    threshold_window: tuple[float, float] = DEFAULT_THRESHOLD_WINDOW,
    threshold_quantile: float = DEFAULT_THRESHOLD_QUANTILE,
    score_column: str = "confidence",
    threshold_method: str = "point",
    detection_start: float | None = None,
    detection_window: tuple[float, float] | None = None,
    min_consecutive: int = 1,
    min_duration: float | None = None,
    require_stable_prediction: bool = False,
) -> pd.DataFrame:
    """Find the first threshold-crossing time for each probability-observation sequence.

    One event row is produced per sequence. Rows earlier than
    ``detection_start`` and rows outside the inclusive ``detection_window``
    are excluded from the search when those options are given; if both are
    given, both restrictions apply.

    ``min_consecutive`` and ``min_duration`` can be used to suppress single-bin
    spikes by requiring the threshold crossing to be sustained. With
    ``require_stable_prediction=True``, an onset run is also broken when the
    predicted class changes across adjacent above-threshold bins.

    Raises:
        ValueError: If ``threshold_quantile`` is outside ``[0, 1]``, if
            ``threshold_method`` is not in ``THRESHOLD_METHODS``, if
            ``min_consecutive`` is below 1, if ``min_duration`` is negative,
            or if ``observations`` has no ``time`` column.
    """

    if not 0.0 <= threshold_quantile <= 1.0:
        raise ValueError("threshold_quantile must be between 0 and 1.")
    if threshold_method not in THRESHOLD_METHODS:
        raise ValueError(f"threshold_method must be one of {THRESHOLD_METHODS}.")
    if min_consecutive < 1:
        raise ValueError("min_consecutive must be at least 1.")
    if min_duration is not None and min_duration < 0:
        raise ValueError("min_duration must be non-negative when provided.")
    if "time" not in observations.columns:
        raise ValueError("Observation rows must contain a time column.")

    # Keyword bundles shared by the helper calls below.
    threshold_settings = {
        "threshold_window": threshold_window,
        "threshold_quantile": threshold_quantile,
        "score_column": score_column,
        "threshold_method": threshold_method,
    }
    run_criteria = {
        "min_consecutive": min_consecutive,
        "min_duration": min_duration,
        "require_stable_prediction": require_stable_prediction,
    }

    observations = _prepare_thresholded_observations(
        observations,
        **threshold_settings,
        **run_criteria,
    )
    group_columns = _group_columns(observations)
    sequence_columns = _sequence_columns(observations)

    # Without group columns the whole frame is handled as one group.
    if group_columns:
        groups = observations.groupby(group_columns, sort=True)
    else:
        groups = [((), observations)]

    event_rows = []
    for group_key, group_frame in groups:
        if not isinstance(group_key, tuple):
            group_key = (group_key,)
        group_values = dict(zip(group_columns, group_key, strict=True))

        # Reuse an already-annotated threshold; otherwise estimate it now.
        if "score_threshold" in group_frame:
            threshold = group_frame["score_threshold"].iloc[0]
        else:
            threshold = _threshold_for_group(
                group_frame,
                **threshold_settings,
                **run_criteria,
            )

        validate_unique_sequence_times(group_frame, sequence_columns)
        ordered = group_frame.sort_values([*sequence_columns, "time"])
        for _, sequence_frame in ordered.groupby(sequence_columns, sort=True, dropna=False):
            candidates = sequence_frame
            if detection_start is not None:
                candidates = candidates.loc[candidates["time"].ge(detection_start)]
            if detection_window is not None:
                window_start, window_stop = detection_window
                # between() is inclusive on both ends by default.
                candidates = candidates.loc[candidates["time"].between(window_start, window_stop)]

            detection_run = _first_detection_run(
                candidates,
                threshold=threshold,
                **run_criteria,
            )
            # The event row receives the full sequence, not just the candidates.
            event_rows.append(
                _event_row(
                    group_values,
                    sequence_frame,
                    detection_run,
                    threshold=threshold,
                    detection_start=detection_start,
                    detection_window=detection_window,
                    **threshold_settings,
                    **run_criteria,
                )
            )
    return pd.DataFrame(event_rows)

detect_onsets_from_csvs(observation_csvs, *, threshold_window=DEFAULT_THRESHOLD_WINDOW, threshold_quantile=DEFAULT_THRESHOLD_QUANTILE, score_column='confidence', threshold_method='point', detection_start=None, event_window=None, min_consecutive=1, min_duration=None, require_stable_prediction=False, out_events=None, out_summary=None, out_thresholded_observations=None, out_threshold_summary=None, detection_window=DEFAULT_DETECTION_WINDOW)

Read probability observations, detect onsets, and optionally write CSV outputs.

Source code in src/neureptrace/onset_detection.py
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
def detect_onsets_from_csvs(
    observation_csvs: list[Path],
    *,
    threshold_window: tuple[float, float] = DEFAULT_THRESHOLD_WINDOW,
    threshold_quantile: float = DEFAULT_THRESHOLD_QUANTILE,
    score_column: str = "confidence",
    threshold_method: str = "point",
    detection_start: float | None = None,
    event_window: tuple[float, float] | None = None,
    min_consecutive: int = 1,
    min_duration: float | None = None,
    require_stable_prediction: bool = False,
    out_events: Path | None = None,
    out_summary: Path | None = None,
    out_thresholded_observations: Path | None = None,
    out_threshold_summary: Path | None = None,
    detection_window: tuple[float, float] = DEFAULT_DETECTION_WINDOW,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Run the onset-detection pipeline on observation CSVs and optionally save the results.

    The observations are read, annotated with threshold crossings, searched
    for onsets, and summarized. ``event_window`` takes precedence over
    ``detection_window`` when it is provided. Each ``out_*`` path that is not
    ``None`` receives the corresponding table as CSV (without the index); its
    parent directory is created if needed.

    Returns:
        A ``(events, summary)`` tuple. The thresholded observations and the
        threshold summary are only written to disk, not returned.
    """

    def _export(frame: pd.DataFrame, destination: Path) -> None:
        # Create the target directory on demand, then write without the index.
        destination.parent.mkdir(parents=True, exist_ok=True)
        frame.to_csv(destination, index=False)

    observations = read_probability_observations(observation_csvs)
    if event_window is not None:
        event_detection_window = event_window
    else:
        event_detection_window = detection_window

    # Settings shared by the annotation and detection steps.
    shared_settings = {
        "threshold_window": threshold_window,
        "threshold_quantile": threshold_quantile,
        "score_column": score_column,
        "threshold_method": threshold_method,
        "min_consecutive": min_consecutive,
        "min_duration": min_duration,
        "require_stable_prediction": require_stable_prediction,
    }
    thresholded_observations = annotate_threshold_crossings(observations, **shared_settings)
    events = detect_onsets(
        thresholded_observations,
        detection_start=detection_start,
        detection_window=event_detection_window,
        **shared_settings,
    )
    summary = summarize_onset_events(events)
    threshold_summary = summarize_threshold_crossings(
        thresholded_observations,
        baseline_window=threshold_window,
        detection_window=event_detection_window,
    )

    if out_events is not None:
        _export(events, out_events)
    if out_summary is not None:
        _export(summary, out_summary)
    if out_thresholded_observations is not None:
        # The "_onset_score" column is dropped (if present) before writing.
        _export(
            thresholded_observations.drop(columns=["_onset_score"], errors="ignore"),
            out_thresholded_observations,
        )
    if out_threshold_summary is not None:
        _export(threshold_summary, out_threshold_summary)
    return events, summary

summarize_onset_events(events)

Summarize onset-detection events by subject/decoder/emission group.

Source code in src/neureptrace/onset_detection.py
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
def summarize_onset_events(events: pd.DataFrame) -> pd.DataFrame:
    """Summarize onset-detection events by subject/decoder/emission group.

    Args:
        events: One row per sequence, as returned by ``detect_onsets``. The
            columns ``detected``, ``detected_before_zero``,
            ``is_correct_at_detection``, ``detection_latency``,
            ``detection_run_duration`` and ``detection_run_length`` are
            required. The threshold and run-criteria columns are optional and
            fall back to a default when absent.

    Returns:
        One row per group with counts and rates (rates are fractions of all
        sequences in the group), latency and run statistics over sequences
        detected without a pre-zero detection, and the threshold settings
        taken from the group's first row. An empty, column-less ``events``
        frame yields an empty frame.
    """

    # detect_onsets returns a column-less frame when it saw no sequences;
    # without this guard the "detected" lookup below raises KeyError.
    if events.empty and "detected" not in events.columns:
        return pd.DataFrame([])

    def _flag(frame: pd.DataFrame, column: str) -> pd.Series:
        # NaN becomes True under a plain astype(bool); count missing as False.
        values = frame[column]
        return values.astype(bool) & values.notna()

    group_columns = _group_columns(events)
    rows = []
    # Without group columns the whole frame is summarized as one group.
    grouped = events.groupby(group_columns, sort=True) if group_columns else [((), events)]
    for keys, group_frame in grouped:
        key_values = keys if isinstance(keys, tuple) else (keys,)
        group_values = dict(zip(group_columns, key_values, strict=True))
        detected = _flag(group_frame, "detected")
        false_alarm = _flag(group_frame, "detected_before_zero")
        correct = _flag(group_frame, "is_correct_at_detection")
        # "Post-zero" detections: detected and not flagged as a false alarm.
        post_detected = detected & ~false_alarm
        latencies = pd.to_numeric(group_frame.loc[post_detected, "detection_latency"], errors="coerce").dropna()
        run_durations = pd.to_numeric(
            group_frame.loc[post_detected, "detection_run_duration"],
            errors="coerce",
        ).dropna()
        run_lengths = pd.to_numeric(
            group_frame.loc[post_detected, "detection_run_length"],
            errors="coerce",
        ).dropna()
        rows.append(
            {
                **group_values,
                "n_sequences": len(group_frame),
                "detected_count": int(detected.sum()),
                "detected_rate": float(detected.mean()) if len(detected) else np.nan,
                "false_alarm_count": int(false_alarm.sum()),
                "false_alarm_rate": float(false_alarm.mean()) if len(false_alarm) else np.nan,
                "post_zero_detected_count": int(post_detected.sum()),
                "post_zero_detected_rate": float(post_detected.mean()) if len(post_detected) else np.nan,
                "correct_detection_count": int((detected & correct).sum()),
                "correct_detection_rate": float((detected & correct).mean()) if len(correct) else np.nan,
                "post_detection_latency_mean": float(latencies.mean()) if not latencies.empty else np.nan,
                "post_detection_latency_median": float(latencies.median()) if not latencies.empty else np.nan,
                "post_detection_run_duration_mean": float(run_durations.mean()) if not run_durations.empty else np.nan,
                "post_detection_run_duration_median": (
                    float(run_durations.median()) if not run_durations.empty else np.nan
                ),
                "post_detection_run_length_median": (
                    float(run_lengths.median()) if not run_lengths.empty else np.nan
                ),
                # Settings are read from the first row of the group.
                "score_threshold": (
                    group_frame["score_threshold"].iloc[0]
                    if "score_threshold" in group_frame
                    else np.nan
                ),
                "threshold_method": (
                    group_frame["threshold_method"].iloc[0]
                    if "threshold_method" in group_frame
                    else ""
                ),
                "threshold_quantile": (
                    group_frame["threshold_quantile"].iloc[0]
                    if "threshold_quantile" in group_frame
                    else np.nan
                ),
                "threshold_window_start": (
                    group_frame["threshold_window_start"].iloc[0]
                    if "threshold_window_start" in group_frame
                    else np.nan
                ),
                "threshold_window_stop": (
                    group_frame["threshold_window_stop"].iloc[0]
                    if "threshold_window_stop" in group_frame
                    else np.nan
                ),
                "min_consecutive": group_frame["min_consecutive"].iloc[0] if "min_consecutive" in group_frame else 1,
                "min_duration": group_frame["min_duration"].iloc[0] if "min_duration" in group_frame else np.nan,
                "require_stable_prediction": group_frame["require_stable_prediction"].iloc[0]
                if "require_stable_prediction" in group_frame
                else False,
            }
        )
    return pd.DataFrame(rows)

summarize_threshold_crossings(thresholded_observations, *, baseline_window=DEFAULT_THRESHOLD_WINDOW, detection_window=DEFAULT_DETECTION_WINDOW)

Summarize baseline false positives separately from post-event detections.

Source code in src/neureptrace/onset_detection.py
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
def summarize_threshold_crossings(
    thresholded_observations: pd.DataFrame,
    *,
    baseline_window: tuple[float, float] = DEFAULT_THRESHOLD_WINDOW,
    detection_window: tuple[float, float] = DEFAULT_DETECTION_WINDOW,
) -> pd.DataFrame:
    """Report threshold crossings in the baseline and detection windows per group.

    Crossing statistics come from ``_window_threshold_stats``, evaluated once
    for ``baseline_window`` (reported as false positives) and once for
    ``detection_window`` (reported as post-stimulus detections). Threshold
    settings are taken from the first row of each group when the
    corresponding column exists.

    Returns:
        One row per group; the correct-detection columns are NaN when the
        detection statistics do not provide them.
    """

    def _leading_value(frame: pd.DataFrame, column: str, fallback):
        # First value of the column, or the fallback when the column is absent.
        return frame[column].iloc[0] if column in frame else fallback

    group_columns = _group_columns(thresholded_observations)
    sequence_columns = _sequence_columns(thresholded_observations)

    # Without group columns the whole frame is summarized as one group.
    if group_columns:
        groups = thresholded_observations.groupby(group_columns, sort=True)
    else:
        groups = [((), thresholded_observations)]

    rows = []
    for group_key, group_frame in groups:
        if not isinstance(group_key, tuple):
            group_key = (group_key,)
        row = dict(zip(group_columns, group_key, strict=True))
        baseline_stats = _window_threshold_stats(group_frame, baseline_window, sequence_columns)
        detection_stats = _window_threshold_stats(group_frame, detection_window, sequence_columns)

        row["score_threshold"] = _leading_value(group_frame, "score_threshold", np.nan)
        row["score_column"] = _leading_value(group_frame, "score_column", "")
        row["threshold_method"] = _leading_value(group_frame, "threshold_method", "")
        row["threshold_quantile"] = _leading_value(group_frame, "threshold_quantile", np.nan)

        row["baseline_window_start"] = baseline_window[0]
        row["baseline_window_stop"] = baseline_window[1]
        row["detection_window_start"] = detection_window[0]
        row["detection_window_stop"] = detection_window[1]

        row["baseline_n_observations"] = baseline_stats["n_observations"]
        row["baseline_false_positive_count"] = baseline_stats["threshold_crossing_count"]
        row["baseline_false_positive_rate"] = baseline_stats["threshold_crossing_rate"]
        row["baseline_false_positive_sequence_count"] = baseline_stats["sequence_crossing_count"]
        row["baseline_false_positive_sequence_rate"] = baseline_stats["sequence_crossing_rate"]

        row["post_stimulus_n_observations"] = detection_stats["n_observations"]
        row["post_stimulus_detection_count"] = detection_stats["threshold_crossing_count"]
        row["post_stimulus_detection_rate"] = detection_stats["threshold_crossing_rate"]
        row["post_stimulus_detection_sequence_count"] = detection_stats["sequence_crossing_count"]
        row["post_stimulus_detection_sequence_rate"] = detection_stats["sequence_crossing_rate"]
        # Correctness statistics are optional in the detection stats.
        row["post_stimulus_correct_detection_count"] = detection_stats.get("correct_crossing_count", np.nan)
        row["post_stimulus_correct_detection_rate"] = detection_stats.get("correct_crossing_rate", np.nan)

        rows.append(row)
    return pd.DataFrame(rows)