+ graph trimmer (to get rid of unwanted joins)
authorEric Prud'hommeaux <eric@w3.org>
Wed, 10 Feb 2010 17:22:04 -0500
changeset 160 517697fc09ab
parent 159 c7bf17241c5d
child 161 ce95b5ab3d6c
+ graph trimmer (to get rid of unwanted joins)
src/main/scala/GraphAnalyzer.scala
src/main/scala/SPARQL.scala
src/main/scala/SparqlToSparql.scala
src/test/scala/SparqlToSparqlTest.scala
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/main/scala/GraphAnalyzer.scala	Wed Feb 10 17:22:04 2010 -0500
@@ -0,0 +1,71 @@
+package w3c.sw.util
+import scala.collection.immutable._
+
+class GraphAnalyzer[A](m:Map[A, Set[A]]) {
+
+  def reaches (from:A, to:A) = {
+    if (m.contains(from))
+      new GraphAnalyzer[A](m + (from -> (m(from) + to)))
+    else
+      new GraphAnalyzer[A](m + (from -> Set(to)))
+  }
+  def pair (l:A, r:A) = {
+    reaches(l, r).reaches(r, l)
+  }
+  def neededFor (need:Set[A], t:A, visited:Set[A]):Boolean = {
+    if (!m.contains(t))
+      error("unknown symbol: " + t)
+    var ret:Boolean = false
+    m(t).map((r) => {
+      if (visited.contains(t)) return false
+      if (need.contains(t)) return true
+      ret |= neededFor(need, r, visited + t)
+    })
+    return ret
+  }
+
+}
+
+object GraphAnalyzer {
+  def apply[A]():GraphAnalyzer[A] = GraphAnalyzer() // Map[A, Set[A]]()
+  // def apply[A](list:List[A]):GraphAnalyzer[A] = new GraphAnalyzer(list)
+  // def apply[A](args:A*):GraphAnalyzer[A] = GraphAnalyzer(args.toList)
+}
+
+// object GraphAnalyzer {
+// }
+// //def createEmptyGraphAnalyzer () = GraphAnalyzer(Map[A, Set[A]]())
+
+  // case class Reacher (b:Map[String, Map[String, Set[String]]]) {
+  //   def reaches (from:String, to:String) = {
+  //     if (b.contains(from)) {
+  // 	if (b(from).contains(to)) {
+  // 	  val back = b(from)(to) + from
+  // 	  val fore = b(from) + (to -> back)
+  // 	  val ret = Reacher(b + (from -> fore))
+  // 	  println("duplicate path from " + from + " to " + to)
+  // 	  // println("ret: " + ret)
+  // 	  ret
+  // 	} else {
+  // 	  println(from + "->" + to + " + " + this)
+  // 	  val back = Set[String](from)
+  // 	  val fore = b(from) + (to -> back)
+  // 	  val ret = Reacher(b + (from -> fore))
+  // 	  println("ret: " + ret)
+  // 	  ret
+  // 	}
+  //     } else {
+  // 	val back = Set[String](from)
+  // 	val fore = Map[String, Set[String]](to -> back)
+  // 	val ret = Reacher(b + (from -> fore))
+  // 	// println("ret: " + ret)
+  // 	ret
+  //     }
+  //   }
+  //   def pair (l:String, r:String) = {
+  //     reaches(l, r).reaches(r, l)
+  //   }
+  //   override def toString = b.toString
+  // }
+  // def createEmptyReacher () = Reacher(Map[String, Map[String, Set[String]]]())
+
--- a/src/main/scala/SPARQL.scala	Mon Feb 08 15:19:02 2010 -0500
+++ b/src/main/scala/SPARQL.scala	Wed Feb 10 17:22:04 2010 -0500
@@ -43,7 +43,36 @@
     }
   }
 
+  def trim (terms:Set[Term]):GraphPattern = {
+    this match {
+      case TableFilter(gp2:GraphPattern, expr:Expression) =>
+	TableFilter(gp2.trim(terms), expr)
+
+      case TriplesBlock(triplepatterns) => {
+	val r0 = new w3c.sw.util.GraphAnalyzer[Term](Map[Term, Set[Term]]())
+	/* Examine each triple, updating the compilation state. */
+	val r = triplepatterns.foldLeft(r0)((r, triple) => r.pair(triple.s, triple.o))
+	val useful = triplepatterns.foldLeft(Set[TriplePattern]())((s, t) => {
+	  if (r.neededFor(terms, t.s, Set(t.o)) &&
+	      r.neededFor(terms, t.o, Set(t.s))) s + t
+	  else s
+	})
+	TriplesBlock(useful.toList)
+      }
+
+      case TableConjunction(list) =>
+	/* Examine each triple, updating the compilation state. */
+	TableConjunction(list.map(gp2 => gp2.trim(terms)))
+
+      case OptionalGraphPattern(gp2) =>
+	/* Examine each triple, updating the compilation state. */
+	OptionalGraphPattern(gp2.trim(terms))
+
+      case x => error("no code to handle " + x)
+    }
+  }
 }
+
 case class TriplesBlock(triplepatterns:List[TriplePattern]) extends GraphPattern {
   override def toString = "{\n  " + (triplepatterns.toList.map(s => s.toString.replace("\n", "\n  ")).mkString(".\n  ")) + "\n}"
 }
