GRABMyoFlow - Dataset extension 1.0.0

File: <base>/grabmyoflow_wfdb_visualize_export.py (7,961 bytes)
import sys
import os
import wfdb
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import argparse
import random

# =============================================================================
# README: WFDB Visualization Script
# =============================================================================
# This script reads and visualizes a single WFDB record (.hea + .dat) from the dataset.
#
# FEATURES:
# - Visualizes EMG signals as matplotlib plots
# - Requires participant ID (extension only: 44-63)
# - Optionally specify session, gesture, trial (randomly chosen if not specified)
# - Optionally exports the signal and its metadata to a .npz file (--export_numpy)
# - Displays plot with all channels
# =============================================================================

# =============================================================================
# CONSTANTS FOR DEFAULT VALUES
# =============================================================================

# Upper bounds of the 1-based recording indices. They serve both as the valid
# range for the command-line values and as the range for random selection.
# NOTE: "NGEESTURE" is misspelled; the name is kept because the rest of the
# script refers to it.
NSESSION, NGEESTURE, NTRIALS = 3, 16, 7

# Presumably a figure resolution (dots per inch); not referenced elsewhere in
# this script.
DPI = 100

# Output folder for --export_numpy, resolved relative to the current working
# directory.
NUMPY_EXPORT_DIR = "numpy_export"

# =============================================================================
# STEP 1: COMMAND-LINE ARGUMENT PARSING
# =============================================================================

# The usage examples use %(prog)s, which argparse expands to the script's actual
# file name when printing --help, so the examples never go stale if the script
# is renamed. (The epilog must therefore contain no other bare '%' characters.)
parser = argparse.ArgumentParser(
    description="Visualize a single WFDB EMG recording",
    formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog="""
USAGE EXAMPLES:
  # Visualize participant 44, session 1, gesture 1, trial 1:
  python %(prog)s --ext_path "/path/to/static_extension_WFDB" --participant 44 --session 1 --gesture 1 --trial 1

  # Visualize + export to numpy:
  python %(prog)s --ext_path "/path/to/static_extension_WFDB" --participant 44 --session 2 --export_numpy

  # Visualize participant 44, random session/gesture/trial:
  python %(prog)s --ext_path "/path/to/static_extension_WFDB" --participant 44

  # Visualize participant 50, session 2, random gesture/trial:
  python %(prog)s --ext_path "/path/to/static_extension_WFDB" --participant 50 --session 2
    """
)

# Required inputs: where the extension data lives and which participant to load.
parser.add_argument("--ext_path", required=True, help="Path to extension data folder (static_extension_WFDB)")
parser.add_argument("--participant", required=True, type=int, help="Participant ID (extension only: 44-63, REQUIRED)")
# Optional selectors; when omitted they default to None and are chosen at random later.
parser.add_argument("--session", type=int, help="Session number (1-3, optional, random if not specified)")
parser.add_argument("--gesture", type=int, help="Gesture number (1-16, optional, random if not specified)")
parser.add_argument("--trial", type=int, help="Trial number (1-7, optional, random if not specified)")
parser.add_argument("--export_numpy", action="store_true", help="Export signal to numpy (.npz) format")

args = parser.parse_args()

# Validate paths: the extension root must be an existing directory.
if not os.path.isdir(args.ext_path):
    print(f"ERROR: Extension path does not exist or is not a directory: {args.ext_path}")
    sys.exit(1)

# Validate participant (extension only)
if not 44 <= args.participant <= 63:
    print(f"ERROR: Participant must be in extension range (44-63), got {args.participant}")
    sys.exit(1)

# Use the value given on the command line, otherwise pick one at random.
# Compare against None rather than relying on truthiness: an explicit 0 must be
# rejected by the range checks below, not silently replaced by a random value.
session = args.session if args.session is not None else random.randint(1, NSESSION)
gesture = args.gesture if args.gesture is not None else random.randint(1, NGEESTURE)
trial = args.trial if args.trial is not None else random.randint(1, NTRIALS)

# Validate ranges (all indices are 1-based)
if not (1 <= session <= NSESSION):
    print(f"ERROR: Session must be 1-{NSESSION}, got {session}")
    sys.exit(1)
if not (1 <= gesture <= NGEESTURE):
    print(f"ERROR: Gesture must be 1-{NGEESTURE}, got {gesture}")
    sys.exit(1)
if not (1 <= trial <= NTRIALS):
    print(f"ERROR: Trial must be 1-{NTRIALS}, got {trial}")
    sys.exit(1)

print("Parameters:")
print(f"  Participant: {args.participant}")
print(f"  Session: {session}")
print(f"  Gesture: {gesture}")
print(f"  Trial: {trial}\n")

# =============================================================================
# STEP 2: BUILD FILE PATH AND READ WFDB
# =============================================================================

# Records live in <ext_path>/session<S>/session<S>_participant<P>/.
foldername = f"session{session}_participant{args.participant}"
session_dir = f"session{session}"
participant_dir = os.path.join(args.ext_path, session_dir, foldername)

if not os.path.isdir(participant_dir):
    print(f"ERROR: Participant folder not found: {participant_dir}")
    sys.exit(1)

