medcat.components.addons.relation_extraction.pad_seq
Classes
Module Contents
- class medcat.components.addons.relation_extraction.pad_seq.Pad_Sequence(seq_pad_value, label_pad_value=-1)
- Parameters:
seq_pad_value (int)
label_pad_value (int)
- __init__(seq_pad_value, label_pad_value=-1)
Used in rel_cat.py by RelCAT to create DataLoaders for the train/test datasets. Serves as the collate_fn of a DataLoader, collating input_ids, ent1/ent2, and labels of different lengths into a fixed-length batch. Padding is applied per batch, not to the whole DataLoader dataset, and yields the padded x sequence, y sequence, x lengths and y lengths of the batch.
- Parameters:
seq_pad_value (int) – pad value for input_ids.
label_pad_value (int) – pad value for labels. Defaults to -1.
- seq_pad_value: int
- label_pad_value: int = -1
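The point that padding is applied per batch, rather than across the whole dataset, can be illustrated with plain lists. This is a minimal sketch, not RelCAT code: the helper `pad_to_batch_max` and the toy dataset are hypothetical, and right-padding to the batch maximum is assumed because that is what `torch.nn.utils.rnn.pad_sequence` does.

```python
from typing import List


def pad_to_batch_max(seqs: List[List[int]], pad_value: int) -> List[List[int]]:
    """Hypothetical helper: right-pad every sequence to the longest one in *this* batch."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_value] * (max_len - len(s)) for s in seqs]


# Toy dataset of token-id sequences with varying lengths.
dataset = [[5, 6], [7, 8, 9, 10], [1], [2, 3, 4]]
batch_size = 2

padded_batches = []
for start in range(0, len(dataset), batch_size):
    batch = dataset[start:start + batch_size]
    # Padding happens per batch, so different batches can have different widths.
    padded_batches.append(pad_to_batch_max(batch, pad_value=0))

print(padded_batches)
# [[[5, 6, 0, 0], [7, 8, 9, 10]], [[1, 0, 0], [2, 3, 4]]]
```

The first batch is padded to width 4 and the second only to width 3. This avoids padding every sequence to the longest sequence in the entire dataset.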
- __call__(batch)
Pads a batch of input_ids and labels.
- Parameters:
batch (list[torch.Tensor]) – the batch of Tensors from RelData.dataset (see its __getitem__() method for the data returned). The token sequences and labels are padded as needed; see https://pytorch.org/docs/stable/_modules/torch/nn/utils/rnn.html#pad_sequence for extra info.
- Returns:
tuple[Tensor, list, Tensor, LongTensor, LongTensor] – the padded data: padded input ids, ent1 & ent2 start token positions, padded labels, input_id lengths, and label lengths.
- Return type:
tuple[torch.Tensor, list, torch.Tensor, torch.LongTensor, torch.LongTensor]
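Below is a minimal sketch of a collate callable that returns a tuple with the five-element shape in the return type above. It is built on `torch.nn.utils.rnn.pad_sequence` and plugged into a `torch.utils.data.DataLoader`. The class name `PadSequenceSketch` is hypothetical. The item layout `(input_ids, ent1_ent2_start, labels)` is also an assumption for illustration and is not necessarily the exact format returned by RelData.dataset.

```python
from typing import Sequence, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


class PadSequenceSketch:
    """Hypothetical stand-in for Pad_Sequence.

    Assumes each dataset item is an (input_ids, ent1_ent2_start, labels) triple.
    """

    def __init__(self, seq_pad_value: int, label_pad_value: int = -1):
        self.seq_pad_value = seq_pad_value
        self.label_pad_value = label_pad_value

    def __call__(self, batch: Sequence[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]):
        input_ids = [item[0] for item in batch]
        ent1_ent2_start = [item[1] for item in batch]  # returned as a plain list
        labels = [item[2] for item in batch]

        # Pad to the longest sequence in this batch only.
        input_ids_padded = pad_sequence(
            input_ids, batch_first=True, padding_value=self.seq_pad_value)
        labels_padded = pad_sequence(
            labels, batch_first=True, padding_value=self.label_pad_value)

        # Original (unpadded) lengths, as LongTensors.
        x_lengths = torch.LongTensor([len(x) for x in input_ids])
        y_lengths = torch.LongTensor([len(y) for y in labels])
        return input_ids_padded, ent1_ent2_start, labels_padded, x_lengths, y_lengths


# Toy dataset: token ids, [ent1_start, ent2_start], label ids.
data = [
    (torch.tensor([101, 7, 8, 102]), torch.tensor([1, 2]), torch.tensor([3])),
    (torch.tensor([101, 9, 102]), torch.tensor([1, 1]), torch.tensor([0])),
]
loader = DataLoader(data, batch_size=2, collate_fn=PadSequenceSketch(seq_pad_value=0))
ids, ent_pos, labels, x_len, y_len = next(iter(loader))
print(ids.tolist())    # [[101, 7, 8, 102], [101, 9, 102, 0]]
print(x_len.tolist())  # [4, 3]
```

The shorter sequence is right-padded with `seq_pad_value`. The length tensors record the unpadded lengths, so downstream code can mask out the padding.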