Transformer-DeID: Deidentification of free-text clinical notes with transformers 1.0.0

File: <base>/transformer_deid/train.py (5,496 bytes)
import torch
import random
import os
import numpy as np
import logging
import argparse
from pathlib import Path
from transformers import DistilBertForTokenClassification, BertForTokenClassification, RobertaForTokenClassification
from transformers import BertTokenizerFast, DistilBertTokenizerFast, RobertaTokenizerFast
from transformers import Trainer, TrainingArguments
from transformer_deid.load_data import create_deid_dataset, get_labels, load_data

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO)
logger = logging.getLogger(__name__)


def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seed every visible GPU, not just the current one
    torch.backends.cudnn.deterministic = True
    # benchmark mode lets cuDNN pick the fastest (potentially non-deterministic)
    # kernels, which would defeat the seeding above, so keep it disabled
    torch.backends.cudnn.benchmark = False


def which_transformer_arch(baseArchitecture):
    """ Gets architecture-specific parameters for each supported base architecture. """
    if baseArchitecture == 'bert':
        load_model = BertForTokenClassification.from_pretrained
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
        baseArchitecture = 'bert-base-cased'

    elif baseArchitecture == 'roberta':
        load_model = RobertaForTokenClassification.from_pretrained
        tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
        baseArchitecture = 'roberta-base'

    elif baseArchitecture == 'distilbert':
        load_model = DistilBertForTokenClassification.from_pretrained
        tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
        baseArchitecture = 'distilbert-base-cased'
    else:
        raise NotImplementedError(f'{baseArchitecture} not a supported model.')

    return load_model, tokenizer, baseArchitecture
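
# Usage sketch (illustrative, not part of the original pipeline): the first
# return value is the architecture's `from_pretrained` classmethod and the
# third is the matching HuggingFace checkpoint name, so all three plug
# straight into model construction:
#   load_model, tokenizer, checkpoint = which_transformer_arch('roberta')
#   model = load_model(checkpoint, num_labels=9)  # label count is hypothetical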


def train(train_data_dict, architecture, epochs, out_dir):
    """ Trains a transformer-based deidentification model over {epochs} using {architecture}. 

        Args:
            - train_data_dict: dict with 'txt', 'ann', and 'guid' keys with the training data
            - architecture: transformer model to use 
                e.g., bert, roberta, or distilbert
            - epochs: int, number of epochs to use
            - out_dir: directory to which to save the model; only passed because it's mandatory for TrainingArguments()
        
        Returns: trained HuggingFace Trainer object
    """
    seed_everything(42)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    logger.info(f'Running using {device}.')

    load_model, tokenizer, baseArchitecture = which_transformer_arch(
        architecture)

    unique_labels = get_labels(train_data_dict['ann'])
    label2id = {tag: id for id, tag in enumerate(unique_labels)}
    id2label = {id: tag for tag, id in label2id.items()}
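    # e.g. (label names illustrative), unique_labels = ['O', 'NAME', 'DATE'] yields
    # label2id = {'O': 0, 'NAME': 1, 'DATE': 2} with id2label as its inverse;
    # passing both to the model below stores them in its config so predicted
    # ids decode back to tag names at inference time.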

    train_dataset = create_deid_dataset(train_data_dict, tokenizer, label2id)

    model = load_model(baseArchitecture,
                       num_labels=len(unique_labels),
                       label2id=label2id,
                       id2label=id2label).to(device)

    train_batch_size = 8
    training_args = TrainingArguments(
        output_dir=out_dir,  # this is a mandatory argument
        num_train_epochs=epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=8,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        save_strategy='no',  # skip intermediate checkpoints; main() saves the final model
    )
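    # Note: with multiple visible GPUs, HuggingFace's Trainer data-parallelizes
    # automatically, so the effective train batch size is
    # per_device_train_batch_size * number of devices.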

    trainer = Trainer(model=model,
                      args=training_args,
                      train_dataset=train_dataset)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", training_args.num_train_epochs)

    trainer.train()

    return trainer


def parse_args():
    parser = argparse.ArgumentParser(
        description='Train a transformer-based PHI deidentification model.')

    parser.add_argument('-i',
                        '--train_path',
                        type=str,
                        required=True,
                        help=
                        'path to the directory containing txt and ann directories for the training set.')

    parser.add_argument('-e',
                        '--epochs',
                        type=int,
                        help='number of epochs to train over',
                        default=5)

    parser.add_argument('-m',
                        '--model_architecture',
                        type=str,
                        choices=['bert', 'distilbert', 'roberta'],
                        help='name of model architecture, either bert, roberta, or distilbert',
                        default='bert')

    parser.add_argument('-o',
                        '--output_path',
                        type=str,
                        required=True,
                        help='output path in which to save the model')

    args = parser.parse_args()

    return args


def main(args):
    # arguments
    train_path = args.train_path
    out_path = args.output_path
    model = args.model_architecture
    epochs = args.epochs

    # read in data from training directory
    data_dict = load_data(Path(train_path))

    trainer = train(data_dict, model, epochs, out_path)

    save_location = f'{out_path}/{model}_model_{epochs}'
    trainer.save_model(save_location)


if __name__ == '__main__':
    args = parse_args()
    main(args)
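
# Example invocation (paths illustrative):
#   python transformer_deid/train.py -i data/train -m roberta -e 5 -o models
# This trains on data/train/{txt,ann} and saves the model to models/roberta_model_5.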