---
license: apache-2.0
datasets:
- webis/tldr-17
language:
- en
library_name: transformers
pipeline_tag: text-classification
widget:
- text: "Biden says US is at tipping point on gun control: We will ban assault weapons in this country"
example_title: "classification"
---
# RedBERT - a Reddit post classifier
This model, based on DistilBERT, is fine-tuned to predict the subreddit of a Reddit post.
## Usage
### Preparations
The model uses the `transformers` library, so make sure it is installed:
```
pip install transformers[torch]
```
After the installation, the model can be loaded from Hugging Face.
The model is stored locally, so if you run these lines multiple times it will be loaded from the cache.
```py
from transformers import pipeline
pipe = pipeline("text-classification", model="traberph/RedBERT")
```
### Basic
For a simple classification task, just call the pipeline with the text of your choice:
```py
text = "I (33f) need to explain to my coworker (30m) I don't want his company on the commute back home"
pipe(text)
```
output:
[{'label': 'relationships', 'score': 0.9622366428375244}]
### Multiclass with visualization
Everyone likes visualizations! The following example outputs the five most probable labels and visualizes the result.
Make sure all requirements are satisfied:
```
pip install pandas seaborn
```
```py
import pandas as pd
import seaborn as sns
# if the model is already loaded this can be skipped
from transformers import pipeline
pipe = pipeline("text-classification", model="traberph/RedBERT")
text = "Today I spilled coffee over my pc. It started to smoke and the screen turned black. I guess I have a problem now."
# predict the 5 most probable labels
res = pipe(text, top_k=5)
# create a pandas dataframe from the result
df = pd.DataFrame(res)
# use seaborn to create a barplot
sns.barplot(data=df, x='score', y='label', color='steelblue')
```
output:
![](./assets/classify01.png)
## Training
The training of the final version of this model took `130h` on a single `Tesla P100` GPU.
90% of the [webis/tldr-17](https://huggingface.co/datasets/webis/tldr-17/) dataset was used for this version.
## Bias and Limitations
The webis/tldr-17 dataset used to train this model contains 3 848 330 posts from 29 651 subreddits.
Those posts, however, are not equally distributed across the subreddits: 589 947 posts, about `15%` of the whole dataset, belong to the subreddit `AskReddit`, while other subreddits are underrepresented.
| top subreddits | distribution |
| --- | --- |
| ![distribution](./assets/distribution01.png) | ![distribution](./assets/distribution02.png) |
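The `15%` figure can be verified with a quick back-of-the-envelope computation from the counts stated above:

```python
# Share of AskReddit posts in webis/tldr-17 (counts from the dataset statistics above)
askreddit_posts = 589_947
total_posts = 3_848_330

share = askreddit_posts / total_posts
print(f"AskReddit share: {share:.1%}")  # prints "AskReddit share: 15.3%"
```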
This bias in the subreddit distribution is also present in the model and can be observed during inference.
| class labels for `"Biden says US is at tipping point on gun control: We will ban assault weapons in this country"`, from r/politics |
| --- |
| ![classification](./assets/classify02.png) |