Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up Dagon API a bit. #7

Merged
merged 2 commits into from
Aug 29, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ object Example {
// 3. set up rewrite rules

object SimplifyNegation extends PartialRule[Eqn] {
def applyWhere[T](on: ExpressionDag[Eqn]) = {
def applyWhere[T](on: Dag[Eqn]) = {
case Negate(Negate(e)) => e
case Negate(Const(x)) => Const(-x)
}
}

object SimplifyAddition extends PartialRule[Eqn] {
def applyWhere[T](on: ExpressionDag[Eqn]) = {
def applyWhere[T](on: Dag[Eqn]) = {
case Add(Const(x), Const(y)) => Const(x + y)
case Add(Add(e, Const(x)), Const(y)) => Add(e, Const(x + y))
case Add(Add(Const(x), e), Const(y)) => Add(e, Const(x + y))
Expand All @@ -128,7 +128,7 @@ object Example {
val rules = SimplifyNegation.orElse(SimplifyAddition)

val simplified: Eqn[Unit] =
ExpressionDag.applyRule(c, toLiteral, rules)
Dag.applyRule(c, toLiteral, rules)
}
```

Expand Down
70 changes: 70 additions & 0 deletions core/src/main/scala/com/stripe/dagon/Cache.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package com.stripe.dagon

/**
* This is a useful cache for memoizing function.
*
* The cache is implemented using a mutable pointer to an immutable
* map value. In the worst-case, race conditions might cause us to
* lose cache values (i.e. compute some keys twice), but we will never
* produce incorrect values.
*/
sealed class Cache[K, V] private (init: Map[K, V]) {

private[this] var map: Map[K, V] = init

/**
* Given a key, either return a cached value, or compute, store, and
* return a new value.
*
* This method is what justifies the existence of Cache. Its second
* parameter (`v`) is by-name: it will only be evaluated in cases
* where the key is not cached.
*
* For example:
*
* def greet(i: Int): Int = {
* println("hi")
* i + 1
* }
*
* val c = Cache.empty[Int, Int]
* c.getOrElseUpdate(1, greet(1)) // says hi, returns 2
* c.getOrElseUpdate(1, greet(1)) // just returns 2
*/
def getOrElseUpdate(k: K, v: => V): V =
map.get(k) match {
case Some(exists) => exists
case None =>
val res = v
map = map.updated(k, res)
res
}

/**
* Create a second cache with the same values as this one.
*
* The two caches will start with the same values, but will be
* independently updated.
*/
def duplicate: Cache[K, V] =
new Cache(map)

/**
* Access the currently-cached keys and values as a map.
*/
def toMap: Map[K, V] =
map

/**
* Forget all cached keys and values.
*
* After calling this method, the resulting cache is equivalent to
* Cache.empty[K, V].
*/
def reset(): Unit =
map = Map.empty
}

object Cache {
def empty[K, V]: Cache[K, V] = new Cache(Map.empty)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,61 +17,84 @@

package com.stripe.dagon

sealed abstract class ExpressionDag[N[_]] { self =>
/**
* Represents a directed acyclic graph (DAG).
*
* The type N[_] represents the type of nodes in the graph.
*/
sealed abstract class Dag[N[_]] { self =>

/**
* These have package visibility to test
* the law that for all Expr, the node they
* evaluate to is unique
*/
protected def idToExp: HMap[Id, Expr[N, ?]]

/**
* The set of roots that were added by addRoot.
* These are Ids that will always evaluate
* such that roots.forall(evaluateOption(_).isDefined)
*/
protected def roots: Set[Id[_]]

/**
* This is the next Id value which will be allocated
*/
protected def nextId: Int

/**
* Convert a N[T] to a Literal[T, N]
* Convert a N[T] to a Literal[T, N].
*/
def toLiteral: FunctionK[N, Literal[N, ?]]

// Caches polymorphic functions of type T => Option[N[T]]
private val idToN: HCache[Id, Lambda[t => Option[N[t]]]] =
HCache.empty[Id, Lambda[t => Option[N[t]]]]

// Caches polymorphic functions of type N[T] => Option[T]
private val nodeToId: HCache[N, Lambda[t => Option[Id[t]]]] =
HCache.empty[N, Lambda[t => Option[Id[t]]]]

// Convenient method to produce new, modified DAGs based on this
// one.
private def copy(
id2Exp: HMap[Id, Expr[N, ?]] = self.idToExp,
node2Literal: FunctionK[N, Literal[N, ?]] = self.toLiteral,
gcroots: Set[Id[_]] = self.roots,
id: Int = self.nextId
): ExpressionDag[N] = new ExpressionDag[N] {
): Dag[N] = new Dag[N] {
def idToExp = id2Exp
def roots = gcroots
def toLiteral = node2Literal
def nextId = id
}

override def toString: String =
s"ExpressionDag(idToExp = $idToExp, roots = $roots)"
// Produce a new DAG that is equivalent to this one, but which frees
// orphaned nodes and other internal state which may no longer be
// needed.
private def gc: Dag[N] = {
val keepers = reachableIds
if (idToExp.forallKeys(keepers)) this
else copy(id2Exp = idToExp.filterKeys(keepers))
}

// This is a cache of Id[T] => Option[N[T]]
private val idToN =
HCache.empty[Id, Lambda[t => Option[N[t]]]]
private val nodeToId =
HCache.empty[N, Lambda[t => Option[Id[t]]]]
/**
* String representation of this DAG.
*/
override def toString: String =
s"Dag(idToExp = $idToExp, roots = $roots)"

/**
* Add a GC root, or tail in the DAG, that can never be deleted.
*/
def addRoot[T](node: N[T]): (ExpressionDag[N], Id[T]) = {
def addRoot[T](node: N[T]): (Dag[N], Id[T]) = {
val (dag, id) = ensure(node)
(dag.copy(gcroots = roots + id), id)
}

/**
* Which ids are reachable from the roots
* Which ids are reachable from the roots?
*/
def reachableIds: Set[Id[_]] = {

Expand All @@ -86,22 +109,14 @@ sealed abstract class ExpressionDag[N[_]] { self =>
Graphs.reflexiveTransitiveClosure(roots.toList)(neighbors _).toSet
}

private def gc: ExpressionDag[N] = {
val goodIds = reachableIds
val toKeepI2E = idToExp.filter(new FunctionK[HMap[Id, Expr[N, ?]]#Pair, BoolT] {
def toFunction[T] = { case (id, _) => goodIds(id) }
})
copy(id2Exp = toKeepI2E)
}

/**
* Apply the given rule to the given dag until
* the graph no longer changes.
*/
def apply(rule: Rule[N]): ExpressionDag[N] = {
def apply(rule: Rule[N]): Dag[N] = {

@annotation.tailrec
def loop(d: ExpressionDag[N]): ExpressionDag[N] = {
def loop(d: Dag[N]): Dag[N] = {
val next = d.applyOnce(rule)
if (next eq d) next
else loop(next)
Expand All @@ -114,8 +129,8 @@ sealed abstract class ExpressionDag[N[_]] { self =>
* apply the rule at the first place that satisfies
* it, and return from there.
*/
def applyOnce(rule: Rule[N]): ExpressionDag[N] = {
type DagT[T] = ExpressionDag[N]
def applyOnce(rule: Rule[N]): Dag[N] = {
type DagT[T] = Dag[N]

val f = new FunctionK[HMap[Id, Expr[N, ?]]#Pair, Lambda[x => Option[DagT[x]]]] {
def toFunction[U] = { (kv: (Id[U], Expr[N, U])) =>
Expand All @@ -134,7 +149,7 @@ sealed abstract class ExpressionDag[N[_]] { self =>
// publicly, and the ids may be embedded in many
// nodes. Instead we remap 'id' to be a pointer
// to 'newid'.
dag.copy(id2Exp = dag.idToExp + (id -> Expr.Var[N, U](newId))).gc
dag.copy(id2Exp = dag.idToExp.updated(id, Expr.Var[N, U](newId))).gc
}
}
}
Expand All @@ -146,10 +161,10 @@ sealed abstract class ExpressionDag[N[_]] { self =>
/**
* Apply a rule at most cnt times.
*/
def applyMax(rule: Rule[N], cnt: Int): ExpressionDag[N] = {
def applyMax(rule: Rule[N], cnt: Int): Dag[N] = {

@annotation.tailrec
def loop(d: ExpressionDag[N], cnt: Int): ExpressionDag[N] =
def loop(d: Dag[N], cnt: Int): Dag[N] =
if (cnt <= 0) d
else {
val next = d.applyOnce(rule)
Expand All @@ -165,10 +180,10 @@ sealed abstract class ExpressionDag[N[_]] { self =>
*
* Note, Expr must never be a Var
*/
private def addExp[T](node: N[T], exp: Expr[N, T]): (ExpressionDag[N], Id[T]) = {
private def addExp[T](node: N[T], exp: Expr[N, T]): (Dag[N], Id[T]) = {
require(!exp.isVar)
val nodeId = Id[T](nextId)
(copy(id2Exp = idToExp + (nodeId -> exp), id = nextId + 1), nodeId)
(copy(id2Exp = idToExp.updated(nodeId, exp), id = nextId + 1), nodeId)
}

/**
Expand Down Expand Up @@ -264,7 +279,7 @@ sealed abstract class ExpressionDag[N[_]] { self =>
* at most one id in the graph. Put another way, for all
* Id[T] in the graph evaluate(id) is distinct.
*/
protected def ensure[T](node: N[T]): (ExpressionDag[N], Id[T]) =
protected def ensure[T](node: N[T]): (Dag[N], Id[T]) =
find(node) match {
case Some(id) => (this, id)
case None =>
Expand Down Expand Up @@ -377,20 +392,20 @@ sealed abstract class ExpressionDag[N[_]] { self =>
}
}

object ExpressionDag {
object Dag {

def empty[N[_]](n2l: FunctionK[N, Literal[N, ?]]): ExpressionDag[N] =
new ExpressionDag[N] {
def empty[N[_]](n2l: FunctionK[N, Literal[N, ?]]): Dag[N] =
new Dag[N] {
val idToExp = HMap.empty[Id, Expr[N, ?]]
val toLiteral = n2l
val roots = Set.empty[Id[_]]
val nextId = 0
}

/**
* This creates a new ExpressionDag rooted at the given tail node
* This creates a new Dag rooted at the given tail node
*/
def apply[T, N[_]](n: N[T], nodeToLit: FunctionK[N, Literal[N, ?]]): (ExpressionDag[N], Id[T]) =
def apply[T, N[_]](n: N[T], nodeToLit: FunctionK[N, Literal[N, ?]]): (Dag[N], Id[T]) =
empty(nodeToLit).addRoot(n)

/**
Expand Down
57 changes: 48 additions & 9 deletions core/src/main/scala/com/stripe/dagon/HCache.scala
Original file line number Diff line number Diff line change
@@ -1,21 +1,36 @@
package com.stripe.dagon

/**
* This is a useful cache for memoizing heterogenously types functions
* This is a useful cache for memoizing natural transformations.
*
* The cache is implemented using a mutable pointer to an immutable
* map value. In the worst-case, race conditions might cause us to
* lose cache values (i.e. compute some keys twice), but we will never
* produce incorrect values.
*/
sealed class HCache[K[_], V[_]] private (init: HMap[K, V]) {
private var hmap: HMap[K, V] = init

/**
* Get an immutable snapshot of the current state
*/
def snapshot: HMap[K, V] = hmap
private[this] var hmap: HMap[K, V] = init

/**
* Get a mutable copy of the current state
* Given a key, either return a cached value, or compute, store, and
* return a new value.
*
* This method is what justifies the existence of Cache. Its second
* parameter (`v`) is by-name: it will only be evaluated in cases
* where the key is not cached.
*
* For example:
*
* def greet(i: Int): Option[Int] = {
* println("hi")
* Option(i + 1)
* }
*
* val c = Cache.empty[Option, Option]
* c.getOrElseUpdate(Some(1), greet(1)) // says hi, returns Some(2)
* c.getOrElseUpdate(Some(1), greet(1)) // just returns Some(2)
*/
def duplicate: HCache[K, V] = new HCache(hmap)

def getOrElseUpdate[T](k: K[T], v: => V[T]): V[T] =
hmap.get(k) match {
case Some(exists) => exists
Expand All @@ -24,6 +39,30 @@ sealed class HCache[K[_], V[_]] private (init: HMap[K, V]) {
hmap = hmap + (k -> res)
res
}

/**
* Create a second cache with the same values as this one.
*
* The two caches will start with the same values, but will be
* independently updated.
*/
def duplicate: HCache[K, V] =
new HCache(hmap)

/**
* Access the currently-cached keys and values as a map.
*/
def toHMap: HMap[K, V] =
hmap

/**
* Forget all cached keys and values.
*
* After calling this method, the resulting cache is equivalent to
* Cache.empty[K, V].
*/
def reset(): Unit =
hmap = HMap.empty[K, V]
}

object HCache {
Expand Down
Loading