import sys
import os
import wfdb
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import argparse
import random
# =============================================================================
# README: WFDB Visualization Script
# =============================================================================
# This script reads and visualizes a single WFDB record from the dataset.
#
# FEATURES:
# - Visualizes EMG signals as matplotlib plots
# - Requires a participant ID (extension participants only: 44-63)
# - Optionally specify session, gesture, and trial (each chosen at random if not specified)
# - Optionally exports the signal and its metadata to a numpy .npz file (--export_numpy)
# - Displays one figure with all channels
# =============================================================================
# =============================================================================
# CONSTANTS FOR DEFAULT VALUES
# =============================================================================
# Inclusive upper bounds for the 1-based session, gesture and trial numbers.
# The script uses them both as the range for random selection and as the
# limit for validating command-line values.
# (NGEESTURE keeps its existing spelling because the script refers to it by
# this name.)
NSESSION, NGEESTURE, NTRIALS = 3, 16, 7
# Dots-per-inch value; presumably meant for figure output -- confirm usage.
DPI = 100
# Output folder for --export_numpy, resolved relative to the current
# working directory.
NUMPY_EXPORT_DIR = "numpy_export"
# =============================================================================
# STEP 1: COMMAND-LINE ARGUMENT PARSING
# =============================================================================
# Example invocations appended to --help output. RawDescriptionHelpFormatter
# (set below) prints this text with its line breaks as written.
_USAGE_EXAMPLES = """
USAGE EXAMPLES:
# Visualize participant 44, session 1, gesture 1, trial 1:
python wfdb_visualize_export.py --ext_path "/path/to/static_extension_WFDB" --participant 44 --session 1 --gesture 1 --trial 1
# Visualize + export to numpy:
python wfdb_visualize_export.py --ext_path "/path/to/static_extension_WFDB" --participant 44 --session 2 --export_numpy
# Visualize participant 44, random session/gesture/trial:
python wfdb_visualize_export.py --ext_path "/path/to/static_extension_WFDB" --participant 44
# Visualize participant 50, session 2, random gesture/trial:
python wfdb_visualize_export.py --ext_path "/path/to/static_extension_WFDB" --participant 50 --session 2
"""

parser = argparse.ArgumentParser(
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description="Visualize a single WFDB EMG recording",
    epilog=_USAGE_EXAMPLES,
)

# Mandatory inputs: where the data lives and whose recording to show.
parser.add_argument(
    "--ext_path",
    required=True,
    help="Path to extension data folder (static_extension_WFDB)",
)
parser.add_argument(
    "--participant",
    type=int,
    required=True,
    help="Participant ID (extension only: 44-63, REQUIRED)",
)

# Optional selectors; left as None when omitted so a value can be drawn later.
parser.add_argument(
    "--session",
    type=int,
    help="Session number (1-3, optional, random if not specified)",
)
parser.add_argument(
    "--gesture",
    type=int,
    help="Gesture number (1-16, optional, random if not specified)",
)
parser.add_argument(
    "--trial",
    type=int,
    help="Trial number (1-7, optional, random if not specified)",
)

# Flag: also write the loaded signal to an .npz file.
parser.add_argument(
    "--export_numpy",
    action="store_true",
    help="Export signal to numpy (.npz) format",
)

args = parser.parse_args()
# Validate paths
# The extension root must be an existing directory; os.path.exists would also
# accept a regular file, which could never contain the session folders.
if not os.path.isdir(args.ext_path):
    print(f"ERROR: Extension path does not exist or is not a directory: {args.ext_path}")
    sys.exit(1)
# Validate participant (extension only: IDs 44-63 inclusive)
if not 44 <= args.participant <= 63:
    print(f"ERROR: Participant must be in extension range (44-63), got {args.participant}")
    sys.exit(1)
# Set or randomize session, gesture, trial.
# Compare against None instead of testing truthiness: an explicit 0 must reach
# the range checks below and be rejected, not be silently replaced by a
# randomly chosen value.
session = args.session if args.session is not None else random.randint(1, NSESSION)
gesture = args.gesture if args.gesture is not None else random.randint(1, NGEESTURE)
trial = args.trial if args.trial is not None else random.randint(1, NTRIALS)
# Validate ranges: every index is 1-based and the upper bound is inclusive.
for _label, _value, _upper in (
    ("Session", session, NSESSION),
    ("Gesture", gesture, NGEESTURE),
    ("Trial", trial, NTRIALS),
):
    if not 1 <= _value <= _upper:
        print(f"ERROR: {_label} must be 1-{_upper}, got {_value}")
        sys.exit(1)
print("Parameters:")
print(f" Participant: {args.participant}")
print(f" Session: {session}")
print(f" Gesture: {gesture}")
print(f" Trial: {trial}\n")
# =============================================================================
# STEP 2: BUILD FILE PATH AND READ WFDB
# =============================================================================
# Layout: <ext_path>/session<S>/session<S>_participant<P>/<record files>
foldername = f"session{session}_participant{args.participant}"
session_dir = f"session{session}"
participant_dir = os.path.join(args.ext_path, session_dir, foldername)
if not os.path.isdir(participant_dir):
    print(f"ERROR: Participant folder not found: {participant_dir}")
    sys.exit(1)
