Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] get MAX_MEM_ALLOC from device property #5270

Merged
merged 4 commits into from
Feb 2, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ namespace dpct
}
size_t get_global_mem_size() const { return _global_mem_size; }
size_t get_local_mem_size() const { return _local_mem_size; }
size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; }
/// Returns the maximum clock rate of device's global memory in kHz. If
/// compiler does not support this API then returns default value 3200000 kHz.
unsigned int get_memory_clock_rate() const { return _memory_clock_rate; }
Expand Down Expand Up @@ -398,6 +399,10 @@ namespace dpct
{
_local_mem_size = local_mem_size;
}
void set_max_mem_alloc_size(size_t max_mem_alloc_size)
{
_max_mem_alloc_size = max_mem_alloc_size;
}
void set_max_work_group_size(int max_work_group_size)
{
_max_work_group_size = max_work_group_size;
Expand Down Expand Up @@ -465,6 +470,7 @@ namespace dpct
int _max_register_size_per_work_group;
size_t _global_mem_size;
size_t _local_mem_size;
size_t _max_mem_alloc_size;
size_t _max_nd_range_size[3];
int _max_nd_range_size_i[3];
uint32_t _device_id;
Expand Down Expand Up @@ -516,6 +522,7 @@ namespace dpct
dev.get_info<sycl::info::device::max_work_group_size>());
prop.set_global_mem_size(dev.get_info<sycl::info::device::global_mem_size>());
prop.set_local_mem_size(dev.get_info<sycl::info::device::local_mem_size>());
prop.set_max_mem_alloc_size(dev.get_info<sycl::info::device::max_mem_alloc_size>());

#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6)
if (dev.has(sycl::aspect::ext_intel_memory_clock_rate))
Expand Down Expand Up @@ -644,6 +651,11 @@ namespace dpct
return get_device_info().get_global_mem_size();
}

size_t get_max_mem_alloc_size() const
{
return get_device_info().get_max_mem_alloc_size();
}

/// Get the number of bytes of free and total memory on the SYCL device.
/// \param [out] free_memory The number of bytes of free memory on the SYCL device.
/// \param [out] total_memory The number of bytes of total memory on the SYCL device.
Expand Down Expand Up @@ -14788,6 +14800,12 @@ static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_ty
UNUSED(buft);
}

static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
return dpct::get_current_device().get_max_mem_alloc_size();

UNUSED(buft);
}

static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
int64_t row_low = 0;
int64_t row_high = ggml_nrows(tensor);
Expand Down Expand Up @@ -14818,7 +14836,7 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
/* .get_name = */ ggml_backend_sycl_buffer_type_name,
/* .alloc_buffer = */ ggml_backend_sycl_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_sycl_buffer_type_get_alignment,
/* .get_max_size = */ NULL, // TODO: return device.maxBufferLength
/* .get_max_size = */ ggml_backend_sycl_buffer_type_get_max_size,
/* .get_alloc_size = */ ggml_backend_sycl_buffer_type_get_alloc_size,
/* .supports_backend = */ ggml_backend_sycl_buffer_type_supports_backend,
/* .is_host = */ nullptr,
Expand Down
Loading