"""Waveform Database Software Package (WFDB) for Python 3.3.0."""
from multiprocessing import cpu_count, Pool
import matplotlib.pyplot as plt
import numpy as np
import requests
from wfdb.io.annotation import rdann
from wfdb.io.download import get_record_list
from wfdb.io.record import rdsamp
class Comparitor(object):
"""
The class to implement and hold comparisons between two sets of
annotations. See methods `compare`, `print_summary` and `plot`.
Attributes
----------
ref_sample : ndarray
An array of the reference sample locations.
test_sample : ndarray
An array of the comparison sample locations.
window_width : int
The maximum absolute difference in sample numbers permitted for
matching a reference annotation to a test annotation.
signal : 1d numpy array, optional
The signal array the annotation samples are labelling. Only used
for plotting.
Examples
--------
>>> import wfdb
>>> from wfdb import processing
>>> sig, fields = wfdb.rdsamp('sample-data/100', channels=[0])
>>> ann_ref = wfdb.rdann('sample-data/100','atr')
>>> xqrs = processing.XQRS(sig=sig[:,0], fs=fields['fs'])
>>> xqrs.detect()
>>> comparitor = processing.Comparitor(ann_ref.sample[1:],
xqrs.qrs_inds,
int(0.1 * fields['fs']),
sig[:,0])
>>> comparitor.compare()
>>> comparitor.print_summary()
>>> comparitor.plot()
"""
def __init__(self, ref_sample, test_sample, window_width, signal=None):
if len(ref_sample) > 1 and len(test_sample) > 1:
if min(np.diff(ref_sample)) < 0 or min(np.diff(test_sample)) < 0:
raise ValueError(('The sample locations must be monotonically'
+ ' increasing'))
self.ref_sample = ref_sample
self.test_sample = test_sample
self.n_ref = len(ref_sample)
self.n_test = len(test_sample)
self.window_width = window_width
# The matching test sample number for each reference annotation.
# -1 for indices with no match
self.matching_sample_nums = np.full(self.n_ref, -1, dtype='int')
self.signal = signal
# TODO: rdann return annotations.where
def _calc_stats(self):
"""
Calculate performance statistics after the two sets of annotations
are compared.
Parameters
----------
N/A
Returns
-------
N/A
Example:
-------------------
ref=500 test=480
{ 30 { 470 } 10 }
-------------------
tp = 470
fp = 10
fn = 30
sensitivity = 470 / 500
positive_predictivity = 470 / 480
"""
# Reference annotation indices that were detected
self.matched_ref_inds = np.where(self.matching_sample_nums != -1)[0]
# Reference annotation indices that were missed
self.unmatched_ref_inds = np.where(self.matching_sample_nums == -1)[0]
# Test annotation indices that were matched to a reference annotation
self.matched_test_inds = self.matching_sample_nums[
self.matching_sample_nums != -1]
# Test annotation indices that were unmatched to a reference annotation
self.unmatched_test_inds = np.setdiff1d(np.array(range(self.n_test)),
self.matched_test_inds, assume_unique=True)
# Sample numbers that were matched and unmatched
self.matched_ref_sample = self.ref_sample[self.matched_ref_inds]
self.unmatched_ref_sample = self.ref_sample[self.unmatched_ref_inds]
self.matched_test_sample = self.test_sample[self.matched_test_inds]
self.unmatched_test_sample = self.test_sample[self.unmatched_test_inds]
# True positives = matched reference samples
self.tp = len(self.matched_ref_inds)
# False positives = extra test samples not matched
self.fp = self.n_test - self.tp
# False negatives = undetected reference samples
self.fn = self.n_ref - self.tp
# True negatives are not defined for this comparison, so no tn attribute is set
self.sensitivity = float(self.tp) / float(self.tp + self.fn)
self.positive_predictivity = float(self.tp) / self.n_test
def compare(self):
"""
Main comparison function. The algorithm must correctly handle the
reference/test alignment scenarios illustrated in the Example section
below, where the top 'o' row of each pair marks reference annotations
and the bottom 'x' row marks test annotations.
Parameters
----------
N/A
Returns
-------
N/A
Example
-------
A:
o----o---o---o
x-------x----x
B:
o----o-----o---o
x--------x--x--x
C:
o------o-----o---o
x-x--------x--x--x
D:
o------o-----o---o
x-x--------x-----x
"""
test_samp_num = 0
ref_samp_num = 0
# Iterate through the reference sample numbers
while ref_samp_num < self.n_ref and test_samp_num < self.n_test:
# Get the closest testing sample number for this reference sample
closest_samp_num, smallest_samp_diff = (
self._get_closest_samp_num(ref_samp_num, test_samp_num))
# Get the closest testing sample number for the next reference
# sample. This doesn't need to be called for the last index.
if ref_samp_num < self.n_ref - 1:
closest_samp_num_next, smallest_samp_diff_next = (
self._get_closest_samp_num(ref_samp_num + 1, test_samp_num))
else:
# Set non-matching value if there is no next reference sample
# to compete for the test sample
closest_samp_num_next = -1
# Found a contested test sample number. Decide which
# reference sample it belongs to. If the sample is closer to
# the next reference sample, leave it to the next reference
# sample and label this reference sample as unmatched.
if (closest_samp_num == closest_samp_num_next
and smallest_samp_diff_next < smallest_samp_diff):
# Get the next closest sample for this reference sample,
# if not already assigned to a previous sample.
# It will be the previous testing sample number in any
possible case (scenario D in the docstring above), or nothing.
if closest_samp_num and (not ref_samp_num or
        closest_samp_num - 1 != self.matching_sample_nums[ref_samp_num - 1]):
# The previous test annotation is inspected
closest_samp_num = closest_samp_num - 1
smallest_samp_diff = abs(self.ref_sample[ref_samp_num]
- self.test_sample[closest_samp_num])
# Assign the reference-test pair if close enough
if smallest_samp_diff < self.window_width:
self.matching_sample_nums[ref_samp_num] = closest_samp_num
# Set the starting test sample number to inspect
# for the next reference sample.
test_samp_num = closest_samp_num + 1
# Otherwise there is no matching test annotation
# If there is no clash, or the contested test sample is
# closer to the current reference, keep the test sample
# for this reference sample.
else:
# Assign the reference-test pair if close enough
if smallest_samp_diff < self.window_width:
self.matching_sample_nums[ref_samp_num] = closest_samp_num
# Increment the starting test sample number to inspect
# for the next reference sample.
test_samp_num = closest_samp_num + 1
ref_samp_num += 1
self._calc_stats()
def _get_closest_samp_num(self, ref_samp_num, start_test_samp_num):
"""
Return the closest testing sample number for the given reference
sample number. Limit the search from start_test_samp_num.
Parameters
----------
ref_samp_num : int
The desired reference sample number to get the closest result.
start_test_samp_num : int
The test sample number from which to start searching for the
closest match.
Returns
-------
closest_samp_num : int
The closest sample number to the reference sample number.
smallest_samp_diff : int
The smallest difference between the reference sample and
the testing sample.
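Examples
--------
Illustrative values (not from the original docstring): with
ref_sample = [100] and test_sample = [40, 90, 130], calling
_get_closest_samp_num(0, 0) returns (1, 10). Test sample 90 is the
closest to reference sample 100, and the search stops after
inspecting sample 130 because the reference location has been passed.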
"""
if start_test_samp_num >= self.n_test:
raise ValueError('Invalid starting test sample number.')
ref_samp = self.ref_sample[ref_samp_num]
test_samp = self.test_sample[start_test_samp_num]
samp_diff = ref_samp - test_samp
# Initialize running parameters
closest_samp_num = start_test_samp_num
smallest_samp_diff = abs(samp_diff)
# Iterate through the testing samples
for test_samp_num in range(start_test_samp_num, self.n_test):
test_samp = self.test_sample[test_samp_num]
samp_diff = ref_samp - test_samp
abs_samp_diff = abs(samp_diff)
# Found a better match
if abs_samp_diff < smallest_samp_diff:
closest_samp_num = test_samp_num
smallest_samp_diff = abs_samp_diff
# Stop iterating when the ref sample is first passed or reached
if samp_diff <= 0:
break
return closest_samp_num, smallest_samp_diff
def print_summary(self):
"""
Print summary metrics of the annotation comparisons.
Parameters
----------
N/A
Returns
-------
N/A
"""
if not hasattr(self, 'sensitivity'):
self._calc_stats()
print('%d reference annotations, %d test annotations\n'
% (self.n_ref, self.n_test))
print('True Positives (matched samples): %d' % self.tp)
print('False Positives (unmatched test samples): %d' % self.fp)
print('False Negatives (unmatched reference samples): %d\n' % self.fn)
print('Sensitivity: %.4f (%d/%d)'
% (self.sensitivity, self.tp, self.n_ref))
print('Positive Predictivity: %.4f (%d/%d)'
% (self.positive_predictivity, self.tp, self.n_test))
def plot(self, sig_style='', title=None, figsize=None,
return_fig=False):
"""
Plot the comparison of two sets of annotations, possibly
overlaid on their original signal.
Parameters
----------
sig_style : str, optional
The matplotlib style of the signal
title : str, optional
The title of the plot
figsize : tuple, optional
Tuple pair specifying the width and height of the figure. It is the
'figsize' argument passed into matplotlib.pyplot's
`figure` function.
return_fig : bool, optional
Whether the figure is to be returned as an output argument.
Returns
-------
fig : matplotlib figure object
The figure information for the plot.
ax : matplotlib axes object
The axes information for the plot.
"""
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(1, 1, 1)
legend = ['Signal',
'Matched Reference Annotations (%d/%d)' % (self.tp, self.n_ref),
'Unmatched Reference Annotations (%d/%d)' % (self.fn, self.n_ref),
'Matched Test Annotations (%d/%d)' % (self.tp, self.n_test),
'Unmatched Test Annotations (%d/%d)' % (self.fp, self.n_test)
]
# Plot the signal if any
if self.signal is not None:
ax.plot(self.signal, sig_style)
# Plot reference annotations
ax.plot(self.matched_ref_sample,
self.signal[self.matched_ref_sample], 'ko')
ax.plot(self.unmatched_ref_sample,
self.signal[self.unmatched_ref_sample], 'ko',
fillstyle='none')
# Plot test annotations
ax.plot(self.matched_test_sample,
self.signal[self.matched_test_sample], 'g+')
ax.plot(self.unmatched_test_sample,
self.signal[self.unmatched_test_sample], 'rx')
ax.legend(legend)
# Just plot annotations
else:
# Plot reference annotations
ax.plot(self.matched_ref_sample, np.ones(self.tp), 'ko')
ax.plot(self.unmatched_ref_sample, np.ones(self.fn), 'ko',
fillstyle='none')
# Plot test annotations
ax.plot(self.matched_test_sample, 0.5 * np.ones(self.tp), 'g+')
ax.plot(self.unmatched_test_sample, 0.5 * np.ones(self.fp), 'rx')
ax.legend(legend[1:])
if title:
ax.set_title(title)
ax.set_xlabel('time/sample')
fig.show()
if return_fig:
return fig, ax
def compare_annotations(ref_sample, test_sample, window_width, signal=None):
"""
Compare a set of reference annotation locations against a set of
test annotation locations. See the Comparitor class docstring
for more information.
Parameters
----------
ref_sample : 1d numpy array
Array of reference sample locations.
test_sample : 1d numpy array
Array of test sample locations to compare.
window_width : int
The maximum absolute difference in sample numbers that is
permitted for matching annotations.
signal : 1d numpy array, optional
The original signal of the two annotations. Only used for
plotting.
Returns
-------
comparitor : Comparitor object
Object containing parameters about the two sets of annotations.
Examples
--------
>>> import wfdb
>>> from wfdb import processing
>>> sig, fields = wfdb.rdsamp('sample-data/100', channels=[0])
>>> ann_ref = wfdb.rdann('sample-data/100','atr')
>>> xqrs = processing.XQRS(sig=sig[:,0], fs=fields['fs'])
>>> xqrs.detect()
>>> comparitor = processing.compare_annotations(ann_ref.sample[1:],
xqrs.qrs_inds,
int(0.1 * fields['fs']),
sig[:,0])
>>> comparitor.print_summary()
>>> comparitor.plot()
"""
comparitor = Comparitor(ref_sample=ref_sample, test_sample=test_sample,
window_width=window_width, signal=signal)
comparitor.compare()
return comparitor
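# Minimal usage sketch (not part of the original module): exercises
# compare_annotations on small synthetic arrays so the matching logic can be
# inspected without downloading any records. The helper name and all numbers
# below are illustrative assumptions.
def _demo_compare_annotations():
    # Three reference locations and three test locations; the third test
    # location falls outside the 10-sample matching window of any reference.
    ref = np.array([100, 300, 500])
    test = np.array([102, 305, 650])
    comparitor = compare_annotations(ref, test, window_width=10)
    # Expected outcome: tp=2, fp=1, fn=1, so sensitivity and positive
    # predictivity are both 2/3.
    comparitor.print_summary()
    return comparitor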
def benchmark_mitdb(detector, verbose=False, print_results=False):
"""
Benchmark a QRS detector against mitdb's records.
Parameters
----------
detector : function
The detector function.
verbose : bool, optional
The verbose option of the detector function.
print_results : bool, optional
Whether to print the overall performance, and the results for
each record.
Returns
-------
comparitors : dictionary
Dictionary of Comparitor objects run on the records, keyed on
the record names.
sensitivity : float
Aggregate sensitivity.
positive_predictivity : float
Aggregate positive_predictivity.
Notes
-----
TODO:
- remove non-qrs detections from reference annotations
- allow kwargs
Examples
--------
>>> import wfdb
>>> from wfdb.processing import benchmark_mitdb, xqrs_detect
>>> comparitors, spec, pp = benchmark_mitdb(xqrs_detect)
"""
record_list = get_record_list('mitdb')
n_records = len(record_list)
# Function arguments for starmap
args = zip(record_list, n_records * [detector], n_records * [verbose])
# Run detector and compare against reference annotations for all
# records
with Pool(cpu_count() - 1) as p:
comparitors = p.starmap(benchmark_mitdb_record, args)
# Calculate aggregate stats
sensitivity = np.mean([c.sensitivity for c in comparitors])
positive_predictivity = np.mean(
[c.positive_predictivity for c in comparitors])
comparitors = dict(zip(record_list, comparitors))
print('Benchmark complete')
if print_results:
print('\nOverall MITDB Performance - Sensitivity: %.4f, Positive Predictivity: %.4f\n'
% (sensitivity, positive_predictivity))
for record_name in record_list:
print('Record %s:' % record_name)
comparitors[record_name].print_summary()
print('\n\n')
return comparitors, sensitivity, positive_predictivity
def benchmark_mitdb_record(rec, detector, verbose):
"""
Benchmark a single mitdb record.
Parameters
----------
rec : str
The mitdb record to be benchmarked.
detector : function
The detector function.
verbose : bool
Whether to print the record names (True) or not (False).
Returns
-------
comparitor : Comparitor object
Object containing parameters about the two sets of annotations.
"""
sig, fields = rdsamp(rec, pn_dir='mitdb', channels=[0])
ann_ref = rdann(rec, pn_dir='mitdb', extension='atr')
qrs_inds = detector(sig=sig[:,0], fs=fields['fs'], verbose=verbose)
comparitor = compare_annotations(ref_sample=ann_ref.sample[1:],
test_sample=qrs_inds,
window_width=int(0.1 * fields['fs']))
if verbose:
print('Finished record %s' % rec)
return comparitor
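# Illustrative entry point (an assumption, not part of the original module):
# benchmark a single MIT-BIH record with the package's XQRS detector. This
# requires network access to PhysioNet; record '100' is used only as an
# example.
if __name__ == '__main__':
    from wfdb.processing import xqrs_detect
    single_comparitor = benchmark_mitdb_record('100', xqrs_detect, verbose=True)
    single_comparitor.print_summary()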