Waveform Database Software Package (WFDB) for Python 4.1.0
(40,768 bytes)
import os
import numpy as np
from wfdb.io.record import Record, rdrecord
from wfdb.io.util import downround, upround
from wfdb.io.annotation import Annotation
def _expand_channels(signal):
    """
    Convert application-specified signal data to a list of channels.

    Parameters
    ----------
    signal : 1d or 2d numpy array or list or None
        The signal or signals to be plotted. If signal is a
        one-dimensional array, it is assumed to represent a single channel.
        If it is a two-dimensional array, axes 0 and 1 must represent time
        and channel number respectively. Otherwise it must be a sequence of
        one-dimensional arrays or plain sequences (one for each channel.)

    Returns
    -------
    signal : list
        A list of one-dimensional arrays (one for each channel.) Channels
        that were supplied as plain sequences are converted to numpy
        arrays; channels that already expose `ndim` are passed through
        unchanged.

    Raises
    ------
    ValueError
        If an array has more than two dimensions, or if any channel of a
        sequence input is not one-dimensional.

    """
    if signal is None:
        return []

    if hasattr(signal, "ndim"):
        if signal.ndim == 1:
            return [signal]
        if signal.ndim == 2:
            # Axis 1 is the channel axis, so iterate over the transpose.
            return list(signal.transpose())
        raise ValueError(
            "invalid shape for signal array: {}".format(signal.shape)
        )

    # Plain sequences (e.g. lists of numbers) have no `ndim`; convert them
    # so that the shape check and later fancy indexing work for them too.
    channels = [
        ch if hasattr(ch, "ndim") else np.asarray(ch) for ch in signal
    ]
    if any(ch.ndim != 1 for ch in channels):
        raise ValueError(
            "invalid shape for signal array(s): {}".format(
                [ch.shape for ch in channels]
            )
        )
    return channels
def _get_sampling_freq(sampling_freq, n_sig, frame_freq):
    """
    Normalize a sampling-frequency specification to one value per channel.

    Parameters
    ----------
    sampling_freq : number or sequence or None
        The sampling frequency or frequencies of the signals. A sequence
        must contain exactly `n_sig` entries. None selects `frame_freq`
        for every channel.
    n_sig : int
        Number of channels.
    frame_freq : number or None
        Default sampling frequency (record frame frequency).

    Returns
    -------
    sampling_freq : list
        The sampling frequency for each channel (a list of length `n_sig`.)

    Raises
    ------
    ValueError
        If `sampling_freq` is a sequence whose length differs from `n_sig`.

    """
    # Anything without a length (a scalar, or None) is replicated.
    if not hasattr(sampling_freq, "__len__"):
        per_channel = frame_freq if sampling_freq is None else sampling_freq
        return [per_channel] * n_sig

    n_given = len(sampling_freq)
    if n_given != n_sig:
        raise ValueError(
            "length mismatch: n_sig = {}, "
            "len(sampling_freq) = {}".format(n_sig, n_given)
        )
    return list(sampling_freq)
def _get_ann_freq(ann_freq, n_annot, frame_freq):
    """
    Normalize an annotation-frequency specification to one value per
    annotation channel.

    Parameters
    ----------
    ann_freq : number or sequence or None
        The sampling frequency or frequencies of the annotations. A
        sequence must contain exactly `n_annot` entries. None selects
        `frame_freq` for every annotation channel.
    n_annot : int
        Number of annotation channels.
    frame_freq : number or None
        Default sampling frequency (record frame frequency).

    Returns
    -------
    ann_freq : list
        The sampling frequency for each annotation channel (a list of
        length `n_annot`).

    Raises
    ------
    ValueError
        If `ann_freq` is a sequence whose length differs from `n_annot`.

    """
    is_sequence = hasattr(ann_freq, "__len__")

    if is_sequence and len(ann_freq) != n_annot:
        raise ValueError(
            "length mismatch: n_annot = {}, "
            "len(ann_freq) = {}".format(n_annot, len(ann_freq))
        )
    if is_sequence:
        return list(ann_freq)

    # Scalar or unset: the same frequency applies to every channel.
    fill = frame_freq if ann_freq is None else ann_freq
    return [fill] * n_annot
