Skip to content

Commit

Permalink
Update imports and secrets handling in main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mathewsrc committed Jan 24, 2024
1 parent 662858f commit a106a62
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions src/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
import boto3
from langchain_community.llms import Bedrock
from botocore.exceptions import ClientError
from langchain_community.llms.bedrock import Bedrock
import os
from dotenv import load_dotenv
from qdrant_client import QdrantClient
Expand All @@ -24,13 +25,10 @@ class Body(BaseModel):

QDRANT_URL = os.environ.get("QDRANT_URL")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
AWS_REGION = os.environ.get("AWS_REGION")

client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
AWS_REGION = 'us-east-1'

boto_session = boto3.Session(region_name=AWS_REGION)
credentials = boto_session.get_credentials()
bedrock_models = boto3.client("bedrock", region_name=AWS_REGION)
bedrock_runtime = boto3.client("bedrock-runtime", region_name=AWS_REGION)

prompt_template = """
Expand All @@ -44,6 +42,20 @@ class Body(BaseModel):
Answer:
"""

def get_secret(secret_name):
client = boto_session.client(
service_name='secretsmanager',
region_name=AWS_REGION
)
try:
get_secret_value_response = client.get_secret_value(
SecretId=secret_name
)
except ClientError as e:
raise e
return get_secret_value_response['SecretString']


def get_bedrock_embeddings(model_name: str) -> BedrockEmbeddings:
embeddings = BedrockEmbeddings(client=bedrock_runtime, model_id=model_name)
return embeddings
Expand All @@ -60,6 +72,13 @@ async def root():
@app.post("/ask", response_class=HTMLResponse)
async def question(body:Body):
try:
if QDRANT_URL is None:
qdrant_url = get_secret("prod/qdrant_url")

if QDRANT_API_KEY is None:
qdrant_api_key = get_secret("prod/qdrant_api_key")

client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
embeddings = get_bedrock_embeddings(BEDROCK_EMBEDDINGS_MODEL_NAME)

qdrant = Qdrant(
Expand Down

0 comments on commit a106a62

Please sign in to comment.