import os
import numpy as np

from wfdb.io.record import Record, rdrecord
from wfdb.io.util import downround, upround
from wfdb.io.annotation import Annotation


def _expand_channels(signal):
    """
    Convert application-specified signal data to a list.

    Parameters
    ----------
    signal : 1d or 2d numpy array or list or None
        The signal or signals to be plotted. If signal is a one-dimensional
        array, it is assumed to represent a single channel. If it is a
        two-dimensional array, axes 0 and 1 must represent time and channel
        number respectively. Otherwise it must be a list of one-dimensional
        arrays (one for each channel.)

    Returns
    -------
    signal : list
        A list of one-dimensional arrays (one for each channel.) Empty if
        `signal` is None.

    Raises
    ------
    ValueError
        If an array input is not one- or two-dimensional, or if any item of
        a list input is not one-dimensional.

    """
    if signal is None:
        return []
    elif hasattr(signal, "ndim"):
        if signal.ndim == 1:
            return [signal]
        elif signal.ndim == 2:
            # Axis 1 is the channel axis; transposing yields one row per
            # channel.
            return list(signal.transpose())
        else:
            raise ValueError(
                "invalid shape for signal array: {}".format(signal.shape)
            )
    else:
        signal = list(signal)
        if any(s.ndim != 1 for s in signal):
            raise ValueError(
                "invalid shape for signal array(s): {}".format(
                    [s.shape for s in signal]
                )
            )
        return signal


def _get_sampling_freq(sampling_freq, n_sig, frame_freq):
    """
    Convert application-specified sampling frequency to a list.

    Parameters
    ----------
    sampling_freq : number or sequence or None
        The sampling frequency or frequencies of the signals. If this is a
        list, its length must equal `n_sig`. If unset, defaults to
        `frame_freq`.
    n_sig : int
        Number of signal channels.
    frame_freq : number or None
        Default sampling frequency (record frame frequency).

    Returns
    -------
    sampling_freq : list
        The sampling frequency for each channel (a list of length `n_sig`.)

    Raises
    ------
    ValueError
        If `sampling_freq` is a sequence whose length is not `n_sig`.

    """
    if sampling_freq is None:
        return [frame_freq] * n_sig
    elif hasattr(sampling_freq, "__len__"):
        if len(sampling_freq) != n_sig:
            raise ValueError(
                "length mismatch: n_sig = {}, "
                "len(sampling_freq) = {}".format(n_sig, len(sampling_freq))
            )
        return list(sampling_freq)
    else:
        return [sampling_freq] * n_sig


def _get_ann_freq(ann_freq, n_annot, frame_freq):
    """
    Convert application-specified annotation frequency to a list.

    Parameters
    ----------
    ann_freq : number or sequence or None
        The sampling frequency or frequencies of the annotations. If this
        is a list, its length must equal `n_annot`. If unset, defaults to
        `frame_freq`.
    n_annot : int
        Number of annotation channels.
    frame_freq : number or None
        Default sampling frequency (record frame frequency).

    Returns
    -------
    ann_freq : list
        The sampling frequency for each annotation channel (a list of
        length `n_annot`).

    Raises
    ------
    ValueError
        If `ann_freq` is a sequence whose length is not `n_annot`.

    """
    if ann_freq is None:
        return [frame_freq] * n_annot
    elif hasattr(ann_freq, "__len__"):
        if len(ann_freq) != n_annot:
            raise ValueError(
                "length mismatch: n_annot = {}, "
                "len(ann_freq) = {}".format(n_annot, len(ann_freq))
            )
        return list(ann_freq)
    else:
        return [ann_freq] * n_annot


