Philipp Traber commited on
Commit
1038a32
1 Parent(s): a23704d

Updated README and added images

Browse files
README.md CHANGED
@@ -1,48 +1,76 @@
1
- ---
2
- license: apache-2.0
3
- datasets:
4
- - webis/tldr-17
5
- language:
6
- - en
7
- library_name: transformers
8
- pipeline_tag: text-classification
9
- inference: false
10
- ---
11
-
12
- ## Reddit post classification
13
-
14
- This model predicts the subreddit of a provided post
15
- The transformers library is required
16
  ```
17
- pip install 'transformers[torch]'
18
  ```
19
 
 
 
 
20
  ```py
21
  from transformers import pipeline
22
- pipe = pipeline('text-classification', model='traberph/RedBERT')
23
- pipe("Biden says US is at tipping point on gun control: We will ban assault weapons in this country")
24
  ```
25
 
26
- ## Class Labels
27
-
28
- To translate the labels back to subreddit names you need to download the `subreddits.json` file from this repo manually
 
 
 
 
 
29
 
 
 
 
 
 
 
30
  ```py
31
- import json
32
- s_count = 0
33
- s_data = []
34
- with open('subreddits.json', 'r') as file:
35
- s_data = json.load(file)
36
- s_count = len(s_data)
37
- labels = list(s_data.keys())
38
-
39
- def translate(d):
40
- d['label'] = s_data[ labels[ int( d['label'].split('_')[1]) ]]
41
- return d
 
 
 
 
 
 
42
  ```
43
 
44
- Now the class labels can be translated back to subreddits
 
45
 
46
- ```py
47
- list(map(translate, pipe("Biden says US is at tipping point on gun control: We will ban assault weapons in this country")))
48
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RedBERT - a Reddit post classifier
2
+
3
+ This model based on distilbert is finetuned to predict the subreddit of a Reddit post.
4
+
5
+
6
+ ## Usage
7
+ ### Preparations
8
+ The model uses the transformers library, so make sure to install it.
 
 
 
 
 
 
 
9
  ```
10
+ pip install transformers[torch]
11
  ```
12
 
13
+ After the installation, the model can be loaded from Hugging Face.
14
+ The model will be sored localy so if you run this lines multiple times the model will be loaded from cache.
15
+
16
  ```py
17
  from transformers import pipeline
18
+ pipe = pipeline("text-classification", model="traberph/RedBERT")
 
19
  ```
20
 
21
+ ### Basic
22
+ For a simple classification task just call the pipeline with the text of your choice
23
+ ```py
24
+ text = "I (33f) need to explain to my coworker (30m) I don't want his company on the commute back home"
25
+ pipe(text)
26
+ ```
27
+ output:
28
+ [{'label': 'relationships', 'score': 0.9622366428375244}]
29
 
30
+ ### Multiclass with visualization
31
+ Everyone likes visualizations! Therefore this is an example to output the 5 most probable labels and visualize the result.
32
+ Make sure that all requirements are satisfied.
33
+ ```
34
+ pip install pandas seaborn
35
+ ```
36
  ```py
37
+ import pandas as pd
38
+ import seaborn as sns
39
+
40
+ # if the model is already loaded this can be skipped
41
+ from transformers import pipeline
42
+ pipe = pipeline("text-classification", model="traberph/RedBERT")
43
+
44
+ text = "Today I spilled coffee over my pc. It started to smoke and the screen turned black. I guess I have a problem now."
45
+
46
+ # predict the 5 most probable labels
47
+ res = pipe(text, top_k=5)
48
+
49
+ # create a pandas dataframe from the result
50
+ df = pd.DataFrame(res)
51
+
52
+ # use seaborn to create a barplot
53
+ sns.barplot(df, x='score', y='label', color='steelblue')
54
  ```
55
 
56
+ output:
57
+ ![](./assets/classify01.png)
58
 
59
+
60
+ ## Training
61
+ The training of the final version of this model took `130h` on a single `Tesla P100 GPU`.
62
+ 90% of the [webis/tldr-17](https://huggingface.co/datasets/webis/tldr-17/) where used for this version.
63
+
64
+
65
+ ## Bias and Limitations
66
+ The webis/tldr-17 dataset used to train this model contains 3 848 330 posts from 29 651 subreddits.
67
+ Those posts however are not equally distributed over the subreddits. 589 947 posts belong to the subreddit `AskReddit`, which is `15%` of the whole dataset. Other subreddits are underrepresented.
68
+ | top subreddits | distribution |
69
+ | --- | --- |
70
+ | ![distribution](./assets/distribution01.png) | ![distribution](./assets/distribution02.png) |
71
+
72
+
73
+ This bias in the subreddit distribution is also represented in the model and can be observed during inference.
74
+ | class labels for `"Biden says US is at tipping point on gun control: We will ban assault weapons in this country"`, from r/politics |
75
+ | --- |
76
+ | ![classification](./assets/classify02.png) |
assets/classify01.png ADDED
assets/classify02.png ADDED
assets/distribution01.png ADDED
assets/distribution02.png ADDED