def plot_items(
    signal=None,
    ann_samp=None,
    ann_sym=None,
    fs=None,
    time_units="samples",
    sig_name=None,
    sig_units=None,
    xlabel=None,
    ylabel=None,
    title=None,
    sig_style=[""],
    ann_style=["r*"],
    ecg_grids=[],
    figsize=None,
    sharex=False,
    sharey=False,
    return_fig=False,
    return_fig_axes=False,
    sampling_freq=None,
    ann_freq=None,
):
    """
    Subplot individual channels of signals and/or annotations.

    Parameters
    ----------
    signal : 1d or 2d numpy array or list, optional
        The signal or signals to be plotted. If signal is a
        one-dimensional array, it is assumed to represent a single channel.
        If it is a two-dimensional array, axes 0 and 1 must represent time
        and channel number respectively. Otherwise it must be a list of
        one-dimensional arrays (one for each channel).
    ann_samp: list, optional
        A list of annotation locations to plot, with each list item
        corresponding to a different channel. List items may be:

        - 1d numpy array, with values representing sample indices. Empty
          arrays are skipped.
        - list, with values representing sample indices. Empty lists
          are skipped.
        - None. For channels in which nothing is to be plotted.

        If `signal` is defined, the annotation locations will be overlaid on
        the signals, with the list index corresponding to the signal channel.
        The length of `ann_samp` does not have to match the number of
        channels of `signal`.
    ann_sym: list, optional
        A list of annotation symbols to plot, with each list item
        corresponding to a different channel. List items should be lists of
        strings. The symbols are plotted over the corresponding `ann_samp`
        index locations.
    fs : int, float, optional
        The sampling frequency of the signals and/or annotations. Used to
        calculate time intervals if `time_units` is not 'samples'. Also
        required for plotting ECG grids.
    time_units : str, optional
        The x axis unit. Allowed options are: 'samples', 'seconds', 'minutes',
        and 'hours'.
    sig_name : list, optional
        A list of strings specifying the signal names. Used with `sig_units`
        to form y labels, if `ylabel` is not set.
    sig_units : list, optional
        A list of strings specifying the units of each signal channel. Used
        with `sig_name` to form y labels, if `ylabel` is not set. This
        parameter is required for plotting ECG grids.
    xlabel : list, optional
        A list of strings specifying the final x labels to be used, one per
        subplot. If this option is present, no 'time/'`time_units` is used.
    ylabel : list, optional
        A list of strings specifying the final y labels, one per subplot.
        If this option is present, `sig_name` and `sig_units` will not be
        used for labels.
    title : str, optional
        The title of the graph.
    sig_style : list, optional
        A list of strings, specifying the style of the matplotlib plot
        for each signal channel. The list length should match the number
        of signal channels. If the list has a length of 1, the style
        will be used for all channels.
    ann_style : list, optional
        A list of strings, specifying the style of the matplotlib plot for each
        annotation channel. If the list has a length of 1, the style will be
        used for all channels.
    ecg_grids : list, optional
        A list of integers specifying channels in which to plot ECG grids. May
        also be set to 'all' for all channels. Major grids at 0.5mV, and minor
        grids at 0.125mV. All channels to be plotted with grids must have
        `sig_units` equal to 'uV', 'mV', or 'V'.
    figsize : tuple, optional
        Tuple pair specifying the width, and height of the figure. It is the
        'figsize' argument passed into matplotlib.pyplot's `figure` function.
    sharex, sharey : bool, optional
        Controls sharing of properties among x (`sharex`) or y (`sharey`) axes.
        If True: x- or y-axis will be shared among all subplots.
        If False, each subplot x- or y-axis will be independent.
    return_fig : bool, optional
        Whether the figure is to be returned as an output argument. Takes
        precedence over `return_fig_axes`.
    return_fig_axes : bool, optional
        Whether the figure and its axes are to be returned as output
        arguments.
    sampling_freq : number or sequence, optional
        The sampling frequency or frequencies of the signals. If this is a
        list, it must have the same length as the number of channels. If
        unspecified, defaults to `fs`.
    ann_freq : number or sequence, optional
        The sampling frequency or frequencies of the annotations. If this
        is a list, it must have the same length as `ann_samp`. If
        unspecified, defaults to `fs`.

    Returns
    -------
    fig : matplotlib figure, optional
        The matplotlib figure generated. Only returned if the 'return_fig'
        or 'return_fig_axes' parameter is set to True.
    axes : matplotlib axes, optional
        The matplotlib axes generated. Only returned if the 'return_fig_axes'
        parameter is set to True (and 'return_fig' is not).

    Raises
    ------
    Exception
        If `xlabel` or `ylabel` is given and its length differs from the
        number of subplots.

    Examples
    --------
    >>> record = wfdb.rdrecord('sample-data/100', sampto=3000)
    >>> ann = wfdb.rdann('sample-data/100', 'atr', sampto=3000)

    >>> wfdb.plot_items(signal=record.p_signal,
                        ann_samp=[ann.sample, ann.sample],
                        title='MIT-BIH Record 100', time_units='seconds',
                        figsize=(10,4), ecg_grids='all')

    """
    import matplotlib.pyplot as plt

    # Convert signal to a list of channels (an empty list if unset)
    signal = _expand_channels(signal)

    # Figure out number of subplots required
    sig_len, n_sig, n_annot, n_subplots = _get_plot_dims(signal, ann_samp)

    # Convert sampling_freq and ann_freq to lists if needed
    sampling_freq = _get_sampling_freq(sampling_freq, n_sig, fs)
    ann_freq = _get_ann_freq(ann_freq, n_annot, fs)

    # Create figure
    fig, axes = _create_figure(n_subplots, sharex, sharey, figsize)
    try:
        # `signal` is always a list here; with no channels this is a no-op.
        _plot_signal(
            signal,
            sig_len,
            n_sig,
            fs,
            time_units,
            sig_style,
            axes,
            sampling_freq=sampling_freq,
        )

        if ann_samp is not None:
            _plot_annotation(
                ann_samp,
                n_annot,
                ann_sym,
                signal,
                n_sig,
                fs,
                time_units,
                ann_style,
                axes,
                sampling_freq=sampling_freq,
                ann_freq=ann_freq,
            )

        if ecg_grids:
            _plot_ecg_grids(
                ecg_grids,
                fs,
                sig_units,
                time_units,
                axes,
                sampling_freq=sampling_freq,
            )

        # Add title and axis labels.
        # First, make sure that xlabel and ylabel inputs are valid. Both
        # are indexed once per subplot, so both must have n_subplots items.
        # (`signal` is a list at this point and has no `.shape`.)
        if xlabel and len(xlabel) != n_subplots:
            raise Exception(
                "The length of the xlabel must be the same as the "
                "signal: {} values".format(n_subplots)
            )

        if ylabel and len(ylabel) != n_subplots:
            raise Exception(
                "The length of the ylabel must be the same as the "
                "signal: {} values".format(n_subplots)
            )

        _label_figure(
            axes,
            n_subplots,
            time_units,
            sig_name,
            sig_units,
            xlabel,
            ylabel,
            title,
        )
    except BaseException:
        # Do not leak a half-drawn figure if anything above fails.
        plt.close(fig)
        raise

    if return_fig:
        return fig

    if return_fig_axes:
        return fig, axes

    plt.show()
