Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a unique_id identifier for NVIDIA and AMD devices #30

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ sha2 = "0.8.2"
thiserror = "1.0.10"
lazy_static = "1.2"
log = "0.4.11"
hex = "0.4.3"
2 changes: 2 additions & 0 deletions src/opencl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ pub enum GPUError {
ProgramInfoNotAvailable(ocl::enums::ProgramInfo),
#[error("IO Error: {0}")]
IO(#[from] std::io::Error),
#[error("Cannot parse UUID, expected hex-encoded string formated as aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee, got {0}.")]
Uuid(String),
}

#[allow(clippy::upper_case_acronyms)]
Expand Down
147 changes: 110 additions & 37 deletions src/opencl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
mod error;
mod utils;

use std::convert::TryInto;
use std::fmt;
use std::hash::{Hash, Hasher};

pub use error::{GPUError, GPUResult};

pub type BusId = u32;
pub type PciId = u32;

#[allow(non_camel_case_types)]
pub type cl_device_id = ocl::ffi::cl_device_id;
Expand Down Expand Up @@ -88,25 +89,77 @@ impl<T> Buffer<T> {
}
}

#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct DeviceUuid([u8; utils::CL_UUID_SIZE_KHR]);

impl TryInto<DeviceUuid> for &str {
type Error = GPUError;

fn try_into(self) -> GPUResult<DeviceUuid> {
let res = self
.split('-')
.map(|s| hex::decode(s).map_err(|_| GPUError::Uuid(self.to_string())))
.collect::<GPUResult<Vec<_>>>()?;

let res = res.into_iter().flatten().collect::<Vec<u8>>();

if res.len() != utils::CL_UUID_SIZE_KHR {
Err(GPUError::Uuid(self.to_string()))
} else {
let mut raw = [0u8; utils::CL_UUID_SIZE_KHR];
raw.copy_from_slice(res.as_slice());
Ok(DeviceUuid(raw))
}
}
}

impl fmt::Display for DeviceUuid {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use hex::encode;

// formats the uuid the same way as clinfo does, as an example:
// the output should looks like 46abccd6-022e-b783-572d-833f7104d05f
write!(
f,
"{}-{}-{}-{}-{}",
encode(&self.0[..4]),
encode(&self.0[4..6]),
encode(&self.0[6..8]),
encode(&self.0[8..10]),
encode(&self.0[10..])
)
}
}

impl fmt::Debug for DeviceUuid {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.to_string())
}
}

#[derive(Debug, Clone)]
pub struct Device {
brand: Brand,
name: String,
memory: u64,
bus_id: Option<BusId>,
platform: ocl::Platform,
pci_id: Option<PciId>,
uuid: Option<DeviceUuid>,
pub device: ocl::Device,
}

impl Hash for Device {
fn hash<H: Hasher>(&self, state: &mut H) {
self.bus_id.hash(state);
// hash both properties because a device might have set only one
self.uuid.hash(state);
neithanmo marked this conversation as resolved.
Show resolved Hide resolved
self.pci_id.hash(state);
}
}

impl PartialEq for Device {
fn eq(&self, other: &Self) -> bool {
self.bus_id == other.bus_id
// A device might have set only one of the properties, hence compare both
self.uuid == other.uuid && self.pci_id == other.pci_id
}
}