def plot_items(
    signal=None,
    ann_samp=None,
    ann_sym=None,
    fs=None,
    time_units="samples",
    sig_name=None,
    sig_units=None,
    xlabel=None,
    ylabel=None,
    title=None,
    sig_style=[""],
    ann_style=["r*"],
    ecg_grids=[],
    figsize=None,
    sharex=False,
    sharey=False,
    return_fig=False,
    return_fig_axes=False,
    sampling_freq=None,
    ann_freq=None,
):
    """
    Subplot individual channels of signals and/or annotations.

    Parameters
    ----------
    signal : 1d or 2d numpy array or list, optional
        The signal or signals to be plotted. If signal is a one-dimensional
        array, it is assumed to represent a single channel. If it is a
        two-dimensional array, axes 0 and 1 must represent time and channel
        number respectively. Otherwise it must be a list of one-dimensional
        arrays (one for each channel).
    ann_samp: list, optional
        A list of annotation locations to plot, with each list item
        corresponding to a different channel. List items may be:

        - 1d numpy array, with values representing sample indices. Empty
          arrays are skipped.
        - list, with values representing sample indices. Empty lists
          are skipped.
        - None. For channels in which nothing is to be plotted.

        If `signal` is defined, the annotation locations will be overlaid on
        the signals, with the list index corresponding to the signal channel.
        The length of `ann_samp` does not have to match the number of
        channels of `signal`.
    ann_sym: list, optional
        A list of annotation symbols to plot, with each list item
        corresponding to a different channel. List items should be lists of
        strings. The symbols are plotted over the corresponding `ann_samp`
        index locations.
    fs : int, float, optional
        The sampling frequency of the signals and/or annotations. Used to
        calculate time intervals if `time_units` is not 'samples'. Also
        required for plotting ECG grids.
    time_units : str, optional
        The x axis unit. Allowed options are: 'samples', 'seconds', 'minutes',
        and 'hours'.
    sig_name : list, optional
        A list of strings specifying the signal names. Used with `sig_units`
        to form y labels, if `ylabel` is not set.
    sig_units : list, optional
        A list of strings specifying the units of each signal channel. Used
        with `sig_name` to form y labels, if `ylabel` is not set. This
        parameter is required for plotting ECG grids.
    xlabel : list, optional
        A list of strings specifying the final x labels to be used, one per
        subplot. If this option is present, no 'time/'`time_units` is used.
    ylabel : list, optional
        A list of strings specifying the final y labels, one per subplot. If
        this option is present, `sig_name` and `sig_units` will not be used
        for labels.
    title : str, optional
        The title of the graph.
    sig_style : list, optional
        A list of strings, specifying the style of the matplotlib plot
        for each signal channel. The list length should match the number
        of signal channels. If the list has a length of 1, the style
        will be used for all channels.
    ann_style : list, optional
        A list of strings, specifying the style of the matplotlib plot for
        each annotation channel. If the list has a length of 1, the style
        will be used for all channels.
    ecg_grids : list, optional
        A list of integers specifying channels in which to plot ECG grids. May
        also be set to 'all' for all signal channels. Major grids at 0.5mV,
        and minor grids at 0.125mV. All channels to be plotted with grids must
        have `sig_units` equal to 'uV', 'mV', or 'V'.
    sharex, sharey : bool, optional
        Controls sharing of properties among x (`sharex`) or y (`sharey`) axes.
        If True: x- or y-axis will be shared among all subplots.
        If False, each subplot x- or y-axis will be independent.
    figsize : tuple, optional
        Tuple pair specifying the width, and height of the figure. It is the
        'figsize' argument passed into matplotlib.pyplot's `figure` function.
    return_fig : bool, optional
        Whether the figure is to be returned as an output argument.
    return_fig_axes : bool, optional
        Whether the figure and its axes are to be returned as output
        arguments. Ignored if `return_fig` is True.
    sampling_freq : number or sequence, optional
        The sampling frequency or frequencies of the signals.  If this is a
        list, it must have the same length as the number of channels.  If
        unspecified, defaults to `fs`.
    ann_freq : number or sequence, optional
        The sampling frequency or frequencies of the annotations.  If this
        is a list, it must have the same length as `ann_samp`.  If
        unspecified, defaults to `fs`.

    Returns
    -------
    fig : matplotlib figure, optional
        The matplotlib figure generated. Only returned if the 'return_fig'
        or 'return_fig_axes' parameter is set to True.
    axes : matplotlib axes, optional
        The matplotlib axes generated. Only returned if the 'return_fig_axes'
        parameter is set to True.

    Raises
    ------
    Exception
        If `xlabel` or `ylabel` is given with a length different from the
        number of subplots.

    Examples
    --------
    >>> record = wfdb.rdrecord('sample-data/100', sampto=3000)
    >>> ann = wfdb.rdann('sample-data/100', 'atr', sampto=3000)

    >>> wfdb.plot_items(signal=record.p_signal,
                        ann_samp=[ann.sample, ann.sample],
                        title='MIT-BIH Record 100', time_units='seconds',
                        figsize=(10,4), ecg_grids='all')

    """
    import matplotlib.pyplot as plt

    # Convert signal to a list if needed
    signal = _expand_channels(signal)

    # Figure out number of subplots required
    sig_len, n_sig, n_annot, n_subplots = _get_plot_dims(signal, ann_samp)

    # Convert sampling_freq and ann_freq to lists if needed
    sampling_freq = _get_sampling_freq(sampling_freq, n_sig, fs)
    ann_freq = _get_ann_freq(ann_freq, n_annot, fs)

    # Create figure
    fig, axes = _create_figure(n_subplots, sharex, sharey, figsize)
    try:
        # `signal` is always a list at this point; plot only if it has
        # channels.
        if n_sig:
            _plot_signal(
                signal,
                sig_len,
                n_sig,
                fs,
                time_units,
                sig_style,
                axes,
                sampling_freq=sampling_freq,
            )

        if ann_samp is not None:
            _plot_annotation(
                ann_samp,
                n_annot,
                ann_sym,
                signal,
                n_sig,
                fs,
                time_units,
                ann_style,
                axes,
                sampling_freq=sampling_freq,
                ann_freq=ann_freq,
            )

        if ecg_grids:
            _plot_ecg_grids(
                ecg_grids,
                fs,
                sig_units,
                time_units,
                axes,
                sampling_freq=sampling_freq,
            )

        # Add title and axis labels.
        # First, make sure that xlabel and ylabel inputs are valid. Both are
        # indexed once per subplot by `_label_figure`.
        if xlabel:
            if len(xlabel) != n_subplots:
                raise Exception(
                    "The length of the xlabel must be the same as the "
                    "signal: {} values".format(n_subplots)
                )

        if ylabel:
            if len(ylabel) != n_subplots:
                raise Exception(
                    "The length of the ylabel must be the same as the "
                    "signal: {} values".format(n_subplots)
                )

        _label_figure(
            axes,
            n_subplots,
            time_units,
            sig_name,
            sig_units,
            xlabel,
            ylabel,
            title,
        )
    except BaseException:
        # Do not leave a half-built figure registered with pyplot.
        plt.close(fig)
        raise

    if return_fig:
        return fig

    if return_fig_axes:
        return fig, axes

    plt.show()