def _get_plot_dims(signal, ann_samp):
    """
    Work out how many signal channels, annotation channels and subplots
    are needed.

    Parameters
    ----------
    signal : 1d or 2d numpy array or list, optional
        The signal or signals to be plotted, in any form accepted by
        `_expand_channels`.
    ann_samp: list, optional
        A list of annotation locations to plot, with each list item
        corresponding to a different channel. Only its length is used
        here.

    Returns
    -------
    sig_len : int or None
        The common length of the signal channels; None if the channels
        differ in length, 0 if there are no channels. Deprecated.
    n_sig : int
        The number of signal channels.
    n_annot : int
        The number of annotation channels (0 if `ann_samp` is None).
    n_subplots : int
        The larger of `n_sig` and `n_annot`.

    """
    channels = _expand_channels(signal)
    n_sig = len(channels)

    if n_sig:
        first_len = len(channels[0])
        # A single length is only meaningful when every channel agrees.
        uniform = all(len(ch) == first_len for ch in channels)
        sig_len = first_len if uniform else None
    else:
        sig_len = 0

    n_annot = 0 if ann_samp is None else len(ann_samp)

    return sig_len, n_sig, n_annot, max(n_sig, n_annot)
def _create_figure(n_subplots, sharex, sharey, figsize):
    """
    Build the figure and a single column of subplot axes.

    Parameters
    ----------
    n_subplots : int
        The number of subplots (rows) to generate.
    sharex, sharey : bool, optional
        Controls sharing of properties among x (`sharex`) or y (`sharey`) axes.
        If True: x- or y-axis will be shared among all subplots.
        If False, each subplot x- or y-axis will be independent.
    figsize : tuple
        The figure's width, height in inches.

    Returns
    -------
    fig : matplotlib plot object
        The entire figure that will hold each subplot.
    axes : list or array
        The axes of each subplot, indexable by channel number.

    """
    import matplotlib.pyplot as plt

    subplot_options = dict(
        nrows=n_subplots,
        ncols=1,
        sharex=sharex,
        sharey=sharey,
        figsize=figsize,
    )
    fig, axes = plt.subplots(**subplot_options)

    # With one row a bare Axes object is returned; wrap it so callers can
    # always index by channel.
    return fig, ([axes] if n_subplots == 1 else axes)
