Skip to content

Commit

Permalink
Improve cache FS security
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-- committed Jan 23, 2025
1 parent 5c4788d commit 24891f4
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 32 deletions.
41 changes: 24 additions & 17 deletions third_party/cached_httpfs/http_file_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@

namespace duckdb {

CachedFile::CachedFile(const string &cache_dir, FileSystem &fs, const string &key, bool cache_file) : cache_directory(cache_dir), fs(fs) {
file_name = cache_dir + "/" + key;

GetDirectoryCacheLock(cache_dir);
CachedFile::CachedFile(LocalCacheFileSystem &_fs, const string &key, bool cache_file) : fs(_fs), file_name(key) {
GetDirectoryCacheLock();

FileOpenFlags flags =
FileFlags::FILE_FLAGS_READ | FileFlags::FILE_FLAGS_NULL_IF_NOT_EXISTS | FileLockType::READ_LOCK;
Expand All @@ -21,38 +19,31 @@ CachedFile::CachedFile(const string &cache_dir, FileSystem &fs, const string &ke
ReleaseDirectoryCacheLock();
}

//! Nothing to clean up by hand; the members release themselves.
CachedFile::~CachedFile() = default;

//! Takes the directory-wide cache lock by opening the ".lock" file with a write lock and
//! keeping the resulting handle in directory_lock_handle.
// NOTE(review): this span is a rendered diff -- the old signature and `lock_file` lines appear
// next to the new parameterless signature and `lock_file_name` lines that replace them.
void CachedFile::GetDirectoryCacheLock(const string &cache_dir) {
string lock_file = cache_dir + "/.lock";
void CachedFile::GetDirectoryCacheLock() {
constexpr const char* lock_file_name = ".lock";
// First attempt: create the lock file exclusively. FILE_FLAGS_NULL_IF_EXISTS is requested, so
// presumably a null handle means the file already exists -- confirm against FileSystem::OpenFile.
FileOpenFlags flags = FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE |
FileFlags::FILE_FLAGS_EXCLUSIVE_CREATE | FileFlags::FILE_FLAGS_NULL_IF_EXISTS |
FileLockType::WRITE_LOCK;
directory_lock_handle = fs.OpenFile(lock_file, flags);
directory_lock_handle = fs.OpenFile(lock_file_name, flags);
if (directory_lock_handle == nullptr) {
// Fallback: open again without the create flags, still requesting a write lock.
flags = FileFlags::FILE_FLAGS_WRITE | FileLockType::WRITE_LOCK;
directory_lock_handle = fs.OpenFile(lock_file, flags);
directory_lock_handle = fs.OpenFile(lock_file_name, flags);
}
}


//! Closes and drops the directory lock handle taken by GetDirectoryCacheLock().
//! Calling it while no lock handle is held is a no-op.
void CachedFile::ReleaseDirectoryCacheLock() {
    // Guard against a release without a matching acquire (or a second release):
    // the handle is reset below, so dereferencing it again would hit a null pointer.
    if (directory_lock_handle == nullptr) {
        return;
    }
    directory_lock_handle->Close();
    directory_lock_handle.reset();
}


//! Creates a handle that shares ownership of the given cached file.
CachedFileHandle::CachedFileHandle(shared_ptr<CachedFile> &file_p) : file(file_p) {
}

void CachedFileHandle::WriteMetadata(const string &cache_key, const string &remote_path, idx_t total_size) {
D_ASSERT(!file->initialized);
string metadata_file_name = file->cache_directory + "/" + cache_key + ".meta";
FileOpenFlags flags = FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE | FileLockType::WRITE_LOCK;
auto handle = file->fs.OpenFile(metadata_file_name, flags);
auto handle = file->fs.OpenFile(cache_key + ".meta", flags);
auto cached_file_timestamp = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
string metadata_info = cache_key + "," + remote_path + "," + std::to_string(total_size) + "," + std::to_string(cached_file_timestamp);
handle->Write((void *)metadata_info.c_str(), metadata_info.length(), 0);
Expand Down Expand Up @@ -90,6 +81,21 @@ void CachedFileHandle::Read(void *buffer, idx_t length, idx_t offset) {
file->handle->Read((void *)buffer, length, offset);
}

//! Returns the file system that confines cache I/O to `cache_dir`, creating it on first use.
//! Every later call must pass the same directory; AssertSameCacheDir throws otherwise.
LocalCacheFileSystem &HTTPFileCache::GetFS(const string &cache_dir) {
    // Check and create under the mutex: `cached_fs` is a plain unique_ptr, so reading it
    // without the lock while another thread assigns it would be a data race.
    lock_guard<mutex> lock(cached_fs_mutex);
    if (!cached_fs) {
        cached_fs = make_uniq<LocalCacheFileSystem>(cache_dir);
    }
    cached_fs->AssertSameCacheDir(cache_dir);
    return *cached_fs;
}

//! Get the cache entry; if it does not exist, create it only if caching is enabled
shared_ptr<CachedFile> HTTPFileCache::GetCachedFile(const string &cache_dir, const string &key, bool cache_file) {
lock_guard<mutex> lock(cached_files_mutex);
Expand All @@ -98,7 +104,8 @@ shared_ptr<CachedFile> HTTPFileCache::GetCachedFile(const string &cache_dir, con
return it->second;
}

auto cache_entry = make_shared_ptr<CachedFile>(cache_dir, cache_fs, key, cache_file);
auto& fs = GetFS(cache_dir);
auto cache_entry = make_shared_ptr<CachedFile>(fs, key, cache_file);
if (cache_entry->Initialized() || cache_file) {
cached_files[key] = cache_entry;
return cache_entry;
Expand Down
90 changes: 75 additions & 15 deletions third_party/cached_httpfs/include/http_file_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,79 @@
namespace duckdb {

class CachedFileHandle;
class LocalFileSystem;

//! Local file system restricted to a single cache directory.
//!
//! OpenFile() only accepts bare file names and resolves them inside the cache directory;
//! every other path-based operation overridden here throws PermissionException.
class LocalCacheFileSystem : public LocalFileSystem {

public:
    //! \param _cache_file_directory directory that every opened file name is resolved against
    LocalCacheFileSystem(const std::string &_cache_file_directory) : cache_file_directory(_cache_file_directory) {
        // TODO: Should we forbid symlinks for cache_file_directory? Or have a more complex check?
    }

    std::string GetName() const override {
        return "LocalCacheFileSystem";
    }

    //! Opens `file_name` inside the cache directory.
    //! Throws PermissionException when `file_name` is empty or contains "..", a slash or a backslash.
    unique_ptr<FileHandle> OpenFile(const string &file_name, FileOpenFlags flags,
                                    optional_ptr<FileOpener> opener = nullptr) override {
        // An empty name would resolve to the cache directory itself, and a backslash is a path
        // separator on Windows, so both are rejected together with ".." and "/".
        const bool is_plain_name = !file_name.empty() && file_name.find("..") == string::npos &&
                                   file_name.find_first_of("/\\") == string::npos;
        if (!is_plain_name) {
            // The message is concatenated up front so the offending name is always part of it,
            // however the exception constructor treats additional arguments.
            throw PermissionException("Must provide a file name, not a path. Got: '" + file_name + "'");
        }
        return LocalFileSystem::OpenFile(cache_file_directory + "/" + file_name, flags, opener);
    }

    //! Throws PermissionException unless `other_dir` equals the directory this file system was created for.
    void AssertSameCacheDir(const string &other_dir) const {
        if (other_dir != cache_file_directory) {
            throw PermissionException("BUG: expected cache directory to be '" + cache_file_directory +
                                      "' but got '" + other_dir + "'");
        }
    }

// Generates an override that rejects the operation. The return type is a separate argument so
// that only the method name ends up in the error message.
#define LOCAL_FILE_SYSTEM_METHOD(return_type, method_name)                                                  \
    return_type method_name(const string &path, optional_ptr<FileOpener> opener = nullptr) override {      \
        throw PermissionException("LocalCacheFileSystem cannot run " #method_name " on '" + path + "'");   \
    }

    LOCAL_FILE_SYSTEM_METHOD(bool, DirectoryExists)
    LOCAL_FILE_SYSTEM_METHOD(void, CreateDirectory)
    LOCAL_FILE_SYSTEM_METHOD(void, RemoveDirectory)
    LOCAL_FILE_SYSTEM_METHOD(bool, FileExists)
    LOCAL_FILE_SYSTEM_METHOD(bool, IsPipe)
    LOCAL_FILE_SYSTEM_METHOD(void, RemoveFile)

// Keep the helper macro from leaking into every translation unit that includes this header.
#undef LOCAL_FILE_SYSTEM_METHOD

    bool ListFiles(const string &directory, const std::function<void(const string &, bool)> &,
                   FileOpener *opener = nullptr) override {
        throw PermissionException("LocalCacheFileSystem cannot run ListFiles on '" + directory + "'");
    }

    void MoveFile(const string &source, const string &target, optional_ptr<FileOpener> opener = nullptr) override {
        throw PermissionException("LocalCacheFileSystem cannot run MoveFile on '" + source + "' to '" + target +
                                  "'");
    }

    vector<string> Glob(const string &path, FileOpener *opener = nullptr) override {
        throw PermissionException("LocalCacheFileSystem cannot run Glob on '" + path + "'");
    }

    bool CanHandleFile(const string &fpath) override {
        throw PermissionException("LocalCacheFileSystem cannot run CanHandleFile on '" + fpath + "'");
    }

    static bool IsPrivateFile(const string &path_p, FileOpener *opener) {
        throw PermissionException("LocalCacheFileSystem cannot run IsPrivateFile on '" + path_p + "'");
    }

private:
    //! Directory that OpenFile() resolves file names against.
    std::string cache_file_directory;
};


//! Represents a file that is intended to be fully downloaded, then used in parallel by multiple threads
class CachedFile : public enable_shared_from_this<CachedFile> {
friend class CachedFileHandle;

public:
CachedFile(const string &cache_dir, FileSystem &fs, const std::string &key, bool cache_file);
~CachedFile();
CachedFile(LocalCacheFileSystem &fs, const std::string &key, bool cache_file);

unique_ptr<CachedFileHandle> GetHandle() {
auto this_ptr = shared_from_this();
Expand All @@ -26,13 +90,11 @@ class CachedFile : public enable_shared_from_this<CachedFile> {
}

private:
void GetDirectoryCacheLock(const string &cache_dir);
void GetDirectoryCacheLock();
void ReleaseDirectoryCacheLock();

private:
string cache_directory;
// FileSystem
FileSystem &fs;
LocalCacheFileSystem &fs;
// File name
std::string file_name;
// Cache file FileDescriptor
Expand Down Expand Up @@ -80,13 +142,6 @@ class CachedFileHandle {
shared_ptr<CachedFile> file;
};

//! File system used for the local HTTP file cache; at this point it only overrides GetName().
class LocalCacheFileSystem: public LocalFileSystem {
// TODO: we could lock down the LocalFileSystem to only allow paths that are in the cache directory
//! Name reported for this file system.
std::string GetName() const override {
return "LocalCacheFileSystem";
}
};

class HTTPFileCache : public ClientContextState {
public:
explicit HTTPFileCache(ClientContext &context) {
Expand All @@ -97,10 +152,15 @@ class HTTPFileCache : public ClientContextState {
shared_ptr<CachedFile> GetCachedFile(const string &cache_dir, const string &key, bool create_cache);

private:
LocalCacheFileSystem cache_fs;
LocalCacheFileSystem& GetFS(const string &cache_dir);

unique_ptr<LocalCacheFileSystem> cached_fs;

//! Database Instance
shared_ptr<DatabaseInstance> db;

//! Mutex to lock when getting the cached fs (Parallel Only)
mutex cached_fs_mutex;
//! Mutex to lock when getting the cached file (Parallel Only)
mutex cached_files_mutex;
//! In case of fully downloading the file, the cached files of this query
Expand Down

0 comments on commit 24891f4

Please sign in to comment.