def _get_plot_dims(signal, ann_samp):
    """
    Figure out the number of plot channels.

    Parameters
    ----------
    signal : 1d or 2d numpy array or list, optional
        The signal or signals to be plotted. If signal is a one-dimensional
        array, it is assumed to represent a single channel. If it is a
        two-dimensional array, axes 0 and 1 must represent time and channel
        number respectively. Otherwise it must be a list of one-dimensional
        arrays (one for each channel).
    ann_samp: list, optional
        A list of annotation locations to plot, with each list item
        corresponding to a different channel. List items may be:

        - 1d numpy array, with values representing sample indices. Empty
          arrays are skipped.
        - list, with values representing sample indices. Empty lists
          are skipped.
        - None. For channels in which nothing is to be plotted.

        If `signal` is defined, the annotation locations will be overlaid on
        the signals, with the list index corresponding to the signal channel.
        The length of `ann_samp` does not have to match the number of
        channels of `signal`.

    Returns
    -------
    sig_len : int or None
        The signal length (per channel), 0 if there is no signal, or None if
        the channels have different lengths. Deprecated.
    n_sig : int
        The number of signal channels.
    n_annot : int
        The number of annotation channels (the length of `ann_samp`).
    n_subplots : int
        The max between number of signals and annotations.

    """
    # Convert signal to a list if needed
    signal = _expand_channels(signal)

    if signal:
        n_sig = len(signal)
        sig_len = len(signal[0])
        if any(len(s) != sig_len for s in signal):
            sig_len = None
    else:
        sig_len = 0
        n_sig = 0

    if ann_samp is not None:
        n_annot = len(ann_samp)
    else:
        n_annot = 0

    return sig_len, n_sig, n_annot, max(n_sig, n_annot)


def _create_figure(n_subplots, sharex, sharey, figsize):
    """
    Create the plot figure and subplot axes.

    Parameters
    ----------
    n_subplots : int
        The number of subplots to generate.
    figsize : tuple
        The figure's width, height in inches.
    sharex, sharey : bool, optional
        Controls sharing of properties among x (`sharex`) or y (`sharey`) axes.
        If True: x- or y-axis will be shared among all subplots.
        If False, each subplot x- or y-axis will be independent.

    Returns
    -------
    fig : matplotlib plot object
        The entire figure that will hold each subplot.
    axes : list
        The axes of each subplot, indexable by channel (a one-item list when
        there is a single subplot).

    """
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(
        nrows=n_subplots, ncols=1, sharex=sharex, sharey=sharey, figsize=figsize
    )
    # plt.subplots returns a bare Axes (not an array) for a single subplot;
    # wrap it so callers can always index by channel.
    if n_subplots == 1:
        axes = [axes]

    return fig, axes


def _plot_signal(
    signal, sig_len, n_sig, fs, time_units, sig_style, axes, sampling_freq=None
):
    """
    Plot signal channels.

    Parameters
    ----------
    signal : 1d or 2d numpy array or list
        The signal or signals to be plotted. If signal is a one-dimensional
        array, it is assumed to represent a single channel. If it is a
        two-dimensional array, axes 0 and 1 must represent time and channel
        number respectively. Otherwise it must be a list of one-dimensional
        arrays (one for each channel).
    sig_len : int
        The signal length (per channel) of the dat file. Deprecated.
    n_sig : int
        The number of signal channels.
    fs : float
        The sampling frequency of the record.
    time_units : str
        The x axis unit. Allowed options are: 'samples', 'seconds',
        'minutes', and 'hours'.
    sig_style : list
        A list of strings, specifying the style of the matplotlib plot
        for each signal channel. The list length should match the number
        of signal channels. If the list has a length of 1, the style
        will be used for all channels.
    axes : list
        The information needed for each subplot.
    sampling_freq : number or sequence, optional
        The sampling frequency or frequencies of the signals.  If this is a
        list, it must have the same length as the number of channels.  If
        unspecified, defaults to `fs`.

    Returns
    -------
    N/A

    """
    # Convert signal to a list if needed
    signal = _expand_channels(signal)
    if n_sig == 0:
        return

    # Extend signal style if necessary
    if len(sig_style) == 1:
        sig_style = n_sig * sig_style

    # Convert sampling_freq to a list if needed
    sampling_freq = _get_sampling_freq(sampling_freq, n_sig, fs)

    # Time axes are shared between channels of equal length and frequency.
    tarrays = {}

    # Plot the signals
    for ch in range(n_sig):
        ch_len = len(signal[ch])
        ch_freq = sampling_freq[ch]

        # Figure out time indices
        try:
            t = tarrays[ch_len, ch_freq]
        except KeyError:
            if time_units == "samples":
                t = np.linspace(0, ch_len - 1, ch_len)
            else:
                downsample_factor = {
                    "seconds": ch_freq,
                    "minutes": ch_freq * 60,
                    "hours": ch_freq * 3600,
                }
                t = np.linspace(0, ch_len - 1, ch_len)
                t /= downsample_factor[time_units]
            tarrays[ch_len, ch_freq] = t

        axes[ch].plot(t, signal[ch], sig_style[ch], zorder=3)


