shruthib sambt commited on
Commit
cd9d2b5
1 Parent(s): 301f526

Fix CXRBertOutput errors (#5)

Browse files

- Fix CXRBertOutput errors (a43c5aab30edfdddc6c7eac3db438823a8cc5c01)


Co-authored-by: Sam Bond-Taylor <[email protected]>

Files changed (1) hide show
  1. modeling_cxrbert.py +4 -2
modeling_cxrbert.py CHANGED
@@ -3,6 +3,7 @@
3
  # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
  # ------------------------------------------------------------------------------------------
5
 
 
6
  from typing import Any, Optional, Tuple, Union
7
 
8
  import torch
@@ -16,9 +17,10 @@ from .configuration_cxrbert import CXRBertConfig
16
 
17
  BERTTupleOutput = Tuple[T, T, T, T, T]
18
 
 
19
  class CXRBertOutput(ModelOutput):
20
- last_hidden_state: torch.FloatTensor
21
- logits: torch.FloatTensor
22
  cls_projected_embedding: Optional[torch.FloatTensor] = None
23
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
24
  attentions: Optional[Tuple[torch.FloatTensor]] = None
 
3
  # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
  # ------------------------------------------------------------------------------------------
5
 
6
+ from dataclasses import dataclass
7
  from typing import Any, Optional, Tuple, Union
8
 
9
  import torch
 
17
 
18
  BERTTupleOutput = Tuple[T, T, T, T, T]
19
 
20
+ @dataclass
21
  class CXRBertOutput(ModelOutput):
22
+ last_hidden_state: torch.FloatTensor = None
23
+ logits: torch.FloatTensor = None
24
  cls_projected_embedding: Optional[torch.FloatTensor] = None
25
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
26
  attentions: Optional[Tuple[torch.FloatTensor]] = None