Waveform Database Software Package (WFDB) for Python 4.1.0

File: <base>/wfdb/processing/evaluate.py (18,430 bytes)
from multiprocessing import cpu_count, Pool

import numpy as np

from wfdb.io.annotation import rdann
from wfdb.io.download import get_record_list
from wfdb.io.record import rdsamp


class Comparitor(object):
    """
    A class for comparing two sets of annotation sample locations and
    storing the results. See the `compare`, `print_summary` and `plot`
    methods.

    Attributes
    ----------
    ref_sample : ndarray
        An array of the reference sample locations.
    test_sample : ndarray
        An array of the comparison sample locations.
    window_width : int
        The maximum absolute difference in sample numbers that is
        permitted for matching annotations.
    signal : 1d numpy array, optional
        The signal array the annotation samples are labelling. Only used
        for plotting.

    Examples
    --------
    >>> import wfdb
    >>> from wfdb import processing

    >>> sig, fields = wfdb.rdsamp('sample-data/100', channels=[0])
    >>> ann_ref = wfdb.rdann('sample-data/100','atr')
    >>> xqrs = processing.XQRS(sig=sig[:,0], fs=fields['fs'])
    >>> xqrs.detect()

    >>> comparitor = processing.Comparitor(ann_ref.sample[1:],
    ...                                    xqrs.qrs_inds,
    ...                                    int(0.1 * fields['fs']),
    ...                                    sig[:,0])
    >>> comparitor.compare()
    >>> comparitor.print_summary()
    >>> comparitor.plot()

    """

    def __init__(self, ref_sample, test_sample, window_width, signal=None):
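        """
        See the Comparitor class docstring for descriptions of the
        parameters. Raises a ValueError if either sample array contains
        decreasing sample locations.
        """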
        if len(ref_sample) > 1 and len(test_sample) > 1:
            if min(np.diff(ref_sample)) < 0 or min(np.diff(test_sample)) < 0:
                raise ValueError(
                    (
                        "The sample locations must be monotonically"
                        + " increasing"
                    )
                )

        self.ref_sample = ref_sample
        self.test_sample = test_sample
        self.n_ref = len(ref_sample)
        self.n_test = len(test_sample)
        self.window_width = window_width

        # The matching test sample number for each reference annotation.
        # -1 for indices with no match
        self.matching_sample_nums = np.full(self.n_ref, -1, dtype="int")

        self.signal = signal
        # TODO: rdann return annotations.where

    def _calc_stats(self):
        """
        Calculate performance statistics after the two sets of annotations
        are compared.

        Parameters
        ----------
        N/A

        Returns
        -------
        N/A

        Example
        -------
        Suppose there are 500 reference annotations and 480 test
        annotations, of which 470 are matched within the window:

        -------------------
         ref=500  test=480
        {  30 { 470 } 10  }
        -------------------

        tp = 470  (matched reference/test pairs)
        fp = 10   (unmatched test annotations)
        fn = 30   (unmatched reference annotations)

        sensitivity = tp / (tp + fn) = 470 / 500
        positive_predictivity = tp / (tp + fp) = 470 / 480
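
        A quick doctest-style check of the same arithmetic, using the
        hypothetical counts above:

        >>> tp, fp, fn = 470, 10, 30
        >>> round(tp / (tp + fn), 4)  # sensitivity
        0.94
        >>> round(tp / (tp + fp), 4)  # positive predictivity
        0.9792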

        """
        # Reference annotation indices that were detected
        self.matched_ref_inds = np.where(self.matching_sample_nums != -1)[0]
        # Reference annotation indices that were missed
        self.unmatched_ref_inds = np.where(self.matching_sample_nums == -1)[0]
        # Test annotation indices that were matched to a reference annotation
        self.matched_test_inds = self.matching_sample_nums[
            self.matching_sample_nums != -1
        ]
        # Test annotation indices that were not matched to any reference annotation
        self.unmatched_test_inds = np.setdiff1d(
            np.array(range(self.n_test)),
            self.matched_test_inds,
            assume_unique=True,
        )

        # Sample numbers that were matched and unmatched
        self.matched_ref_sample = self.ref_sample[self.matched_ref_inds]
        self.unmatched_ref_sample = self.ref_sample[self.unmatched_ref_inds]
        self.matched_test_sample = self.test_sample[self.matched_test_inds]
        self.unmatched_test_sample = self.test_sample[self.unmatched_test_inds]

        # True positives = matched reference samples
        self.tp = len(self.matched_ref_inds)
        # False positives = extra test samples not matched
        self.fp = self.n_test - self.tp
        # False negatives = undetected reference samples
        self.fn = self.n_ref - self.tp
        # No tn attribute

        self.sensitivity = float(self.tp) / float(self.tp + self.fn)
        self.positive_predictivity = float(self.tp) / self.n_test

    def compare(self):
        """
        Main comparison function. The matching logic must handle the
        ref/test alignment scenarios shown below, where 'o' marks a
        reference annotation and 'x' marks a test annotation.

        Parameters
        ----------
        N/A

        Returns
        -------
        N/A

        Example
        -------
        A:
        o----o---o---o
        x-------x----x

        B:
        o----o-----o---o
        x--------x--x--x

        C:
        o------o-----o---o
        x-x--------x--x--x

        D:
        o------o-----o---o
        x-x--------x-----x
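
        A minimal doctest-style illustration, using made-up sample
        locations and a window width of 5 samples (assumes numpy is
        imported as np):

        >>> comparitor = Comparitor(np.array([10, 30, 40]),
        ...                         np.array([12, 31, 55]), 5)
        >>> comparitor.compare()
        >>> comparitor.matching_sample_nums
        array([ 0,  1, -1])
        >>> comparitor.tp, comparitor.fp, comparitor.fn
        (2, 1, 1)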

        """
        test_samp_num = 0
        ref_samp_num = 0

        # Iterate through the reference sample numbers
        while ref_samp_num < self.n_ref and test_samp_num < self.n_test:
            # Get the closest testing sample number for this reference sample
            closest_samp_num, smallest_samp_diff = self._get_closest_samp_num(
                ref_samp_num, test_samp_num
            )
            # Get the closest testing sample number for the next reference
            # sample. This doesn't need to be called for the last index.
            if ref_samp_num < self.n_ref - 1:
                (
                    closest_samp_num_next,
                    smallest_samp_diff_next,
                ) = self._get_closest_samp_num(ref_samp_num + 1, test_samp_num)
            else:
                # Set non-matching value if there is no next reference sample
                # to compete for the test sample
                closest_samp_num_next = -1

            # Found a contested test sample number. Decide which
            # reference sample it belongs to. If the sample is closer to
            # the next reference sample, leave it to the next reference
            # sample and label this reference sample as unmatched.
            if (
                closest_samp_num == closest_samp_num_next
                and smallest_samp_diff_next < smallest_samp_diff
            ):
                # Get the next closest sample for this reference sample,
                # if it has not already been assigned to a previous
                # reference sample. In any possible case it will be the
                # previous testing sample number (scenario D in the
                # docstring), or nothing.
                if closest_samp_num and (
                    not ref_samp_num
                    or closest_samp_num - 1
                    != self.matching_sample_nums[ref_samp_num - 1]
                ):
                    # The previous test annotation is inspected
                    closest_samp_num = closest_samp_num - 1
                    smallest_samp_diff = abs(
                        self.ref_sample[ref_samp_num]
                        - self.test_sample[closest_samp_num]
                    )
                    # Assign the reference-test pair if close enough
                    if smallest_samp_diff < self.window_width:
                        self.matching_sample_nums[
                            ref_samp_num
                        ] = closest_samp_num
                    # Set the starting test sample number to inspect
                    # for the next reference sample.
                    test_samp_num = closest_samp_num + 1

                # Otherwise there is no matching test annotation

            # If there is no clash, or the contested test sample is
            # closer to the current reference, keep the test sample
            # for this reference sample.
            else:
                # Assign the reference-test pair if close enough
                if smallest_samp_diff < self.window_width:
                    self.matching_sample_nums[ref_samp_num] = closest_samp_num
                # Increment the starting test sample number to inspect
                # for the next reference sample.
                test_samp_num = closest_samp_num + 1

            ref_samp_num += 1

        self._calc_stats()

    def _get_closest_samp_num(self, ref_samp_num, start_test_samp_num):
        """
        Return the closest testing sample number for the given reference
        sample number. Limit the search from start_test_samp_num.

        Parameters
        ----------
        ref_samp_num : int
            The desired reference sample number to get the closest result.
        start_test_samp_num : int
            The test sample number from which to start searching.

        Returns
        -------
        closest_samp_num : int
            The closest sample number to the reference sample number.
        smallest_samp_diff : int
            The smallest difference between the reference sample and
            the testing sample.
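
        A small illustrative example with made-up sample locations,
        assuming numpy is imported as np:

        >>> comparitor = Comparitor(np.array([10, 30]),
        ...                         np.array([12, 31]), 5)
        >>> closest, diff = comparitor._get_closest_samp_num(0, 0)
        >>> closest, int(diff)
        (0, 2)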

        """
        if start_test_samp_num >= self.n_test:
            raise ValueError("Invalid starting test sample number.")

        ref_samp = self.ref_sample[ref_samp_num]
        test_samp = self.test_sample[start_test_samp_num]
        samp_diff = ref_samp - test_samp

        # Initialize running parameters
        closest_samp_num = start_test_samp_num
        smallest_samp_diff = abs(samp_diff)

        # Iterate through the testing samples
        for test_samp_num in range(start_test_samp_num, self.n_test):
            test_samp = self.test_sample[test_samp_num]
            samp_diff = ref_samp - test_samp
            abs_samp_diff = abs(samp_diff)

            # Found a better match
            if abs_samp_diff < smallest_samp_diff:
                closest_samp_num = test_samp_num
                smallest_samp_diff = abs_samp_diff

            # Stop iterating when the ref sample is first passed or reached
            if samp_diff <= 0:
                break

        return closest_samp_num, smallest_samp_diff

    def print_summary(self):
        """
        Print summary metrics of the annotation comparisons.

        Parameters
        ----------
        N/A

        Returns
        -------
        N/A

        """
        if not hasattr(self, "sensitivity"):
            self._calc_stats()

        print(
            "%d reference annotations, %d test annotations\n"
            % (self.n_ref, self.n_test)
        )
        print("True Positives (matched samples): %d" % self.tp)
        print("False Positives (unmatched test samples): %d" % self.fp)
        print("False Negatives (unmatched reference samples): %d\n" % self.fn)

        print(
            "Sensitivity: %.4f (%d/%d)"
            % (self.sensitivity, self.tp, self.n_ref)
        )
        print(
            "Positive Predictivity: %.4f (%d/%d)"
            % (self.positive_predictivity, self.tp, self.n_test)
        )

    def plot(self, sig_style="", title=None, figsize=None, return_fig=False):
        """
        Plot the comparison of two sets of annotations, possibly
        overlaid on their original signal.

        Parameters
        ----------
        sig_style : str, optional
            The matplotlib style of the signal.
        title : str, optional
            The title of the plot.
        figsize : tuple, optional
            Tuple pair specifying the width and height of the figure.
            It is the 'figsize' argument passed into matplotlib.pyplot's
            `figure` function.
        return_fig : bool, optional
            Whether the figure is to be returned as an output argument.

        Returns
        -------
        fig : matplotlib figure object
            The figure information for the plot. Only returned if
            `return_fig` is True.
        ax : matplotlib axes object
            The axes information for the plot. Only returned if
            `return_fig` is True.
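
        Examples
        --------
        A minimal usage sketch, assuming `comparitor` is a Comparitor
        instance on which `compare()` has already been run:

        >>> fig, ax = comparitor.plot(sig_style='b-', return_fig=True)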

        """
        import matplotlib.pyplot as plt

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(1, 1, 1)

        legend = [
            "Signal",
            "Matched Reference Annotations (%d/%d)" % (self.tp, self.n_ref),
            "Unmatched Reference Annotations (%d/%d)" % (self.fn, self.n_ref),
            "Matched Test Annotations (%d/%d)" % (self.tp, self.n_test),
            "Unmatched Test Annotations (%d/%d)" % (self.fp, self.n_test),
        ]

        # Plot the signal if any
        if self.signal is not None:
            ax.plot(self.signal, sig_style)

            # Plot reference annotations
            ax.plot(
                self.matched_ref_sample,
                self.signal[self.matched_ref_sample],
                "ko",
            )
            ax.plot(
                self.unmatched_ref_sample,
                self.signal[self.unmatched_ref_sample],
                "ko",
                fillstyle="none",
            )
            # Plot test annotations
            ax.plot(
                self.matched_test_sample,
                self.signal[self.matched_test_sample],
                "g+",
            )
            ax.plot(
                self.unmatched_test_sample,
                self.signal[self.unmatched_test_sample],
                "rx",
            )

            ax.legend(legend)

        # Just plot annotations
        else:
            # Plot reference annotations
            ax.plot(self.matched_ref_sample, np.ones(self.tp), "ko")
            ax.plot(
                self.unmatched_ref_sample,
                np.ones(self.fn),
                "ko",
                fillstyle="none",
            )
            # Plot test annotations
            ax.plot(self.matched_test_sample, 0.5 * np.ones(self.tp), "g+")
            ax.plot(self.unmatched_test_sample, 0.5 * np.ones(self.fp), "rx")
            ax.legend(legend[1:])

        if title:
            ax.set_title(title)

        ax.set_xlabel("time/sample")

        fig.show()

        if return_fig:
            return fig, ax


def compare_annotations(ref_sample, test_sample, window_width, signal=None):
    """
    Compare a set of reference annotation locations against a set of
    test annotation locations. See the `Comparitor` class docstring
    for more information.

    Parameters
    ----------
    ref_sample : 1d numpy array
        Array of reference sample locations.
    test_sample : 1d numpy array
        Array of test sample locations to compare.
    window_width : int
        The maximum absolute difference in sample numbers that is
        permitted for matching annotations.
    signal : 1d numpy array, optional
        The original signal of the two annotations. Only used for
        plotting.

    Returns
    -------
    comparitor : Comparitor object
        Object containing parameters about the two sets of annotations.

    Examples
    --------
    >>> import wfdb
    >>> from wfdb import processing

    >>> sig, fields = wfdb.rdsamp('sample-data/100', channels=[0])
    >>> ann_ref = wfdb.rdann('sample-data/100','atr')
    >>> xqrs = processing.XQRS(sig=sig[:,0], fs=fields['fs'])
    >>> xqrs.detect()

    >>> comparitor = processing.compare_annotations(ann_ref.sample[1:],
    ...                                             xqrs.qrs_inds,
    ...                                             int(0.1 * fields['fs']),
    ...                                             sig[:,0])
    >>> comparitor.print_summary()
    >>> comparitor.plot()

    """
    comparitor = Comparitor(
        ref_sample=ref_sample,
        test_sample=test_sample,
        window_width=window_width,
        signal=signal,
    )
    comparitor.compare()

    return comparitor


def benchmark_mitdb(detector, verbose=False, print_results=False):
    """
    Benchmark a QRS detector against mitdb's records.

    Parameters
    ----------
    detector : function
        The detector function.
    verbose : bool, optional
        The verbose option of the detector function.
    print_results : bool, optional
        Whether to print the overall performance, and the results for
        each record.

    Returns
    -------
    comparitors : dictionary
        Dictionary of Comparitor objects run on the records, keyed on
        the record names.
    sensitivity : float
        Aggregate sensitivity.
    positive_predictivity : float
        Aggregate positive_predictivity.

    Notes
    -----
    TODO:
    - remove non-qrs detections from reference annotations
    - allow kwargs

    Examples
    --------
    >>> import wfdb
    >>> from wfdb.processing import benchmark_mitdb, xqrs_detect

    >>> comparitors, sens, ppv = benchmark_mitdb(xqrs_detect)

    """
    record_list = get_record_list("mitdb")
    n_records = len(record_list)

    # Function arguments for starmap
    args = zip(record_list, n_records * [detector], n_records * [verbose])

    # Run detector and compare against reference annotations for all
    # records
    with Pool(cpu_count() - 1) as p:
        comparitors = p.starmap(benchmark_mitdb_record, args)

    # Calculate aggregate stats
    sensitivity = np.mean([c.sensitivity for c in comparitors])
    positive_predictivity = np.mean(
        [c.positive_predictivity for c in comparitors]
    )

    comparitors = dict(zip(record_list, comparitors))

    print("Benchmark complete")

    if print_results:
        print(
            "\nOverall MITDB Performance - Sensitivity: %.4f, Positive Predictivity: %.4f\n"
            % (sensitivity, positive_predictivity)
        )
        for record_name in record_list:
            print("Record %s:" % record_name)
            comparitors[record_name].print_summary()
            print("\n\n")

    return comparitors, sensitivity, positive_predictivity


def benchmark_mitdb_record(rec, detector, verbose):
    """
    Benchmark a single mitdb record.

    Parameters
    ----------
    rec : str
        The mitdb record to be benchmarked.
    detector : function
        The detector function.
    verbose : bool
        The verbose option passed to the detector function. If True,
        a message is also printed when the record is finished.

    Returns
    -------
    comparitor : Comparitor object
        Object containing parameters about the two sets of annotations.
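
    Examples
    --------
    An illustrative usage sketch; this downloads mitdb record 100 from
    PhysioNet and uses the packaged `xqrs_detect` detector:

    >>> from wfdb.processing import xqrs_detect
    >>> comparitor = benchmark_mitdb_record('100', xqrs_detect, False)
    >>> comparitor.print_summary()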

    """
    sig, fields = rdsamp(rec, pn_dir="mitdb", channels=[0])
    ann_ref = rdann(rec, pn_dir="mitdb", extension="atr")

    qrs_inds = detector(sig=sig[:, 0], fs=fields["fs"], verbose=verbose)

    comparitor = compare_annotations(
        ref_sample=ann_ref.sample[1:],
        test_sample=qrs_inds,
        window_width=int(0.1 * fields["fs"]),
    )
    if verbose:
        print("Finished record %s" % rec)
    return comparitor