def _plot_annotation(
    ann_samp,
    n_annot,
    ann_sym,
    signal,
    n_sig,
    fs,
    time_units,
    ann_style,
    axes,
    sampling_freq=None,
    ann_freq=None,
):
    """
    Plot annotations, possibly overlaid on signals.

    Parameters
    ----------
    ann_samp : list
        The annotation locations (sample indices at the annotation
        frequency), one item per channel. Items may be arrays, lists, or None.
    n_annot : int
        The number of annotation channels.
    ann_sym : list
        The annotation symbols to draw, one list of strings per channel
        (or None).
    signal : 1d or 2d numpy array or list
        The signal or signals to be plotted. If signal is a one-dimensional
        array, it is assumed to represent a single channel. If it is a
        two-dimensional array, axes 0 and 1 must represent time and channel
        number respectively. Otherwise it must be a list of one-dimensional
        arrays (one for each channel).
    n_sig : int
        The number of signal channels.
    fs : float
        The sampling frequency of the record.
    time_units : str
        The x axis unit. Allowed options are: 'samples', 'seconds',
        'minutes', and 'hours'.
    ann_style : list
        A list of strings, specifying the style of the matplotlib plot
        for each annotation channel. If the list has a length of 1, the
        style will be used for all channels.
    axes : list
        The information needed for each subplot.
    sampling_freq : number or sequence, optional
        The sampling frequency or frequencies of the signals.  If this is a
        list, it must have the same length as the number of channels.  If
        unspecified, defaults to `fs`.
    ann_freq : number or sequence, optional
        The sampling frequency or frequencies of the annotations.  If this
        is a list, it must have the same length as `ann_samp`.  If
        unspecified, defaults to `fs`.

    Returns
    -------
    N/A

    Raises
    ------
    Exception
        If an annotation location lies outside its signal channel.

    """
    # Convert signal to a list if needed
    signal = _expand_channels(signal)

    # Extend annotation style if necessary
    if len(ann_style) == 1:
        ann_style = n_annot * ann_style

    # Convert sampling_freq and ann_freq to lists if needed
    sampling_freq = _get_sampling_freq(sampling_freq, n_sig, fs)
    ann_freq = _get_ann_freq(ann_freq, n_annot, fs)

    # Plot the annotations
    for ch in range(n_annot):
        afreq = ann_freq[ch]
        if ch < n_sig:
            sfreq = sampling_freq[ch]
        else:
            sfreq = afreq

        # If either frequency is unknown the two cannot be rescaled relative
        # to each other, so treat them as equal.
        same_freq = afreq is None or sfreq is None or afreq == sfreq

        # Figure out downsample factor for time indices
        if time_units == "samples":
            # x positions are expressed in signal samples, so annotation
            # samples are rescaled by sfreq / afreq.
            downsample_factor = 1 if same_freq else afreq / sfreq
        else:
            downsample_factor = {
                "seconds": float(afreq),
                "minutes": float(afreq) * 60,
                "hours": float(afreq) * 3600,
            }[time_units]

        if ann_samp[ch] is None or not len(ann_samp[ch]):
            continue

        # Accept plain lists as well as arrays (arithmetic below needs numpy).
        samp = np.asarray(ann_samp[ch])

        # Figure out the y values to plot on a channel basis
        try:
            if ch < n_sig:
                if same_freq:
                    index = samp
                else:
                    index = (sfreq / afreq * samp).astype("int")
                y = signal[ch][index]
            else:
                # No signal for this channel: draw annotations on y = 0.
                y = np.zeros(len(samp))
        except IndexError as err:
            raise Exception(
                "IndexError: try setting shift_samps=True in "
                'the "rdann" function?'
            ) from err

        axes[ch].plot(samp / downsample_factor, y, ann_style[ch], zorder=4)

        # Plot the annotation symbols if any
        if ann_sym is not None and ann_sym[ch] is not None:
            for i, s in enumerate(ann_sym[ch]):
                axes[ch].annotate(s, (samp[i] / downsample_factor, y[i]))


