Transformer-DeID: Deidentification of free-text clinical notes with transformers 1.0.0
(5,813 bytes)
import argparse
import math
from datetime import datetime
import logging
from pathlib import Path
import os
from tqdm import tqdm
import numpy as np
from transformers import AutoModelForTokenClassification
from transformers import Trainer, TrainingArguments
from datasets import load_metric
# local packages
from transformer_deid.evaluation import compute_metrics
from transformer_deid.train import which_transformer_arch
from transformer_deid.model_evaluation_functions import load_data
# Configure root logging once at import time so all module loggers share
# the same timestamped format.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO
)
logger = logging.getLogger(__name__)
# CSV column names pulled from the flattened multi-class evaluation dict:
# per-entity precision/recall/F1/support for each PHI category, plus the
# overall scores. Order here defines the column order in training_eval.csv.
multi_class_fields = [
    'AGEprecision', 'AGErecall', 'AGEf1', 'AGEnumber', 'CONTACTprecision',
    'CONTACTrecall', 'CONTACTf1', 'CONTACTnumber', 'DATEprecision',
    'DATErecall', 'DATEf1', 'DATEnumber', 'IDprecision', 'IDrecall', 'IDf1',
    'IDnumber', 'LOCATIONprecision', 'LOCATIONrecall', 'LOCATIONf1',
    'LOCATIONnumber', 'NAMEprecision', 'NAMErecall', 'NAMEf1', 'NAMEnumber',
    'PROFESSIONprecision', 'PROFESSIONrecall', 'PROFESSIONf1',
    'PROFESSIONnumber', 'overall_precision', 'overall_recall', 'overall_f1',
    'overall_accuracy'
]
# CSV column names for the binary (PHI vs. non-PHI) evaluation dict.
binary_fields = [
    'PHIprecision', 'PHIrecall', 'PHIf1', 'PHInumber', 'overall_precision',
    'overall_recall', 'overall_f1', 'overall_accuracy'
]
def flatten_dict(d):
    """
    Return a flattened version of a (possibly nested) evaluation result dict.

    Nested keys are concatenated with their parent key with no separator,
    e.g. ``{'AGE': {'precision': 0.9}}`` -> ``{'AGEprecision': 0.9}``, which
    matches the field names in ``multi_class_fields`` / ``binary_fields``.
    NumPy integer scalars (e.g. ``np.int64`` support counts) are converted
    to plain ``int`` so the values serialize cleanly to CSV/JSON.

    Args:
        d: metrics dict, at most one level of nesting is flattened per call
           (deeper nesting is handled by recursion).

    Returns:
        A flat dict mapping concatenated keys to scalar values.
    """
    out = {}
    for key, value in d.items():
        if isinstance(value, dict):
            # Recurse, then prefix each child key with the parent key.
            for child_key, child_val in flatten_dict(value).items():
                if isinstance(child_val, np.integer):
                    # Covers np.int64 and any other numpy integer width.
                    child_val = int(child_val)
                out[key + child_key] = child_val
        else:
            out[key] = value
    return out
def add_row(
    path, epochs, results_multiclass, results_binary, multi_class_fields,
    binary_fields, test_loss
):
    """
    Append one evaluation row to ``training_eval.csv`` next to the checkpoints.

    Columns are ``[epochs] + multi_class_fields + binary_fields + [test_loss]``;
    any field missing from the flattened results is written as ``None``.

    Args:
        path: checkpoint directory; the CSV lives in its parent directory.
        epochs: (possibly fractional) epoch number for this checkpoint.
        results_multiclass: nested per-entity metrics dict from compute_metrics.
        results_binary: nested binary (PHI vs. non-PHI) metrics dict.
        multi_class_fields: ordered flattened keys for the multi-class columns.
        binary_fields: ordered flattened keys for the binary columns.
        test_loss: loss reported by ``trainer.predict`` on the test set.
    """
    root = Path(path).parent
    # Flatten each results dict once, rather than once per field lookup.
    flat_multiclass = flatten_dict(results_multiclass)
    flat_binary = flatten_dict(results_binary)
    row = (
        [epochs]
        + [flat_multiclass.get(field) for field in multi_class_fields]
        + [flat_binary.get(field) for field in binary_fields]
        + [test_loss]
    )
    text_metrics = ','.join(map(str, row)) + '\n'
    # 'at' = append text; the header is written once by main().
    with open(root / 'training_eval.csv', 'at') as f:
        f.write(text_metrics)