filename = f"session{session}_participant{args.participant}_gesture{gesture}_trial{trial}"
wfdb_record_path = os.path.join(participant_dir, filename)

# Both the .dat and the .hea file must be present. Fail when EITHER is missing
# (not only when both are), and list exactly which ones are absent.
missing_files = [
    wfdb_record_path + ext
    for ext in (".dat", ".hea")
    if not os.path.exists(wfdb_record_path + ext)
]
if missing_files:
    print("ERROR: WFDB files not found:")
    for missing in missing_files:
        print(f"  {missing}")
    sys.exit(1)

# Read WFDB record. This is the script's top-level boundary, so any reader
# failure is reported and treated as fatal.
print(f"Reading: {filename}")
try:
    record = wfdb.rdrecord(wfdb_record_path)
except Exception as e:
    print(f"ERROR: Could not read WFDB file: {e}")
    sys.exit(1)

# p_signal is used as a (samples x channels) array below; reject a missing or
# empty signal here instead of failing later in the duration/plot code.
data_emg = record.p_signal
if data_emg is None or data_emg.ndim != 2 or data_emg.shape[0] == 0:
    print(f"ERROR: WFDB record contains no signal samples: {wfdb_record_path}")
    sys.exit(1)

print(f"  Shape: {data_emg.shape} (samples × channels)")
print(f"  Sampling rate: {record.fs} Hz")
print(f"  Duration: {data_emg.shape[0] / record.fs:.2f} seconds")
print(f"  Channel names: {record.sig_name}")
print()

# Extract metadata
fs = record.fs
n_samples, n_channels = data_emg.shape
# Fall back to generic labels when the header carries no signal names
# (attribute missing or None/empty).
channel_names = getattr(record, "sig_name", None) or [f"Ch {i+1}" for i in range(n_channels)]

# =============================================================================
# STEP 3: OPTIONAL NUMPY EXPORT
# =============================================================================

if args.export_numpy:
    # Signal plus the metadata needed to identify the recording.
    save_dict = {
        "signal": data_emg,
        "fs": record.fs,
        "n_samples": data_emg.shape[0],
        "n_channels": data_emg.shape[1],
        "participant": args.participant,
        "session": session,
        "gesture": gesture,
        "trial": trial,
    }

    # Store channel names only when the header actually provides them:
    # np.array(None) would become an object array, which np.load refuses to
    # read unless allow_pickle=True is passed.
    sig_names = getattr(record, "sig_name", None)
    if sig_names:
        save_dict["channel_names"] = np.array(sig_names)

    npz_path = os.path.join(NUMPY_EXPORT_DIR, f"{filename}.npz")
    # Export is best-effort: report any failure (including creating the
    # output directory) and still continue to the plot.
    try:
        os.makedirs(NUMPY_EXPORT_DIR, exist_ok=True)
        np.savez(npz_path, **save_dict)
    except Exception as e:
        print(f"ERROR: Could not export to numpy: {e}")
    else:
        print(f"Exported to numpy: {npz_path}")

# =============================================================================
# STEP 4: VISUALIZATION
# =============================================================================

print("Displaying plot... (close window to exit)\n")

# One stacked subplot per channel. squeeze=False always returns a 2-D array of
# Axes, even for a single channel, so no special case is needed. Keep at least
# 2 inches of height so the title/labels fit when there are only a few channels.
fig, axes = plt.subplots(
    n_channels, 1,
    figsize=(8, max(2.0, 1.0 * n_channels)),
    squeeze=False,
)
axes = axes[:, 0]

# Time axis in seconds; guard the last-sample lookup against an empty record.
time_axis = np.arange(n_samples) / fs
t_end = time_axis[-1] if n_samples > 0 else 0.0

# Fixed y-axis limits shared by all channels so amplitudes are directly comparable.
y_lim = [-0.15, 0.15]

# Use header names when available, otherwise generic "Ch N" labels.
labels = channel_names if channel_names else [f"Ch {i+1}" for i in range(n_channels)]

# Plot each channel
for ch, ax in enumerate(axes):
    ax.plot(time_axis, data_emg[:, ch], linewidth=0.7, color='black')
    ax.set_ylabel(labels[ch], fontsize=10, fontweight='bold')
    ax.set_ylim(y_lim)  # fixed limits: -0.15 to 0.15
    ax.grid(True, alpha=0.2, linestyle='-', linewidth=0.5)
    ax.margins(x=0)
    # Same two-decimal tick labels on every channel's y-axis.
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda value, _pos: f'{value:.2f}'))

# X-axis formatting: six evenly spaced time ticks on every subplot.
axes[-1].set_xlabel("Time (seconds)", fontsize=11)
time_ticks = np.linspace(0, t_end, 6)
tick_labels = [f"{t:.2f}" for t in time_ticks]
for ax in axes:
    ax.set_xticks(time_ticks)
    ax.set_xticklabels(tick_labels)

# Title and layout
fig.suptitle(filename, fontsize=12, fontweight='bold', y=0.995)
plt.tight_layout()
plt.show()

print("Done!")