Skip to content

Commit

Permalink
Add cache policies to CachedTool (#792)
Browse files Browse the repository at this point in the history
* feat: add the ability of removing all expired entries when one is expired

* feat: add cached tool expiration policies

* feat: eviction and expiration policies

* docs: update comments

* refactor: change naming
  • Loading branch information
realdavidvega authored Oct 2, 2024
1 parent c55ae53 commit c5dc939
Showing 1 changed file with 99 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,65 @@ import kotlin.time.Duration.Companion.days

data class CachedToolKey<K>(val value: K, val seed: String)

data class CachedToolValue<V>(val value: V, val timestamp: Long)
data class CachedToolValue<V>(val value: V, val accessTimestamp: Long, val writeTimestamp: Long) {
fun withAccessTimestamp() = copy(accessTimestamp = timeInMillis())

companion object {
fun <V> withActualResponse(response: V): CachedToolValue<V> =
CachedToolValue(
value = response,
accessTimestamp = timeInMillis(),
writeTimestamp = timeInMillis()
)
}
}

data class CachedToolConfig(
val timeCachePolicy: Duration,
val cacheExpirationPolicy: CacheExpirationPolicy,
val cacheEvictionPolicy: CacheEvictionPolicy
) {

/** Policy to expire the entries in the cache, based on last access or last write time. */
enum class CacheExpirationPolicy {
/** Last access time is used to determine expiration */
ACCESS,
/** Last write time is used to determine expiration */
WRITE
}

/** Policy to evict the expired entries from the cache, based on one or all expired entries. */
enum class CacheEvictionPolicy {
/** Removes the expired entry when found */
SINGLE,
/** Removes all expired entries when one expired entry found */
ALL
}

companion object {
val Default =
CachedToolConfig(
timeCachePolicy = 1.days,
cacheEvictionPolicy = CacheEvictionPolicy.ALL,
cacheExpirationPolicy = CacheExpirationPolicy.WRITE
)
}
}

/**
* Tool that caches the result of the execution of [onCacheMissed] if [shouldUseCache] returns true.
* Otherwise, returns the result of [onCacheMissed]. This output is added to the cache when
* [shouldCacheOutput] returns true.
*
* Cache is stored in a [Map] of [CachedToolKey] to [CachedToolValue].
*
* Supports expiration policies using [CachedToolConfig].
*/
abstract class CachedTool<Input, Output>(
private val cache: Atomic<MutableMap<CachedToolKey<Input>, CachedToolValue<Output>>>,
private val seed: String,
private val timeCachePolicy: Duration = 1.days
private val config: CachedToolConfig = CachedToolConfig.Default
) : Tool<Input, Output> {

/**
* Logic to be executed when the cache is missed.
*
Expand Down Expand Up @@ -49,44 +100,59 @@ abstract class CachedTool<Input, Output>(
else onCacheMissed(input)

/**
* Exposes the cache as a [Map] of [Input] to [Output] filtered by instance [seed] and
* [timeCachePolicy]. Removes expired cache entries.
* Returns a snapshot of the cache as a [Map] of [Input] to [Output] filtered by instance [seed]
* and removing expired cache entries with the given [config] policies. Does not modify the cache.
*
* @return the map of input to output.
*/
suspend fun getCache(): Map<Input, Output> {
val lastTimeInCache = timeInMillis() - timeCachePolicy.inWholeMilliseconds
val withoutExpired =
suspend fun getValidCacheSnapshot(): Map<Input, Output> {
val validEntries =
cache.modify { cachedToolInfo ->
// Filter entries belonging to the current seed and have not expired
val validEntries =
cachedToolInfo
.filter { (key, value) ->
if (key.seed == seed) lastTimeInCache <= value.timestamp else true
}
.toMutableMap()
// Remove expired entries for the current seed only
cachedToolInfo.keys.removeAll { key -> key.seed == seed && !validEntries.containsKey(key) }
// Modifies state A, and returns state B
val validEntries = cachedToolInfo.filterExpired().filter { (key, _) -> key.seed == seed }
Pair(cachedToolInfo, validEntries)
}
return withoutExpired.map { it.key.value to it.value.value }.toMap()
return validEntries.map { it.key.value to it.value.value }.toMap()
}

private suspend fun cache(input: CachedToolKey<Input>, block: suspend () -> Output): Output {
val cachedToolInfo = cache.get()[input]
if (cachedToolInfo != null) {
val lastTimeInCache = timeInMillis() - timeCachePolicy.inWholeMilliseconds
if (lastTimeInCache > cachedToolInfo.timestamp) {
cache.get().remove(input)
} else {
return cachedToolInfo.value
}
private suspend fun cache(input: CachedToolKey<Input>, block: suspend () -> Output): Output =
cache.modify { cachedToolInfo ->
cachedToolInfo[input]?.let { output ->
if (output.isExpired()) {
val updatedCache =
when (config.cacheEvictionPolicy) {
CachedToolConfig.CacheEvictionPolicy.SINGLE -> cachedToolInfo.apply { remove(input) }
CachedToolConfig.CacheEvictionPolicy.ALL -> cachedToolInfo.filterExpired()
}
Pair(updatedCache, null)
} else {
val updatedOutput = output.withAccessTimestamp()
Pair(cachedToolInfo, updatedOutput.value)
}
} ?: Pair(cachedToolInfo, null)
}
val response = block()
if (shouldCacheOutput(input.value, response)) {
cache.get()[input] = CachedToolValue(response, timeInMillis())
?: run {
val response = block()
if (shouldCacheOutput(input.value, response)) {
cache.update { cachedToolInfo ->
cachedToolInfo[input] = CachedToolValue.withActualResponse(response)
cachedToolInfo
}
}
response
}

private fun MutableMap<CachedToolKey<Input>, CachedToolValue<Output>>.filterExpired() =
this.filter { (_, value) -> !value.isExpired() }.toMutableMap()

private fun CachedToolValue<Output>.isExpired(): Boolean =
when (config.cacheExpirationPolicy) {
CachedToolConfig.CacheExpirationPolicy.ACCESS -> {
val lastTimeInCache = timeInMillis() - accessTimestamp
lastTimeInCache > config.timeCachePolicy.inWholeMilliseconds
}
CachedToolConfig.CacheExpirationPolicy.WRITE -> {
val lastTimeInCache = timeInMillis() - writeTimestamp
lastTimeInCache > config.timeCachePolicy.inWholeMilliseconds
}
}
return response
}
}

0 comments on commit c5dc939

Please sign in to comment.