def _plot_signal(
    signal, sig_len, n_sig, fs, time_units, sig_style, axes, sampling_freq=None
):
    """
    Draw every signal channel on its own subplot.

    Parameters
    ----------
    signal : 1d or 2d numpy array or list
        The signal or signals to be plotted, in any form accepted by
        `_expand_channels`.
    sig_len : int
        The signal length (per channel). Deprecated; not used.
    n_sig : int
        The number of signal channels to plot.
    fs : float
        The sampling frequency of the record.
    time_units : str
        The x axis unit. Allowed options are: 'samples', 'seconds', 'minutes',
        and 'hours'.
    sig_style : list
        A list of strings, specifying the style of the matplotlib plot
        for each signal channel. The list length should match the number
        of signal channels. If the list has a length of 1, the style
        will be used for all channels.
    axes : list
        The axes of each subplot.
    sampling_freq : number or sequence, optional
        The sampling frequency or frequencies of the signals. If this is a
        list, it must have the same length as the number of channels. If
        unspecified, defaults to `fs`.

    Returns
    -------
    N/A

    """
    channels = _expand_channels(signal)

    if n_sig == 0:
        return

    # A single style applies to every channel
    styles = n_sig * sig_style if len(sig_style) == 1 else sig_style

    freqs = _get_sampling_freq(sampling_freq, n_sig, fs)

    # Time axes are shared between channels that have the same length and
    # sampling frequency, so each distinct axis is only computed once.
    time_cache = {}

    for ch in range(n_sig):
        samples = channels[ch]
        key = (len(samples), freqs[ch])

        if key not in time_cache:
            n_samp, freq = key
            t = np.linspace(0, n_samp - 1, n_samp)
            if time_units != "samples":
                samples_per_unit = {
                    "seconds": freq,
                    "minutes": freq * 60,
                    "hours": freq * 3600,
                }
                t /= samples_per_unit[time_units]
            time_cache[key] = t

        axes[ch].plot(time_cache[key], samples, styles[ch], zorder=3)
def _plot_annotation(
    ann_samp,
    n_annot,
    ann_sym,
    signal,
    n_sig,
    fs,
    time_units,
    ann_style,
    axes,
    sampling_freq=None,
    ann_freq=None,
):
    """
    Plot annotations, possibly overlaid on signals.

    Parameters
    ----------
    ann_samp : list
        The annotation locations for each channel. Each item may be a 1d
        numpy array or a list of sample indices, or None. Empty and None
        items are skipped.
    n_annot : int
        The number of annotation channels.
    ann_sym : list
        The annotation symbols for each channel, or None to plot no
        symbols.
    signal : 1d or 2d numpy array or list
        The signal or signals to be plotted. If signal is a
        one-dimensional array, it is assumed to represent a single channel.
        If it is a two-dimensional array, axes 0 and 1 must represent time
        and channel number respectively. Otherwise it must be a list of
        one-dimensional arrays (one for each channel).
    n_sig : int
        The number of signal channels.
    fs : float
        The sampling frequency of the record.
    time_units : str
        The x axis unit. Allowed options are: 'samples', 'seconds', 'minutes',
        and 'hours'.
    ann_style : list
        A list of strings, specifying the style of the matplotlib plot
        for each annotation channel. If the list has a length of 1, the
        style will be used for all channels.
    axes : list
        The axes of each subplot.
    sampling_freq : number or sequence, optional
        The sampling frequency or frequencies of the signals. If this is a
        list, it must have the same length as the number of channels. If
        unspecified, defaults to `fs`.
    ann_freq : number or sequence, optional
        The sampling frequency or frequencies of the annotations. If this
        is a list, it must have the same length as `ann_samp`. If
        unspecified, defaults to `fs`.

    Returns
    -------
    N/A

    Raises
    ------
    Exception
        If an annotation location falls outside of the signal it is
        overlaid on.

    """
    # Convert signal to a list if needed
    signal = _expand_channels(signal)

    # Extend annotation style if necessary
    if len(ann_style) == 1:
        ann_style = n_annot * ann_style

    # Convert sampling_freq and ann_freq to lists if needed
    sampling_freq = _get_sampling_freq(sampling_freq, n_sig, fs)
    ann_freq = _get_ann_freq(ann_freq, n_annot, fs)

    # Plot the annotations
    for ch in range(n_annot):
        # Nothing to plot on this channel; skip before touching frequencies
        # so that unset frequencies cannot fail for an empty channel.
        if ann_samp[ch] is None or not len(ann_samp[ch]):
            continue

        # Plain lists are allowed; arithmetic below needs an array.
        samp = np.asarray(ann_samp[ch])

        afreq = ann_freq[ch]
        # Annotation-only channels have no signal frequency of their own
        sfreq = sampling_freq[ch] if ch < n_sig else afreq

        # Figure out downsample factor for time indices
        if time_units == "samples":
            if afreq is None and sfreq is None:
                downsample_factor = 1
            else:
                downsample_factor = afreq / sfreq
        else:
            downsample_factor = {
                "seconds": float(afreq),
                "minutes": float(afreq) * 60,
                "hours": float(afreq) * 3600,
            }[time_units]

        # Figure out the y values to plot on a channel basis
        try:
            if n_sig > ch:
                if sfreq == afreq:
                    index = samp
                else:
                    # Rescale annotation indices to signal sample indices
                    index = (sfreq / afreq * samp).astype("int")
                y = signal[ch][index]
            else:
                y = np.zeros(len(samp))
        except IndexError as err:
            raise Exception(
                "IndexError: try setting shift_samps=True in "
                'the "rdann" function?'
            ) from err

        x = samp / downsample_factor
        axes[ch].plot(x, y, ann_style[ch], zorder=4)

        # Plot the annotation symbols if any
        if ann_sym is not None and ann_sym[ch] is not None:
            for i, s in enumerate(ann_sym[ch]):
                axes[ch].annotate(s, (x[i], y[i]))
