Skip to content

Commit

Permalink
Merge branch 'develop' into WX-1506
Browse files Browse the repository at this point in the history
  • Loading branch information
THWiseman authored Mar 14, 2024
2 parents b5f1a94 + a8f6e9e commit 172b338
Show file tree
Hide file tree
Showing 14 changed files with 77 additions and 97 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ lazy val `cloud-nio-impl-ftp` = (project in cloudNio / "cloud-nio-impl-ftp")
lazy val `cloud-nio-impl-drs` = (project in cloudNio / "cloud-nio-impl-drs")
.withLibrarySettings(libraryName = "cloud-nio-impl-drs", dependencies = implDrsDependencies)
.dependsOn(`cloud-nio-util`)
.dependsOn(cloudSupport)
.dependsOn(common)
.dependsOn(common % "test->test")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import common.exception._
import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.http.HttpStatus

class DrsCloudNioFileProvider(drsPathResolver: EngineDrsPathResolver, drsReadInterpreter: DrsReadInterpreter)
class DrsCloudNioFileProvider(drsPathResolver: DrsPathResolver, drsReadInterpreter: DrsReadInterpreter)
extends CloudNioFileProvider {

private def checkIfPathExistsThroughDrsResolver(drsPath: String): IO[Boolean] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ class DrsCloudNioFileSystemProvider(rootConfig: Config,
lazy val drsResolverConfig = rootConfig.getConfig("resolver")
lazy val drsConfig: DrsConfig = DrsConfig.fromConfig(drsResolverConfig)

lazy val drsPathResolver: EngineDrsPathResolver =
EngineDrsPathResolver(drsConfig, drsCredentials)
lazy val drsPathResolver: DrsPathResolver = new DrsPathResolver(drsConfig, drsCredentials)

override def config: Config = rootConfig

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
package cloud.nio.impl.drs

import cats.syntax.validated._
import com.azure.core.credential.TokenRequestContext
import com.azure.core.management.AzureEnvironment
import com.azure.core.management.profile.AzureProfile
import com.azure.identity.DefaultAzureCredentialBuilder
import com.google.auth.oauth2.{AccessToken, GoogleCredentials, OAuth2Credentials}
import com.typesafe.config.Config
import common.validation.ErrorOr.ErrorOr
import net.ceedubs.ficus.Ficus._
import cromwell.cloudsupport.azure.AzureCredentials

import scala.concurrent.duration._
import scala.jdk.DurationConverters._
import scala.util.{Failure, Success, Try}

/**
Expand Down Expand Up @@ -76,38 +72,6 @@ case object GoogleAppDefaultTokenStrategy extends DrsCredentials {
* If you need to disambiguate among multiple active user-assigned managed identities, pass
* in the client id of the identity that should be used.
*/
case class AzureDrsCredentials(identityClientId: Option[String]) extends DrsCredentials {

final val tokenAcquisitionTimeout = 5.seconds

val azureProfile = new AzureProfile(AzureEnvironment.AZURE)
val tokenScope = "https://management.azure.com/.default"

def tokenRequestContext: TokenRequestContext = {
val trc = new TokenRequestContext()
trc.addScopes(tokenScope)
trc
}

def defaultCredentialBuilder: DefaultAzureCredentialBuilder =
new DefaultAzureCredentialBuilder()
.authorityHost(azureProfile.getEnvironment.getActiveDirectoryEndpoint)

def getAccessToken: ErrorOr[String] = {
val credentials = identityClientId
.foldLeft(defaultCredentialBuilder) { (builder, clientId) =>
builder.managedIdentityClientId(clientId)
}
.build()

Try(
credentials
.getToken(tokenRequestContext)
.block(tokenAcquisitionTimeout.toJava)
) match {
case Success(null) => "null token value attempting to obtain access token".invalidNel
case Success(token) => token.getToken.validNel
case Failure(error) => s"Failed to refresh access token: ${error.getMessage}".invalidNel
}
}
case class AzureDrsCredentials(identityClientId: Option[String] = None) extends DrsCredentials {
def getAccessToken: ErrorOr[String] = AzureCredentials.getAccessToken(identityClientId)
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import java.nio.ByteBuffer
import java.nio.channels.{Channels, ReadableByteChannel}
import scala.util.Try

abstract class DrsPathResolver(drsConfig: DrsConfig) {
class DrsPathResolver(drsConfig: DrsConfig, drsCredentials: DrsCredentials) {

protected lazy val httpClientBuilder: HttpClientBuilder = {
val clientBuilder = HttpClientBuilder.create()
Expand All @@ -38,7 +38,17 @@ abstract class DrsPathResolver(drsConfig: DrsConfig) {
clientBuilder
}

def getAccessToken: ErrorOr[String]
def getAccessToken: ErrorOr[String] = drsCredentials.getAccessToken

private lazy val currentCloudPlatform: Option[DrsCloudPlatform.Value] = drsCredentials match {
case _: GoogleOauthDrsCredentials => Option(DrsCloudPlatform.GoogleStorage)
case GoogleAppDefaultTokenStrategy => Option(DrsCloudPlatform.GoogleStorage)
case _: AzureDrsCredentials => Option(DrsCloudPlatform.Azure)
case _ => None
}

def makeDrsResolverRequest(drsPath: String, fields: NonEmptyList[DrsResolverField.Value]): DrsResolverRequest =
DrsResolverRequest(drsPath, currentCloudPlatform, fields)

private def makeHttpRequestToDrsResolver(drsPath: String,
fields: NonEmptyList[DrsResolverField.Value]
Expand All @@ -47,7 +57,7 @@ abstract class DrsPathResolver(drsConfig: DrsConfig) {
case Valid(token) =>
IO {
val postRequest = new HttpPost(drsConfig.drsResolverUrl)
val requestJson = DrsResolverRequest(drsPath, fields).asJson.noSpaces
val requestJson = makeDrsResolverRequest(drsPath, fields).asJson.noSpaces
postRequest.setEntity(new StringEntity(requestJson, ContentType.APPLICATION_JSON))
postRequest.setHeader("Authorization", s"Bearer $token")
postRequest
Expand Down Expand Up @@ -118,7 +128,9 @@ abstract class DrsPathResolver(drsConfig: DrsConfig) {
* Please note, this method returns an IO that would make a synchronous HTTP request to DRS Resolver when run.
*/
def resolveDrs(drsPath: String, fields: NonEmptyList[DrsResolverField.Value]): IO[DrsResolverResponse] =
rawDrsResolverResponse(drsPath, fields).use(httpResponseToDrsResolverResponse(drsPathForDebugging = drsPath))
rawDrsResolverResponse(drsPath, fields).use(
httpResponseToDrsResolverResponse(drsPathForDebugging = drsPath)
)

def openChannel(accessUrl: AccessUrl): IO[ReadableByteChannel] =
IO {
Expand Down Expand Up @@ -178,7 +190,20 @@ object DrsResolverField extends Enumeration {
val LocalizationPath: DrsResolverField.Value = Value("localizationPath")
}

final case class DrsResolverRequest(url: String, fields: NonEmptyList[DrsResolverField.Value])
// We supply a cloud platform value to the DRS service. In cases where the DRS repository
// has multiple cloud files associated with a DRS link, it will prefer sending a file on the same
// platform as this Cromwell instance. That is, if a DRS file has copies on both GCP and Azure,
// we'll get the GCP one when running on GCP and the Azure one when running on Azure.
object DrsCloudPlatform extends Enumeration {
val GoogleStorage: DrsCloudPlatform.Value = Value("gs")
val Azure: DrsCloudPlatform.Value = Value("azure")
val AmazonS3: DrsCloudPlatform.Value = Value("s3") // supported by DRSHub but not currently used by us
}

final case class DrsResolverRequest(url: String,
cloudPlatform: Option[DrsCloudPlatform.Value],
fields: NonEmptyList[DrsResolverField.Value]
)

final case class SADataObject(data: Json)

Expand Down Expand Up @@ -219,6 +244,8 @@ object DrsResolverResponseSupport {

implicit lazy val drsResolverFieldEncoder: Encoder[DrsResolverField.Value] =
Encoder.encodeEnumeration(DrsResolverField)
implicit lazy val drsResolverCloudPlatformEncoder: Encoder[DrsCloudPlatform.Value] =
Encoder.encodeEnumeration(DrsCloudPlatform)
implicit lazy val drsResolverRequestEncoder: Encoder[DrsResolverRequest] = deriveEncoder

implicit lazy val saDataObjectDecoder: Decoder[SADataObject] = deriveDecoder
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class DrsCloudNioFileProviderSpec extends AnyFlatSpecLike with CromwellTimeoutSp
}

it should "return a file provider that can read bytes from gcs" in {
val drsPathResolver = new MockEngineDrsPathResolver() {
val drsPathResolver = new MockDrsPathResolver() {
override def resolveDrs(drsPath: String, fields: NonEmptyList[DrsResolverField.Value]): IO[DrsResolverResponse] =
IO(DrsResolverResponse(gsUri = Option("gs://bucket/object/path")))
}
Expand All @@ -99,7 +99,7 @@ class DrsCloudNioFileProviderSpec extends AnyFlatSpecLike with CromwellTimeoutSp
}

it should "return a file provider that can read bytes from an access url" in {
val drsPathResolver = new MockEngineDrsPathResolver() {
val drsPathResolver = new MockDrsPathResolver() {
override def resolveDrs(drsPath: String, fields: NonEmptyList[DrsResolverField.Value]): IO[DrsResolverResponse] =
IO(DrsResolverResponse(accessUrl = Option(AccessUrl("https://host/object/path", None))))
}
Expand All @@ -121,7 +121,7 @@ class DrsCloudNioFileProviderSpec extends AnyFlatSpecLike with CromwellTimeoutSp
}

it should "return a file provider that can return file attributes" in {
val drsPathResolver = new MockEngineDrsPathResolver() {
val drsPathResolver = new MockDrsPathResolver() {
override def resolveDrs(drsPath: String,
fields: NonEmptyList[DrsResolverField.Value]
): IO[DrsResolverResponse] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package cloud.nio.impl.drs

import cats.data.NonEmptyList

import java.nio.file.attribute.FileTime
import java.time.OffsetDateTime

import cloud.nio.impl.drs.DrsCloudNioRegularFileAttributes._
import cloud.nio.spi.{FileHash, HashType}
import common.assertion.CromwellTimeoutSpec
Expand Down Expand Up @@ -157,6 +158,18 @@ class DrsPathResolverSpec extends AnyFlatSpecLike with CromwellTimeoutSpec with
val responseStatusLine = new BasicStatusLine(new ProtocolVersion("http", 1, 2), 345, "test-reason")
val testDrsResolverUri = "www.drshub_v4.com"

it should "construct the right request when using Azure creds" in {
val resolver = new MockDrsPathResolver(drsCredentials = AzureDrsCredentials())
val drsRequest = resolver.makeDrsResolverRequest(drsPathForDebugging, NonEmptyList.of(DrsResolverField.AccessUrl))
drsRequest.cloudPlatform shouldBe Option(DrsCloudPlatform.Azure)
}

it should "construct the right request when using Google creds" in {
val resolver = new MockDrsPathResolver()
val drsRequest = resolver.makeDrsResolverRequest(drsPathForDebugging, NonEmptyList.of(DrsResolverField.AccessUrl))
drsRequest.cloudPlatform shouldBe Option(DrsCloudPlatform.GoogleStorage)
}

it should "construct an error message from a populated, well-formed failure response" in {
val failureResponse = Option(failureResponseJson)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,23 @@ import com.google.cloud.NoCredentials
import com.typesafe.config.{Config, ConfigFactory}
import org.apache.http.impl.client.HttpClientBuilder

import scala.concurrent.duration.Duration

class MockDrsCloudNioFileSystemProvider(config: Config = mockConfig,
httpClientBuilder: Option[HttpClientBuilder] = None,
drsReadInterpreter: DrsReadInterpreter = (_, _) =>
IO.raiseError(
new UnsupportedOperationException("mock did not specify a read interpreter")
),
mockResolver: Option[EngineDrsPathResolver] = None
mockResolver: Option[DrsPathResolver] = None
) extends DrsCloudNioFileSystemProvider(config,
GoogleOauthDrsCredentials(NoCredentials.getInstance, config),
drsReadInterpreter
) {

override lazy val drsPathResolver: EngineDrsPathResolver =
override lazy val drsPathResolver: DrsPathResolver =
mockResolver getOrElse
new MockEngineDrsPathResolver(
new MockDrsPathResolver(
drsConfig = drsConfig,
httpClientBuilderOverride = httpClientBuilder,
accessTokenAcceptableTTL = Duration.Inf
httpClientBuilderOverride = httpClientBuilder
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@ import common.mock.MockSugar

import scala.concurrent.duration.Duration

class MockEngineDrsPathResolver(drsConfig: DrsConfig = MockDrsPaths.mockDrsConfig,
httpClientBuilderOverride: Option[HttpClientBuilder] = None,
accessTokenAcceptableTTL: Duration = Duration.Inf
) extends EngineDrsPathResolver(drsConfig,
GoogleOauthDrsCredentials(NoCredentials.getInstance, accessTokenAcceptableTTL)
) {
class MockDrsPathResolver(drsConfig: DrsConfig = MockDrsPaths.mockDrsConfig,
httpClientBuilderOverride: Option[HttpClientBuilder] = None,
drsCredentials: DrsCredentials =
GoogleOauthDrsCredentials(NoCredentials.getInstance, Duration.Inf)
) extends DrsPathResolver(drsConfig, drsCredentials) {

override protected lazy val httpClientBuilder: HttpClientBuilder =
httpClientBuilderOverride getOrElse MockSugar.mock[HttpClientBuilder]
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ class DrsLocalizerMain(toResolveAndDownload: IO[List[UnresolvedDrsUrl]],
}
}

def getDrsPathResolver: IO[DrsLocalizerDrsPathResolver] =
def getDrsPathResolver: IO[DrsPathResolver] =
IO {
val drsConfig = DrsConfig.fromEnv(sys.env)
logger.info(s"Using ${drsConfig.drsResolverUrl} to resolve DRS Objects")
new DrsLocalizerDrsPathResolver(drsConfig, drsCredentials)
new DrsPathResolver(drsConfig, drsCredentials)
}

/**
Expand Down Expand Up @@ -182,9 +182,7 @@ class DrsLocalizerMain(toResolveAndDownload: IO[List[UnresolvedDrsUrl]],
/**
* Runs a synchronous HTTP request to resolve the provided DRS URL with the provided resolver.
*/
def resolveSingleUrl(resolverObject: DrsLocalizerDrsPathResolver,
drsUrlToResolve: UnresolvedDrsUrl
): IO[ResolvedDrsUrl] = {
def resolveSingleUrl(resolverObject: DrsPathResolver, drsUrlToResolve: UnresolvedDrsUrl): IO[ResolvedDrsUrl] = {
val fields = NonEmptyList.of(DrsResolverField.GsUri,
DrsResolverField.GoogleServiceAccount,
DrsResolverField.AccessUrl,
Expand Down Expand Up @@ -213,7 +211,7 @@ class DrsLocalizerMain(toResolveAndDownload: IO[List[UnresolvedDrsUrl]],
}
}

def resolveWithRetries(resolverObject: DrsLocalizerDrsPathResolver,
def resolveWithRetries(resolverObject: DrsPathResolver,
drsUrlToResolve: UnresolvedDrsUrl,
resolutionRetries: Int,
backoff: Option[CloudNioBackoff],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import cats.data.NonEmptyList
import cats.effect.{ExitCode, IO}
import cats.syntax.validated._
import drs.localizer.MockDrsPaths.{fakeAccessUrls, fakeDrsUrlWithGcsResolutionOnly, fakeGoogleUrls}
import cloud.nio.impl.drs.{AccessUrl, DrsConfig, DrsCredentials, DrsResolverField, DrsResolverResponse}
import cloud.nio.impl.drs.{AccessUrl, DrsConfig, DrsCredentials, DrsPathResolver, DrsResolverField, DrsResolverResponse}
import common.assertion.CromwellTimeoutSpec
import common.validation.ErrorOr.ErrorOr
import drs.localizer.MockDrsLocalizerDrsPathResolver.{FakeAccessTokenStrategy, FakeHashes}
Expand Down Expand Up @@ -372,11 +372,11 @@ class MockDrsLocalizerMain(toResolveAndDownload: IO[List[UnresolvedDrsUrl]],
requesterPaysProjectIdOption
) {

override def getDrsPathResolver: IO[DrsLocalizerDrsPathResolver] =
override def getDrsPathResolver: IO[DrsPathResolver] =
IO {
new MockDrsLocalizerDrsPathResolver(cloud.nio.impl.drs.MockDrsPaths.mockDrsConfig)
}
override def resolveSingleUrl(resolverObject: DrsLocalizerDrsPathResolver,
override def resolveSingleUrl(resolverObject: DrsPathResolver,
drsUrlToResolve: UnresolvedDrsUrl
): IO[ResolvedDrsUrl] =
IO {
Expand All @@ -391,7 +391,7 @@ class MockDrsLocalizerMain(toResolveAndDownload: IO[List[UnresolvedDrsUrl]],
}

class MockDrsLocalizerDrsPathResolver(drsConfig: DrsConfig)
extends DrsLocalizerDrsPathResolver(drsConfig, FakeAccessTokenStrategy) {
extends DrsPathResolver(drsConfig, FakeAccessTokenStrategy) {

override def resolveDrs(drsPath: String, fields: NonEmptyList[DrsResolverField.Value]): IO[DrsResolverResponse] = {

Expand Down
Loading

0 comments on commit 172b338

Please sign in to comment.