Expand All @@ -125,19 +178,31 @@ impl Device {
pub fn is_little_endian(&self) -> GPUResult<bool> {
utils::is_little_endian(self.device)
}
pub fn bus_id(&self) -> Option<BusId> {
self.bus_id
pub fn pci_id(&self) -> Option<PciId> {
self.pci_id
}
pub fn uuid(&self) -> Option<DeviceUuid> {
self.uuid
}

/// Return all available GPU devices of supported brands.
pub fn all() -> Vec<&'static Device> {
Self::all_iter().collect()
}

pub fn by_bus_id(bus_id: BusId) -> GPUResult<&'static Device> {
Self::all_iter()
.find(|d| match d.bus_id {
Some(id) => bus_id == id,
pub fn by_pci_id(pci_id: PciId) -> GPUResult<&'static Device> {
Device::all_iter()
.find(|d| match d.pci_id {
Some(id) => pci_id == id,
None => false,
})
.ok_or(GPUError::DeviceNotFound)
}

pub fn by_uuid(uuid: &DeviceUuid) -> GPUResult<&'static Device> {
Device::all_iter()
.find(|d| match d.uuid {
Some(ref id) => id == uuid,
None => false,
})
.ok_or(GPUError::DeviceNotFound)
Expand All @@ -152,50 +217,41 @@ impl Device {
}
}

#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy)]
pub enum GPUSelector {
BusId(u32),
Uuid(DeviceUuid),
PciId(u32),
Index(usize),
}

impl GPUSelector {
pub fn get_bus_id(&self) -> Option<u32> {
match self {
GPUSelector::BusId(bus_id) => Some(*bus_id),
GPUSelector::Index(index) => get_device_bus_id_by_index(*index),
}
pub fn get_uuid(&self) -> Option<DeviceUuid> {
self.get_device().and_then(|dev| dev.uuid)
}

pub fn get_pci_id(&self) -> Option<u32> {
self.get_device().and_then(|dev| dev.pci_id)
}

pub fn get_device(&self) -> Option<&'static Device> {
match self {
GPUSelector::BusId(bus_id) => Device::by_bus_id(*bus_id).ok(),
GPUSelector::Uuid(uuid) => Device::all_iter().find(|d| d.uuid == Some(*uuid)),
GPUSelector::PciId(pci_id) => Device::all_iter().find(|d| d.pci_id == Some(*pci_id)),
GPUSelector::Index(index) => get_device_by_index(*index),
}
}

pub fn get_key(&self) -> String {
match self {
GPUSelector::BusId(id) => format!("BusID: {}", id),
GPUSelector::Uuid(uuid) => format!("Uuid: {}", uuid),
GPUSelector::PciId(id) => format!("PciId: {}", id),
GPUSelector::Index(idx) => {
if let Some(id) = self.get_bus_id() {
format!("BusID: {}", id)
} else {
format!("Index: {}", idx)
}
format!("Index: {}", idx)
}
}
}
}

fn get_device_bus_id_by_index(index: usize) -> Option<BusId> {
if let Some(device) = get_device_by_index(index) {
device.bus_id
} else {
None
}
}

fn get_device_by_index(index: usize) -> Option<&'static Device> {
Device::all_iter().nth(index)
}
Expand Down Expand Up @@ -370,13 +426,30 @@ macro_rules! call_kernel {

#[cfg(test)]
mod test {
use super::Device;
use super::{Device, DeviceUuid};
use std::convert::TryInto;

#[test]
fn test_device_all() {
for _ in 0..10 {
let devices = Device::all();
dbg!(&devices.len());
}
let devices = Device::all();
dbg!(&devices.len());
println!("{:?}", devices);
}

#[test]
fn test_uuid() {
let test_uuid = "46abccd6-022e-b783-572d-833f7104d05f";
let uuid: DeviceUuid = test_uuid.try_into().unwrap();
assert_eq!(test_uuid, &uuid.to_string());

// test wrong length uuid
let bad_uuid = "46abccd6-022e-b783-572-833f7104d05f";
let uuid: Result<DeviceUuid, _> = bad_uuid.try_into();
assert!(uuid.is_err());

// test invalid hex character
let bad_uuid = "46abccd6-022e-b783-572d-833f7104d05h";
let uuid: Result<DeviceUuid, _> = bad_uuid.try_into();
assert!(uuid.is_err());
}
}
57 changes: 37 additions & 20 deletions src/opencl/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use lazy_static::lazy_static;
use log::{debug, warn};
use sha2::{Digest, Sha256};

use super::{Brand, Device, GPUError, GPUResult};
use super::{Brand, Device, DeviceUuid, GPUError, GPUResult};

#[repr(C)]
#[derive(Debug, Clone, Default)]
Expand All @@ -17,8 +17,13 @@ struct cl_amd_device_topology {
function: u8,
}

const AMD_DEVICE_VENDOR_STRING: &str = "AMD";
const NVIDIA_DEVICE_VENDOR_STRING: &str = "NVIDIA Corporation";
const AMD_DEVICE_VENDOR_STRING: &'static str = "AMD";
const NVIDIA_DEVICE_VENDOR_STRING: &'static str = "NVIDIA Corporation";

// constants defined as part of the opencl spec
// https://github.com/KhronosGroup/OpenCL-Headers/blob/master/CL/cl_ext.h#L687
const CL_DEVICE_UUID_KHR: u32 = 0x106A;
pub(crate) const CL_UUID_SIZE_KHR: usize = 16;

pub fn is_little_endian(d: ocl::Device) -> GPUResult<bool> {
match d.info(ocl::enums::DeviceInfo::EndianLittle)? {
Expand All @@ -29,26 +34,36 @@ pub fn is_little_endian(d: ocl::Device) -> GPUResult<bool> {
}
}

pub fn get_bus_id(d: ocl::Device) -> ocl::Result<u32> {
pub fn get_device_uuid(d: ocl::Device) -> ocl::Result<DeviceUuid> {
let result = d.info_raw(CL_DEVICE_UUID_KHR)?;
assert_eq!(result.len(), CL_UUID_SIZE_KHR);
let mut raw = [0u8; CL_UUID_SIZE_KHR];
raw.copy_from_slice(result.as_slice());
Ok(DeviceUuid(raw))
}

pub fn get_pci_id(d: ocl::Device) -> ocl::Result<u32> {
let vendor = d.vendor()?;
match vendor.as_str() {
AMD_DEVICE_VENDOR_STRING => get_amd_bus_id(d),
NVIDIA_DEVICE_VENDOR_STRING => get_nvidia_bus_id(d),
AMD_DEVICE_VENDOR_STRING => get_amd_pci_id(d),
NVIDIA_DEVICE_VENDOR_STRING => get_nvidia_pci_id(d),
_ => Err(ocl::Error::from(format!(
"cannot get bus ID for device with vendor {} ",
"cannot get pciId for device with vendor {} ",
vendor
))),
}
}

pub fn get_nvidia_bus_id(d: ocl::Device) -> ocl::Result<u32> {
const CL_DEVICE_PCI_BUS_ID_NV: u32 = 0x4008;
let result = d.info_raw(CL_DEVICE_PCI_BUS_ID_NV)?;
fn get_nvidia_pci_id(d: ocl::Device) -> ocl::Result<u32> {
const CL_DEVICE_PCI_SLOT_ID_NV: u32 = 0x4009;

let result = d.info_raw(CL_DEVICE_PCI_SLOT_ID_NV)?;
Ok(u32::from_le_bytes(result[..].try_into().unwrap()))
}

pub fn get_amd_bus_id(d: ocl::Device) -> ocl::Result<u32> {
fn get_amd_pci_id(d: ocl::Device) -> ocl::Result<u32> {
const CL_DEVICE_TOPOLOGY_AMD: u32 = 0x4037;

let result = d.info_raw(CL_DEVICE_TOPOLOGY_AMD)?;
let size = std::mem::size_of::<cl_amd_device_topology>();
assert_eq!(result.len(), size);
Expand All @@ -57,7 +72,10 @@ pub fn get_amd_bus_id(d: ocl::Device) -> ocl::Result<u32> {
std::slice::from_raw_parts_mut(&mut topo as *mut cl_amd_device_topology as *mut u8, size)
.copy_from_slice(&result);
}
Ok(topo.bus as u32)
let device = topo.device as u32;
let bus = topo.bus as u32;
let function = topo.function as u32;
Ok((device << 16) | (bus << 8) | function)
}

pub fn cache_path(device: &Device, cl_source: &str) -> std::io::Result<std::path::PathBuf> {
Expand All @@ -66,14 +84,12 @@ pub fn cache_path(device: &Device, cl_source: &str) -> std::io::Result<std::path
std::fs::create_dir(&path)?;
}
let mut hasher = Sha256::new();
// If there are multiple devices with the same name and neither has a Bus-Id,
// then there will be a collision. Bus-Id can be missing in the case of an Apple
// GPU. For now, we assume that in the unlikely event of a collision, the same
// cache can be used.
// TODO: We might be able to get around this issue by using cl_vendor_id instead of Bus-Id.
hasher.input(device.name.as_bytes());
if let Some(bus_id) = device.bus_id {
hasher.input(bus_id.to_be_bytes());
if let Some(uuid) = device.uuid {
neithanmo marked this conversation as resolved.
Show resolved Hide resolved
hasher.input(uuid.to_string());
}
if let Some(pci) = device.pci_id {
hasher.input(pci.to_le_bytes());
}
hasher.input(cl_source.as_bytes());
let mut digest = String::new();
Expand Down Expand Up @@ -134,7 +150,8 @@ fn build_device_list() -> Vec<Device> {
brand,
name: d.name()?,
memory: get_memory(d)?,
bus_id: get_bus_id(d).ok(),
uuid: get_device_uuid(d).ok(),
pci_id: get_pci_id(d).ok(),
platform: *platform,
device: d,
})
Expand Down