def _plot_ecg_grids(ecg_grids, fs, units, time_units, axes, sampling_freq=None):
    """
    Add ECG grids to the axes.

    Parameters
    ----------
    ecg_grids : list, str
        The channel indices to draw grids on, or 'all' for every channel
        that has a known sampling frequency and signal unit.
    fs : float
        The sampling frequency of the record.
    units : list
        The units used for plotting each signal.
    time_units : str
        The x axis unit. Allowed options are: 'samples', 'seconds', 'minutes',
        and 'hours'.
    axes : list
        The information needed for each subplot.
    sampling_freq : number or sequence, optional
        The sampling frequency or frequencies of the signals. A sequence
        gives one frequency per signal channel and may be shorter than
        `axes` (annotation-only subplots have no signal). If unspecified,
        defaults to `fs`.

    Returns
    -------
    N/A

    """
    # Convert sampling_freq to a list if needed. A per-channel sequence (as
    # passed by `plot_items`) covers only the signal channels, which can be
    # fewer than the subplots, so it is not length-checked against `axes`.
    if hasattr(sampling_freq, "__len__"):
        sampling_freq = list(sampling_freq)
    else:
        sampling_freq = _get_sampling_freq(sampling_freq, len(axes), fs)

    if isinstance(ecg_grids, str) and ecg_grids == "all":
        # Grids need a sampling frequency and a signal unit, so 'all' only
        # covers the channels for which both are available.
        n_grids = min(len(axes), len(sampling_freq))
        if units is not None:
            n_grids = min(n_grids, len(units))
        ecg_grids = range(n_grids)

    for ch in ecg_grids:
        # Get the initial plot limits
        auto_xlims = axes[ch].get_xlim()
        auto_ylims = axes[ch].get_ylim()

        (
            major_ticks_x,
            minor_ticks_x,
            major_ticks_y,
            minor_ticks_y,
        ) = _calc_ecg_grids(
            auto_ylims[0],
            auto_ylims[1],
            units[ch],
            sampling_freq[ch],
            auto_xlims[1],
            time_units,
        )

        min_x, max_x = np.min(minor_ticks_x), np.max(minor_ticks_x)
        min_y, max_y = np.min(minor_ticks_y), np.max(minor_ticks_y)

        for tick in minor_ticks_x:
            axes[ch].plot(
                [tick, tick], [min_y, max_y], c="#ededed", marker="|", zorder=1
            )
        for tick in major_ticks_x:
            axes[ch].plot(
                [tick, tick], [min_y, max_y], c="#bababa", marker="|", zorder=2
            )
        for tick in minor_ticks_y:
            axes[ch].plot(
                [min_x, max_x], [tick, tick], c="#ededed", marker="_", zorder=1
            )
        for tick in major_ticks_y:
            axes[ch].plot(
                [min_x, max_x], [tick, tick], c="#bababa", marker="_", zorder=2
            )

        # Plotting the lines changes the graph. Set the limits back
        axes[ch].set_xlim(auto_xlims)
        axes[ch].set_ylim(auto_ylims)


def _calc_ecg_grids(minsig, maxsig, sig_units, fs, maxt, time_units):
    """
    Calculate tick intervals for ECG grids.

    - 5mm 0.2s major grids, 0.04s minor grids.
    - 0.5mV major grids, 0.125 minor grids.

    10 mm is equal to 1mV in voltage.

    Parameters
    ----------
    minsig : float
        The min value of the signal.
    maxsig : float
        The max value of the signal.
    sig_units : str
        The units of the signal ('uV', 'mV' or 'V', case-insensitive).
    fs : float
        The sampling frequency of the signal.
    maxt : float
        The max time of the signal.
    time_units : str
        The x axis unit. Allowed options are: 'samples', 'seconds', 'minutes',
        and 'hours'.

    Returns
    -------
    major_ticks_x : ndarray
        The locations of the major ticks on the x-axis.
    minor_ticks_x : ndarray
        The locations of the minor ticks on the x-axis.
    major_ticks_y : ndarray
        The locations of the major ticks on the y-axis.
    minor_ticks_y : ndarray
        The locations of the minor ticks on the y-axis.

    Raises
    ------
    ValueError
        If `time_units` or `sig_units` is not one of the allowed values.

    """
    # Get the grid interval of the x axis
    if time_units == "samples":
        majorx = 0.2 * fs
        minorx = 0.04 * fs
    elif time_units == "seconds":
        majorx = 0.2
        minorx = 0.04
    elif time_units == "minutes":
        majorx = 0.2 / 60
        minorx = 0.04 / 60
    elif time_units == "hours":
        majorx = 0.2 / 3600
        minorx = 0.04 / 3600
    else:
        raise ValueError(
            "time_units must be 'samples', 'seconds', 'minutes', or 'hours'."
        )

    # Get the grid interval of the y axis
    if sig_units.lower() == "uv":
        majory = 500
        minory = 125
    elif sig_units.lower() == "mv":
        majory = 0.5
        minory = 0.125
    elif sig_units.lower() == "v":
        majory = 0.0005
        minory = 0.000125
    else:
        raise ValueError("Signal units must be uV, mV, or V to plot ECG grids.")

    # Pad each stop value by half a step so that arange reliably includes the
    # rounded end point without adding an extra step. A fixed absolute pad
    # is larger than the step itself for very small steps (e.g. 'hours').
    end_x = upround(maxt, majorx)
    start_y = downround(minsig, majory)
    end_y = upround(maxsig, majory)

    major_ticks_x = np.arange(0, end_x + majorx / 2, majorx)
    minor_ticks_x = np.arange(0, end_x + minorx / 2, minorx)

    major_ticks_y = np.arange(start_y, end_y + majory / 2, majory)
    minor_ticks_y = np.arange(start_y, end_y + minory / 2, minory)

    return (major_ticks_x, minor_ticks_x, major_ticks_y, minor_ticks_y)


