Mind the Gap: The PhysioNet/Computing in Cardiology Challenge 2010 1.0.0

File: <base>/sources/Luiz-Silva/RNNCinC2010.java (10,719 bytes)
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Arrays;
import java.util.Scanner;
import java.util.Vector;

/**
 * RNNCinC2010.java
 *
 * This class implements a solution for Physionet/CinC Challenge 2010.
 *
 * The solution uses a recurrent neural network (RNN) to predict the samples of the
 * zero padded channel of a multichannel signal, from sample index 71250 to the end
 * of the record (the last 3750 samples of a 75000-sample record).
 *
 * Processing steps:
 *   1. read the parameters from standard input and the signal from a text file;
 *   2. discard the channels that are constant before the gap;
 *   3. detect the channel that contains the gap;
 *   4. train an MLP sample by sample, feeding its previous output back as an input;
 *   5. run the trained network over the gap and write the prediction to
 *      "&lt;signal file&gt;_prediction.txt", one value per line.
 *
 * @author Juliano Jinzenji Duque <julianojd@gmail.com>
 * @author Luiz Eduardo Virgilio da Silva <luizeduardovs@gmail.com>
 *
 * CSIM
 * Computing on Signals and Images on Medicine Group
 * University of Sao Paulo
 * Ribeirao Preto - SP - Brazil
 */
public class RNNCinC2010 {

    /** Index of the first sample of the gap; the samples before it are the training data. */
    private static final int START_GAP = 71250;

    /** Lower bound of the normalized desired output. */
    private static final double OUTPUT_LOWER = 0.2;

    /** Upper bound of the normalized desired output. */
    private static final double OUTPUT_UPPER = 0.8;

    // Parameters (filled by readParameters)
    private static int iterations;
    private static double learningRate;
    private static int numNeuronsHiddenLayer;
    private static String file;


    /**
     * Program entry point: reads the parameters and the signal, reconstructs the gap
     * and saves the reconstruction.
     *
     * @param args command line arguments (not used)
     */
    public static void main(String[] args) {
        readParameters();
        double[][] signal = readSignal(file);

        int nsig = signal.length;
        int sigSize = signal[0].length;
        int startGap = START_GAP;

        // Without this check Arrays.copyOfRange would either pad the training data
        // with zeros or throw, depending on the range requested.
        if(sigSize <= startGap) {
            fail("Error: the signal has " + sigSize + " samples per channel, but the gap"
                    + " is expected to start at sample " + startGap + ".");
            return;
        }

        // Removing (and counting) channels that are fully flat before the gap
        boolean[] discardedSignal = new boolean[nsig];
        int nDiscarded = 0;
        for(int i=0; i<nsig; i++) {
            double[] known = Arrays.copyOfRange(signal[i], 0, startGap);
            if(max(known)-min(known) == 0) {
                discardedSignal[i] = true;
                nDiscarded++;
            }
        }

        // If default (or a non-positive, unusable value), set 2*inputs + 1
        if(numNeuronsHiddenLayer <= 0)
            numNeuronsHiddenLayer = 2*(nsig-nDiscarded)+1;

        System.out.println("USING:");
        System.out.println("Learning rate = "+learningRate);
        System.out.println("Neurons in hidden layer = "+numNeuronsHiddenLayer);
        System.out.println("Iterations = "+iterations+"\n");

        int flatSig = findGapChannel(signal, startGap);
        System.out.println("Signal with gap: "+flatSig);

        // Predicted series for channel with gap.
        // There is no prediction before the gap: the known samples are copied.
        double[] predicted = new double[sigSize];
        System.arraycopy(signal[flatSig], 0, predicted, 0, startGap);

        if(discardedSignal[flatSig]) {
            // The channel is constant before the gap: the gap is filled with that
            // constant. Training is skipped, since the normalized target would be
            // undefined (zero range) and the network output would not be used.
            Arrays.fill(predicted, startGap, sigSize, signal[flatSig][0]);
        } else {
            predictGap(signal, discardedSignal, nDiscarded, flatSig, startGap, predicted);
        }

        // Saving prediction
        String outputPath = file+"_prediction.txt";
        try {
            savePrediction(predicted, startGap, outputPath);
            System.out.println("Reconstruction saved in file '"+outputPath+"'");
        } catch(IOException e) {
            fail("Error: could not write '" + outputPath + "': " + e.getMessage());
        }
    }

