Skip to content

Commit

Permalink
Refactor code and update endpoint names
Browse files Browse the repository at this point in the history
  • Loading branch information
mathewsrc committed Feb 2, 2024
1 parent 4185bc7 commit 315cb50
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions src/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,8 @@ class Body(BaseModel):

load_dotenv()

QDRANT_URL = os.environ.get("QDRANT_URL")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
AWS_REGION = 'us-east-1'

boto_session = boto3.Session(region_name=AWS_REGION)
credentials = boto_session.get_credentials()
bedrock_runtime = boto3.client("bedrock-runtime", region_name=AWS_REGION)
session = boto3.Session(region_name=AWS_REGION)

prompt_template = """
Use the following pieces of context to provide a concise answer to the question at the end.
Expand All @@ -43,7 +38,7 @@ class Body(BaseModel):
"""

def get_secret(secret_name):
client = boto_session.client(
client = session.client(
service_name='secretsmanager',
region_name=AWS_REGION
)
Expand All @@ -55,39 +50,46 @@ def get_secret(secret_name):
raise e
return get_secret_value_response['SecretString']

def get_bedrock_embeddings(model_name: str) -> BedrockEmbeddings:

def get_bedrock_embeddings(model_name: str, bedrock_runtime) -> BedrockEmbeddings:
    """Build a Bedrock embeddings object for the given model.

    Args:
        model_name: Forwarded to ``BedrockEmbeddings`` as ``model_id``.
        bedrock_runtime: Forwarded to ``BedrockEmbeddings`` as ``client``.

    Returns:
        The newly constructed ``BedrockEmbeddings`` instance.
    """
    return BedrockEmbeddings(client=bedrock_runtime, model_id=model_name)

@app.get("/", response_class=HTMLResponse)
async def root():
return HTMLResponse(
"""
<h1>Welcome to our Question/Answering application with Bedrock</h1><br>
<p>Use the /question endpoint to ask a question.</p>
<h1>Welcome to our Question/Answering application where you can ask
questions about the Concurso Nacional Unificado</h1><br>
<p>Use the /ask endpoint to ask a question.</p>
"""
)

@app.post("/ask", response_class=HTMLResponse)
@app.post("/ask")
async def question(body:Body):
try:
if QDRANT_URL is None:
try:
qdrant_url = os.environ.get("QDRANT_URL")
qdrant_api_key = os.environ.get("QDRANT_API_KEY")

bedrock_runtime = boto3.client("bedrock-runtime", region_name=AWS_REGION)

if qdrant_url is None:
qdrant_url = get_secret("prod/qdrant_url")

if QDRANT_API_KEY is None:
if qdrant_api_key is None:
qdrant_api_key = get_secret("prod/qdrant_api_key")

client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
embeddings = get_bedrock_embeddings(BEDROCK_EMBEDDINGS_MODEL_NAME)
embeddings = get_bedrock_embeddings(BEDROCK_EMBEDDINGS_MODEL_NAME, bedrock_runtime)

qdrant = Qdrant(
client=client,
embeddings=embeddings,
collection_name=COLLECTION_NAME,
)

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
prompt = PromptTemplate(template=prompt_template,
input_variables=["context", "question"])

# Bedrock Hyperparameters
inference_modifier = {
Expand All @@ -114,4 +116,4 @@ async def question(body:Body):
answer = result["result"]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
return HTMLResponse(f"<h1>Question: {body.text}</h1><p>Answer: {answer}</p>")
return {"answer": answer}

0 comments on commit 315cb50

Please sign in to comment.