--- a/src/main/scala/SparqlToSparql.scala	Mon Feb 08 15:19:02 2010 -0500
+++ b/src/main/scala/SparqlToSparql.scala	Wed Feb 10 17:22:04 2010 -0500
@@ -46,13 +46,14 @@
     /* "Uniquely" prefix unmapped vars to void conflict with other rules. */
     // val bound = Set[sparql.Var](vartermmap.map((varterm) => varterm._1))
     val bound = vartermmap.foldLeft(Set[sparql.Var]())((s, varterm) => s + varterm._1)
+    val mappedTo = vartermmap.foldLeft(Set[sparql.Term]())((s, varterm) => s + varterm._2)
     val vars = gp.findVars
     val diff = vars -- bound
     diff.foldLeft(mapped)((incrementalGP, varr) => {
-      substitute(incrementalGP, sparql.TermVar(varr), sparql.TermVar(sparql.Var(varPrefix + varr.s)))
+      substitute(incrementalGP, sparql.TermVar(varr), sparql.TermVar(sparql.Var(varPrefix + varr.s))).trim(mappedTo)
     })
   }
-  case class HornRule (trigger:sparql.TriplePattern, construct:sparql.Construct) {
+  case class RuleIndex (trigger:sparql.TriplePattern, construct:sparql.Construct) {
     override def toString = "{ \"" + trigger + "\" } => {\"\n  " + _shorten(construct.gp.toString).replace("\n", "\n  ") + "\n\"}"
     def transform (tp:sparql.TriplePattern):sparql.GraphPattern = {
       substitute(substitute(construct.gp, trigger.s, tp.s), trigger.o, tp.o)
@@ -147,7 +148,7 @@
   }
   def createEmptyBindings () = Bindings(Map[sparql.Construct, List[Map[sparql.Var, sparql.Term]]]())
 
-  case class RuleMap (rules:Map[sparql.Uri, List[HornRule]]) {
+  case class RuleMap (rules:Map[sparql.Uri, List[RuleIndex]]) {
     def transform (prove:List[sparql.TriplePattern], used:Set[sparql.TriplePattern], varsP:Bindings):Bindings = {
       val _pad = used.foldLeft("")((s, x) => s + " ")
       def _deepPrint (s:String):Unit = { println(used.size + ":" + _pad + s.replace("\n", "\n" + _pad)) }
@@ -213,15 +214,15 @@
   def apply (query:sparql.Select, constructs:List[sparql.Construct]) : sparql.Select = {
     var _ruleNo = 0
     val ruleMap = RuleMap({
-      constructs.foldLeft(Map[sparql.Uri, List[HornRule]]())((m, rule) => {
+      constructs.foldLeft(Map[sparql.Uri, List[RuleIndex]]())((m, rule) => {
 	RuleLabels.update(rule.head.toString, "head" + _ruleNo)
 	RuleLabels.update(rule.gp.toString, "body" + _ruleNo)
 	_ruleNo = _ruleNo + 1
 	rule.head.triplepatterns.foldLeft(m)((m, tp) => m + ({
 	  tp.p match {
 	    case sparql.TermUri(u) => u -> {
-	      if (m.contains(u)) m(u) ++ List(HornRule(tp, rule))
-	      else List(HornRule(tp, rule))}
+	      if (m.contains(u)) m(u) ++ List(RuleIndex(tp, rule))
+	      else List(RuleIndex(tp, rule))}
 	    case _ => error("not implemented: " + tp.p)
 	  }
 	}))
@@ -233,5 +234,6 @@
       mapGraphPattern(query.gp, ruleMap)
     )
   }
+
 }
 
--- a/src/test/scala/SparqlToSparqlTest.scala	Mon Feb 08 15:19:02 2010 -0500
+++ b/src/test/scala/SparqlToSparqlTest.scala	Wed Feb 10 17:22:04 2010 -0500
@@ -86,7 +86,6 @@
 PREFIX empP : <http://hr.example/DB/Employee#>
 PREFIX xsd : <http://www.w3.org/2001/XMLSchema#>
 SELECT ?emp {
-?emp  empP:firstName    ?_0_fname .
 ?emp  empP:lastName    "Smith"^^xsd:string
 }
 """).get
@@ -167,10 +166,8 @@
 SELECT ?lname
      {{?who     empP:lastName   ?lname .
        ?_0_pair task:drone      ?who .
-       ?_0_pair task:manager    ?whom .
-       ?whom    empP:lastName   ?_0_mname }
-      {?whom    empP:lastName   ?_1_wname .
-       ?_1_pair task:drone      ?whom .
+       ?_0_pair task:manager    ?whom }
+      {?_1_pair task:drone      ?whom .
        ?_1_pair task:manager    ?whom2 .
        ?whom2   empP:lastName   "Smith"^^xsd:string }
      }
@@ -178,4 +175,23 @@
     assert(transformed === expected)
   }
 
+  test("reaches") {
+    val graph = Set[(String,String)](
+    ("a", "b"),
+    ("c", "a"),
+    ("b", "d"),
+    ("c", "e"),
+    ("e", "a"),
+    ("e", "f"))
+    val r0 = new w3c.sw.util.GraphAnalyzer[String](Map[String, Set[String]]())
+    val r = graph.foldLeft(r0)((r, pair) => r.pair(pair._1, pair._2))
+    println("r: " + r0)
+    val useful = graph.foldLeft(Set[(String,String)]())((s, pair) => {
+      if (r.neededFor(Set("a", "f"), pair._1, Set(pair._2)) &&
+	  r.neededFor(Set("a", "f"), pair._2, Set(pair._1))) s + pair
+      else s
+    })
+    println("of interest: " + useful)
+  }
+
 }