def _plot_ecg_grids(ecg_grids, fs, units, time_units, axes, sampling_freq=None):
    """
    Draw ECG-paper style grid lines on the selected subplots.

    Parameters
    ----------
    ecg_grids : list, str
        The channels on which to draw a grid, or 'all' for every subplot.
    fs : float
        The sampling frequency of the record.
    units : list
        The units used for plotting each signal.
    time_units : str
        The x axis unit. Allowed options are: 'samples', 'seconds', 'minutes',
        and 'hours'.
    axes : list
        The axes of each subplot.
    sampling_freq : number or sequence, optional
        The sampling frequency or frequencies of the signals. If this is a
        list, it must have the same length as the number of subplots. If
        unspecified, defaults to `fs`.

    Returns
    -------
    N/A

    """
    minor_color, major_color = "#ededed", "#bababa"

    if ecg_grids == "all":
        ecg_grids = range(0, len(axes))

    freqs = _get_sampling_freq(sampling_freq, len(axes), fs)

    for ch in ecg_grids:
        ax = axes[ch]

        # Remember the automatic limits; they are restored at the end
        x_limits = ax.get_xlim()
        y_limits = ax.get_ylim()

        major_x, minor_x, major_y, minor_y = _calc_ecg_grids(
            y_limits[0],
            y_limits[1],
            units[ch],
            freqs[ch],
            x_limits[1],
            time_units,
        )

        x_lo, x_hi = np.min(minor_x), np.max(minor_x)
        y_lo, y_hi = np.min(minor_y), np.max(minor_y)

        # Vertical lines: minor grid underneath (zorder 1), major on top
        for ticks, color, layer in (
            (minor_x, minor_color, 1),
            (major_x, major_color, 2),
        ):
            for tick in ticks:
                ax.plot(
                    [tick, tick],
                    [y_lo, y_hi],
                    c=color,
                    marker="|",
                    zorder=layer,
                )

        # Horizontal lines, in the same minor-then-major order
        for ticks, color, layer in (
            (minor_y, minor_color, 1),
            (major_y, major_color, 2),
        ):
            for tick in ticks:
                ax.plot(
                    [x_lo, x_hi],
                    [tick, tick],
                    c=color,
                    marker="_",
                    zorder=layer,
                )

        # Plotting the lines changes the graph. Set the limits back
        ax.set_xlim(x_limits)
        ax.set_ylim(y_limits)
def _calc_ecg_grids(minsig, maxsig, sig_units, fs, maxt, time_units):
    """
    Calculate tick intervals for ECG grids.

    - 5mm 0.2s major grids, 0.04s minor grids.
    - 0.5mV major grids, 0.125 minor grids.

    10 mm is equal to 1mV in voltage.

    Parameters
    ----------
    minsig : float
        The min value of the signal.
    maxsig : float
        The max value of the signal.
    sig_units : str
        The units of the signal: 'uV', 'mV' or 'V' (case-insensitive).
    fs : float
        The sampling frequency of the signal. Only used when `time_units`
        is 'samples'.
    maxt : float
        The max time of the signal.
    time_units : str
        The x axis unit. Allowed options are: 'samples', 'seconds', 'minutes',
        and 'hours'.

    Returns
    -------
    major_ticks_x : ndarray
        The locations of the major ticks on the x-axis.
    minor_ticks_x : ndarray
        The locations of the minor ticks on the x-axis.
    major_ticks_y : ndarray
        The locations of the major ticks on the y-axis.
    minor_ticks_y : ndarray
        The locations of the minor ticks on the y-axis.

    Raises
    ------
    ValueError
        If `time_units` or `sig_units` is not one of the allowed options.

    """
    # Number of seconds in one x axis unit, for the time-based units
    seconds_per_unit = {"seconds": 1, "minutes": 60, "hours": 3600}

    # Get the grid interval of the x axis
    if time_units == "samples":
        majorx = 0.2 * fs
        minorx = 0.04 * fs
    elif time_units in seconds_per_unit:
        majorx = 0.2 / seconds_per_unit[time_units]
        minorx = 0.04 / seconds_per_unit[time_units]
    else:
        raise ValueError(
            "time_units must be 'samples', 'seconds', 'minutes', or 'hours'."
        )

    # Get the grid interval of the y axis (0.5mV major, 0.125mV minor)
    y_intervals = {
        "uv": (500, 125),
        "mv": (0.5, 0.125),
        "v": (0.0005, 0.000125),
    }
    try:
        majory, minory = y_intervals[sig_units.lower()]
    except KeyError:
        raise ValueError(
            "Signal units must be uV, mV, or V to plot ECG grids."
        ) from None

    # The grid spans whole major intervals around the data.
    x_stop = upround(maxt, majorx)
    y_start = downround(minsig, majory)
    y_stop = upround(maxsig, majory)

    # np.arange excludes its stop value, so push the stop out by half a
    # step to keep the final tick. Half a step scales with the grid; a
    # fixed offset would add spurious ticks whenever the step is smaller
    # than the offset (e.g. with time_units='hours').
    major_ticks_x = np.arange(0, x_stop + majorx / 2, majorx)
    minor_ticks_x = np.arange(0, x_stop + minorx / 2, minorx)

    major_ticks_y = np.arange(y_start, y_stop + majory / 2, majory)
    minor_ticks_y = np.arange(y_start, y_stop + minory / 2, minory)

    return (major_ticks_x, minor_ticks_x, major_ticks_y, minor_ticks_y)