def eval_checkpoints(
    path, deid_task, train_dataset, val_dataset, test_dataset, training_args
):
    """
    Evaluate one saved checkpoint on the test set and log its metrics.

    The global step is recovered from the checkpoint directory name
    ('checkpoint-<step>') and converted to a fractional epoch via the
    training batch size; both multi-class and binary (PHI vs. non-PHI)
    metrics are then appended to training_eval.csv via add_row.
    """
    # Directory names look like '.../checkpoint-1155'.
    step = int(path.split('-')[-1])
    steps_per_epoch = math.ceil(
        len(train_dataset) / training_args.per_device_train_batch_size
    )
    epoch = step / steps_per_epoch

    checkpoint_model = AutoModelForTokenClassification.from_pretrained(
        path, num_labels=len(deid_task.labels)
    )
    checkpoint_model.eval()

    trainer = Trainer(
        model=checkpoint_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset
    )

    # Run inference on the held-out test set; take argmax over the label
    # dimension to get per-token predicted label ids.
    predictions, labels, metrics = trainer.predict(test_dataset)
    predicted_label = np.argmax(predictions, axis=2)

    # Token-level evaluation metric implemented locally in the package.
    metric = load_metric("transformer_deid/token_evaluation.py")

    multiclass_results = compute_metrics(
        predicted_label, labels, deid_task.labels, metric=metric
    )
    binary_results = compute_metrics(
        predicted_label,
        labels,
        deid_task.labels,
        metric=metric,
        binary_evaluation=True
    )

    add_row(
        path, epoch, multiclass_results, binary_results, multi_class_fields,
        binary_fields, metrics['test_loss']
    )
def parse_args():
    """Parse command-line options for checkpoint evaluation.

    Returns:
        argparse.Namespace with ``task_name`` and ``model`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate transformer-based model at each checkpoint.'
    )
    parser.add_argument(
        '-n', '--task_name',
        type=str,
        default='i2b2_2014',
        help='name of folder containing train and test data; defaults to i2b2_2014'
    )
    parser.add_argument(
        '-m', '--model',
        type=str,
        default='bert',
        help='folder containing checkpoint files'
    )
    return parser.parse_args()
def main():
    """Evaluate every checkpoint under the model directory and log metrics to CSV."""
    args = parse_args()
    root = f'{args.model}'
    # The model folder name is assumed to follow '<arch>results<epochs>',
    # e.g. 'bertresults5' -> arch='bert', epochs=5 — TODO(review): confirm
    # this naming convention against the training script.
    arch = args.model.split('results')[0].lower()
    epochs = int(args.model.split('results')[1])
    task_name = args.task_name

    # Only the tokenizer is used here; the model class and base-model name
    # returned by which_transformer_arch are discarded.
    _, tokenizer, _ = which_transformer_arch(arch)

    dataDir = f'{task_name}'
    testDir = f'{task_name}/test'
    deid_task, train_dataset, val_dataset, test_dataset = load_data(
        task_name, dataDir, testDir, tokenizer
    )

    train_batch_size = 8
    # NOTE(review): per_device_train_batch_size must match the value used
    # during training, since eval_checkpoints derives the epoch number from
    # it; eval_steps=1155 looks dataset-specific — verify before reuse.
    training_args = TrainingArguments(
        output_dir=root,
        num_train_epochs=epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=8,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        save_strategy='steps',
        eval_steps=1155
    )

    # Write the CSV header once; add_row appends a data row per checkpoint.
    if not os.path.exists(str(root) + '/training_eval.csv'):
        with open(str(root) + '/training_eval.csv', 'wt') as f:
            header = 'epoch,' + ','.join(
                map(str, multi_class_fields + binary_fields + ['test_loss'])
            ) + '\n'
            f.write(header)

    # Collect checkpoint subdirectories ('checkpoint-<step>') and evaluate
    # them in ascending step order.
    checkpoints = [
        item for item in os.listdir(root)
        if 'checkpoint' in item and os.path.isdir(os.path.join(root, item))
    ]
    for item in tqdm(sorted(checkpoints, key=lambda x: int(x.split('-')[1]))):
        path = os.path.join(root, item)
        eval_checkpoints(
            path, deid_task, train_dataset, val_dataset, test_dataset,
            training_args
        )
# Script entry point: run the full checkpoint-evaluation sweep.
if __name__ == '__main__':
    main()