    /**
     * Prints an error message to the standard error stream and terminates the
     * program with exit status 1.
     *
     * @param message the message to print
     */
    private static void fail(String message) {
        System.err.println(message);
        System.exit(1);
    }

    /**
     * Detects the channel that contains the gap: the one whose samples from
     * <code>startGap</code> to the end have the smallest value of
     * (standard deviation + absolute mean). A zero padded stretch scores 0, the
     * smallest possible value. The absolute value keeps a channel with a negative
     * mean from scoring below a truly flat one.
     *
     * @param signal 2D array with the channels of the signal
     * @param startGap index of the first sample of the gap
     * @return the index of the detected channel (the first one in case of a tie)
     */
    private static int findGapChannel(double[][] signal, int startGap) {
        int gapChannel = 0;
        double lowestScore = Double.POSITIVE_INFINITY;

        for(int i=0; i<signal.length; i++) {
            double[] tail = Arrays.copyOfRange(signal[i], startGap, signal[i].length);
            double score = stdev(tail) + Math.abs(mean(tail));

            if(score < lowestScore) {
                lowestScore = score;
                gapChannel = i;
            }
        }

        return gapChannel;
    }

    /**
     * Trains the network on the samples before the gap and writes the denormalized
     * prediction of the gap into <code>predicted</code>.
     *
     * The desired output (channel with gap) is normalized between 0.2 and 0.8 and
     * the inputs (other channels) between 0.0 and 1.0.
     *
     * @param signal 2D array with the channels of the signal
     * @param discardedSignal flags of the channels that are not used as inputs
     * @param nDiscarded number of flagged channels
     * @param flatSig index of the channel with the gap; must not be a discarded channel
     * @param startGap index of the first sample of the gap
     * @param predicted array that receives the prediction from <code>startGap</code> on
     */
    private static void predictGap(double[][] signal, boolean[] discardedSignal, int nDiscarded,
                                   int flatSig, int startGap, double[] predicted) {
        int nsig = signal.length;
        int sigSize = predicted.length;

        // The same range [0, startGap) is used to normalize and, later, to denormalize
        double[] known = Arrays.copyOfRange(signal[flatSig], 0, startGap);
        double[] normDesiredOutput = normalize(known, OUTPUT_LOWER, OUTPUT_UPPER);

        // Normalized input channels, stored contiguously in channel order
        double[][] normInputs = new double[nsig-nDiscarded-1][];
        int cont = 0;
        for(int j=0; j<nsig; j++) {
            if(j!=flatSig && !discardedSignal[j]) {
                normInputs[cont] = normalize(signal[j], 0.0, 1.0);
                cont++;
            }
        }

        // Training RNN: one input per used channel plus the fed back output
        MLP mlp = new MLP(nsig-nDiscarded, numNeuronsHiddenLayer, 1);
        mlp.setLearningRate(learningRate);

        System.out.println("Training...");
        double lastOutput = 0.0;
        for(int loops=0; loops<iterations; loops++) {
            lastOutput = 0.0;
            for(int i=0; i<startGap; i++) {  // Training data range from 0 to startGap
                double[] input = buildInput(normInputs, i, lastOutput, nsig);
                double[] desiredOutput = { normDesiredOutput[i] };
                lastOutput = mlp.train(input, desiredOutput)[1];  // Network output starts from 1
            }
        }

        // Now predicting GAP
        // lastOutput is already the output for the last training sample
        for(int i=startGap; i<sigSize; i++) {
            double[] input = buildInput(normInputs, i, lastOutput, nsig);
            lastOutput = mlp.passNet(input)[1];  // Network output starts from 1
            predicted[i] = lastOutput;
        }

        // Denormalizing predicted values
        double max = max(known);
        double min = min(known);
        for(int i=startGap; i<sigSize; i++)
            predicted[i] = (predicted[i]-OUTPUT_LOWER)*(max-min)/(OUTPUT_UPPER-OUTPUT_LOWER) + min;
    }

