File size: 3,758 Bytes
3b7b011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from tensorboard.backend.event_processing import event_accumulator

import os
from shutil import copy2
from re import search as RSearch
import pandas as pd
from ast import literal_eval as LEval

weights_dir = 'logs/weights/'

def find_biggest_tensorboard(tensordir):
    try:
        files = [f for f in os.listdir(tensordir) if f.endswith('.0')]
        if not files:
            print("No files with the '.0' extension found!")
            return

        max_size = 0
        biggest_file = ""

        for file in files:
            file_path = os.path.join(tensordir, file)
            if os.path.isfile(file_path):
                file_size = os.path.getsize(file_path)
                if file_size > max_size:
                    max_size = file_size
                    biggest_file = file

        return biggest_file

    except FileNotFoundError:
        print("Couldn't find your model!")
        return

def main(model_name, save_freq, lastmdls):
    global lowestval_weight_dir, scl

    tensordir = os.path.join('logs', model_name)
    lowestval_weight_dir = os.path.join(tensordir, "lowestvals")
    
    latest_file = find_biggest_tensorboard(tensordir)
    
    if latest_file is None:
        print("Couldn't find a valid tensorboard file!")
        return
    
    tfile = os.path.join(tensordir, latest_file)
    
    ea = event_accumulator.EventAccumulator(tfile,
        size_guidance={
        event_accumulator.COMPRESSED_HISTOGRAMS: 500,
        event_accumulator.IMAGES: 4,
        event_accumulator.AUDIO: 4,
        event_accumulator.SCALARS: 0,
        event_accumulator.HISTOGRAMS: 1,
    })

    ea.Reload()
    ea.Tags()

    scl = ea.Scalars('loss/g/total')

    listwstep = {}
    
    for val in scl:
        if (val.step // save_freq) * save_freq in [val.step for val in scl]:
            listwstep[float(val.value)] = (val.step // save_freq) * save_freq

    lowest_vals = sorted(listwstep.keys())[:lastmdls]

    sorted_dict = {value: step for value, step in listwstep.items() if value in lowest_vals}
    
    return sorted_dict

def selectweights(model_name, file_dict, weights_dir, lowestval_weight_dir):
    os.makedirs(lowestval_weight_dir, exist_ok=True)
    logdir = []
    files = []
    lbldict = {
        'Values': {},
        'Names': {}
    }
    weights_dir_path = os.path.join(weights_dir, "")
    low_val_path = os.path.join(os.getcwd(), os.path.join(lowestval_weight_dir, ""))
    
    try:
        file_dict = LEval(file_dict)
    except Exception as e: 
        print(f"Error! {e}")
        return f"Couldn't load tensorboard file! {e}"
    
    weights = [f for f in os.scandir(weights_dir)]
    for key, value in file_dict.items():
        pattern = fr"^{model_name}_.*_s{value}\.pth$"
        matching_weights = [f.name for f in weights if f.is_file() and RSearch(pattern, f.name)]
        for weight in matching_weights:
            source_path = weights_dir_path + weight
            destination_path = os.path.join(lowestval_weight_dir, weight)
            
            copy2(source_path, destination_path)

            logdir.append(f"File = {weight} Value: {key}, Step: {value}")

            lbldict['Names'][weight] = weight
            lbldict['Values'][weight] = key

            files.append(low_val_path + weight)

            print(f"File = {weight} Value: {key}, Step: {value}")

            yield ('\n'.join(logdir), files, pd.DataFrame(lbldict))
            

    return ''.join(logdir), files, pd.DataFrame(lbldict)
    

if __name__ == "__main__":
    model = str(input("Enter the name of the model: "))
    sav_freq = int(input("Enter save frequency of the model: "))
    ds = main(model, sav_freq)
    
    if ds: selectweights(model, ds, weights_dir, lowestval_weight_dir)