def _label_figure(
    axes, n_subplots, time_units, sig_name, sig_units, xlabel, ylabel, title
):
    """
    Add title, and axes labels.

    Parameters
    ----------
    axes : list
        The information needed for each subplot.
    n_subplots : int
        The number of subplots to generate.
    time_units : str, optional
        The x axis unit. Allowed options are: 'samples', 'seconds', 'minutes',
        and 'hours'.
    sig_name : list, optional
        A list of strings specifying the signal names. Used with `sig_units`
        to form y labels, if `ylabel` is not set.
    sig_units : list, optional
        A list of strings specifying the units of each signal channel. Used
        with `sig_name` to form y labels, if `ylabel` is not set. This
        parameter is required for plotting ECG grids.
    xlabel : list, optional
        A list of strings specifying the final x labels to be used. If this
        option is present, no 'time/'`time_units` is used.
    ylabel : list, optional
        A list of strings specifying the final y labels. If this option is
        present, `sig_name` and `sig_units` will not be used for labels.
    title : str, optional
        The title of the graph.

    Returns
    -------
    N/A

    """
    if title:
        axes[0].set_title(title)

    # Determine x label
    # Explicit labels take precedence if present. Otherwise, construct labels
    # using signal time units (e.g. 'seconds' -> 'time/second').
    if not xlabel:
        axes[-1].set_xlabel("/".join(["time", time_units[:-1]]))
    else:
        for ch in range(n_subplots):
            axes[ch].set_xlabel(xlabel[ch])

    # Determine y label
    # Explicit labels take precedence if present. Otherwise, construct labels
    # using signal names and units
    if not ylabel:
        ylabel = []
        # Set default channel and signal names if needed
        if not sig_name:
            sig_name = ["ch_" + str(i) for i in range(n_subplots)]
        if not sig_units:
            sig_units = n_subplots * ["NU"]

        ylabel = ["/".join(pair) for pair in zip(sig_name, sig_units)]

        # If there are annotations with channels outside of signal range
        # put placeholders
        n_missing_labels = n_subplots - len(ylabel)
        if n_missing_labels:
            ylabel = ylabel + [
                "ch_%d/NU" % i for i in range(len(ylabel), n_subplots)
            ]

    for ch in range(n_subplots):
        axes[ch].set_ylabel(ylabel[ch])


