Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WX-1346] Scalafmt #7257

Merged
merged 5 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
18 changes: 18 additions & 0 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
version = 3.7.17
align.preset = none
align.openParenCallSite = true
align.openParenDefnSite = true
maxColumn = 120
continuationIndent.defnSite = 2
assumeStandardLibraryStripMargin = true
align.stripMargin = true
danglingParentheses.preset = true
rewrite.rules = [Imports, RedundantBraces, RedundantParens, SortModifiers]
rewrite.imports.sort = scalastyle
docstrings.style = keep
project.excludeFilters = [
Dependencies.scala,
Settings.scala,
build.sbt
]
runner.dialect = scala213
23 changes: 13 additions & 10 deletions CromIAM/src/main/scala/cromiam/auth/Collection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import scala.util.{Success, Try}
final case class Collection(name: String) extends AnyVal

object Collection {

/**
* Parses a raw JSON string to make sure it fits the standard pattern (see below) for labels,
* performs some CromIAM-specific checking to ensure the user isn't attempting to manipulate the
Expand All @@ -19,13 +20,14 @@ object Collection {
*/
def validateLabels(labelsJson: Option[String]): Directive1[Option[Map[String, JsValue]]] = {

val labels = labelsJson map { l =>
Try(l.parseJson) match {
case Success(JsObject(json)) if json.keySet.contains(CollectionLabelName) => throw new LabelContainsCollectionException
case Success(JsObject(json)) => json
case _ => throw InvalidLabelsException(l)
}
val labels = labelsJson map { l =>
Try(l.parseJson) match {
case Success(JsObject(json)) if json.keySet.contains(CollectionLabelName) =>
throw new LabelContainsCollectionException
case Success(JsObject(json)) => json
case _ => throw InvalidLabelsException(l)
}
}

provide(labels)
}
Expand All @@ -34,15 +36,16 @@ object Collection {
val LabelsKey = "labels"

// LabelContainsCollectionException is a class because of ScalaTest, some of the constructs don't play well w/ case objects
final class LabelContainsCollectionException extends Exception(s"Submitted labels contain the key $CollectionLabelName, which is not allowed\n")
final case class InvalidLabelsException(labels: String) extends Exception(s"Labels must be a valid JSON object, received: $labels\n")
final class LabelContainsCollectionException
extends Exception(s"Submitted labels contain the key $CollectionLabelName, which is not allowed\n")
final case class InvalidLabelsException(labels: String)
extends Exception(s"Labels must be a valid JSON object, received: $labels\n")

/**
* Returns the default collection for a user.
*/
def forUser(user: User): Collection = {
def forUser(user: User): Collection =
Collection(user.userId.value)
}

implicit val collectionJsonReader = new JsonReader[Collection] {
import spray.json.DefaultJsonProtocol._
Expand Down
1 change: 0 additions & 1 deletion CromIAM/src/main/scala/cromiam/auth/User.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,3 @@ import org.broadinstitute.dsde.workbench.model.WorkbenchUserId
* Wraps the concept of an authenticated workbench user including their numeric ID as well as their bearer token
*/
final case class User(userId: WorkbenchUserId, authorization: Authorization)

63 changes: 37 additions & 26 deletions CromIAM/src/main/scala/cromiam/cromwell/CromwellClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,16 @@ import scala.concurrent.{ExecutionContextExecutor, Future}
*
* FIXME: Look for ways to synch this up with the mothership
*/
class CromwellClient(scheme: String, interface: String, port: Int, log: LoggingAdapter, serviceRegistryActorRef: ActorRef)(implicit system: ActorSystem,
ece: ExecutionContextExecutor,
materializer: ActorMaterializer)
extends SprayJsonSupport with DefaultJsonProtocol with StatusCheckedSubsystem with CromIamInstrumentation{
class CromwellClient(scheme: String,
interface: String,
port: Int,
log: LoggingAdapter,
serviceRegistryActorRef: ActorRef
)(implicit system: ActorSystem, ece: ExecutionContextExecutor, materializer: ActorMaterializer)
extends SprayJsonSupport
with DefaultJsonProtocol
with StatusCheckedSubsystem
with CromIamInstrumentation {

val cromwellUrl = new URL(s"$scheme://$interface:$port")
val cromwellApiVersion = "v1"
Expand All @@ -41,35 +47,38 @@ class CromwellClient(scheme: String, interface: String, port: Int, log: LoggingA

def collectionForWorkflow(workflowId: String,
user: User,
cromIamRequest: HttpRequest): FailureResponseOrT[Collection] = {
cromIamRequest: HttpRequest
): FailureResponseOrT[Collection] = {
import CromwellClient.EnhancedWorkflowLabels

log.info("Requesting collection for " + workflowId + " for user " + user.userId + " from metadata")

// Look up in Cromwell what the collection is for this workflow. If it doesn't exist, fail the Future
val cromwellApiLabelFunc = () => cromwellApiClient.labels(WorkflowId.fromString(workflowId), headers = List(user.authorization)) flatMap {
_.caasCollection match {
case Some(c) => FailureResponseOrT.pure[IO, HttpResponse](c)
case None =>
val exception = new IllegalArgumentException(s"Workflow $workflowId has no associated collection")
val failure = IO.raiseError[Collection](exception)
FailureResponseOrT.right[HttpResponse](failure)
val cromwellApiLabelFunc = () =>
cromwellApiClient.labels(WorkflowId.fromString(workflowId), headers = List(user.authorization)) flatMap {
_.caasCollection match {
case Some(c) => FailureResponseOrT.pure[IO, HttpResponse](c)
case None =>
val exception = new IllegalArgumentException(s"Workflow $workflowId has no associated collection")
val failure = IO.raiseError[Collection](exception)
FailureResponseOrT.right[HttpResponse](failure)
}
}
}

instrumentRequest(cromwellApiLabelFunc, cromIamRequest, wfCollectionPrefix)
}

def forwardToCromwell(httpRequest: HttpRequest): FailureResponseOrT[HttpResponse] = {
val future = {
// See CromwellClient's companion object for info on these header modifications
val headers = httpRequest.headers.filterNot(header => header.name == TimeoutAccessHeader || header.name == HostHeader)
val headers =
httpRequest.headers.filterNot(header => header.name == TimeoutAccessHeader || header.name == HostHeader)
val cromwellRequest = httpRequest
.copy(uri = httpRequest.uri.withAuthority(interface, port).withScheme(scheme))
.withHeaders(headers)
Http().singleRequest(cromwellRequest)
} recoverWith {
case e => Future.failed(CromwellConnectionFailure(e))
} recoverWith { case e =>
Future.failed(CromwellConnectionFailure(e))
}
future.asFailureResponseOrT
}
Expand All @@ -86,7 +95,7 @@ class CromwellClient(scheme: String, interface: String, port: Int, log: LoggingA
use the current workflow id.

This is all called from inside the context of a Future, so exceptions will be properly caught.
*/
*/
metadata.value.parseJson.asJsObject.fields.get("rootWorkflowId").map(_.convertTo[String]).getOrElse(workflowId)
}

Expand All @@ -96,11 +105,13 @@ class CromwellClient(scheme: String, interface: String, port: Int, log: LoggingA
Grab the metadata from Cromwell filtered down to the rootWorkflowId. Then transform the response to get just the
root workflow ID itself
*/
val cromwellApiMetadataFunc = () => cromwellApiClient.metadata(
WorkflowId.fromString(workflowId),
args = Option(Map("includeKey" -> List("rootWorkflowId"))),
headers = List(user.authorization)).map(metadataToRootWorkflowId
)
val cromwellApiMetadataFunc = () =>
cromwellApiClient
.metadata(WorkflowId.fromString(workflowId),
args = Option(Map("includeKey" -> List("rootWorkflowId"))),
headers = List(user.authorization)
)
.map(metadataToRootWorkflowId)

instrumentRequest(cromwellApiMetadataFunc, cromIamRequest, rootWfIdPrefix)
}
Expand All @@ -120,14 +131,14 @@ object CromwellClient {
// See: https://broadworkbench.atlassian.net/browse/DDO-2190
val HostHeader = "Host"

final case class CromwellConnectionFailure(f: Throwable) extends Exception(s"Unable to connect to Cromwell (${f.getMessage})", f)
final case class CromwellConnectionFailure(f: Throwable)
extends Exception(s"Unable to connect to Cromwell (${f.getMessage})", f)

implicit class EnhancedWorkflowLabels(val wl: WorkflowLabels) extends AnyVal {

import Collection.{CollectionLabelName, collectionJsonReader}
import Collection.{collectionJsonReader, CollectionLabelName}

def caasCollection: Option[Collection] = {
def caasCollection: Option[Collection] =
wl.labels.fields.get(CollectionLabelName).map(_.convertTo[Collection])
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ trait CromIamInstrumentation extends CromwellInstrumentation {
val rootWfIdPrefix = NonEmptyList.one("root-workflow-id")
val wfCollectionPrefix = NonEmptyList.one("workflow-collection")


def convertRequestToPath(httpRequest: HttpRequest): NonEmptyList[String] = NonEmptyList.of(
// Returns the path of the URI only, without query parameters (e.g: api/engine/workflows/metadata)
httpRequest.uri.path.toString().stripPrefix("/")
httpRequest.uri.path
.toString()
.stripPrefix("/")
// Replace UUIDs with [id] to keep paths same regardless of the workflow
.replaceAll(CromIamInstrumentation.UUIDRegex, "[id]"),
// Name of the method (e.g: GET)
Expand All @@ -43,15 +44,19 @@ trait CromIamInstrumentation extends CromwellInstrumentation {
def makePathFromRequestAndResponse(httpRequest: HttpRequest, httpResponse: HttpResponse): InstrumentationPath =
convertRequestToPath(httpRequest).concatNel(NonEmptyList.of(httpResponse.status.intValue.toString))

def sendTimingApi(statsDPath: InstrumentationPath, timing: FiniteDuration, prefixToStatsd: NonEmptyList[String]): Unit = {
def sendTimingApi(statsDPath: InstrumentationPath,
timing: FiniteDuration,
prefixToStatsd: NonEmptyList[String]
): Unit =
sendTiming(prefixToStatsd.concatNel(statsDPath), timing, CromIamPrefix)
}

def instrumentationPrefixForSam(methodPrefix: NonEmptyList[String]): NonEmptyList[String] = samPrefix.concatNel(methodPrefix)
def instrumentationPrefixForSam(methodPrefix: NonEmptyList[String]): NonEmptyList[String] =
samPrefix.concatNel(methodPrefix)

def instrumentRequest[A](func: () => FailureResponseOrT[A],
httpRequest: HttpRequest,
prefix: NonEmptyList[String]): FailureResponseOrT[A] = {
prefix: NonEmptyList[String]
): FailureResponseOrT[A] = {
def now(): Deadline = Deadline.now

val startTimestamp = now()
Expand Down
66 changes: 39 additions & 27 deletions CromIAM/src/main/scala/cromiam/sam/SamClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,21 @@ class SamClient(scheme: String,
port: Int,
checkSubmitWhitelist: Boolean,
log: LoggingAdapter,
serviceRegistryActorRef: ActorRef)
(implicit system: ActorSystem, ece: ExecutionContextExecutor, materializer: ActorMaterializer) extends StatusCheckedSubsystem with CromIamInstrumentation {
serviceRegistryActorRef: ActorRef
)(implicit system: ActorSystem, ece: ExecutionContextExecutor, materializer: ActorMaterializer)
extends StatusCheckedSubsystem
with CromIamInstrumentation {

private implicit val cs = IO.contextShift(ece)
implicit private val cs = IO.contextShift(ece)

override val statusUri = uri"$samBaseUri/status"
override val serviceRegistryActor: ActorRef = serviceRegistryActorRef

def isSubmitWhitelisted(user: User, cromIamRequest: HttpRequest): FailureResponseOrT[Boolean] = {
def isSubmitWhitelisted(user: User, cromIamRequest: HttpRequest): FailureResponseOrT[Boolean] =
checkSubmitWhitelist.fold(
isSubmitWhitelistedSam(user, cromIamRequest),
FailureResponseOrT.pure(true)
)
}

def isSubmitWhitelistedSam(user: User, cromIamRequest: HttpRequest): FailureResponseOrT[Boolean] = {
val request = HttpRequest(
Expand All @@ -64,7 +65,7 @@ class SamClient(scheme: String,
whitelisted <- response.status match {
case StatusCodes.OK =>
// Does not seem to be already provided?
implicit val entityToBooleanUnmarshaller : Unmarshaller[HttpEntity, Boolean] =
implicit val entityToBooleanUnmarshaller: Unmarshaller[HttpEntity, Boolean] =
(Unmarshaller.stringUnmarshaller flatMap Unmarshaller.booleanFromStringUnmarshaller).asScala
val unmarshal = IO.fromFuture(IO(Unmarshal(response.entity).to[Boolean]))
FailureResponseOrT.right[HttpResponse](unmarshal)
Expand Down Expand Up @@ -95,14 +96,19 @@ class SamClient(scheme: String,
userInfo.enabled
}
case _ =>
log.error("Could not verify access with Sam for user {}, error was {} {}", user.userId, response.status, response.toString().take(100))
log.error("Could not verify access with Sam for user {}, error was {} {}",
user.userId,
response.status,
response.toString().take(100)
)
FailureResponseOrT.pure[IO, HttpResponse](false)
}
} yield userEnabled
}

def collectionsForUser(user: User, cromIamRequest: HttpRequest): FailureResponseOrT[List[Collection]] = {
val request = HttpRequest(method = HttpMethods.GET, uri = samBaseCollectionUri, headers = List[HttpHeader](user.authorization))
val request =
HttpRequest(method = HttpMethods.GET, uri = samBaseCollectionUri, headers = List[HttpHeader](user.authorization))

for {
response <- instrumentRequest(
Expand All @@ -120,24 +126,25 @@ class SamClient(scheme: String,
* @return Successful future if the auth is accepted, a Failure otherwise.
*/
def requestAuth(authorizationRequest: CollectionAuthorizationRequest,
cromIamRequest: HttpRequest): FailureResponseOrT[Unit] = {
cromIamRequest: HttpRequest
): FailureResponseOrT[Unit] = {
val logString = authorizationRequest.action + " access for user " + authorizationRequest.user.userId +
" on a request to " + authorizationRequest.action + " for collection " + authorizationRequest.collection.name
" on a request to " + authorizationRequest.action + " for collection " + authorizationRequest.collection.name

def validateEntityBytes(byteString: ByteString): FailureResponseOrT[Unit] = {
def validateEntityBytes(byteString: ByteString): FailureResponseOrT[Unit] =
if (byteString.utf8String == "true") {
Monad[FailureResponseOrT].unit
} else {
log.warning("Sam denied " + logString)
FailureResponseOrT[IO, HttpResponse, Unit](IO.raiseError(new SamDenialException))
}
}

log.info("Requesting authorization for " + logString)

val request = HttpRequest(method = HttpMethods.GET,
uri = samAuthorizeActionUri(authorizationRequest),
headers = List[HttpHeader](authorizationRequest.user.authorization))
uri = samAuthorizeActionUri(authorizationRequest),
headers = List[HttpHeader](authorizationRequest.user.authorization)
)

for {
response <- instrumentRequest(
Expand All @@ -158,26 +165,28 @@ class SamClient(scheme: String,
- If user has the 'add' permission we're ok
- else fail the future
*/
def requestSubmission(user: User,
collection: Collection,
cromIamRequest: HttpRequest
): FailureResponseOrT[Unit] = {
def requestSubmission(user: User, collection: Collection, cromIamRequest: HttpRequest): FailureResponseOrT[Unit] = {
log.info("Verifying user " + user.userId + " can submit a workflow to collection " + collection.name)
val createCollection = registerCreation(user, collection, cromIamRequest)

createCollection flatMap {
case r if r.status == StatusCodes.NoContent => Monad[FailureResponseOrT].unit
case r => FailureResponseOrT[IO, HttpResponse, Unit](IO.raiseError(SamRegisterCollectionException(r.status)))
} recoverWith {
case r if r.status == StatusCodes.Conflict => requestAuth(CollectionAuthorizationRequest(user, collection, "add"), cromIamRequest)
case r if r.status == StatusCodes.Conflict =>
requestAuth(CollectionAuthorizationRequest(user, collection, "add"), cromIamRequest)
case r => FailureResponseOrT[IO, HttpResponse, Unit](IO.raiseError(SamRegisterCollectionException(r.status)))
}
}

protected def registerCreation(user: User,
collection: Collection,
cromIamRequest: HttpRequest): FailureResponseOrT[HttpResponse] = {
val request = HttpRequest(method = HttpMethods.POST, uri = samRegisterUri(collection), headers = List[HttpHeader](user.authorization))
cromIamRequest: HttpRequest
): FailureResponseOrT[HttpResponse] = {
val request = HttpRequest(method = HttpMethods.POST,
uri = samRegisterUri(collection),
headers = List[HttpHeader](user.authorization)
)

instrumentRequest(
() => Http().singleRequest(request).asFailureResponseOrT,
Expand All @@ -186,9 +195,9 @@ class SamClient(scheme: String,
)
}

private def samAuthorizeActionUri(authorizationRequest: CollectionAuthorizationRequest) = {
akka.http.scaladsl.model.Uri(s"${samBaseUriForWorkflow(authorizationRequest.collection)}/action/${authorizationRequest.action}")
}
private def samAuthorizeActionUri(authorizationRequest: CollectionAuthorizationRequest) =
akka.http.scaladsl.model
.Uri(s"${samBaseUriForWorkflow(authorizationRequest.collection)}/action/${authorizationRequest.action}")

private def samRegisterUri(collection: Collection) = akka.http.scaladsl.model.Uri(samBaseUriForWorkflow(collection))

Expand All @@ -207,15 +216,18 @@ object SamClient {

class SamDenialException extends Exception("Access Denied")

final case class SamConnectionFailure(phase: String, f: Throwable) extends Exception(s"Unable to connect to Sam during $phase (${f.getMessage})", f)
final case class SamConnectionFailure(phase: String, f: Throwable)
extends Exception(s"Unable to connect to Sam during $phase (${f.getMessage})", f)

final case class SamRegisterCollectionException(errorCode: StatusCode) extends Exception(s"Can't register collection with Sam. Status code: ${errorCode.value}")
final case class SamRegisterCollectionException(errorCode: StatusCode)
extends Exception(s"Can't register collection with Sam. Status code: ${errorCode.value}")

final case class CollectionAuthorizationRequest(user: User, collection: Collection, action: String)

val SamDenialResponse = HttpResponse(status = StatusCodes.Forbidden, entity = new SamDenialException().getMessage)

def SamRegisterCollectionExceptionResp(statusCode: StatusCode) = HttpResponse(status = statusCode, entity = SamRegisterCollectionException(statusCode).getMessage)
def SamRegisterCollectionExceptionResp(statusCode: StatusCode) =
HttpResponse(status = statusCode, entity = SamRegisterCollectionException(statusCode).getMessage)

case class UserStatusInfo(adminEnabled: Boolean, enabled: Boolean, userEmail: String, userSubjectId: String)

Expand Down
Loading
Loading