Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix incorrect object store caching for ADLS URI #20357

Merged
merged 2 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions crates/polars-io/src/cloud/glob.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use futures::TryStreamExt;
use object_store::path::Path;
use polars_core::error::to_compute_err;
use polars_core::prelude::{polars_ensure, polars_err};
use polars_error::PolarsResult;
use polars_core::prelude::polars_ensure;
use polars_error::{polars_bail, PolarsResult};
use polars_utils::format_pl_smallstr;
use polars_utils::pl_str::PlSmallStr;
use regex::Regex;
use url::Url;
Expand Down Expand Up @@ -98,13 +99,16 @@ impl CloudLocation {
}

let key = parsed.path();
let bucket = parsed
.host()
.ok_or_else(
|| polars_err!(ComputeError: "cannot parse bucket (host) from url: {}", parsed),
)?
.to_string()
.into();

let bucket = format_pl_smallstr!(
"{}",
&parsed[url::Position::BeforeUsername..url::Position::AfterPort]
);

if bucket.is_empty() {
polars_bail!(ComputeError: "CloudLocation::from_url(): empty bucket: {}", parsed);
}

(bucket, key)
};

Expand Down
44 changes: 29 additions & 15 deletions crates/polars-io/src/cloud/object_store_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use polars_core::config;
use polars_error::{polars_bail, to_compute_err, PolarsError, PolarsResult};
use polars_utils::aliases::PlHashMap;
use polars_utils::pl_str::PlSmallStr;
use polars_utils::{format_pl_smallstr, pl_serialize};
use tokio::sync::RwLock;
use url::Url;

Expand All @@ -17,7 +18,7 @@ use crate::cloud::CloudConfig;
/// get rate limited when querying the DNS (can take up to 5s).
/// Another reason is that connection pools should be shared as much as possible.
#[allow(clippy::type_complexity)]
static OBJECT_STORE_CACHE: Lazy<RwLock<PlHashMap<String, PolarsObjectStore>>> =
static OBJECT_STORE_CACHE: Lazy<RwLock<PlHashMap<Vec<u8>, PolarsObjectStore>>> =
ritchie46 marked this conversation as resolved.
Show resolved Hide resolved
Lazy::new(Default::default);

#[allow(dead_code)]
Expand All @@ -29,10 +30,10 @@ fn err_missing_feature(feature: &str, scheme: &str) -> PolarsResult<Arc<dyn Obje
}