    /**
     * Builds the network input for one sample: the normalized value of every input
     * channel followed by the previous network output. The remaining positions,
     * if any, are left as 0.0.
     *
     * @param normInputs normalized input channels
     * @param sample index of the sample
     * @param feedback previous output of the network
     * @param length length of the returned array; must be greater than the number
     *               of input channels
     * @return a new input array
     */
    private static double[] buildInput(double[][] normInputs, int sample, double feedback, int length) {
        double[] input = new double[length];

        for(int k=0; k<normInputs.length; k++)
            input[k] = normInputs[k][sample];
        input[normInputs.length] = feedback;

        return input;
    }

    /**
     * Writes the samples of <code>predicted</code> from <code>startGap</code> to
     * the end into a text file, one value per line.
     *
     * @param predicted the predicted series
     * @param startGap index of the first sample to be written
     * @param outputPath path of the file to be written
     * @throws IOException if the file cannot be created or written
     */
    private static void savePrediction(double[] predicted, int startGap, String outputPath)
            throws IOException {
        BufferedWriter bw = new BufferedWriter(new FileWriter(outputPath));
        try {
            for(int i=startGap; i<predicted.length; i++) {
                bw.write(""+predicted[i]);
                bw.newLine();
            }
        } finally {
            bw.close();  // Closed even if a write fails
        }
    }


    /**
     * Read the parameters from standard input.
     *
     * An empty or non-numeric answer selects the default value of the parameter.
     * The hidden layer default is stored as 0 and resolved after the signal is read.
     */
    private static void readParameters() {
        Scanner scanner = new Scanner(System.in);

        System.out.print("\nEnter the learnig rate or press ENTER for default (0.1): ");
        try {
            learningRate = Double.parseDouble(nextLine(scanner).trim());
        } catch(NumberFormatException e) {
            learningRate = 0.1;
        }

        System.out.print("Enter the hidden layer neurons number or press ENTER for default (2*inputs + 1): ");
        try {
            numNeuronsHiddenLayer = Integer.parseInt(nextLine(scanner).trim());
        } catch(NumberFormatException e) {
            numNeuronsHiddenLayer = 0;  // Flag to set 2*input + 1 after read signal
        }

        System.out.print("Enter the number of training iterations or press ENTER for default (500): ");
        try {
            iterations = Integer.parseInt(nextLine(scanner).trim());
        } catch(NumberFormatException e) {
            iterations = 500;
        }

        System.out.print("Enter the multi-channel signal file path: ");
        file = nextLine(scanner);
    }

    /**
     * Reads the next line of <code>scanner</code>.
     *
     * @param scanner the source of lines
     * @return the next line, or an empty string if the input has ended
     */
    private static String nextLine(Scanner scanner) {
        return scanner.hasNextLine() ? scanner.nextLine() : "";
    }

    /**
     * Reads a multichannel signal with channels over columns, separated from
     * each other by whitespace. Blank lines are ignored. Every other line must
     * have as many values as the first one.
     *
     * If the file cannot be read, is empty or is malformed, an error message is
     * printed and the program terminates with exit status 1.
     *
     * @param path String representing the file path
     * @return 2D array with channels of signal, indexed as [channel][sample]
     */
    private static double[][] readSignal(String path) {
        double[][] signal = null;

        try {
            Vector<double[]> rows = new Vector<double[]>();
            int numChannels = 0;

            BufferedReader br = new BufferedReader(new FileReader(path));
            try {
                int lineNumber = 0;
                String line = br.readLine();
                while(line != null) {
                    lineNumber++;
                    String trimmed = line.trim();
                    if(trimmed.length() > 0) {
                        String[] fields = trimmed.split("\\s+");

                        // The first data line defines the number of channels
                        if(rows.isEmpty())
                            numChannels = fields.length;
                        else if(fields.length != numChannels)
                            throw new IllegalArgumentException("line " + lineNumber + " has "
                                    + fields.length + " values, expected " + numChannels);

                        rows.add(parseRow(fields, lineNumber));
                    }
                    line = br.readLine();
                }
            } finally {
                br.close();
            }

            if(rows.isEmpty())
                throw new IllegalArgumentException("the file contains no samples");

            // Getting channels: rows of the file become columns of the array
            signal = new double[numChannels][rows.size()];
            for(int i=0; i<rows.size(); i++) {
                double[] row = rows.elementAt(i);
                for(int j=0; j<numChannels; j++)
                    signal[j][i] = row[j];
            }
        } catch(FileNotFoundException e) {
            fail("Error: File not found.");
        } catch(IOException e) {
            fail("Error: could not read '" + path + "': " + e.getMessage());
        } catch(IllegalArgumentException e) {
            fail("Error: invalid signal file '" + path + "': " + e.getMessage());
        }

        return signal;
    }

