diff --git a/Cargo.toml b/Cargo.toml index 8e70a89..ececb40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,3 +17,9 @@ sha2 = "0.8.2" thiserror = "1.0.10" lazy_static = "1.2" log = "0.4.11" +hex = "0.4.3" +serde = { version = "1.0.126", optional = true} + +[features] +default = [] +serde_support = ["serde/derive"] diff --git a/src/opencl/error.rs b/src/opencl/error.rs index 5028160..f42a69d 100644 --- a/src/opencl/error.rs +++ b/src/opencl/error.rs @@ -11,6 +11,8 @@ pub enum GPUError { ProgramInfoNotAvailable(ocl::enums::ProgramInfo), #[error("IO Error: {0}")] IO(#[from] std::io::Error), + #[error("Cannot parse UUID, expected hex-encoded string formated as aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee, got {0}.")] + Uuid(String), } #[allow(clippy::upper_case_acronyms)] diff --git a/src/opencl/mod.rs b/src/opencl/mod.rs index e3e147e..544777c 100644 --- a/src/opencl/mod.rs +++ b/src/opencl/mod.rs @@ -1,12 +1,13 @@ mod error; mod utils; +use std::convert::TryFrom; use std::fmt; use std::hash::{Hash, Hasher}; pub use error::{GPUError, GPUResult}; -pub type BusId = u32; +pub type PciId = u32; #[allow(non_camel_case_types)] pub type cl_device_id = ocl::ffi::cl_device_id; @@ -88,25 +89,89 @@ impl Buffer { } } +#[cfg_attr( + feature = "serde_support", + derive(serde::Serialize, serde::Deserialize) +)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct DeviceUuid([u8; utils::CL_UUID_SIZE_KHR]); + +impl TryFrom<&str> for DeviceUuid { + type Error = GPUError; + + fn try_from(value: &str) -> GPUResult { + let res = value + .split('-') + .map(|s| hex::decode(s).map_err(|_| GPUError::Uuid(value.to_string()))) + .collect::>>()?; + + let res = res.into_iter().flatten().collect::>(); + + if res.len() != utils::CL_UUID_SIZE_KHR { + Err(GPUError::Uuid(value.to_string())) + } else { + let mut raw = [0u8; utils::CL_UUID_SIZE_KHR]; + raw.copy_from_slice(res.as_slice()); + Ok(DeviceUuid(raw)) + } + } +} + +impl TryFrom for DeviceUuid { + type Error = GPUError; + + fn try_from(value: String) -> GPUResult { + DeviceUuid::try_from(value.as_ref()) + } +} + +impl fmt::Display for DeviceUuid { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use hex::encode; + + // formats the uuid the same way as clinfo does, as an example: + // the output should looks like 46abccd6-022e-b783-572d-833f7104d05f + write!( + f, + "{}-{}-{}-{}-{}", + encode(&self.0[..4]), + encode(&self.0[4..6]), + encode(&self.0[6..8]), + encode(&self.0[8..10]), + encode(&self.0[10..]) + ) + } +} + +impl fmt::Debug for DeviceUuid { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.to_string()) + } +} + #[derive(Debug, Clone)] pub struct Device { brand: Brand, name: String, memory: u64, - bus_id: Option, platform: ocl::Platform, + pci_id: Option, + uuid: Option, pub device: ocl::Device, } impl Hash for Device { fn hash(&self, state: &mut H) { - self.bus_id.hash(state); + // hash both properties because a device might have set only one + self.uuid.hash(state); + self.pci_id.hash(state); } } impl PartialEq for Device { fn eq(&self, other: &Self) -> bool { - self.bus_id == other.bus_id + // A device might have set only one of the properties, hence compare both + self.uuid == other.uuid && self.pci_id == other.pci_id } } @@ -125,8 +190,11 @@ impl Device { pub fn is_little_endian(&self) -> GPUResult { utils::is_little_endian(self.device) } - pub fn bus_id(&self) -> Option { - self.bus_id + pub fn pci_id(&self) -> Option { + self.pci_id + } + pub fn uuid(&self) -> Option { + self.uuid } /// Return all available GPU devices of supported brands. @@ -134,10 +202,19 @@ impl Device { Self::all_iter().collect() } - pub fn by_bus_id(bus_id: BusId) -> GPUResult<&'static Device> { - Self::all_iter() - .find(|d| match d.bus_id { - Some(id) => bus_id == id, + pub fn by_pci_id(pci_id: PciId) -> GPUResult<&'static Device> { + Device::all_iter() + .find(|d| match d.pci_id { + Some(id) => pci_id == id, + None => false, + }) + .ok_or(GPUError::DeviceNotFound) + } + + pub fn by_uuid(uuid: &DeviceUuid) -> GPUResult<&'static Device> { + Device::all_iter() + .find(|d| match d.uuid { + Some(ref id) => id == uuid, None => false, }) .ok_or(GPUError::DeviceNotFound) @@ -152,50 +229,45 @@ impl Device { } } -#[allow(clippy::upper_case_acronyms)] -#[derive(Debug, Clone, Copy)] +#[cfg_attr( + feature = "serde_support", + derive(serde::Serialize, serde::Deserialize) +)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum GPUSelector { - BusId(u32), + Uuid(DeviceUuid), + PciId(u32), Index(usize), } impl GPUSelector { - pub fn get_bus_id(&self) -> Option { - match self { - GPUSelector::BusId(bus_id) => Some(*bus_id), - GPUSelector::Index(index) => get_device_bus_id_by_index(*index), - } + pub fn get_uuid(&self) -> Option { + self.get_device().and_then(|dev| dev.uuid) + } + + pub fn get_pci_id(&self) -> Option { + self.get_device().and_then(|dev| dev.pci_id) } pub fn get_device(&self) -> Option<&'static Device> { match self { - GPUSelector::BusId(bus_id) => Device::by_bus_id(*bus_id).ok(), + GPUSelector::Uuid(uuid) => Device::all_iter().find(|d| d.uuid == Some(*uuid)), + GPUSelector::PciId(pci_id) => Device::all_iter().find(|d| d.pci_id == Some(*pci_id)), GPUSelector::Index(index) => get_device_by_index(*index), } } pub fn get_key(&self) -> String { match self { - GPUSelector::BusId(id) => format!("BusID: {}", id), + GPUSelector::Uuid(uuid) => format!("Uuid: {}", uuid), + GPUSelector::PciId(id) => format!("PciId: {}", id), GPUSelector::Index(idx) => { - if let Some(id) = self.get_bus_id() { - format!("BusID: {}", id) - } else { - format!("Index: {}", idx) - } + format!("Index: {}", idx) } } } } -fn get_device_bus_id_by_index(index: usize) -> Option { - if let Some(device) = get_device_by_index(index) { - device.bus_id - } else { - None - } -} - fn get_device_by_index(index: usize) -> Option<&'static Device> { Device::all_iter().nth(index) } @@ -370,13 +442,30 @@ macro_rules! call_kernel { #[cfg(test)] mod test { - use super::Device; + use super::{Device, DeviceUuid}; + use std::convert::TryFrom; #[test] fn test_device_all() { - for _ in 0..10 { - let devices = Device::all(); - dbg!(&devices.len()); - } + let devices = Device::all(); + dbg!(&devices.len()); + println!("{:?}", devices); + } + + #[test] + fn test_uuid() { + let test_uuid = "46abccd6-022e-b783-572d-833f7104d05f"; + let uuid = DeviceUuid::try_from(test_uuid).unwrap(); + assert_eq!(test_uuid, &uuid.to_string()); + + // test wrong length uuid + let bad_uuid = "46abccd6-022e-b783-572-833f7104d05f"; + let uuid = DeviceUuid::try_from(bad_uuid); + assert!(uuid.is_err()); + + // test invalid hex character + let bad_uuid = "46abccd6-022e-b783-572d-833f7104d05h"; + let uuid = DeviceUuid::try_from(bad_uuid); + assert!(uuid.is_err()); } } diff --git a/src/opencl/utils.rs b/src/opencl/utils.rs index f952249..cee5e38 100644 --- a/src/opencl/utils.rs +++ b/src/opencl/utils.rs @@ -5,7 +5,7 @@ use lazy_static::lazy_static; use log::{debug, warn}; use sha2::{Digest, Sha256}; -use super::{Brand, Device, GPUError, GPUResult}; +use super::{Brand, Device, DeviceUuid, GPUError, GPUResult}; #[repr(C)] #[derive(Debug, Clone, Default)] @@ -17,8 +17,13 @@ struct cl_amd_device_topology { function: u8, } -const AMD_DEVICE_VENDOR_STRING: &str = "AMD"; -const NVIDIA_DEVICE_VENDOR_STRING: &str = "NVIDIA Corporation"; +const AMD_DEVICE_VENDOR_STRING: &'static str = "AMD"; +const NVIDIA_DEVICE_VENDOR_STRING: &'static str = "NVIDIA Corporation"; + +// constants defined as part of the opencl spec +// https://github.com/KhronosGroup/OpenCL-Headers/blob/master/CL/cl_ext.h#L687 +const CL_DEVICE_UUID_KHR: u32 = 0x106A; +pub(crate) const CL_UUID_SIZE_KHR: usize = 16; pub fn is_little_endian(d: ocl::Device) -> GPUResult { match d.info(ocl::enums::DeviceInfo::EndianLittle)? { @@ -29,26 +34,36 @@ pub fn is_little_endian(d: ocl::Device) -> GPUResult { } } -pub fn get_bus_id(d: ocl::Device) -> ocl::Result { +pub fn get_device_uuid(d: ocl::Device) -> ocl::Result { + let result = d.info_raw(CL_DEVICE_UUID_KHR)?; + assert_eq!(result.len(), CL_UUID_SIZE_KHR); + let mut raw = [0u8; CL_UUID_SIZE_KHR]; + raw.copy_from_slice(result.as_slice()); + Ok(DeviceUuid(raw)) +} + +pub fn get_pci_id(d: ocl::Device) -> ocl::Result { let vendor = d.vendor()?; match vendor.as_str() { - AMD_DEVICE_VENDOR_STRING => get_amd_bus_id(d), - NVIDIA_DEVICE_VENDOR_STRING => get_nvidia_bus_id(d), + AMD_DEVICE_VENDOR_STRING => get_amd_pci_id(d), + NVIDIA_DEVICE_VENDOR_STRING => get_nvidia_pci_id(d), _ => Err(ocl::Error::from(format!( - "cannot get bus ID for device with vendor {} ", + "cannot get pciId for device with vendor {} ", vendor ))), } } -pub fn get_nvidia_bus_id(d: ocl::Device) -> ocl::Result { - const CL_DEVICE_PCI_BUS_ID_NV: u32 = 0x4008; - let result = d.info_raw(CL_DEVICE_PCI_BUS_ID_NV)?; +fn get_nvidia_pci_id(d: ocl::Device) -> ocl::Result { + const CL_DEVICE_PCI_SLOT_ID_NV: u32 = 0x4009; + + let result = d.info_raw(CL_DEVICE_PCI_SLOT_ID_NV)?; Ok(u32::from_le_bytes(result[..].try_into().unwrap())) } -pub fn get_amd_bus_id(d: ocl::Device) -> ocl::Result { +fn get_amd_pci_id(d: ocl::Device) -> ocl::Result { const CL_DEVICE_TOPOLOGY_AMD: u32 = 0x4037; + let result = d.info_raw(CL_DEVICE_TOPOLOGY_AMD)?; let size = std::mem::size_of::(); assert_eq!(result.len(), size); @@ -57,7 +72,10 @@ pub fn get_amd_bus_id(d: ocl::Device) -> ocl::Result { std::slice::from_raw_parts_mut(&mut topo as *mut cl_amd_device_topology as *mut u8, size) .copy_from_slice(&result); } - Ok(topo.bus as u32) + let device = topo.device as u32; + let bus = topo.bus as u32; + let function = topo.function as u32; + Ok((device << 16) | (bus << 8) | function) } pub fn cache_path(device: &Device, cl_source: &str) -> std::io::Result { @@ -66,14 +84,12 @@ pub fn cache_path(device: &Device, cl_source: &str) -> std::io::Result Vec { brand, name: d.name()?, memory: get_memory(d)?, - bus_id: get_bus_id(d).ok(), + uuid: get_device_uuid(d).ok(), + pci_id: get_pci_id(d).ok(), platform: *platform, device: d, })