~ enforce Content-Type for POST
authorAlexandre Bertails <bertails@gmail.com>
Sat, 15 Oct 2011 10:58:04 -0400
changeset 67 353368a48cf5
parent 66 42222285bfea
child 69 3e84e72b6e82
child 80 4328660e5a0b
~ enforce Content-Type for POST
src/main/scala/Lang.scala
src/main/scala/Post.scala
src/main/scala/plan.scala
src/test/scala/CreateContentSpecs.scala
src/test/scala/OtherSpecs.scala
src/test/scala/SparqlQuerySpecs.scala
src/test/scala/SparqlUpdateSpecs.scala
src/test/scala/util/utiltest.scala
--- a/src/main/scala/Lang.scala	Sat Oct 15 10:28:08 2011 -0400
+++ b/src/main/scala/Lang.scala	Sat Oct 15 10:58:04 2011 -0400
@@ -20,7 +20,7 @@
 
 object Lang {
   
-  val supportedLanguages = Seq(RDFXML, TURTLE, N3)
+  val supportedLanguages = Set(RDFXML, TURTLE, N3)
   val supportContentTypes = supportedLanguages map (_.contentType)
   val supportedAsString = supportContentTypes mkString ", "
   
@@ -33,13 +33,16 @@
       case "application/rdf+xml" => Some(RDFXML)
       case _ => None
   }
+  
+  def unapply(contentType: String): Option[Lang] =
+    apply(contentType)
 
   def apply(req: HttpRequest[_]): Option[Lang] =
     RequestContentType(req) flatMap Lang.apply
     
   def unapply(req: HttpRequest[_]): Option[Lang] =
     apply(req)
-
+    
 }
 
 case object RDFXML extends Lang
--- a/src/main/scala/Post.scala	Sat Oct 15 10:28:08 2011 -0400
+++ b/src/main/scala/Post.scala	Sat Oct 15 10:58:04 2011 -0400
@@ -21,20 +21,29 @@
 
 object Post {
   
-  val SparqlContentType = "application/sparql-query"
-  val supportContentTypes = SparqlContentType + Lang.supportContentTypes
+  val SPARQL = "application/sparql-query"
+  val supportContentTypes = Lang.supportContentTypes + SPARQL
   val supportedAsString = supportContentTypes mkString ", "
 
   
   val logger: Logger = LoggerFactory.getLogger(this.getClass)
 
-  def parse(is: InputStream, baseURI:String): Post = {
+  def parse(
+      is: InputStream,
+      baseURI: String,
+      contentType: String): Post = {
+    assert(supportContentTypes contains contentType)
     val source = Source.fromInputStream(is, "UTF-8")
     val s = source.getLines.mkString("\n")
-    parse(s, baseURI)
+    parse(s, baseURI, contentType)
   }
   
-  def parse(s: String, baseURI: String): Post = {
+  def parse(
+      s: String,
+      baseURI: String,
+      contentType: String): Post = {
+    assert(supportContentTypes contains contentType)
+    
     val reader = new StringReader(s)
     
     def postUpdate =
@@ -45,9 +54,8 @@
         case qpe: QueryParseException => qpe.fail
       }
       
-    // TODO
-    def postRDF =
-      modelFromString(s, baseURI, RDFXML) flatMap { model => PostRDF(model).success }
+    def postRDF(lang: Lang) =
+      modelFromString(s, baseURI, lang) flatMap { model => PostRDF(model).success }
     
     def postQuery =
       try {
@@ -57,7 +65,11 @@
         case qe: QueryException => qe.fail
       }
     
-    postUpdate | (postRDF | (postQuery | PostUnknown))
+    contentType match {
+      case SPARQL => postUpdate | (postQuery | PostUnknown)
+      case Lang(lang) => postRDF(lang) | PostUnknown
+    }
+
   }
   
 }
--- a/src/main/scala/plan.scala	Sat Oct 15 10:28:08 2011 -0400
+++ b/src/main/scala/plan.scala	Sat Oct 15 10:58:04 2011 -0400
@@ -79,18 +79,11 @@
           } yield Created
         case PUT(_) =>
           BadRequest ~> ResponseString("Content-Type MUST be one of: " + Lang.supportedAsString)
