TomokiFujihara's picture
Upload model
fbf5826 verified
raw
history blame contribute delete
No virus
2.02 kB
from transformers import PreTrainedModel
from configuration import *
import torch.nn as nn
import torch
from transformers import AutoModel
class OffensivenessEstimationModel(PreTrainedModel):
    """Classification/regression head on top of a pretrained text encoder.

    ``PretrainedLanguageModel`` maps token ids plus an attention mask to one
    vector per example. A dropout + linear layer then maps that vector to
    ``config.output_class_num`` raw outputs (no softmax/sigmoid is applied).
    """

    config_class = OffensivenessEstimationConfig

    def __init__(self, config):
        super().__init__(config)
        self.text_encoder = PretrainedLanguageModel(config)
        # Size the head from the encoder's own config rather than a hard-coded
        # 1024, so encoders with other hidden sizes work. Falls back to the
        # previous value of 1024 if the config exposes no ``hidden_size``.
        hidden_size = getattr(
            self.text_encoder.language_model.config, "hidden_size", 1024
        )
        self.decoder = nn.Sequential(
            nn.Dropout(p=config.dropout_rate),
            nn.Linear(hidden_size, config.output_class_num),
        )

    def forward(self, ids, mask):
        """Return the raw (un-normalised) outputs of the linear head.

        Args:
            ids: token ids forwarded to the text encoder.
            mask: attention mask forwarded to the text encoder.

        Returns:
            The decoder output; presumably shaped
            ``(batch, config.output_class_num)`` — TODO confirm against the
            encoder's output shape.
        """
        h = self.text_encoder(ids, mask)
        return self.decoder(h)
class PretrainedLanguageModel(PreTrainedModel):
    """Wrap a Hugging Face ``AutoModel`` and return its position-0 vector.

    The backbone is loaded with ``AutoModel.from_pretrained`` using the name or
    path in ``config.language_model``. If ``config.reinit_n_layers`` is
    positive, that many of the last encoder layers are re-initialised after
    loading.
    """

    config_class = OffensivenessEstimationConfig

    def __init__(self, config):
        super().__init__(config)
        self.language_model = AutoModel.from_pretrained(config.language_model)
        self.reinit_n_layers = config.reinit_n_layers
        if self.reinit_n_layers > 0:
            self._do_reinit()

    def _do_reinit(self):
        """Re-initialise the last ``self.reinit_n_layers`` encoder layers.

        Linear and Embedding weights are redrawn from
        N(0, ``initializer_range``). Linear biases and the Embedding padding row
        are zeroed. LayerNorm is reset to weight=1, bias=0.

        The backbone is assumed to expose ``encoder.layer`` (BERT-style) —
        TODO confirm for other architectures. A non-positive count is a no-op;
        without this guard the slice ``[-0:]`` would select every layer.
        """
        if self.reinit_n_layers <= 0:
            return
        std = self.language_model.config.initializer_range
        layers = self.language_model.encoder.layer[-self.reinit_n_layers:]
        # In-place writes to leaf parameters must not be tracked by autograd.
        with torch.no_grad():
            for layer in layers:
                for module in layer.modules():
                    self._reinit_module(module, std)

    @staticmethod
    def _reinit_module(module, std):
        """Reset one submodule's parameters in place; other types are ignored."""
        if isinstance(module, nn.Linear):
            module.weight.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None when elementwise_affine=False (or bias=False).
            if module.bias is not None:
                module.bias.zero_()
            if module.weight is not None:
                module.weight.fill_(1.0)

    def forward(self, ids, mask):
        """Run the backbone and return the vector at sequence position 0.

        ``output[0]`` is presumably the last hidden state shaped
        ``(batch, seq_len, hidden)`` — TODO confirm for the chosen backbone — so
        the result would be ``(batch, hidden)``.
        """
        output = self.language_model(ids, attention_mask=mask)
        return output[0][:, 0, :]