-
Notifications
You must be signed in to change notification settings - Fork 62
/
__init__.py
executable file
·69 lines (62 loc) · 3.12 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
from openwakeword.model import Model
from openwakeword.vad import VAD
from openwakeword.custom_verifier_model import train_custom_verifier
# Public names exported by a star-import of this module; all three are
# imported from submodules at the top of this file.
__all__ = ['Model', 'VAD', 'train_custom_verifier']
# Feature models, keyed by name. Each entry holds:
#   "model_path"   - absolute path built from this module's directory plus
#                    "resources/models/<file>"
#   "download_url" - URL string associated with the same file name
# NOTE(review): the download host is "github.com" rather than
# "github.com" - confirm this host is intentional before relying on it.
FEATURE_MODELS = {
    "embedding": {
        "model_path": os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "resources/models/embedding_model.tflite",
        ),
        "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/embedding_model.tflite",
    },
    "melspectrogram": {
        "model_path": os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "resources/models/melspectrogram.tflite",
        ),
        "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/melspectrogram.tflite",
    },
}
# VAD models, keyed by name. The single entry holds the absolute path of the
# ONNX file under this module's "resources/models" directory and the URL
# string associated with the same file name.
# NOTE(review): the download host is "github.com" rather than
# "github.com" - confirm this host is intentional before relying on it.
VAD_MODELS = {
    "silero_vad": {
        "model_path": os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "resources/models/silero_vad.onnx",
        ),
        "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/silero_vad.onnx",
    },
}
# Models keyed by name. Every entry has the same two fields:
#   "model_path"   - absolute path built from this module's directory plus
#                    "resources/models/<name>_v0.1.tflite"
#   "download_url" - URL string associated with the same file name
# Insertion order matters to callers that list the paths in order.
# NOTE(review): the download host is "github.com" rather than
# "github.com" - confirm this host is intentional before relying on it.
MODELS = {
    "alexa": {
        "model_path": os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "resources/models/alexa_v0.1.tflite",
        ),
        "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/alexa_v0.1.tflite",
    },
    "hey_mycroft": {
        "model_path": os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "resources/models/hey_mycroft_v0.1.tflite",
        ),
        "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/hey_mycroft_v0.1.tflite",
    },
    "hey_jarvis": {
        "model_path": os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "resources/models/hey_jarvis_v0.1.tflite",
        ),
        "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/hey_jarvis_v0.1.tflite",
    },
    "hey_rhasspy": {
        "model_path": os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "resources/models/hey_rhasspy_v0.1.tflite",
        ),
        "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/hey_rhasspy_v0.1.tflite",
    },
    "timer": {
        "model_path": os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "resources/models/timer_v0.1.tflite",
        ),
        "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/timer_v0.1.tflite",
    },
    "weather": {
        "model_path": os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "resources/models/weather_v0.1.tflite",
        ),
        "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/weather_v0.1.tflite",
    },
}
# Per-model lookup from a string key ("1".."6") to a descriptive label.
# Only the "timer" model has an entry here.
# NOTE(review): the keys are presumably the model's output class ids -
# confirm against the code that consumes this table.
model_class_mappings = {
    "timer": {
        "1": "1_minute_timer",
        "2": "5_minute_timer",
        "3": "10_minute_timer",
        "4": "20_minute_timer",
        "5": "30_minute_timer",
        "6": "1_hour_timer",
    },
}
def get_pretrained_model_paths(inference_framework="tflite"):
    """Return the "model_path" of every entry in MODELS, in MODELS order.

    Args:
        inference_framework: "tflite" (default) returns the paths exactly as
            stored in MODELS. "onnx" returns the same paths with a trailing
            ".tflite" extension swapped for ".onnx".

    Returns:
        A list of path strings, one per MODELS entry.

    Raises:
        ValueError: If ``inference_framework`` is neither "tflite" nor "onnx".
    """
    tflite_suffix = ".tflite"
    stored_paths = [entry["model_path"] for entry in MODELS.values()]

    if inference_framework == "tflite":
        return stored_paths

    if inference_framework == "onnx":
        # Swap only the trailing extension: a blanket str.replace would also
        # rewrite ".tflite" appearing in a directory name of the install path.
        return [
            path[: -len(tflite_suffix)] + ".onnx" if path.endswith(tflite_suffix) else path
            for path in stored_paths
        ]

    # Fail loudly instead of implicitly returning None for an unknown value.
    raise ValueError(
        "Unsupported inference_framework {!r}; expected 'tflite' or 'onnx'".format(inference_framework)
    )