-
Notifications
You must be signed in to change notification settings - Fork 1
/
tutor_model.py
60 lines (48 loc) · 1.91 KB
/
tutor_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
# from langchain.chat_models import ChatOpenAI
import templates
import json
class Tutor:
    """Wrap an OpenAI completion model to generate, rate and improve answers.

    The ``role`` selects two things:

    - the question-generation prompt, from ``templates.role_templates``;
    - the "situation" text injected into the rating and suggestion prompts,
      from ``templates.suggestion_situation``.

    Intended flow: ``get_questions()`` -> ``rate_answer(q, a)`` ->
    ``get_suggestion(q, a)``. The last step reuses the rating stored by
    ``rate_answer``.
    """

    def __init__(self, API_KEY, role, domain=None):
        """Build the question-generation prompt for ``role``.

        Args:
            API_KEY: OpenAI API key passed to every ``OpenAI`` client this
                instance creates.
            role: Key looked up in ``templates.role_templates`` here. It is
                also looked up in ``templates.suggestion_situation`` when
                rating or suggesting.
            domain: Only used when ``role == "Learn"``, where it fills the
                ``{domain}`` placeholder of that role's template.
        """
        self.role = role
        self.API_KEY = API_KEY
        if self.role == "Learn":
            # Only the "Learn" template is parameterised.
            # NOTE(review): with domain=None the prompt presumably contains
            # the text "None"; consider requiring a domain for this role.
            self.situation = PromptTemplate.from_template(
                templates.role_templates[self.role]
            ).format(domain=domain)
        else:
            self.situation = templates.role_templates[self.role]
        # Filled in by rate_answer(); get_suggestion() depends on it.
        self.rating = None

    def _llm(self, temperature):
        """Create an OpenAI completion client using this tutor's API key."""
        return OpenAI(openai_api_key=self.API_KEY, temperature=temperature)

    @staticmethod
    def _split_lines(text):
        """Return the whitespace-stripped, non-blank lines of ``text``."""
        return [line.strip() for line in text.splitlines() if line.strip()]

    @staticmethod
    def _lowest_rated_key(rating):
        """Return the key of ``rating`` with the smallest value.

        On ties, the key that comes first in the mapping wins.
        """
        return min(rating, key=rating.get)

    def get_questions(self):
        """Ask the model for questions matching this tutor's role.

        Returns:
            A list with one entry per non-blank line of the model's
            completion, each stripped of surrounding whitespace. Blank lines
            (completions often start with some) are dropped so callers never
            receive empty questions.
        """
        completion = self._llm(temperature=0.6).predict(self.situation)
        return self._split_lines(completion)

    def rate_answer(self, question, answer):
        """Ask the model to rate ``answer`` to ``question`` and remember it.

        The prompt is ``templates.answer_rating_template`` filled with the
        question, the answer and this role's situation text. It runs at
        temperature 0 to keep ratings as repeatable as the model allows.

        Returns:
            The model's reply parsed with ``json.loads``. The same object is
            stored on ``self.rating`` for ``get_suggestion``, which expects a
            mapping of criterion -> comparable score. TODO: confirm this
            schema against the rating template.

        Raises:
            json.JSONDecodeError: if the model's reply is not valid JSON.
        """
        prompt = PromptTemplate.from_template(templates.answer_rating_template).format(
            question=question,
            answer=answer,
            situation=templates.suggestion_situation[self.role],
        )
        rating = json.loads(self._llm(temperature=0).predict(prompt))
        self.rating = rating
        return rating

    def get_suggestion(self, question, answer):
        """Ask the model how to improve the weakest-rated aspect of ``answer``.

        Uses the rating stored by the most recent ``rate_answer`` call and
        picks the criterion with the lowest score. On ties, the criterion
        listed first in the rating wins.

        Returns:
            The model's completion for ``templates.suggestion_template``.

        Raises:
            RuntimeError: if no non-empty rating is available, i.e.
                ``rate_answer`` has not been called or returned nothing to rank.
        """
        # getattr keeps the guard working even if __init__ was bypassed.
        rating = getattr(self, "rating", None)
        if not rating:
            raise RuntimeError(
                "get_suggestion() requires a non-empty rating; "
                "call rate_answer() first"
            )
        prompt = PromptTemplate.from_template(templates.suggestion_template).format(
            question=question,
            answer=answer,
            key=self._lowest_rated_key(rating),
            situation=templates.suggestion_situation[self.role],
        )
        return self._llm(temperature=0.6).predict(prompt)