-        case POST(_) =>
-          req match {
-            case RequestContentType("application/sparql-query") => null
-            case RequestContentType(ct) if Lang.supportContentTypes contains ct => null
-            case _ => BadRequest ~> ResponseString("Content-Type MUST be one of: " + Post.supportedAsString)
-          }
-          
-          {
-          Post.parse(Body.stream(req), baseURI) match {
+        case POST(_) & RequestContentType(ct) if Post.supportContentTypes contains ct => {
+          Post.parse(Body.stream(req), baseURI, ct) match {
             case PostUnknown => {
               logger.info("Couldn't parse the request")
-              BadRequest ~> ResponseString("You MUST provide valid content for either: SPARQL UPDATE, SPARQL Query, RDF/XML, TURTLE")
+              BadRequest ~> ResponseString("You MUST provide valid content for given Content-Type: " + ct)
             }
             case PostUpdate(update) => {
               logger.info("SPARQL UPDATE:\n" + update.toString())
@@ -135,6 +128,8 @@
             }
           }
         }
+        case POST(_) =>
+          BadRequest ~> ResponseString("Content-Type MUST be one of: " + Post.supportedAsString)
         case _ => MethodNotAllowed ~> Allow("GET", "PUT", "POST")
       }
     }
--- a/src/test/scala/CreateContentSpecs.scala	Sat Oct 15 10:28:08 2011 -0400
+++ b/src/test/scala/CreateContentSpecs.scala	Sat Oct 15 10:58:04 2011 -0400
@@ -60,7 +60,7 @@
 
   "POSTing an RDF document to Joe's URI" should {
     "succeed" in {
-      val httpCode:Int = Http(uri.post(diffRDF) get_statusCode)
+      val httpCode:Int = Http(uri.post(diffRDF, RDFXML) get_statusCode)
       httpCode must_== 200
     }
     "append the diff graph to the initial graph" in {
--- a/src/test/scala/OtherSpecs.scala	Sat Oct 15 10:28:08 2011 -0400
+++ b/src/test/scala/OtherSpecs.scala	Sat Oct 15 10:58:04 2011 -0400
@@ -9,7 +9,7 @@
 
   """POSTing something that does not make sense to Joe's URI""" should {
     "return a 400 Bad Request" in {
-      val statusCode = Http.when(_ == 400)(uri.post("that's bouleshit") get_statusCode)
+      val statusCode = Http.when(_ == 400)(uri.post("that's bouleshit", RDFXML) get_statusCode)
       statusCode must_== 400
     }
   }
--- a/src/test/scala/SparqlQuerySpecs.scala	Sat Oct 15 10:28:08 2011 -0400
+++ b/src/test/scala/SparqlQuerySpecs.scala	Sat Oct 15 10:58:04 2011 -0400
@@ -20,7 +20,7 @@
   
   """POSTing "SELECT ?name WHERE { [] foaf:name ?name }" to Joe's URI""" should {
     "return Joe's name" in {
-      val resultSet = Http(uri.post(selectFoafName) >- { body => ResultSetFactory.fromXML(body) } )
+      val resultSet = Http(uri.postSPARQL(selectFoafName) >- { body => ResultSetFactory.fromXML(body) } )
       resultSet.next().getLiteral("name").getString must_== "Joe Lambda"
     }
   }
@@ -39,7 +39,7 @@
   """POSTing "ASK ?name WHERE { [] foaf:name ?name }" to Joe's URI""" should {
     "return true" in {
       val result: Boolean =
-        Http(uri.post(askFoafName) >~ { s => 
+        Http(uri.postSPARQL(askFoafName) >~ { s => 
           (XML.fromSource(s) \ "boolean" \ text).head.toBoolean
           } )
       result must_== true
--- a/src/test/scala/SparqlUpdateSpecs.scala	Sat Oct 15 10:28:08 2011 -0400
+++ b/src/test/scala/SparqlUpdateSpecs.scala	Sat Oct 15 10:58:04 2011 -0400
@@ -15,7 +15,7 @@
   
   "POSTing an INSERT query on Joe's URI (which does not exist yet)" should {
     "succeed" in {
-      val httpCode = Http(uri.post(insertQuery) get_statusCode)
+      val httpCode = Http(uri.postSPARQL(insertQuery) get_statusCode)
       httpCode must_== 200
     }
     "produce a graph with one more triple than the original one" in {
--- a/src/test/scala/util/utiltest.scala	Sat Oct 15 10:28:08 2011 -0400
+++ b/src/test/scala/util/utiltest.scala	Sat Oct 15 10:58:04 2011 -0400
@@ -42,8 +42,15 @@
     def as_model(base: String, lang: Lang): Handler[Model] =
       req >> { is => modelFromInputStream(is, base, lang).toOption.get }
 
-    def post(body: String): Request =
-      (req <<< body).copy(method="POST")
+    def post(body: String, lang: Lang): Request =
+      post(body, lang.contentType)
+    
+    def postSPARQL(body: String): Request =
+      post(body, Post.SPARQL)
+      
+    private def post(body: String, contentType: String): Request =
+      (req <:< Map("Content-Type" -> contentType) <<< body).copy(method="POST")
+
       
     def put(lang: Lang, body: String): Request =
       req <:< Map("Content-Type" -> lang.contentType) <<< body