def plot_wfdb(
    record=None,
    annotation=None,
    plot_sym=False,
    time_units="seconds",
    title=None,
    sig_style=[""],
    ann_style=["r*"],
    ecg_grids=[],
    figsize=None,
    return_fig=False,
    sharex="auto",
):
    """
    Subplot individual channels of a WFDB record and/or annotation.

    This function implements the base functionality of the `plot_items`
    function, while allowing direct input of WFDB objects.

    If the record object is input, the function will extract from it:
      - signal values, from the `e_p_signal`, `e_d_signal`, `p_signal`, or
        `d_signal` attribute (in that order of priority.)
      - frame frequency, from the `fs` attribute
      - signal names, from the `sig_name` attribute
      - signal units, from the `units` attribute

    If the annotation object is input, the function will extract from it:
      - sample locations, from the `sample` attribute
      - symbols, from the `symbol` attribute
      - the annotation channels, from the `chan` attribute
      - the sampling frequency, from the `fs` attribute if present, and if fs
        was not already extracted from the `record` argument.

    Parameters
    ----------
    record : WFDB Record, optional
        The Record object to be plotted.
    annotation : WFDB Annotation, optional
        The Annotation object to be plotted.
    plot_sym : bool, optional
        Whether to plot the annotation symbols on the graph.
    time_units : str, optional
        The x axis unit. Allowed options are: 'samples', 'seconds',
        'minutes', and 'hours'.
    title : str, optional
        The title of the graph.
    sig_style : list, optional
        A list of strings, specifying the style of the matplotlib plot
        for each signal channel. The list length should match the number
        of signal channels. If the list has a length of 1, the style
        will be used for all channels.
    ann_style : list, optional
        A list of strings, specifying the style of the matplotlib plot
        for each annotation channel. The list length should match the
        number of annotation channels. If the list has a length of 1,
        the style will be used for all channels.
    ecg_grids : list, optional
        A list of integers specifying channels in which to plot ECG grids. May
        also be set to 'all' for all signal channels. Major grids at 0.5mV,
        and minor grids at 0.125mV. All channels to be plotted with grids must
        have `sig_units` equal to 'uV', 'mV', or 'V'.
    figsize : tuple, optional
        Tuple pair specifying the width, and height of the figure. It is the
        'figsize' argument passed into matplotlib.pyplot's `figure` function.
    return_fig : bool, optional
        Whether the figure is to be returned as an output argument.
    sharex : bool or 'auto', optional
        Whether the X axis should be shared between all subplots.  If set
        to True, then all signals will be aligned with each other.  If set
        to False, then each subplot can be panned/zoomed independently.  If
        set to 'auto' (default), then the X axis will be shared unless
        record is multi-frequency and the time units are set to 'samples'.

    Returns
    -------
    figure : matplotlib figure, optional
        The matplotlib figure generated. Only returned if the 'return_fig'
        option is set to True.

    Examples
    --------
    >>> record = wfdb.rdrecord('sample-data/100', sampto=3000)
    >>> annotation = wfdb.rdann('sample-data/100', 'atr', sampto=3000)

    >>> wfdb.plot_wfdb(record=record, annotation=annotation, plot_sym=True,
                       time_units='seconds', title='MIT-BIH Record 100',
                       figsize=(10,4), ecg_grids='all')

    """
    (
        signal,
        ann_samp,
        ann_sym,
        fs,
        ylabel,
        record_name,
        sig_units,
    ) = _get_wfdb_plot_items(
        record=record, annotation=annotation, plot_sym=plot_sym
    )

    if record:
        # Expanded (multi-frequency) signals carry one frequency per channel.
        if record.e_p_signal is not None or record.e_d_signal is not None:
            sampling_freq = [spf * record.fs for spf in record.samps_per_frame]
        else:
            sampling_freq = record.fs
    else:
        sampling_freq = None

    if sharex == "auto":
        # If the sampling frequencies are equal, or if we are using
        # hours/minutes/seconds as the time unit, then share the X axes so
        # that the channels are synchronized.  If time units are 'samples'
        # and sampling frequencies are not uniform, then sharing X axes
        # doesn't work and may even be misleading.
        if (
            time_units == "samples"
            and isinstance(sampling_freq, list)
            and any(f != sampling_freq[0] for f in sampling_freq)
        ):
            sharex = False
        else:
            sharex = True

    if annotation and annotation.fs is not None:
        ann_freq = annotation.fs
    elif record:
        ann_freq = record.fs
    else:
        ann_freq = None

    return plot_items(
        signal=signal,
        ann_samp=ann_samp,
        ann_sym=ann_sym,
        fs=fs,
        time_units=time_units,
        ylabel=ylabel,
        title=(title or record_name),
        sig_style=sig_style,
        sig_units=sig_units,
        ann_style=ann_style,
        ecg_grids=ecg_grids,
        figsize=figsize,
        return_fig=return_fig,
        sampling_freq=sampling_freq,
        ann_freq=ann_freq,
        sharex=sharex,
    )


