Skip to content

Commit

Permalink
Clean up JITList and associated helpers
Browse files Browse the repository at this point in the history
Summary:
* Renamed onJITList{,Impl} in pyjit.cpp to shouldCompile(), which is more
  descriptive of what the code does.  Even when there's no JIT list the function
  currently returns "true".  Speaking of returns, there was a bug where if the
  JIT list failed to look up an item it would return -1, but all uses of this
  value blindly converted it to a true/false success status so it looked like
  success.

* JITList::lookup() is unused. Its functional equivalent is onJitListImpl() in
  pyjit.cpp, so move that definiton over and name it lookupFunc() to be
  consistent with lookupCode().

* Use std::string_view instead of const char* wherever possible.  Added a
  stringAsUnicode() helper to make this easier, there's many other places in
  Jit/ where we can use this as well.

* Added nullptr checks to uses of JITList::pathBasename(), which can return
  nullptr on error.

Reviewed By: jbower-fb

Differential Revision: D51764325

fbshipit-source-id: 30f39f8f29b8c092aa7909b17219f9afa689199b
  • Loading branch information
Alex Malyshev authored and facebook-github-bot committed Dec 5, 2023
1 parent bea832f commit 67f5fe3
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 184 deletions.
231 changes: 113 additions & 118 deletions CinderX/Jit/jit_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,53 +59,33 @@ bool JITList::parseFile(const char* filename) {
return true;
}