filename = f"session{session}_participant{args.participant}_gesture{gesture}_trial{trial}"
wfdb_record_path = os.path.join(participant_dir, filename)
# Both the signal file (.dat) and the header (.hea) must be present; report
# exactly which ones are missing.
missing_files = [
    wfdb_record_path + ext
    for ext in (".dat", ".hea")
    if not os.path.isfile(wfdb_record_path + ext)
]
if missing_files:
    print("ERROR: WFDB files not found:")
    for _missing in missing_files:
        print(f" {_missing}")
    sys.exit(1)
# Read WFDB record. Only the reader call is guarded; any failure inside wfdb
# is reported and ends the script.
print(f"Reading: {filename}")
try:
    record = wfdb.rdrecord(wfdb_record_path)
except Exception as e:
    print(f"ERROR: Could not read WFDB file: {e}")
    sys.exit(1)
data_emg = record.p_signal
# Guard against a missing or empty signal matrix so the shape-based metadata
# and the plot below do not fail with an obscure error.
if data_emg is None or data_emg.size == 0:
    print(f"ERROR: WFDB record contains no signal data: {wfdb_record_path}")
    sys.exit(1)
# Extract metadata (data_emg is indexed as samples x channels)
fs = record.fs
n_samples = data_emg.shape[0]
n_channels = data_emg.shape[1]
# sig_name may be absent, None, or not match the channel count; fall back to
# generic "Ch N" labels so every plotted channel has a name.
sig_names = getattr(record, "sig_name", None)
if sig_names and len(sig_names) == n_channels:
    channel_names = list(sig_names)
else:
    channel_names = [f"Ch {i + 1}" for i in range(n_channels)]
print(f" Shape: {data_emg.shape} (samples × channels)")
print(f" Sampling rate: {fs} Hz")
print(f" Duration: {n_samples / fs:.2f} seconds")
print(f" Channel names: {sig_names}")
print()
# =============================================================================
# STEP 3: OPTIONAL NUMPY EXPORT
# =============================================================================
# Best-effort: a failed export is reported but does not stop the plot below.
if args.export_numpy:
    # Signal plus the identifying metadata of the recording.
    save_dict = {
        "signal": data_emg,
        "fs": record.fs,
        "n_samples": data_emg.shape[0],
        "n_channels": data_emg.shape[1],
        "participant": args.participant,
        "session": session,
        "gesture": gesture,
        "trial": trial,
    }
    # Store channel names only when the record actually provides them
    # (sig_name may be missing or None).
    names_for_export = getattr(record, "sig_name", None)
    if names_for_export:
        save_dict["channel_names"] = np.array(names_for_export)
    npz_path = os.path.join(NUMPY_EXPORT_DIR, f"{filename}.npz")
    try:
        # Directory creation is inside the try so a filesystem error is
        # reported like a write failure instead of aborting the script.
        os.makedirs(NUMPY_EXPORT_DIR, exist_ok=True)
        np.savez(npz_path, **save_dict)
    except OSError as e:
        print(f"ERROR: Could not export to numpy: {e}")
    else:
        print(f"Exported to numpy: {npz_path}")
# =============================================================================
# STEP 4: VISUALIZATION
# =============================================================================
print("Displaying plot... (close window to exit)\n")
# One stacked subplot per channel. squeeze=False always yields a 2-D array of
# Axes, so the single-channel case needs no special handling.
fig, axes_grid = plt.subplots(n_channels, 1, figsize=(8, 1.0 * n_channels), squeeze=False)
axes = axes_grid[:, 0]
# Time axis in seconds
time_axis = np.arange(n_samples) / fs
# Same fixed y-limits on every channel so amplitudes are directly comparable;
# samples outside this range are clipped from view.
y_lim = [-0.15, 0.15]
# Plot each channel
for ch, ax in enumerate(axes):
    ax.plot(time_axis, data_emg[:, ch], linewidth=0.7, color="black")
    ax.set_ylabel(channel_names[ch], fontsize=10, fontweight="bold")
    ax.set_ylim(y_lim)  # fixed limits: -0.15 to 0.15
    ax.grid(True, alpha=0.2, linestyle="-", linewidth=0.5)
    ax.margins(x=0)
    # Two-decimal y tick labels on every channel (one formatter per axis).
    ax.yaxis.set_major_formatter(ticker.StrMethodFormatter("{x:.2f}"))
# X-axis formatting: six evenly spaced ticks from 0 to the last sample time
# (0 when there are no samples, where time_axis[-1] would raise).
axes[-1].set_xlabel("Time (seconds)", fontsize=11)
t_end = time_axis[-1] if n_samples > 0 else 0.0
time_ticks = np.linspace(0, t_end, 6)
tick_labels = [f"{t:.2f}" for t in time_ticks]
for ax in axes:
    ax.set_xticks(time_ticks)
    ax.set_xticklabels(tick_labels)
# Title and layout
fig.suptitle(filename, fontsize=12, fontweight="bold", y=0.995)
plt.tight_layout()
plt.show()
print("Done!")