-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5435 from Courtesy-Xs/add_gpu_launch_config
Add query and other components
- Loading branch information
Showing
22 changed files
with
401 additions
and
118 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#pragma once | ||
|
||
#include <memory> | ||
|
||
#include "common/nvgpu_dev_info.h" | ||
#include "target.h" | ||
|
||
namespace colossalAI { | ||
namespace common { | ||
|
||
template <typename Ret> | ||
class DevInfoMgr final { | ||
public: | ||
static std::unique_ptr<Ret> GetDevInfo(int device_num) const { | ||
return std::make_unique<Ret>(device_num); | ||
} | ||
}; | ||
|
||
} // namespace common | ||
} // namespace colossalAI |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
#pragma once | ||
|
||
#include <exception> | ||
#include <iostream> | ||
#include <string> | ||
|
||
namespace colossalAI { | ||
namespace common { | ||
|
||
class Target { | ||
public: | ||
enum class OS : int { | ||
Unk = -1, | ||
Linux, | ||
Windows, | ||
}; | ||
enum class Arch : int { | ||
Unk = -1, | ||
X86, | ||
Arm, | ||
NVGPU, | ||
AMDGPU, | ||
Ascend, | ||
}; | ||
enum class BitLen : int { | ||
Unk = -1, | ||
k32, | ||
k64, | ||
}; | ||
|
||
explicit Target(OS os, Arch arch, BitLen bitlen) | ||
: os_(os), arch_(arch), bitlen_(bitlen) {} | ||
|
||
bool defined() const { | ||
return (os_ != OS::Unk) && (arch_ != Arch::Unk) && (bitlen_ != BitLen::Unk); | ||
} | ||
|
||
std::string str() const { | ||
std::string s{"OS: "}; | ||
switch (os_) { | ||
case OS::Unk: | ||
s += "Unk"; | ||
break; | ||
case OS::Linux: | ||
s += "Linux"; | ||
break; | ||
case OS::Windows: | ||
s += "Windows"; | ||
break; | ||
default: | ||
throw std::invalid_argument("Invalid OS type!"); | ||
} | ||
s += "\t"; | ||
s += "Arch: "; | ||
|
||
switch (arch_) { | ||
case Arch::Unk: | ||
s += "Unk"; | ||
break; | ||
case Arch::X86: | ||
s += "X86"; | ||
break; | ||
case Arch::Arm: | ||
s += "Arm"; | ||
break; | ||
case Arch::NVGPU: | ||
s += "NVGPU"; | ||
break; | ||
case Arch::AMDGPU: | ||
s += "AMDGPU"; | ||
break; | ||
case Arch::Ascend: | ||
s += "Ascend"; | ||
break; | ||
default: | ||
throw std::invalid_argument("Invalid Arch type!"); | ||
} | ||
s += "\t"; | ||
s += "BitLen: "; | ||
|
||
switch (bitlen_) { | ||
case BitLen::Unk: | ||
s += "Unk"; | ||
break; | ||
case BitLen::k32: | ||
s += "k32"; | ||
break; | ||
case BitLen::k64: | ||
s += "k64"; | ||
break; | ||
default: | ||
throw std::invalid_argument("Invalid target bit length!"); | ||
} | ||
|
||
return s; | ||
} | ||
|
||
OS os() const { return os_; } | ||
Arch arch() const { return arch_; } | ||
BitLen bitlen() const { return bitlen_; } | ||
|
||
static Target DefaultX86Target(); | ||
static Target DefaultArmTarget(); | ||
static Target DefaultRocmTarget(); | ||
static Target DefaultAscendTarget(); | ||
|
||
static Target DefaultCUDATarget() { | ||
return Target(OS::Linux, Arch::CUDA, BitLen::k64); | ||
} | ||
|
||
friend std::ostream& operator<<(std::ostream& os, const Target& target); | ||
friend bool operator==(const Target& lhs, const Target& rhs); | ||
friend bool operator!=(const Target& lhs, const Target& rhs); | ||
|
||
private: | ||
OS os_{OS::Unk}; | ||
Arch arch_{Arch::Unk}; | ||
BitLen bitlen_{BitLen::Unk}; | ||
}; | ||
|
||
std::ostream& operator<<(std::ostream& os, const Target& target) { | ||
std::cout << target.str() << std::endl; | ||
} | ||
bool operator==(const Target& lhs, const Target& rhs) { | ||
return (lhs.os_ == rhs.os_) && (lhs.arch_ == rhs.arch_) && | ||
(lhs.bitlen_ == rhs.bitlen_); | ||
} | ||
bool operator!=(const Target& lhs, const Target& rhs) { | ||
return (lhs.os_ != rhs.os_) && (lhs.arch_ != rhs.arch_) && | ||
(lhs.bitlen_ != rhs.bitlen_); | ||
} | ||
|
||
} // namespace common | ||
} // namespace colossalAI |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +0,0 @@ | ||
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h | ||
#ifndef TORCH_CHECK | ||
#define TORCH_CHECK AT_CHECK | ||
#endif | ||
|
||
#ifdef VERSION_GE_1_3 | ||
#define DATA_PTR data_ptr | ||
#else | ||
#define DATA_PTR data | ||
#endif | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.