Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Dec 19, 2024
1 parent 27e9f42 commit 67af420
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 42 deletions.
11 changes: 9 additions & 2 deletions crates/polars-io/src/cloud/credential_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ mod python_impl {

#[cfg(feature = "azure")]
fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
use object_store::azure::AzureAccessKey;
use polars_error::{to_compute_err, PolarsResult};

use crate::cloud::credential_provider::{
Expand All @@ -570,15 +571,21 @@ mod python_impl {

// We only support bearer for now
match k.as_ref() {
"account_key" => {
credentials = object_store::azure::AzureCredential::AccessKey(
AzureAccessKey::try_new(v.as_str())
.map_err(|e| PyValueError::new_err(e.to_string()))?,
)
},
"bearer_token" => {
credentials =
object_store::azure::AzureCredential::BearerToken(v)
},
v => {
return pyo3::PyResult::Err(PyValueError::new_err(format!(
"unknown configuration key for azure: {}, \
valid configuration keys are: {}",
v, "bearer_token",
valid configuration keys are: {}, {}",
v, "account_key", "bearer_token",
)))
},
}
Expand Down
22 changes: 13 additions & 9 deletions crates/polars-io/src/cloud/glob.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use futures::TryStreamExt;
use object_store::path::Path;
use polars_core::error::to_compute_err;
use polars_core::prelude::{polars_ensure, polars_err};
use polars_error::PolarsResult;
use polars_core::prelude::polars_ensure;
use polars_error::{polars_bail, PolarsResult};
use polars_utils::format_pl_smallstr;
use polars_utils::pl_str::PlSmallStr;
use regex::Regex;
use url::Url;
Expand Down Expand Up @@ -98,13 +99,16 @@ impl CloudLocation {
}

let key = parsed.path();
let bucket = parsed
.host()
.ok_or_else(
|| polars_err!(ComputeError: "cannot parse bucket (host) from url: {}", parsed),
)?
.to_string()
.into();

let bucket = format_pl_smallstr!(
"{}",
&parsed[url::Position::BeforeUsername..url::Position::AfterPort]
);

if bucket.is_empty() {
polars_bail!(ComputeError: "CloudLocation::from_url(): empty bucket: {}", parsed);
}

(bucket, key)
};

Expand Down
42 changes: 28 additions & 14 deletions crates/polars-io/src/cloud/object_store_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use polars_core::config;
use polars_error::{polars_bail, to_compute_err, PolarsError, PolarsResult};
use polars_utils::aliases::PlHashMap;
use polars_utils::pl_str::PlSmallStr;
use polars_utils::{format_pl_smallstr, pl_serialize};
use tokio::sync::RwLock;
use url::Url;

Expand All @@ -17,7 +18,7 @@ use crate::cloud::CloudConfig;
/// get rate limited when querying the DNS (can take up to 5s).
/// Other reasons are connection pools that must be shared between as much as possible.
#[allow(clippy::type_complexity)]
static OBJECT_STORE_CACHE: Lazy<RwLock<PlHashMap<String, PolarsObjectStore>>> =
static OBJECT_STORE_CACHE: Lazy<RwLock<PlHashMap<Vec<u8>, PolarsObjectStore>>> =
Lazy::new(Default::default);

#[allow(dead_code)]
Expand All @@ -29,10 +30,10 @@ fn err_missing_feature(feature: &str, scheme: &str) -> PolarsResult<Arc<dyn Obje
}

/// Get the key of a url for object store registration.
fn url_and_creds_to_key(url: &Url, options: Option<&CloudOptions>) -> String {
fn url_and_creds_to_key(url: &Url, options: Option<&CloudOptions>) -> Vec<u8> {
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct S {
struct C {
max_retries: usize,
#[cfg(feature = "file_cache")]
file_cache_ttl: u64,
Expand All @@ -41,8 +42,15 @@ fn url_and_creds_to_key(url: &Url, options: Option<&CloudOptions>) -> String {
credential_provider: usize,
}

#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct S {
url_base: PlSmallStr,
cloud_options: Option<C>,
}

// We include credentials as they can expire, so users will send new credentials for the same url.
let creds = serde_json::to_string(&options.map(
let cloud_options = options.map(
|CloudOptions {
// Destructure to ensure this breaks if anything changes.
max_retries,
Expand All @@ -52,7 +60,7 @@ fn url_and_creds_to_key(url: &Url, options: Option<&CloudOptions>) -> String {
#[cfg(feature = "cloud")]
credential_provider,
}| {
S {
C {
max_retries: *max_retries,
#[cfg(feature = "file_cache")]
file_cache_ttl: *file_cache_ttl,
Expand All @@ -61,15 +69,21 @@ fn url_and_creds_to_key(url: &Url, options: Option<&CloudOptions>) -> String {
credential_provider: credential_provider.as_ref().map_or(0, |x| x.func_addr()),
}
},
))
.unwrap();

format!(
"{}://{}<\\creds\\>{}",
url.scheme(),
&url[url::Position::BeforeHost..url::Position::AfterPort],
creds
)
);

let cache_key = S {
url_base: format_pl_smallstr!(
"{}",
&url[url::Position::BeforeScheme..url::Position::AfterPort]
),
cloud_options,
};

if config::verbose() {
eprintln!("object store cache key: {} {:?}", url, &cache_key);
}

pl_serialize::serialize_to_bytes(&cache_key).unwrap()
}

/// Construct an object_store `Path` from a string without any encoding/decoding.
Expand Down
117 changes: 112 additions & 5 deletions crates/polars-io/src/cloud/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,16 +406,20 @@ impl CloudOptions {
pub fn build_azure(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
use super::credential_provider::IntoCredentialProvider;

let mut builder = if self.credential_provider.is_none() {
MicrosoftAzureBuilder::from_env()
} else {
MicrosoftAzureBuilder::new()
};
let mut storage_account: Option<PlSmallStr> = None;

// The credential provider `self.credentials` is prioritized if it is set. We also need
// `from_env()` as it may source environment configured storage account name.
let mut builder = MicrosoftAzureBuilder::from_env();

if let Some(options) = &self.config {
let CloudConfig::Azure(options) = options else {
panic!("impl error: cloud type mismatch")
};
for (key, value) in options.iter() {
if key == &AzureConfigKey::AccountName {
storage_account = Some(value.into());
}
builder = builder.with_config(*key, value);
}
}
Expand All @@ -425,8 +429,18 @@ impl CloudOptions {
.with_url(url)
.with_retry(get_retry_config(self.max_retries));

// Prefer the one embedded in the path
storage_account = extract_adls_uri_storage_account(url)
.map(|x| x.into())
.or(storage_account);

let builder = if let Some(v) = self.credential_provider.clone() {
builder.with_credentials(v.into_azure_provider())
} else if let Some(v) = storage_account
.as_deref()
.and_then(get_azure_storage_account_key)
{
builder.with_access_key(v)
} else {
builder
};
Expand Down Expand Up @@ -610,6 +624,99 @@ impl CloudOptions {
}
}

/// ```text
/// "abfss://{CONTAINER}@{STORAGE_ACCOUNT}.dfs.core.windows.net/"
/// ^^^^^^^^^^^^^^^^^
/// ```
#[cfg(feature = "azure")]
fn extract_adls_uri_storage_account(path: &str) -> Option<&str> {
Some(
path.split_once("://")?
.1
.split_once('/')?
.0
.split_once('@')?
.1
.split_once(".dfs.core.windows.net")?
.0,
)
}

/// Attempt to retrieve the storage account key for this account using the Azure CLI.
#[cfg(feature = "azure")]
fn get_azure_storage_account_key(account_name: &str) -> Option<String> {
if polars_core::config::verbose() {
eprintln!(
"get_azure_storage_account_key: storage_account_name: {}",
account_name
);
}

let mut cmd = if cfg!(target_family = "windows") {
// https://github.com/apache/arrow-rs/blob/565c24b8071269b02c3937e34c51eacf0f4cbad6/object_store/src/azure/credential.rs#L877-L894
let mut v = std::process::Command::new("cmd");
v.args([
"/C",
"az",
"storage",
"account",
"keys",
"list",
"--output",
"json",
"--account-name",
account_name,
]);
v
} else {
let mut v = std::process::Command::new("az");
v.args([
"storage",
"account",
"keys",
"list",
"--output",
"json",
"--account-name",
account_name,
]);
v
};

let json_resp = cmd
.output()
.ok()
.filter(|x| x.status.success())
.map(|x| String::from_utf8(x.stdout))?
.ok()?;

// [
// {
// "creationTime": "1970-01-01T00:00:00.000000+00:00",
// "keyName": "key1",
// "permissions": "FULL",
// "value": "..."
// },
// {
// "creationTime": "1970-01-01T00:00:00.000000+00:00",
// "keyName": "key2",
// "permissions": "FULL",
// "value": "..."
// }
// ]

#[derive(Debug, serde::Deserialize)]
struct S {
value: String,
}

let resp: Vec<S> = serde_json::from_str(&json_resp).ok()?;

let access_key = resp.into_iter().next()?.value;

Some(access_key)
}

#[cfg(feature = "cloud")]
#[cfg(test)]
mod tests {
Expand Down
18 changes: 8 additions & 10 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,21 +163,19 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult

let sources = match &scan_type {
#[cfg(feature = "parquet")]
FileScan::Parquet {
ref cloud_options, ..
} => sources
FileScan::Parquet { cloud_options, .. } => sources
.expand_paths_with_hive_update(&mut file_options, cloud_options.as_ref())?,
#[cfg(feature = "ipc")]
FileScan::Ipc {
ref cloud_options, ..
} => sources
FileScan::Ipc { cloud_options, .. } => sources
.expand_paths_with_hive_update(&mut file_options, cloud_options.as_ref())?,
#[cfg(feature = "csv")]
FileScan::Csv {
ref cloud_options, ..
} => sources.expand_paths(&file_options, cloud_options.as_ref())?,
FileScan::Csv { cloud_options, .. } => {
sources.expand_paths(&file_options, cloud_options.as_ref())?
},
#[cfg(feature = "json")]
FileScan::NDJson { .. } => sources.expand_paths(&file_options, None)?,
FileScan::NDJson { cloud_options, .. } => {
sources.expand_paths(&file_options, cloud_options.as_ref())?
},
FileScan::Anonymous { .. } => sources,
};

Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/io/cloud/credential_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def __init__(
----------
Parameters are passed to `google.auth.default()`
"""
msg = "`CredentialProviderAWS` functionality is considered unstable"
msg = "`CredentialProviderGCP` functionality is considered unstable"
issue_unstable_warning(msg)

self._check_module_availability()
Expand Down Expand Up @@ -194,7 +194,7 @@ def __init__(
self.creds = creds

def __call__(self) -> CredentialProviderFunctionReturn:
"""Fetch the credentials for the configured profile name."""
"""Fetch the credentials."""
import google.auth.transport.requests

self.creds.refresh(google.auth.transport.requests.__dict__["Request"]())
Expand Down

0 comments on commit 67af420

Please sign in to comment.