medcat.components.addons.relation_extraction.pad_seq

Classes

Pad_Sequence

Module Contents

class medcat.components.addons.relation_extraction.pad_seq.Pad_Sequence(seq_pad_value, label_pad_value=-1)
Parameters:
  • seq_pad_value (int)

  • label_pad_value (int)

__init__(seq_pad_value, label_pad_value=-1)

Used in rel_cat.py by RelCAT to create the DataLoaders for the train/test datasets. It acts as the DataLoader's collate_fn, collating input_ids, ent1/ent2 and labels of differing lengths into a fixed-length batch. Padding is applied per batch, not to the whole DataLoader dataset, and produces the padded x sequence, the padded y sequence, and the x lengths and y lengths of the batch.
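To illustrate what "applied per batch" means, the following is a minimal pure-Python sketch. The helper `pad_batch` is hypothetical and not part of medcat; it pads lists instead of tensors, roughly mimicking `torch.nn.utils.rnn.pad_sequence(..., batch_first=True)` for a single batch. Each batch is padded only to the length of its own longest sequence, so different batches can end up with different padded widths.

```python
from typing import List, Sequence, Tuple


def pad_batch(seqs: Sequence[Sequence[int]], pad_value: int) -> Tuple[List[List[int]], List[int]]:
    """Pad one batch of variable-length sequences to that batch's longest length.

    Hypothetical helper for illustration only (lists, not torch tensors).
    Returns the padded sequences and the original (unpadded) lengths.
    """
    lengths = [len(s) for s in seqs]
    max_len = max(lengths, default=0)
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    return padded, lengths


# Two batches from the same dataset: each is padded only to its own maximum length.
batch_a = [[5, 6, 7], [8]]
batch_b = [[1, 2, 3, 4, 5], [9, 9]]
print(pad_batch(batch_a, pad_value=0))  # ([[5, 6, 7], [8, 0, 0]], [3, 1])
print(pad_batch(batch_b, pad_value=0))  # ([[1, 2, 3, 4, 5], [9, 9, 0, 0, 0]], [5, 2])
```

Padding per batch rather than over the whole dataset avoids padding every example to the single longest sequence in the dataset.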

Parameters:
  • seq_pad_value (int) – pad value for input_ids.

  • label_pad_value (int) – pad value for labels. Defaults to -1.

seq_pad_value: int
label_pad_value: int = -1
__call__(batch)

Pads a batch of input_ids and their labels.

Parameters:

batch (list[torch.Tensor]) – the batch of Tensors from RelData.dataset (see its __getitem__() method for the data returned). The token sequences and labels are padded as needed. See https://pytorch.org/docs/stable/_modules/torch/nn/utils/rnn.html#pad_sequence for extra info.

Returns:

tuple[Tensor, Tensor, Tensor, LongTensor, LongTensor] – the padded batch: padded input_ids, ent1 & ent2 start token positions, padded labels, input_ids lengths and labels lengths.

Return type:

tuple[torch.Tensor, list, torch.Tensor, torch.LongTensor, torch.LongTensor]
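The shape of the returned 5-tuple can be sketched in pure Python as below. This is a hypothetical stand-in (`PadSequenceSketch`), not the medcat implementation: it works on plain lists rather than tensors, and it assumes each dataset item is a tuple `(input_ids, (ent1_start, ent2_start), labels)`, which may not match what `RelData.__getitem__()` actually returns.

```python
from typing import List, Sequence, Tuple

# Assumed (hypothetical) shape of one dataset item: (input_ids, entity start positions, labels)
Item = Tuple[List[int], Tuple[int, int], List[int]]


class PadSequenceSketch:
    """Hypothetical list-based stand-in for Pad_Sequence.__call__; illustration only."""

    def __init__(self, seq_pad_value: int, label_pad_value: int = -1) -> None:
        self.seq_pad_value = seq_pad_value
        self.label_pad_value = label_pad_value

    @staticmethod
    def _pad(seqs: Sequence[List[int]], value: int) -> List[List[int]]:
        # Right-pad every sequence to the longest length within this batch only.
        max_len = max((len(s) for s in seqs), default=0)
        return [s + [value] * (max_len - len(s)) for s in seqs]

    def __call__(self, batch: Sequence[Item]):
        input_ids = [list(item[0]) for item in batch]
        ent_starts = [item[1] for item in batch]
        labels = [list(item[2]) for item in batch]
        return (
            self._pad(input_ids, self.seq_pad_value),  # padded input ids
            ent_starts,                                 # ent1 & ent2 start token positions
            self._pad(labels, self.label_pad_value),    # padded labels
            [len(x) for x in input_ids],                # input_ids lengths (before padding)
            [len(y) for y in labels],                   # labels lengths (before padding)
        )


collate = PadSequenceSketch(seq_pad_value=0)
batch = [([101, 7, 8, 102], (1, 2), [1, 0]), ([101, 9, 102], (1, 1), [2])]
print(collate(batch))
# ([[101, 7, 8, 102], [101, 9, 102, 0]], [(1, 2), (1, 1)], [[1, 0], [2, -1]], [4, 3], [2, 1])
```

With the real class, an instance would typically be passed to `torch.utils.data.DataLoader` through its `collate_fn` argument, so this padding runs once for every batch the loader yields.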