def _get_wfdb_plot_items(record, annotation, plot_sym):
    """
    Get items to plot from WFDB objects.

    Parameters
    ----------
    record : WFDB Record
        The Record object to be plotted
    annotation : WFDB Annotation
        The Annotation object to be plotted
    plot_sym : bool
        Whether to plot the annotation symbols on the graph.

    Returns
    -------
    signal : numpy array or list or None
        The signal to be plotted: the record's 2d `p_signal`/`d_signal`
        array (axes 0 and 1 are time and channel), or its `e_p_signal`/
        `e_d_signal` list of per-channel arrays. None without a record.
    ann_samp: list or None
        A list of annotation locations to plot, with each list item
        corresponding to a different subplot. List items may be:

        - 1d numpy array, with values representing sample indices. Empty
          arrays are skipped.
        - None. For channels in which nothing is to be plotted.

        If `signal` is defined, the annotation locations will be overlaid on
        the signals, with the list index corresponding to the signal channel.
    ann_sym: list or None
        A list of annotation symbols to plot, with each list item
        corresponding to a different subplot. List items are lists of
        strings (or None). The symbols are plotted over the corresponding
        `ann_samp` index locations. None unless `plot_sym` is True.
    fs : int, float or None
        The record's `fs`, or the annotation's `fs` if there is no record.
    ylabel : list or None
        A list of strings specifying the final y labels, one per subplot.
    record_name : str or None
        A title string 'Record: <name>' built from the record, or the
        annotation's `record_name` if there is no record.
    sig_units : list or None
        A list of strings specifying the units of each signal channel
        ('adu' for digital signals). None without a record.

    """
    # Get record attributes
    if record:
        if record.e_p_signal is not None:
            signal = record.e_p_signal
            n_sig = len(signal)
            physical = True
        elif record.e_d_signal is not None:
            signal = record.e_d_signal
            n_sig = len(signal)
            physical = False
        elif record.p_signal is not None:
            signal = record.p_signal
            n_sig = signal.shape[1]
            physical = True
        elif record.d_signal is not None:
            signal = record.d_signal
            n_sig = signal.shape[1]
            physical = False
        else:
            raise ValueError("The record has no signal to plot")

        fs = record.fs
        sig_name = [str(s) for s in record.sig_name]
        if physical:
            sig_units = [str(s) for s in record.units]
        else:
            sig_units = ["adu"] * n_sig
        record_name = "Record: %s" % record.record_name
        ylabel = ["/".join(pair) for pair in zip(sig_name, sig_units)]
    else:
        signal = fs = ylabel = record_name = sig_units = None

    # Get annotation attributes
    if annotation:
        # Get channels
        ann_chans = set(annotation.chan)
        n_ann_chans = max(ann_chans) + 1

        # Indices for each channel. The shared empty array is only ever
        # replaced, never mutated, so aliasing it is harmless.
        chan_inds = n_ann_chans * [np.empty(0, dtype="int")]
        for chan in ann_chans:
            chan_inds[chan] = np.where(annotation.chan == chan)[0]

        ann_samp = [annotation.sample[ci] for ci in chan_inds]

        if plot_sym:
            ann_sym = n_ann_chans * [None]
            for ch in ann_chans:
                ann_sym[ch] = [annotation.symbol[ci] for ci in chan_inds[ch]]
        else:
            ann_sym = None

        # Try to get fs from annotation if not already in record
        if fs is None:
            fs = annotation.fs

        record_name = record_name or annotation.record_name
    else:
        ann_samp = None
        ann_sym = None

    # Cleaning: remove empty channels and set labels and styles.

    # Wrangle together the signal and annotation channels if necessary
    if record and annotation:
        # There may be instances in which the annotation `chan`
        # attribute has non-overlapping channels with the signal.
        # In this case, omit empty middle channels. This function should
        # already process labels and arrangements before passing into
        # `plot_items`
        sig_chans = set(range(n_sig))
        all_chans = sorted(sig_chans.union(ann_chans))

        # Need to update ylabels and annotation values
        if not ann_chans.issubset(sig_chans):
            compact_ann_samp = []
            if plot_sym:
                compact_ann_sym = []
            else:
                compact_ann_sym = None
            ylabel = []
            for ch in all_chans:  # ie. 0, 1, 9
                # Keep exactly one entry per displayed channel (None where
                # the channel has no annotations) so every annotation list
                # stays aligned with its subplot.
                if ch in ann_chans:
                    compact_ann_samp.append(ann_samp[ch])
                    if plot_sym:
                        compact_ann_sym.append(ann_sym[ch])
                else:
                    compact_ann_samp.append(None)
                    if plot_sym:
                        compact_ann_sym.append(None)
                if ch in sig_chans:
                    ylabel.append("/".join([sig_name[ch], sig_units[ch]]))
                else:
                    ylabel.append("ch_%d/NU" % ch)
            ann_samp = compact_ann_samp
            ann_sym = compact_ann_sym
        # Signals encompass annotations
        else:
            ylabel = ["/".join(pair) for pair in zip(sig_name, sig_units)]

    # Remove any empty middle channels from annotations
    elif annotation:
        ann_samp = [a for a in ann_samp if a.size]
        if ann_sym is not None:
            ann_sym = [a for a in ann_sym if a]
        # Sorted so the labels follow the same ascending channel order as
        # the filtered ann_samp list.
        ylabel = ["ch_%d/NU" % ch for ch in sorted(ann_chans)]

    return signal, ann_samp, ann_sym, fs, ylabel, record_name, sig_units


def plot_all_records(directory=""):
    """
    Plot all WFDB records in a directory (by finding header files), one at
    a time, until the 'enter' key is pressed.

    Parameters
    ----------
    directory : str, optional
        The directory in which to search for WFDB records. Defaults to
        current working directory.

    Returns
    -------
    N/A

    """
    directory = directory or os.getcwd()

    suffix = ".hea"
    # Strip only the trailing suffix so names that contain ".hea" elsewhere
    # are preserved.
    records = sorted(
        f[: -len(suffix)]
        for f in os.listdir(directory)
        if f.endswith(suffix) and os.path.isfile(os.path.join(directory, f))
    )

    for record_name in records:
        record = rdrecord(os.path.join(directory, record_name))

        plot_wfdb(record, title="Record - %s" % record.record_name)
        input("Press enter to continue...")