def _label_figure(
    axes, n_subplots, time_units, sig_name, sig_units, xlabel, ylabel, title
):
    """
    Set the figure title and the x/y labels of every subplot.

    Parameters
    ----------
    axes : list
        The axes of each subplot.
    n_subplots : int
        The number of subplots.
    time_units : str, optional
        The x axis unit. Allowed options are: 'samples', 'seconds', 'minutes',
        and 'hours'.
    sig_name : list, optional
        A list of strings specifying the signal names. Used with `sig_units`
        to form y labels, if `ylabel` is not set.
    sig_units : list, optional
        A list of strings specifying the units of each signal channel. Used
        with `sig_name` to form y labels, if `ylabel` is not set.
    xlabel : list, optional
        A list of strings specifying the final x labels to be used. If this
        option is present, no 'time/'`time_units` is used.
    ylabel : list, optional
        A list of strings specifying the final y labels. If this option is
        present, `sig_name` and `sig_units` will not be used for labels.
    title : str, optional
        The title of the graph; placed on the first subplot.

    Returns
    -------
    N/A

    """
    if title:
        axes[0].set_title(title)

    # X labels: explicit labels go on every subplot; otherwise only the
    # bottom subplot gets a generic 'time/<unit>' label (unit singular).
    if xlabel:
        for ch in range(n_subplots):
            axes[ch].set_xlabel(xlabel[ch])
    else:
        axes[-1].set_xlabel("/".join(["time", time_units[:-1]]))

    # Y labels: explicit labels win; otherwise build 'name/units' pairs,
    # falling back to placeholder names and units where needed.
    if not ylabel:
        names = (
            sig_name
            if sig_name
            else ["ch_" + str(i) for i in range(n_subplots)]
        )
        units = sig_units if sig_units else n_subplots * ["NU"]
        ylabel = ["/".join(pair) for pair in zip(names, units)]

        # Annotation channels beyond the signal range get placeholders
        if len(ylabel) != n_subplots:
            ylabel = ylabel + [
                "ch_%d/NU" % i for i in range(len(ylabel), n_subplots)
            ]

    for ch in range(n_subplots):
        axes[ch].set_ylabel(ylabel[ch])