    /**
     * Converts the text fields of one line into numbers.
     *
     * @param fields the text fields of the line
     * @param lineNumber 1-based number of the line, used in the error message
     * @return the parsed values
     * @throws IllegalArgumentException if a field is not a valid number
     */
    private static double[] parseRow(String[] fields, int lineNumber) {
        double[] row = new double[fields.length];

        for(int j=0; j<fields.length; j++) {
            try {
                row[j] = Double.parseDouble(fields[j]);
            } catch(NumberFormatException e) {
                throw new IllegalArgumentException("line " + lineNumber + ": '" + fields[j]
                        + "' is not a number", e);
            }
        }

        return row;
    }

    /**
     * Normalize an array: the minimum of <code>vec</code> is mapped to
     * <code>lower</code> and its maximum to <code>upper</code>.
     *
     * If all values of <code>vec</code> are equal there is no range to scale, and
     * every value is mapped to <code>lower</code>.
     *
     * @param vec the array to be normalized
     * @param lower the new lower bound of <code>vec</code>
     * @param upper the new upper bound of <code>vec</code>
     * @return a new array with normalized values.
     */
    public static double[] normalize(double[] vec, double lower, double upper) {
        double[] normalized = new double[vec.length];

        double max = max(vec);
        double min = min(vec);

        if(max - min == 0) {
            // Avoids 0/0 (NaN) for a constant array
            Arrays.fill(normalized, lower);
            return normalized;
        }

        for(int i=0; i<normalized.length; i++) {
            normalized[i] = (vec[i] - min)*(upper - lower)/(max - min) + lower;
        }

        return normalized;
    }

    /**
     * Calculates the mean value of <code>array</code>
     *
     * @param array the array of values
     * @return the mean value of <code>array</code> (NaN if it is empty)
     */
    public static double mean(double[] array) {
        double sum = 0.0;

        for(int i=0; i<array.length; i++)
            sum += array[i];

        return (sum/array.length);
    }

    /**
     * Calculates the standard deviation of values in <code>serie</code>,
     * dividing the sum of squared deviations by the number of values.
     *
     * @param serie the array of values
     * @return the standard deviation (NaN if <code>serie</code> is empty)
     */
    public static double stdev(double[] serie) {
        double sd = 0.0;
        double mean = mean(serie);

        for(int i=0; i<serie.length; i++)
            sd += (serie[i]-mean) * (serie[i]-mean);

        return Math.sqrt(sd/serie.length);
    }

    /**
     * Calculates the minimum value of <code>signal</code>
     *
     * @param signal the array of values
     * @return the minimum value of <code>signal</code>
     *         (<code>Double.MAX_VALUE</code> if it is empty)
     */
    public static double min(double[] signal) {
        double min = Double.MAX_VALUE;

        for(int i=0; i<signal.length; i++)
            if(signal[i] < min)
                min = signal[i];

        return min;
    }

    /**
     * Calculates the maximum value of <code>signal</code>
     *
     * @param signal the array of values
     * @return the maximum value of <code>signal</code>
     *         (<code>-Double.MAX_VALUE</code> if it is empty)
     */
    public static double max(double[] signal) {
        double max = -Double.MAX_VALUE;

        for(int i=0; i<signal.length; i++)
            if(signal[i] > max)
                max = signal[i];

        return max;
    }

}