Skip to content

Commit

Permalink
[SPK-186] Vault error messages if a request fails (apache#160)
Browse files Browse the repository at this point in the history
* Vault error messages if a request is failing

* Changelog and solved compilation issues in Mesos

* Mesos fixes

* Code style

* Uncomment tests

* Trying timeout

* Removed todo

* Go Jenkins

* From future to try

* Fixed futures to try

* More specific messages

* Revert "More specific messages"

This reverts commit 3b4304a0efca1ea2f837070cd40369085c3ce741.

* Uncommented tests
  • Loading branch information
pianista215 authored Mar 14, 2018
1 parent 0115a1c commit 9da8ad4
Show file tree
Hide file tree
Showing 12 changed files with 416 additions and 180 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* Possibility of specifying the name of the certificate within a directory with multiple certificates
* Add support to multiples CAs
* Added performance tests PNF
* Add messages for errors in Vault

## 2.2.0.5 (upcoming)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import java.util.zip.Adler32
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import scala.util.Random

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
Expand Down Expand Up @@ -92,7 +91,7 @@ private[spark] class TorrentSecretBroadcast(secretVaultPath: String,
private var checksums: Array[Int] = _

override protected def getValue() = {
VaultHelper.retrieveSecret(secretVaultPath, idJson)
VaultHelper.retrieveSecret(secretVaultPath, idJson).get
}

private def calcChecksum(block: ByteBuffer): Int = {
Expand Down
26 changes: 20 additions & 6 deletions core/src/main/scala/org/apache/spark/security/ConfigSecurity.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.util.{Failure, Success, Try}
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


object ConfigSecurity extends Logging {

val secretsFolder: String = sys.env.get("SPARK_DRIVER_SECRET_FOLDER") match {
Expand All @@ -39,21 +40,34 @@ object ConfigSecurity extends Logging {
logInfo("Obtaining vault token using VAULT_TOKEN")
sys.env.get("VAULT_TOKEN")
} else if (sys.env.get("VAULT_TEMP_TOKEN").isDefined) {

logInfo("Obtaining vault token using VAULT_TEMP_TOKEN")
scala.util.Try {
VaultHelper.getRealToken(sys.env.get("VAULT_TEMP_TOKEN"))
} match {
val token = VaultHelper.getRealToken(sys.env.get("VAULT_TEMP_TOKEN"))

token match {
case Success(token) => Option(token)
case Failure(e) =>
logWarning("An error ocurred while trying to obtain" +
" Application Token from a temporal token", e)
None
}

} else if (sys.env.get("VAULT_ROLE_ID").isDefined && sys.env.get("VAULT_SECRET_ID").isDefined) {
logInfo("Obtaining vault token using ROLE_ID and SECRET_ID")
Option(VaultHelper.getTokenFromAppRole(

logInfo("Obtaining vault token using ROLE_ID and SECRET_ID")
val tokenRole = VaultHelper.getTokenFromAppRole(
sys.env("VAULT_ROLE_ID"),
sys.env("VAULT_SECRET_ID")))
sys.env("VAULT_SECRET_ID")
)

tokenRole match {
case Success(token) => Option(token)
case Failure(e) =>
logWarning("An error ocurred while trying to obtain" +
" Application Token from a ROLE_ID and SECRET_ID", e)
None
}

} else {
logInfo("No Vault token variables provided. Skipping Vault token retrieving")
None
Expand Down
5 changes: 4 additions & 1 deletion core/src/main/scala/org/apache/spark/security/DBConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
*/
package org.apache.spark.security


object DBConfig {
def prepareEnvironment(options: Map[String, String]): Map[String, String] = {
options.filter(_._1.endsWith("DB_USER_VAULT_PATH")).flatMap{case (_, path) =>
val (pass, user) = VaultHelper.getPassPrincipalFromVault(path)

val (pass, user) = VaultHelper.getPassPrincipalFromVault(path).get

Seq(("spark.db.enable", "true"), ("spark.db.user", user), ("spark.db.pass", pass))
}
}
Expand Down
72 changes: 40 additions & 32 deletions core/src/main/scala/org/apache/spark/security/HTTPHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.security
import java.io.{BufferedReader, File, InputStreamReader}
import java.security.cert.X509Certificate

import scala.annotation.tailrec
import scala.util.Try
import scala.util.parsing.json.JSON

import org.apache.http.client.HttpClient
Expand All @@ -31,6 +31,10 @@ import org.apache.http.ssl.{SSLContextBuilder, TrustStrategy}

import org.apache.spark.internal.Logging





object HTTPHelper extends Logging{

lazy val clientNaive: HttpClient = {
Expand All @@ -54,49 +58,53 @@ object HTTPHelper extends Logging{
def executePost(requestUrl: String,
parentField: String,
headers: Option[Seq[(String, String)]],
entity: Option[String] = None): Map[String, Any] = {
entity: Option[String] = None): Try[Map[String, Any]] = {
val post = new HttpPost(requestUrl)

getContentFromResponse(post, parentField, headers, entity)
}
def executeGet(requestUrl: String,
parentField: String,
headers: Option[Seq[(String, String)]]): Map[String, Any] = {
headers: Option[Seq[(String, String)]]): Try[Map[String, Any]] = {
val get = new HttpGet(requestUrl)
getContentFromResponse(get, parentField, headers)
}

private def getContentFromResponse(uriRequest: HttpUriRequest,
parentField: String,
headers: Option[Seq[(String, String)]],
entities: Option[String] = None): Map[String, Any] = {

headers.map(head => head.foreach { case (head, value) => uriRequest.addHeader(head, value) })

entities.map(entity => uriRequest.asInstanceOf[HttpPost].setEntity(new StringEntity(entity)))

val client = secureClient match {
case Some(secureClient) =>
logInfo(s"Using secure client")
secureClient
case _ => logInfo(s"Using non secure client")
clientNaive
}

val response = client.execute(uriRequest)

val rd = new BufferedReader(
new InputStreamReader(response.getEntity().getContent()))

val json = JSON.parseFull(Stream.continually(rd.readLine()).takeWhile(_ != null).mkString).
get.asInstanceOf[Map[String, Any]]
logTrace(s"getFrom Vault ${json.mkString("\n")}")
if(response.getStatusLine.getStatusCode != 200) {
val errors = json("errors").asInstanceOf[List[String]].mkString("\n")
throw new RuntimeException(errors)
}
else {
json(parentField).asInstanceOf[Map[String, Any]]
}
entities: Option[String] = None): Try[Map[String, Any]] =
Try {
headers.foreach(
head => head.foreach { case (head, value) => uriRequest.addHeader(head, value) }
)

entities.foreach(
entity => uriRequest.asInstanceOf[HttpPost].setEntity(new StringEntity(entity))
)

val client = secureClient match {
case Some(secureClient) =>
logInfo(s"Using secure client")
secureClient
case _ => logInfo(s"Using non secure client")
clientNaive
}

val response = client.execute(uriRequest)

val rd = new BufferedReader(
new InputStreamReader(response.getEntity().getContent()))

val json = JSON.parseFull(Stream.continually(rd.readLine()).takeWhile(_ != null).mkString).
get.asInstanceOf[Map[String, Any]]
logTrace(s"getFrom Vault ${json.mkString("\n")}")
if (response.getStatusLine.getStatusCode != 200) {
val errors = json("errors").asInstanceOf[List[String]].mkString("\n")
throw new RuntimeException(errors)
}
else {
json(parentField).asInstanceOf[Map[String, Any]]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,23 @@ import javax.xml.bind.DatatypeConverter

import org.apache.spark.internal.Logging


object KerberosConfig extends Logging{

def prepareEnviroment(options: Map[String, String]): Map[String, String] = {
val kerberosVaultPath = options.get("KERBEROS_VAULT_PATH")
if(kerberosVaultPath.isDefined) {

options.get("KERBEROS_VAULT_PATH") map { kerberosVaultPath =>

val (keytab64, principal) =
VaultHelper.getKeytabPrincipalFromVault(kerberosVaultPath.get)
VaultHelper.getKeytabPrincipalFromVault(kerberosVaultPath).get

val keytabPath = getKeytabPrincipal(keytab64, principal)
Map("principal" -> principal, "keytabPath" -> keytabPath)
} else {

} getOrElse {
logInfo(s"tying to get ssl secrets from vault for Kerberos but not found vault path," +
s" skipping")
Map[String, String]()
Map.empty
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ package org.apache.spark.security
object MesosConfig {
def prepareEnvironment(options: Map[String, String]): Map[String, String] = {
options.filter(_._1.endsWith("MESOS_VAULT_PATH")).flatMap{case (_, path) =>
val (pass, user) = VaultHelper.getPassPrincipalFromVault(path)

val (pass, user) =
VaultHelper.getPassPrincipalFromVault(path).get

Seq(("spark.mesos.principal", user), ("spark.mesos.secret", pass))

}
}
}
102 changes: 66 additions & 36 deletions core/src/main/scala/org/apache/spark/security/SSLConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,46 +30,89 @@ import sun.security.util.DerInputStream

import org.apache.spark.internal.Logging



object SSLConfig extends Logging {

val sslTypeDataStore = "DATASTORE"

private val sparkSSLPrefix = "spark.ssl."

def prepareEnvironment(sslType: String,
options: Map[String, String]): Map[String, String] = {

val sparkSSLPrefix = "spark.ssl."

val trustStore = VaultHelper.getAllCas
val trustPass = VaultHelper.getCAPass
val trustStorePath = generateTrustStore(sslType, trustStore, trustPass)
val trustStoreOptions = generateTruststoreOptions(sslType)
val keyStoreOptions = generateKeystoreOptions(sslType, options)

val vaultKeyPassPath = options(s"${sslType}_VAULT_KEY_PASS_PATH")

val keyPass =
VaultHelper.getCertPassForAppFromVault(vaultKeyPassPath).get

val keyPassOptions = Map(
s"$sparkSSLPrefix${sslType.toLowerCase}.keyPassword" -> keyPass
)

val certFilesPath =
Map(s"$sparkSSLPrefix${sslType.toLowerCase}.certPem.path" ->
s"${ConfigSecurity.secretsFolder}/cert.crt",
s"$sparkSSLPrefix${sslType.toLowerCase}.keyPKCS8.path" ->
s"${ConfigSecurity.secretsFolder}/key.pkcs8",
s"$sparkSSLPrefix${sslType.toLowerCase}.caPem.path" ->
s"${ConfigSecurity.secretsFolder}/ca.crt")

trustStoreOptions ++ keyStoreOptions ++ keyPassOptions ++ certFilesPath
}

private def generateTruststoreOptions(sslType: String): Map[String, String] = {

val getTrustStoreAndPass = for {

trustStore <- VaultHelper.getAllCas
trustPass <- VaultHelper.getCAPass

} yield {

generatePemFile(trustStore, "ca.crt")

(generateTrustStore(sslType, trustStore, trustPass), trustPass)
}

val (trustStorePath, trustPass) = getTrustStoreAndPass.get

logInfo(s"Setting SSL values for $sslType")

val trustStoreOptions =

Map(s"$sparkSSLPrefix${sslType.toLowerCase}.enabled" -> "true",
s"$sparkSSLPrefix${sslType.toLowerCase}.trustStore" -> trustStorePath,
s"$sparkSSLPrefix${sslType.toLowerCase}.trustStorePassword" -> trustPass,
s"$sparkSSLPrefix${sslType.toLowerCase}.security.protocol" -> "SSL")

val vaultKeystorePath = options.get(s"${sslType}_VAULT_CERT_PATH")
}

val vaultKeystorePassPath = options.get(s"${sslType}_VAULT_CERT_PASS_PATH")
val certName = options.get(s"${sslType}_CERTIFICATE_NAME")
private def generateKeystoreOptions(
sslType: String,
options: Map[String, String]
): Map[String, String] = {

val keyStoreOptions = if (vaultKeystorePath.isDefined
&& vaultKeystorePassPath.isDefined) {

val (key, certs) = certName match {
case Some(cert) => VaultHelper.getCertKeyForAppFromVault(
vaultKeystorePath.get, certName.get)
case None => VaultHelper.getCertKeyForAppFromVault(vaultKeystorePath.get)
}
val keystoreOptions = for {

vaultKeystorePath <- options.get(s"${sslType}_VAULT_CERT_PATH")
vaultKeystorePassPath <- options.get(s"${sslType}_VAULT_CERT_PASS_PATH")

} yield {

val certName = options.get(s"${sslType}_CERTIFICATE_NAME")

val (key, certs) =
VaultHelper.getCertKeyForAppFromVault(vaultKeystorePath, certName).get

pemToDer(key)
generatePemFile(certs, "cert.crt")
generatePemFile(trustStore, "ca.crt")

val pass = VaultHelper.getCertPassForAppFromVault( vaultKeystorePassPath.get)
val pass =
VaultHelper.getCertPassForAppFromVault( vaultKeystorePassPath).get

val keyStorePath = generateKeyStore(sslType, certs, key, pass)

Expand All @@ -78,27 +121,14 @@ object SSLConfig extends Logging {
s"$sparkSSLPrefix${sslType.toLowerCase}.protocol" -> "TLSv1.2",
s"$sparkSSLPrefix${sslType.toLowerCase}.needClientAuth" -> "true"
)

} else {
logInfo(s"trying to get ssl secrets from vault for ${sslType.toLowerCase} keyStore" +
s" but not found pass and cert vault paths, exiting")
Map[String, String]()
}

val vaultKeyPassPath = options.get(s"${sslType}_VAULT_KEY_PASS_PATH")

val keyPass = Map(s"$sparkSSLPrefix${sslType.toLowerCase}.keyPassword"
-> VaultHelper.getCertPassForAppFromVault(vaultKeyPassPath.get))

val certFilesPath =
Map(s"$sparkSSLPrefix${sslType.toLowerCase}.certPem.path" ->
s"${ConfigSecurity.secretsFolder}/cert.crt",
s"$sparkSSLPrefix${sslType.toLowerCase}.keyPKCS8.path" ->
s"${ConfigSecurity.secretsFolder}/key.pkcs8",
s"$sparkSSLPrefix${sslType.toLowerCase}.caPem.path" ->
s"${ConfigSecurity.secretsFolder}/ca.crt")
keystoreOptions.getOrElse {
logInfo(s"trying to get ssl secrets from vault for ${sslType.toLowerCase} keyStore" +
s" but not found pass and cert vault paths, exiting")
Map.empty
}

trustStoreOptions ++ keyStoreOptions ++ keyPass ++ certFilesPath
}

def generateTrustStore(sslType: String, cas: String, password: String): String = {
Expand Down
Loading

0 comments on commit 9da8ad4

Please sign in to comment.