def plot_wfdb(
    record=None,
    annotation=None,
    plot_sym=False,
    time_units="seconds",
    title=None,
    sig_style=[""],
    ann_style=["r*"],
    ecg_grids=[],
    figsize=None,
    return_fig=False,
    sharex="auto",
):
    """
    Subplot individual channels of a WFDB record and/or annotation.

    This function implements the base functionality of the `plot_items`
    function, while allowing direct input of WFDB objects.

    If the record object is input, the function will extract from it:
      - signal values, from the `e_p_signal`, `e_d_signal`, `p_signal`, or
        `d_signal` attribute (in that order of priority.)
      - frame frequency, from the `fs` attribute
      - signal names, from the `sig_name` attribute
      - signal units, from the `units` attribute

    If the annotation object is input, the function will extract from it:
      - sample locations, from the `sample` attribute
      - symbols, from the `symbol` attribute
      - the annotation channels, from the `chan` attribute
      - the sampling frequency, from the `fs` attribute if present, and if fs
        was not already extracted from the `record` argument.

    Parameters
    ----------
    record : WFDB Record, optional
        The Record object to be plotted.
    annotation : WFDB Annotation, optional
        The Annotation object to be plotted.
    plot_sym : bool, optional
        Whether to plot the annotation symbols on the graph.
    time_units : str, optional
        The x axis unit. Allowed options are: 'samples', 'seconds',
        'minutes', and 'hours'.
    title : str, optional
        The title of the graph. Defaults to the record name.
    sig_style : list, optional
        A list of strings, specifying the style of the matplotlib plot
        for each signal channel. The list length should match the number
        of signal channels. If the list has a length of 1, the style
        will be used for all channels.
    ann_style : list, optional
        A list of strings, specifying the style of the matplotlib plot
        for each annotation channel. The list length should match the
        number of annotation channels. If the list has a length of 1,
        the style will be used for all channels.
    ecg_grids : list, optional
        A list of integers specifying channels in which to plot ECG grids. May
        also be set to 'all' for all channels. Major grids at 0.5mV, and minor
        grids at 0.125mV. All channels to be plotted with grids must have
        `sig_units` equal to 'uV', 'mV', or 'V'.
    figsize : tuple, optional
        Tuple pair specifying the width, and height of the figure. It is the
        'figsize' argument passed into matplotlib.pyplot's `figure` function.
    return_fig : bool, optional
        Whether the figure is to be returned as an output argument.
    sharex : bool or 'auto', optional
        Whether the X axis should be shared between all subplots. If set
        to True, then all signals will be aligned with each other. If set
        to False, then each subplot can be panned/zoomed independently. If
        set to 'auto' (default), then the X axis will be shared unless
        record is multi-frequency and the time units are set to 'samples'.

    Returns
    -------
    figure : matplotlib figure, optional
        The matplotlib figure generated. Only returned if the 'return_fig'
        option is set to True.

    Examples
    --------
    >>> record = wfdb.rdrecord('sample-data/100', sampto=3000)
    >>> annotation = wfdb.rdann('sample-data/100', 'atr', sampto=3000)

    >>> wfdb.plot_wfdb(record=record, annotation=annotation, plot_sym=True,
                       time_units='seconds', title='MIT-BIH Record 100',
                       figsize=(10,4), ecg_grids='all')

    """
    items = _get_wfdb_plot_items(
        record=record, annotation=annotation, plot_sym=plot_sym
    )
    signal, ann_samp, ann_sym, fs, ylabel, record_name, sig_units = items

    # Per-channel sampling frequencies are only needed for expanded
    # (multi-frequency) signals; otherwise the frame frequency applies.
    if not record:
        sampling_freq = None
    elif record.e_p_signal is None and record.e_d_signal is None:
        sampling_freq = record.fs
    else:
        sampling_freq = [spf * record.fs for spf in record.samps_per_frame]

    if sharex == "auto":
        # Channels stay synchronized on a shared X axis whenever real time
        # units are used or all sampling frequencies agree. With 'samples'
        # and non-uniform frequencies, sharing would be misleading.
        mismatched_samples = (
            time_units == "samples"
            and isinstance(sampling_freq, list)
            and any(f != sampling_freq[0] for f in sampling_freq)
        )
        sharex = not mismatched_samples

    # The annotation's own frequency wins; fall back to the record's
    if annotation and annotation.fs is not None:
        ann_freq = annotation.fs
    else:
        ann_freq = record.fs if record else None

    plot_options = dict(
        signal=signal,
        ann_samp=ann_samp,
        ann_sym=ann_sym,
        fs=fs,
        time_units=time_units,
        ylabel=ylabel,
        title=(title or record_name),
        sig_style=sig_style,
        sig_units=sig_units,
        ann_style=ann_style,
        ecg_grids=ecg_grids,
        figsize=figsize,
        return_fig=return_fig,
        sampling_freq=sampling_freq,
        ann_freq=ann_freq,
        sharex=sharex,
    )
    return plot_items(**plot_options)