bool JITList::parseLine(const std::string& line) {
bool JITList::parseLine(std::string_view line) {
if (line.empty() || line.at(0) == '#') {
return true;
}
auto atpos = line.find("@");
if (atpos == std::string::npos) {
if (atpos == std::string_view::npos) {
auto cln_pos = line.find(":");
if (cln_pos == std::string::npos) {
return false;
}
std::string mod = line.substr(0, cln_pos);
std::string qualname = line.substr(cln_pos + 1);
if (!addEntryFO(mod.c_str(), qualname.c_str())) {
return false;
}
} else {
std::string name = line.substr(0, atpos);
std::string loc_str = line.substr(atpos + 1);
auto cln_pos = loc_str.find(":");
if (cln_pos == std::string::npos) {
return false;
}
std::string file = line.substr(atpos + 1, cln_pos);
std::string file_line = loc_str.substr(cln_pos + 1);
if (!addEntryCO(name.c_str(), file.c_str(), file_line.c_str())) {
if (cln_pos == std::string_view::npos) {
return false;
}
std::string_view mod = line.substr(0, cln_pos);
std::string_view qualname = line.substr(cln_pos + 1);
return addEntryFunc(mod, qualname);
}
return true;
}

bool JITList::addEntryFO(const char* module_name, const char* qualname) {
JIT_DCHECK(
!g_threaded_compile_context.compileRunning(),
"unexpected multithreading");
auto mn_obj = Ref<>::steal(PyUnicode_FromString(module_name));
if (mn_obj == nullptr) {
std::string_view name = line.substr(0, atpos);
std::string_view loc_str = line.substr(atpos + 1);
auto cln_pos = loc_str.find(":");
if (cln_pos == std::string_view::npos) {
return false;
}
auto qn_obj = Ref<>::steal(PyUnicode_FromString(qualname));
if (qn_obj == nullptr) {
return false;
}
return addEntryFO(mn_obj, qn_obj);
std::string_view file = line.substr(atpos + 1, cln_pos);
std::string_view file_line = loc_str.substr(cln_pos + 1);
return addEntryCode(name, file, file_line);
}

bool JITList::addEntryFO(BorrowedRef<> module_name, BorrowedRef<> qualname) {
bool JITList::addEntryFunc(BorrowedRef<> module_name, BorrowedRef<> qualname) {
JIT_DCHECK(
!g_threaded_compile_context.compileRunning(),
"unexpected multithreading");
Expand All @@ -122,57 +102,22 @@ bool JITList::addEntryFO(BorrowedRef<> module_name, BorrowedRef<> qualname) {
return PySet_Add(qualname_set, qualname) == 0;
}

bool JITList::addEntryCO(
const char* name,
const char* file,
const char* line_no_str) {
bool JITList::addEntryFunc(std::string_view module_name, std::string_view qualname) {
JIT_DCHECK(
!g_threaded_compile_context.compileRunning(),
"unexpected multithreading");
auto name_obj = Ref<>::steal(PyUnicode_FromString(name));
if (name_obj == nullptr) {
return false;
}
auto file_obj = Ref<>::steal(PyUnicode_FromString(file));
if (file_obj == nullptr) {
return false;
}
Ref<> basename_obj(pathBasename(file_obj));
long line_no;
try {
line_no = std::stol(line_no_str);
} catch (...) {
Ref<> mn_obj = stringAsUnicode(module_name);
if (mn_obj == nullptr) {
return false;
}
auto line_no_obj = Ref<>::steal(PyLong_FromLong(line_no));
if (file_obj == nullptr) {
Ref<> qn_obj = stringAsUnicode(qualname);
if (qn_obj == nullptr) {
return false;
}
return addEntryCO(name_obj, basename_obj, line_no_obj);
return addEntryFunc(mn_obj, qn_obj);
}

Ref<> JITList::pathBasename(BorrowedRef<> path) {
JIT_DCHECK(
!g_threaded_compile_context.compileRunning(),
"unexpected multithreading");
if (path_sep_ == nullptr) {
const wchar_t* sep_str = L"/";
auto sep_str_obj = Ref<>::steal(PyUnicode_FromWideChar(&sep_str[0], 1));
if (sep_str_obj == nullptr) {
return nullptr;
}
path_sep_ = std::move(sep_str_obj);
}
auto split_path_obj = Ref<>::steal(PyUnicode_RSplit(path, path_sep_, 1));
if (split_path_obj == nullptr || !PyList_Check(split_path_obj) ||
PyList_GET_SIZE(split_path_obj.get()) < 1) {
return nullptr;
}
return Ref<>::create(PyList_GET_ITEM(
split_path_obj.get(), PyList_GET_SIZE(split_path_obj.get()) - 1));
}

bool JITList::addEntryCO(
bool JITList::addEntryCode(
BorrowedRef<> name,
BorrowedRef<> file,
BorrowedRef<> line_no) {
Expand Down Expand Up @@ -202,39 +147,61 @@ bool JITList::addEntryCO(
return PySet_Add(line_set, line_no) == 0;
}

int JITList::lookup(BorrowedRef<PyFunctionObject> func) {
int res;
if (func->func_module) {
if ((res = lookupFO(func->func_module, func->func_qualname))) {
return res;
}
bool JITList::addEntryCode(
std::string_view name,
std::string_view file,
std::string_view line_no_str) {
JIT_DCHECK(
!g_threaded_compile_context.compileRunning(),
"unexpected multithreading");
Ref<> name_obj = stringAsUnicode(name);
if (name_obj == nullptr) {
return false;
}
if (func->func_code) {
return lookupCO(reinterpret_cast<PyCodeObject*>(func->func_code));
Ref<> file_obj = stringAsUnicode(file);
if (file_obj == nullptr) {
return false;
}
Ref<> basename_obj = pathBasename(file_obj);
if (basename_obj == nullptr) {
return false;
}
return 0;
}

int JITList::lookupFO(BorrowedRef<> mod, BorrowedRef<> qualname) {
if (mod == nullptr) {
return 0;
long line_no = 0;
auto result = std::from_chars(
line_no_str.begin(), line_no_str.end(), line_no);
if (result.ec != std::errc{}) {
return false;
}
// Check for an exact module:qualname match
BorrowedRef<> name_set = PyDict_GetItemWithError(qualnames_, mod);
if (name_set == nullptr) {
return 0;

auto line_no_obj = Ref<>::steal(PyLong_FromLong(line_no));
if (file_obj == nullptr) {
return false;
}
return addEntryCode(name_obj, basename_obj, line_no_obj);
}

int JITList::lookupFunc(BorrowedRef<PyFunctionObject> func) const {
BorrowedRef<PyCodeObject> code = reinterpret_cast<PyCodeObject*>(
func->func_code
);
if (lookupCode(code) == 1) {
return 1;
}
return PySet_Contains(name_set, qualname);
return lookupName(func->func_module, func->func_qualname);
}

int JITList::lookupCO(BorrowedRef<PyCodeObject> code) {
int JITList::lookupCode(BorrowedRef<PyCodeObject> code) const {
JIT_DCHECK(
!g_threaded_compile_context.compileRunning(),
"unexpected multithreading");
"Unexpected multithreading");

auto name =
Ref<>::create(code->co_qualname ? code->co_qualname : code->co_name);
Ref<> line_no = Ref<>::steal(PyLong_FromLong(code->co_firstlineno));
Ref<> file(pathBasename(code->co_filename));
Ref<> file = pathBasename(code->co_filename);
if (file == nullptr) {
return 0;
}

BorrowedRef<> file_set = PyDict_GetItemWithError(name_file_line_no_, name);
if (file_set == nullptr) {
Expand All @@ -245,7 +212,22 @@ int JITList::lookupCO(BorrowedRef<PyCodeObject> code) {
return 0;
}

return g_jitlist_match_line_numbers ? PySet_Contains(line_set, line_no) : 1;
if (!g_jitlist_match_line_numbers) {
return 1;
}

Ref<> line_no = Ref<>::steal(PyLong_FromLong(code->co_firstlineno));
return PySet_Contains(line_set, line_no);
}

int JITList::lookupName(BorrowedRef<> module_name, BorrowedRef<> qualname) const {
if (module_name == nullptr) {
return 0;
}

// Check for an exact module:qualname match.
BorrowedRef<> name_set = PyDict_GetItemWithError(qualnames_, module_name);
return name_set != nullptr ? PySet_Contains(name_set, qualname) : 0;
}

Ref<> JITList::getList() const {
Expand All @@ -265,7 +247,7 @@ std::unique_ptr<WildcardJITList> WildcardJITList::create() {
return nullptr;
}

auto wildcard = Ref<>::steal(PyUnicode_FromString("*"));
Ref<> wildcard = stringAsUnicode("*");
if (wildcard == nullptr) {
return nullptr;
}
Expand All @@ -274,32 +256,47 @@ std::unique_ptr<WildcardJITList> WildcardJITList::create() {
new WildcardJITList(std::move(wildcard), std::move(qualnames)));
}

bool WildcardJITList::addEntryFO(
const char* module_name,
const char* qualname) {
if ((strcmp(module_name, "*") == 0) && (strcmp(qualname, "*") == 0)) {
// *:* is invalid
return false;
Ref<> JITList::pathBasename(BorrowedRef<> path) const {
JIT_DCHECK(
!g_threaded_compile_context.compileRunning(),
"unexpected multithreading");
if (path_sep_ == nullptr) {
const wchar_t* sep_str = L"/";
auto sep_str_obj = Ref<>::steal(PyUnicode_FromWideChar(&sep_str[0], 1));
if (sep_str_obj == nullptr) {
return nullptr;
}
path_sep_ = std::move(sep_str_obj);
}
auto split_path_obj = Ref<>::steal(PyUnicode_RSplit(path, path_sep_, 1));
if (split_path_obj == nullptr || !PyList_Check(split_path_obj) ||
PyList_GET_SIZE(split_path_obj.get()) < 1) {
return nullptr;
}
return JITList::addEntryFO(module_name, qualname);
return Ref<>::create(PyList_GET_ITEM(
split_path_obj.get(), PyList_GET_SIZE(split_path_obj.get()) - 1));
}

bool WildcardJITList::addEntryFunc(
std::string_view module_name,
std::string_view qualname) {
// *:* is invalid.
return (module_name != "*" || qualname != "*") && JITList::addEntryFunc(module_name, qualname);
}

int WildcardJITList::lookupFO(BorrowedRef<> mod, BorrowedRef<> qualname) {
int WildcardJITList::lookupName(BorrowedRef<> module_name, BorrowedRef<> qualname) const {
// Check for an exact match
int st = JITList::lookupFO(mod, qualname);
if (st != 0) {
if (int st = JITList::lookupName(module_name, qualname); st != 0) {
return st;
}

// Check if all functions in the module are enabled
st = JITList::lookupFO(mod, wildcard_);
if (st != 0) {
if (int st = JITList::lookupName(module_name, wildcard_); st != 0) {
return st;
}

// Check if the qualname is unconditionally enabled
st = JITList::lookupFO(wildcard_, qualname);
if (st != 0) {
if (int st = JITList::lookupName(wildcard_, qualname); st != 0) {
return st;
}

Expand Down Expand Up @@ -327,14 +324,12 @@ int WildcardJITList::lookupFO(BorrowedRef<> mod, BorrowedRef<> qualname) {
}

// Check if the instance method is unconditionally enabled
st = JITList::lookupFO(wildcard_, query);
if (st != 0) {
if (int st = JITList::lookupName(wildcard_, query); st != 0) {
return st;
}

// Check if the instance method is enabled in the module
st = JITList::lookupFO(mod, query);
if (st != 0) {
if (int st = JITList::lookupName(module_name, query); st != 0) {
return st;
}

Expand Down
Loading

0 comments on commit 67f5fe3

Please sign in to comment.