anycores commited on
Commit
ff49b11
1 Parent(s): 9a1749c

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +0 -3
  2. convert.py +23 -0
  3. main.cpp +79 -0
  4. xg_runtime_api.h +138 -0
README.md CHANGED
@@ -1,3 +0,0 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
convert.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from onnx import numpy_helper
2
+ import numpy as np
3
+ import onnx
4
+ import ffmpeg
5
+ import argparse
6
+
7
+ # Parameter settings
8
+ parser = argparse.ArgumentParser(description='Whisper format converter')
9
+ parser.add_argument('--ipath', metavar='S', help='path to the input file')
10
+ parser.add_argument('--opath', metavar='S', help='path to the output file (.pb extension)')
11
+ args = parser.parse_args()
12
+
13
+ if __name__ == '__main__':
14
+
15
+ out, _ = (
16
+ ffmpeg.input(args.ipath, threads=0)
17
+ .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
18
+ .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
19
+ )
20
+ audio = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
21
+
22
+ onnx_tp = numpy_helper.from_array(audio, 'raw_audio')
23
+ onnx.save_tensor(onnx_tp, args.opath)
main.cpp ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <iostream>
2
+ #include "xg_runtime_api.h"
3
+
4
+
5
+ void test_whisper(const std::string& weight_path, const std::string& input_path);
6
+
7
+ int main(int argc, char** argv) {
8
+
9
+ if (argc == 3)
10
+ {
11
+ std::string weight_path = argv[1];
12
+ std::string input_path = argv[2];
13
+ test_whisper(weight_path, input_path);
14
+ }
15
+
16
+ return 0;
17
+ }
18
+
19
+ void test_whisper(const std::string& weight_path, const std::string& input_path)
20
+ {
21
+ XgModelInfo minfo = {};
22
+ xg_get_model_info(&minfo);
23
+ std::cout << minfo.model_name << " " << minfo.model_version << std::endl;
24
+
25
+ std::cout << "initing graph" << std::endl;
26
+ XgGraph* graph = nullptr;
27
+ if (xg_init_graph(weight_path, XGWeightSource::XG_ONNX, &graph) != XGResult::XG_SUCCESS)
28
+ {
29
+ std::cout << "Graph init error" << std::endl;
30
+ return;
31
+ }
32
+ else
33
+ {
34
+ std::cout << "Graph init: successful" << std::endl;
35
+ }
36
+
37
+ XgData* input_data = nullptr;
38
+ if (xg_allocate_input_compatible_data(0, &input_data) != XGResult::XG_SUCCESS)
39
+ {
40
+ std::cout << "Input allocation error" << std::endl;
41
+ return;
42
+ }
43
+ else
44
+ {
45
+ std::cout << "Input allocation: successful" << std::endl;
46
+ }
47
+
48
+ // load the data into XgData
49
+ reinterpret_cast<std::string*>(input_data->raw_data)[0] = input_path;
50
+
51
+ if (xg_set_input_data(graph, 0, input_data) != XGResult::XG_SUCCESS)
52
+ {
53
+ std::cout << "Input data set error" << std::endl;
54
+ return;
55
+ }
56
+ else
57
+ {
58
+ std::cout << "Input data set: successful" << std::endl;
59
+ }
60
+
61
+ // execute the graph
62
+ xg_execute_graph(graph);
63
+
64
+ // write output
65
+ XgData* output_data = nullptr;
66
+ if (xg_get_output_data(graph, 0, &output_data) != XGResult::XG_SUCCESS)
67
+ {
68
+ std::cout << "Getting output error" << std::endl;
69
+ return;
70
+ }
71
+ else
72
+ {
73
+ std::cout << "Getting output: successful" << std::endl;
74
+ }
75
+
76
+ // print output
77
+ std::string* o1 = reinterpret_cast<std::string*>(output_data->raw_data);
78
+ std::cout << o1[0] << std::endl;
79
+ }
xg_runtime_api.h ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifndef __XG_RUNTIME_API__
2
+ #define __XG_RUNTIME_API__
3
+
4
+ #include <vector>
5
+ #include <string>
6
+
7
+ #if _WIN32
8
+ #define XG_API extern "C" __declspec(dllexport)
9
+ #elif __unix__ || __linux__
10
+ #define XG_API extern "C"
11
+ #endif
12
+
13
+ // XG type definitions
14
+
15
+ enum class XGResult
16
+ {
17
+ XG_SUCCESS,
18
+ XG_INPUT_SIZE_MISSMATCH,
19
+ XG_INPUT_TYPE_MISSMATCH,
20
+ XG_WRONG_INPUT_INDEX,
21
+ XG_WRONG_OUTPUT_INDEX,
22
+ // device related
23
+ XG_DEVICE_NOT_SUPPORTED,
24
+ XG_MEMORY_ALLOCATION_FAILED,
25
+ // weight file access
26
+ XG_FILE_NOT_FOUND,
27
+ XG_EXECUTION_FAILED
28
+ };
29
+
30
+ enum class XGWeightSource
31
+ {
32
+ XG_ONNX,
33
+ XG_XGDB
34
+ };
35
+
36
+ enum class XGDataType
37
+ {
38
+ XG_BOOL,
39
+ XG_TOKEN,
40
+ XG_STRING,
41
+ XG_UINT8,
42
+ XG_UINT16,
43
+ XG_UINT32,
44
+ XG_UINT64,
45
+ XG_INT8,
46
+ XG_INT16,
47
+ XG_INT32,
48
+ XG_INT64,
49
+ XG_BFLOAT16,
50
+ XG_FLOAT16,
51
+ XG_FLOAT32,
52
+ XG_FLOAT64
53
+ };
54
+
55
+ // access information about the contained model
56
+ struct XgModelInfo
57
+ {
58
+ std::string model_name;
59
+ std::string model_version;
60
+ std::string device; // cpu, gpu, tpu etc.
61
+ std::string hardware; // e.g. intel i7 9th gen
62
+ unsigned int num_inputs;
63
+ unsigned int num_outputs;
64
+ };
65
+
66
+ XG_API void xg_get_model_info(
67
+ XgModelInfo* model_info
68
+ );
69
+
70
+ XG_API bool is_current_device_supported(); // may be list the supported devices on this machine
71
+
72
+ // create graph
73
+ struct XgGraph;
74
+
75
+ XG_API XGResult xg_init_graph(
76
+ const std::string& weight_path,
77
+ const XGWeightSource weight_source,
78
+ XgGraph** graph
79
+ );
80
+ XG_API XGResult xg_execute_graph(
81
+ XgGraph* graph
82
+ );
83
+ XG_API XGResult xg_destroy_graph(
84
+ XgGraph** graph
85
+ );
86
+
87
+ // set the input to the graph,
88
+ // query the output
89
+
90
+ struct XgData
91
+ {
92
+ XGDataType dtype;
93
+ unsigned int size_in_bytes;
94
+ unsigned int dimension;
95
+ unsigned int length;
96
+ unsigned int* shape;
97
+ char* raw_data;
98
+ };
99
+
100
+ XG_API unsigned int xg_calculate_tensor_size_in_bytes(
101
+ const XGDataType dtype,
102
+ const unsigned int* shape,
103
+ const unsigned int dimension
104
+ );
105
+ XG_API XGResult xg_allocate_input_compatible_data(
106
+ const unsigned int input_idx,
107
+ XgData** data
108
+ );
109
+ XG_API XGResult xg_destroy_data(
110
+ XgData** data
111
+ );
112
+ XG_API XGResult xg_get_output_data(
113
+ const XgGraph* graph,
114
+ const unsigned int output_idx,
115
+ XgData** data
116
+ );
117
+ XG_API XGResult xg_set_input_data(
118
+ const XgGraph* graph,
119
+ const unsigned int input_idx,
120
+ const XgData* data
121
+ );
122
+
123
+ // helper functions
124
+ XG_API bool xg_is_data_bool(const XgData* data);
125
+ XG_API bool xg_is_data_uint8(const XgData* data);
126
+ XG_API bool xg_is_data_uint16(const XgData* data);
127
+ XG_API bool xg_is_data_uint32(const XgData* data);
128
+ XG_API bool xg_is_data_uint64(const XgData* data);
129
+ XG_API bool xg_is_data_int8(const XgData* data);
130
+ XG_API bool xg_is_data_int16(const XgData* data);
131
+ XG_API bool xg_is_data_int32(const XgData* data);
132
+ XG_API bool xg_is_data_int64(const XgData* data);
133
+ XG_API bool xg_is_data_bfloat16(const XgData* data);
134
+ XG_API bool xg_is_data_float16(const XgData* data);
135
+ XG_API bool xg_is_data_float32(const XgData* data);
136
+ XG_API bool xg_is_data_float64(const XgData* data);
137
+
138
+ #endif // __XG_RUNTIME_API__