def _get_wfdb_plot_items(record, annotation, plot_sym):
    """
    Get items to plot from WFDB objects.

    Parameters
    ----------
    record : WFDB Record
        The Record object to be plotted
    annotation : WFDB Annotation
        The Annotation object to be plotted
    plot_sym : bool
        Whether to plot the annotation symbols on the graph.

    Returns
    -------
    signal : 1d or 2d numpy array or list
        The signal taken from the record (`e_p_signal`, `e_d_signal`,
        `p_signal` or `d_signal`, in that order of priority), or None if
        no record was given.
    ann_samp: list
        A list of annotation locations to plot, with each list item
        corresponding to a different subplot. Channels without annotations
        hold an empty array. None if no annotation was given.
    ann_sym: list
        A list of annotation symbols to plot, with each list item
        corresponding to a different subplot. Items are lists of strings,
        or None for channels without annotations. None if `plot_sym` is
        False or no annotation was given.
    fs : int, float
        The sampling frequency, from the record if given, otherwise from
        the annotation.
    ylabel : list
        A list of strings specifying the y label of each subplot.
    record_name : str
        The name to display for the plotted record.
    sig_units : list
        A list of strings specifying the units of each signal channel
        ('adu' for digital signals). None if no record was given.

    Raises
    ------
    ValueError
        If a record is given but holds no signal.

    """
    # Get record attributes
    if record:
        if record.e_p_signal is not None:
            signal = record.e_p_signal
            n_sig = len(signal)
            physical = True
        elif record.e_d_signal is not None:
            signal = record.e_d_signal
            n_sig = len(signal)
            physical = False
        elif record.p_signal is not None:
            signal = record.p_signal
            n_sig = signal.shape[1]
            physical = True
        elif record.d_signal is not None:
            signal = record.d_signal
            n_sig = signal.shape[1]
            physical = False
        else:
            raise ValueError("The record has no signal to plot")

        fs = record.fs
        sig_name = [str(s) for s in record.sig_name]
        if physical:
            sig_units = [str(s) for s in record.units]
        else:
            sig_units = ["adu"] * n_sig
        record_name = "Record: %s" % record.record_name
        ylabel = ["/".join(pair) for pair in zip(sig_name, sig_units)]
    else:
        signal = fs = ylabel = record_name = sig_units = None

    # Get annotation attributes
    if annotation:
        # Get channels
        ann_chans = set(annotation.chan)
        n_ann_chans = int(max(ann_chans)) + 1

        # Indices for each channel
        chan_inds = [np.empty(0, dtype="int") for _ in range(n_ann_chans)]

        for chan in ann_chans:
            chan_inds[chan] = np.where(annotation.chan == chan)[0]

        ann_samp = [annotation.sample[ci] for ci in chan_inds]

        if plot_sym:
            ann_sym = n_ann_chans * [None]
            for ch in ann_chans:
                ann_sym[ch] = [annotation.symbol[ci] for ci in chan_inds[ch]]
        else:
            ann_sym = None

        # Try to get fs from annotation if not already in record
        if fs is None:
            fs = annotation.fs

        record_name = record_name or annotation.record_name
    else:
        ann_samp = None
        ann_sym = None

    # Cleaning: remove empty channels and set labels and styles.

    # Wrangle together the signal and annotation channels if necessary
    if record and annotation:
        # There may be instances in which the annotation `chan`
        # attribute has non-overlapping channels with the signal.
        # In this case, omit empty middle channels. This function should
        # already process labels and arrangements before passing into
        # `plot_items`
        sig_chans = set(range(n_sig))
        all_chans = sorted(sig_chans.union(ann_chans))

        # Some annotations lie on channels beyond the signal channels:
        # need to update ylabels and annotation values. Compare sets;
        # a set never equals a list.
        if sig_chans != set(all_chans):
            compact_ann_samp = []
            compact_ann_sym = [] if plot_sym else None
            ylabel = []
            for ch in all_chans:  # ie. 0, 1, 9
                # Every plotted channel gets exactly one entry so that the
                # list index keeps matching the subplot index; channels
                # without annotations get an empty placeholder.
                if ch in ann_chans:
                    compact_ann_samp.append(ann_samp[ch])
                    if plot_sym:
                        compact_ann_sym.append(ann_sym[ch])
                else:
                    compact_ann_samp.append(np.empty(0, dtype="int"))
                    if plot_sym:
                        compact_ann_sym.append(None)

                if ch in sig_chans:
                    ylabel.append("/".join([sig_name[ch], sig_units[ch]]))
                else:
                    ylabel.append("ch_%d/NU" % ch)
            ann_samp = compact_ann_samp
            ann_sym = compact_ann_sym
        # Signals encompass annotations
        else:
            ylabel = ["/".join(pair) for pair in zip(sig_name, sig_units)]

    # Remove any empty middle channels from annotations
    elif annotation:
        ann_samp = [a for a in ann_samp if a.size]
        if ann_sym is not None:
            ann_sym = [a for a in ann_sym if a]
        # Sorted so the labels line up with the channel-ordered samples
        ylabel = ["ch_%d/NU" % ch for ch in sorted(ann_chans)]

    return signal, ann_samp, ann_sym, fs, ylabel, record_name, sig_units
def plot_all_records(directory=""):
    """
    Plot all WFDB records in a directory (by finding header files), one at
    a time, waiting for the 'enter' key to be pressed after each plot.

    Parameters
    ----------
    directory : str, optional
        The directory in which to search for WFDB records. Defaults to
        current working directory.

    Returns
    -------
    N/A

    """
    header_suffix = ".hea"
    directory = directory or os.getcwd()

    # A record name is its header file name minus the trailing extension.
    # Only the suffix is removed, so names that contain ".hea" elsewhere
    # are kept intact.
    records = sorted(
        name[: -len(header_suffix)]
        for name in os.listdir(directory)
        if name.endswith(header_suffix)
        and os.path.isfile(os.path.join(directory, name))
    )

    for record_name in records:
        record = rdrecord(os.path.join(directory, record_name))

        plot_wfdb(record, title="Record - %s" % record.record_name)
        input("Press enter to continue...")