/// Get the key of a url for object store registration.
fn url_and_creds_to_key(url: &Url, options: Option<&CloudOptions>) -> String {
fn url_and_creds_to_key(url: &Url, options: Option<&CloudOptions>) -> Vec<u8> {
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct S {
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
struct C {
max_retries: usize,
#[cfg(feature = "file_cache")]
file_cache_ttl: u64,
Expand All @@ -41,8 +42,15 @@ fn url_and_creds_to_key(url: &Url, options: Option<&CloudOptions>) -> String {
credential_provider: usize,
}

#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
struct S {
url_base: PlSmallStr,
cloud_options: Option<C>,
}

// We include credentials as they can expire, so users will send new credentials for the same url.
let creds = serde_json::to_string(&options.map(
let cloud_options = options.map(
|CloudOptions {
// Destructure to ensure this breaks if anything changes.
max_retries,
Expand All @@ -52,7 +60,7 @@ fn url_and_creds_to_key(url: &Url, options: Option<&CloudOptions>) -> String {
#[cfg(feature = "cloud")]
credential_provider,
}| {
S {
C {
max_retries: *max_retries,
#[cfg(feature = "file_cache")]
file_cache_ttl: *file_cache_ttl,
Expand All @@ -61,15 +69,21 @@ fn url_and_creds_to_key(url: &Url, options: Option<&CloudOptions>) -> String {
credential_provider: credential_provider.as_ref().map_or(0, |x| x.func_addr()),
}
},
))
.unwrap();

format!(
"{}://{}<\\creds\\>{}",
url.scheme(),
&url[url::Position::BeforeHost..url::Position::AfterPort],
creds
)
);

let cache_key = S {
url_base: format_pl_smallstr!(
"{}",
&url[url::Position::BeforeScheme..url::Position::AfterPort]
),
cloud_options,
};

if config::verbose() {
eprintln!("object store cache key: {} {:?}", url, &cache_key);
}

pl_serialize::serialize_to_bytes(&cache_key).unwrap()
}

/// Construct an object_store `Path` from a string without any encoding/decoding.
Expand Down
126 changes: 117 additions & 9 deletions crates/polars-io/src/cloud/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ use polars_error::*;
#[cfg(feature = "aws")]
use polars_utils::cache::FastFixedCache;
#[cfg(feature = "aws")]
use polars_utils::pl_str::PlSmallStr;
#[cfg(feature = "aws")]
use regex::Regex;
#[cfg(feature = "http")]
use reqwest::header::HeaderMap;
Expand All @@ -43,8 +41,11 @@ use crate::file_cache::get_env_file_cache_ttl;
use crate::pl_async::with_concurrency_budget;

#[cfg(feature = "aws")]
static BUCKET_REGION: Lazy<std::sync::Mutex<FastFixedCache<PlSmallStr, PlSmallStr>>> =
Lazy::new(|| std::sync::Mutex::new(FastFixedCache::new(32)));
static BUCKET_REGION: Lazy<
std::sync::Mutex<
FastFixedCache<polars_utils::pl_str::PlSmallStr, polars_utils::pl_str::PlSmallStr>,
>,
> = Lazy::new(|| std::sync::Mutex::new(FastFixedCache::new(32)));

/// The type of the config keys must satisfy the following requirements:
/// 1. must be easily collected into a HashMap, the type required by the `object_store` crate API.
Expand Down Expand Up @@ -406,16 +407,20 @@ impl CloudOptions {
pub fn build_azure(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
use super::credential_provider::IntoCredentialProvider;

let mut builder = if self.credential_provider.is_none() {
MicrosoftAzureBuilder::from_env()
} else {
MicrosoftAzureBuilder::new()
};
let mut storage_account: Option<polars_utils::pl_str::PlSmallStr> = None;

// The credential provider `self.credential_provider` is prioritized if it is set. We also need
// `from_env()` as it may source an environment-configured storage account name.
let mut builder = MicrosoftAzureBuilder::from_env();

if let Some(options) = &self.config {
let CloudConfig::Azure(options) = options else {
panic!("impl error: cloud type mismatch")
};
for (key, value) in options.iter() {
if key == &AzureConfigKey::AccountName {
storage_account = Some(value.into());
}
builder = builder.with_config(*key, value);
}
}
Expand All @@ -425,8 +430,18 @@ impl CloudOptions {
.with_url(url)
.with_retry(get_retry_config(self.max_retries));

// Prefer the one embedded in the path
storage_account = extract_adls_uri_storage_account(url)
.map(|x| x.into())
.or(storage_account);

let builder = if let Some(v) = self.credential_provider.clone() {
builder.with_credentials(v.into_azure_provider())
} else if let Some(v) = storage_account
.as_deref()
.and_then(get_azure_storage_account_key)
{
builder.with_access_key(v)
} else {
builder
};
Expand Down Expand Up @@ -610,6 +625,99 @@ impl CloudOptions {
}
}

/// ```text
/// "abfss://{CONTAINER}@{STORAGE_ACCOUNT}.dfs.core.windows.net/"
/// ^^^^^^^^^^^^^^^^^
/// ```
#[cfg(feature = "azure")]
///
/// Extracts the storage account name from an ADLS URI of the form
/// `{scheme}://{container}@{account}.dfs.core.windows.net[/{path}]`.
///
/// Returns `None` when the input has no `://` separator, when the authority
/// has no `{container}@` prefix, when the host does not contain
/// `.dfs.core.windows.net`, or when the account name would be empty.
fn extract_adls_uri_storage_account(path: &str) -> Option<&str> {
    let (_scheme, rest) = path.split_once("://")?;

    // The authority ends at the first '/'. A URI without any path component
    // (no trailing slash) consists of the authority only, so fall back to the
    // whole remainder instead of giving up.
    let authority = rest
        .split_once('/')
        .map_or(rest, |(authority, _path)| authority);

    let (_container, host) = authority.split_once('@')?;
    let (account, _suffix) = host.split_once(".dfs.core.windows.net")?;

    // An empty account name is never usable for a key lookup.
    (!account.is_empty()).then_some(account)
}

/// Attempt to retrieve the storage account key for this account using the Azure CLI.
#[cfg(feature = "azure")]
///
/// Runs `az storage account keys list --output json --account-name <account_name>`
/// and returns the `value` of the first key in the response.
///
/// Returns `None` if the command cannot be spawned, exits unsuccessfully,
/// writes non-UTF-8 output, or if the output cannot be parsed as a non-empty
/// JSON array of objects with a `value` field.
fn get_azure_storage_account_key(account_name: &str) -> Option<String> {
    // Only the `value` field of each key entry is needed. The CLI response
    // looks like:
    // [
    //     {
    //         "creationTime": "1970-01-01T00:00:00.000000+00:00",
    //         "keyName": "key1",
    //         "permissions": "FULL",
    //         "value": "..."
    //     },
    //     {
    //         "creationTime": "1970-01-01T00:00:00.000000+00:00",
    //         "keyName": "key2",
    //         "permissions": "FULL",
    //         "value": "..."
    //     }
    // ]
    #[derive(Debug, serde::Deserialize)]
    struct KeyEntry {
        value: String,
    }

    if polars_core::config::verbose() {
        eprintln!(
            "get_azure_storage_account_key: storage_account_name: {}",
            account_name
        );
    }

    // On Windows the CLI is launched through `cmd /C`, see
    // https://github.com/apache/arrow-rs/blob/565c24b8071269b02c3937e34c51eacf0f4cbad6/object_store/src/azure/credential.rs#L877-L894
    let mut command = if cfg!(target_family = "windows") {
        let mut shell = std::process::Command::new("cmd");
        shell.args(["/C", "az"]);
        shell
    } else {
        std::process::Command::new("az")
    };

    // Arguments shared by both platforms.
    command.args([
        "storage",
        "account",
        "keys",
        "list",
        "--output",
        "json",
        "--account-name",
        account_name,
    ]);

    let output = command.output().ok()?;
    if !output.status.success() {
        return None;
    }

    let json_resp = String::from_utf8(output.stdout).ok()?;
    let keys: Vec<KeyEntry> = serde_json::from_str(&json_resp).ok()?;

    // Use the first listed key.
    keys.into_iter().next().map(|entry| entry.value)
}

#[cfg(feature = "cloud")]
#[cfg(test)]
mod tests {
Expand Down
18 changes: 8 additions & 10 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,21 +163,19 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult

let sources = match &scan_type {
#[cfg(feature = "parquet")]
FileScan::Parquet {
ref cloud_options, ..
} => sources
FileScan::Parquet { cloud_options, .. } => sources
.expand_paths_with_hive_update(&mut file_options, cloud_options.as_ref())?,
#[cfg(feature = "ipc")]
FileScan::Ipc {
ref cloud_options, ..
} => sources
FileScan::Ipc { cloud_options, .. } => sources
.expand_paths_with_hive_update(&mut file_options, cloud_options.as_ref())?,
#[cfg(feature = "csv")]
FileScan::Csv {
ref cloud_options, ..
} => sources.expand_paths(&file_options, cloud_options.as_ref())?,
FileScan::Csv { cloud_options, .. } => {
sources.expand_paths(&file_options, cloud_options.as_ref())?
},
#[cfg(feature = "json")]
FileScan::NDJson { .. } => sources.expand_paths(&file_options, None)?,
FileScan::NDJson { cloud_options, .. } => {
sources.expand_paths(&file_options, cloud_options.as_ref())?
ritchie46 marked this conversation as resolved.
Show resolved Hide resolved
},
FileScan::Anonymous { .. } => sources,
};

Expand Down
Loading