Skip to content

Commit

Permalink
Add Document structure with source to VectorStore
Browse files Browse the repository at this point in the history
  • Loading branch information
raulraja committed Jun 20, 2024
1 parent a211c79 commit 31a2d89
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,6 @@ constructor(
val conversationId: ConversationId? = ConversationId(UUID.generateUUID().toString())
) {

@AiDsl
@JvmSynthetic
suspend fun addContext(vararg docs: String) {
store.addTexts(docs.toList())
}

@AiDsl
@JvmSynthetic
suspend fun addContext(docs: Iterable<String>): Unit {
store.addTexts(docs.toList())
}

companion object {

@JvmSynthetic
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,18 @@ class CombinedVectorStore(private val top: VectorStore, private val bottom: Vect
.reversed()
}

override suspend fun similaritySearch(query: String, limit: Int): List<String> {
override suspend fun similaritySearch(query: String, limit: Int): List<VectorStore.Document> {
val topResults = top.similaritySearch(query, limit)
return when {
topResults.size >= limit -> topResults
else -> topResults + bottom.similaritySearch(query, limit - topResults.size)
}
}

override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<String> {
override suspend fun similaritySearchByVector(
embedding: Embedding,
limit: Int
): List<VectorStore.Document> {
val topResults = top.similaritySearchByVector(embedding, limit)
return when {
topResults.size >= limit -> topResults
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import kotlin.math.sqrt

private data class State(
val orderedMemories: Map<ConversationId, List<Memory>>,
val documents: List<String>,
val documents: List<VectorStore.Document>,
val precomputedEmbeddings: Map<String, Embedding>
) {
companion object {
Expand Down Expand Up @@ -75,26 +75,30 @@ private constructor(
.reversed()
}

override suspend fun addTexts(texts: List<String>) {
override suspend fun addDocuments(texts: List<VectorStore.Document>) {
val docsAsJson = texts.map { it.content }
val embeddingsList =
embeddings.embedDocuments(texts, embeddingRequestModel = embeddingRequestModel)
embeddings.embedDocuments(docsAsJson, embeddingRequestModel = embeddingRequestModel)
state.getAndUpdate { prevState ->
val newEmbeddings = prevState.precomputedEmbeddings + texts.zip(embeddingsList)
val newEmbeddings = prevState.precomputedEmbeddings + docsAsJson.zip(embeddingsList)
State(prevState.orderedMemories, prevState.documents + texts, newEmbeddings)
}
}

override suspend fun similaritySearch(query: String, limit: Int): List<String> {
override suspend fun similaritySearch(query: String, limit: Int): List<VectorStore.Document> {
val queryEmbedding =
embeddings.embedQuery(query, embeddingRequestModel = embeddingRequestModel).firstOrNull()
return queryEmbedding?.let { similaritySearchByVector(it, limit) }.orEmpty()
}

override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<String> {
override suspend fun similaritySearchByVector(
embedding: Embedding,
limit: Int
): List<VectorStore.Document> {
val state0 = state.get()
return state0.documents
.asSequence()
.mapNotNull { doc -> state0.precomputedEmbeddings[doc]?.let { doc to it } }
.mapNotNull { doc -> state0.precomputedEmbeddings[doc.content]?.let { doc to it } }
.map { (doc, e) -> doc to embedding.cosineSimilarity(e) }
.sortedByDescending { (_, similarity) -> similarity }
.take(limit)
Expand All @@ -103,9 +107,9 @@ private constructor(
}

private fun Embedding.cosineSimilarity(other: Embedding): Double {
val dotProduct = this.embedding.zip(other.embedding).sumOf { (a, b) -> (a * b).toDouble() }
val magnitudeA = sqrt(this.embedding.sumOf { (it * it).toDouble() })
val magnitudeB = sqrt(other.embedding.sumOf { (it * it).toDouble() })
val dotProduct = this.embedding.zip(other.embedding).sumOf { (a, b) -> (a * b) }
val magnitudeA = sqrt(this.embedding.sumOf { (it * it) })
val magnitudeB = sqrt(other.embedding.sumOf { (it * it) })
return dotProduct / (magnitudeA * magnitudeB)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,22 @@ package com.xebia.functional.xef.store
import arrow.atomic.AtomicInt
import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
import com.xebia.functional.openai.generated.model.Embedding
import com.xebia.functional.xef.Config
import kotlin.jvm.JvmStatic
import kotlinx.serialization.Serializable

interface VectorStore {

@Serializable
data class Document(val content: String, val source: String) {
fun toJson(): String = Config.DEFAULT.json.encodeToString(serializer(), this)

companion object {
fun fromJson(json: String): Document =
Config.DEFAULT.json.decodeFromString(serializer(), json)
}
}

val indexValue: AtomicInt

fun incrementIndexAndGet(): Int = indexValue.addAndGet(1)
Expand All @@ -27,9 +39,9 @@ interface VectorStore {
* @param texts list of text to add to the vector store
* @return a list of IDs from adding the texts to the vector store
*/
suspend fun addTexts(texts: List<String>)
suspend fun addDocuments(texts: List<Document>)

suspend fun addText(texts: String) = addTexts(listOf(texts))
suspend fun addDocument(texts: Document) = addDocuments(listOf(texts))

/**
* Return the docs most similar to the query
Expand All @@ -38,7 +50,7 @@ interface VectorStore {
* @param limit number of documents to return
* @return a list of Documents most similar to query
*/
suspend fun similaritySearch(query: String, limit: Int): List<String>
suspend fun similaritySearch(query: String, limit: Int): List<Document>

/**
* Return the docs most similar to the embedding
Expand All @@ -47,7 +59,7 @@ interface VectorStore {
* @param limit number of documents to return
* @return list of Documents most similar to the embedding
*/
suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<String>
suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<Document>

companion object {
@JvmStatic
Expand All @@ -65,14 +77,15 @@ interface VectorStore {
limitTokens: Int
): List<Memory> = emptyList()

override suspend fun addTexts(texts: List<String>) {}
override suspend fun addDocuments(texts: List<Document>) {}

override suspend fun similaritySearch(query: String, limit: Int): List<String> = emptyList()
override suspend fun similaritySearch(query: String, limit: Int): List<Document> =
emptyList()

override suspend fun similaritySearchByVector(
embedding: Embedding,
limit: Int
): List<String> = emptyList()
): List<Document> = emptyList()
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.xebia.functional.xef.vectorstore

import com.xebia.functional.xef.OpenAI
import com.xebia.functional.xef.store.LocalVectorStore
import com.xebia.functional.xef.store.VectorStore.Document

suspend fun main() {
val embeddings = OpenAI().embeddings
val vectorStore = LocalVectorStore(embeddings)
val helloDoc = Document("Hello, how are you?", "source1")
val unrelatedDoc = Document("Unrelated text", "source2")
vectorStore.addDocuments(listOf(helloDoc, unrelatedDoc))
val maybeHelloDoc = vectorStore.similaritySearch("Hello", 1).first()
assert(maybeHelloDoc == helloDoc) { "Expected $helloDoc but got $maybeHelloDoc" }
val maybeUnrelatedDoc = vectorStore.similaritySearch("Unrelated", 1).first()
assert(maybeUnrelatedDoc == unrelatedDoc) { "Expected $unrelatedDoc but got $maybeUnrelatedDoc" }
println("All expected documents found!")
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,22 +84,23 @@ class PGVectorStore(
}
}

override suspend fun addTexts(texts: List<String>): Unit =
override suspend fun addDocuments(texts: List<VectorStore.Document>): Unit =
dataSource.connection {
val embeddings = embeddings.embedDocuments(texts, chunkSize, embeddingRequestModel)
val docsContent = texts.map { it.content }
val embeddings = embeddings.embedDocuments(docsContent, chunkSize, embeddingRequestModel)
val collection = getCollection(collectionName)
texts.zip(embeddings) { text, embedding ->
val uuid = UUID.generateUUID()
update(addNewText) {
bind(uuid.toString())
bind(collection.uuid.toString())
bind(embedding.embedding.toString())
bind(text)
bind(text.toJson())
}
}
}

override suspend fun similaritySearch(query: String, limit: Int): List<String> =
override suspend fun similaritySearch(query: String, limit: Int): List<VectorStore.Document> =
dataSource.connection {
val collection = getCollection(collectionName)

Expand All @@ -123,10 +124,12 @@ class PGVectorStore(
}
) {
string()
}.map { json ->
VectorStore.Document.fromJson(json)
}
}

override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<String> =
override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<VectorStore.Document> =
dataSource.connection {
val collection = getCollection(collectionName)
queryAsList(
Expand All @@ -138,6 +141,8 @@ class PGVectorStore(
}
) {
string()
}.map { json ->
VectorStore.Document.fromJson(json)
}
}

Expand Down
14 changes: 8 additions & 6 deletions integrations/postgresql/src/test/kotlin/xef/PGVectorStoreSpec.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestMo
import com.xebia.functional.openai.generated.model.CreateEmbeddingRequestModel
import com.xebia.functional.openai.generated.model.Embedding
import com.xebia.functional.xef.store.PGVectorStore
import com.xebia.functional.xef.store.VectorStore
import com.xebia.functional.xef.store.migrations.runDatabaseMigrations
import com.xebia.functional.xef.store.postgresql.PGDistanceStrategy
import com.zaxxer.hikari.HikariConfig
Expand All @@ -17,7 +18,6 @@ import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.assertThrows
import org.testcontainers.containers.PostgreSQLContainer
import org.testcontainers.utility.DockerImageName
import kotlin.coroutines.coroutineContext

val postgres: PostgreSQLContainer<Nothing> =
PostgreSQLContainer(
Expand Down Expand Up @@ -67,10 +67,12 @@ class PGVectorStoreSpec :
postgresVector.createCollection()
}

val docs = listOf(VectorStore.Document(content = "foo", source = "tests"), VectorStore.Document(content = "bar", source = "tests"))

"initialDbSetup should configure the DB properly" { pg().initialDbSetup() }

"addTexts should fail with a CollectionNotFoundError if collection isn't present in the DB" {
assertThrows<IllegalStateException> { pg().addTexts(listOf("foo", "bar")) }.message shouldBe
assertThrows<IllegalStateException> { pg().addDocuments(docs) }.message shouldBe
"Collection 'test_collection' not found"
}

Expand All @@ -82,13 +84,13 @@ class PGVectorStoreSpec :
"createCollection should create collection" { pg().createCollection() }

"addTexts should not fail now that we created the collection" {
pg().addTexts(listOf("foo", "bar"))
pg().addDocuments(docs)
}

"similaritySearchByVector should return both documents" {
pg().addTexts(listOf("bar", "foo"))
pg().addDocuments(docs.reversed())
pg().similaritySearchByVector(Embedding(0, listOf(4.0, 5.0, 6.0), Embedding.Object.embedding), 2) shouldBe
listOf("bar", "foo")
docs.reversed()
}

"similaritySearch should return 2 documents" {
Expand All @@ -104,7 +106,7 @@ class PGVectorStoreSpec :
pg().similaritySearchByVector(
Embedding(0, listOf(1.0, 2.0, 3.0), Embedding.Object.embedding),
1
) shouldBe listOf("foo")
) shouldBe listOf(docs[0])
}

"the added memories sorted by index should be obtained in the same order" {
Expand Down

0 comments on commit 31a2d89

Please sign in to comment.