From 067fdf8c700818817e4de6959668dcaafffb8b84 Mon Sep 17 00:00:00 2001 From: Alex Malyshev Date: Wed, 24 Jan 2024 08:22:45 -0800 Subject: [PATCH] Expect PyFunctionObject* in _PyJit_IsCompiled() Summary: If you look at most of its callsites, they already have a PyFunctionObject* but have to cast it to PyObject*. Within the function, the argument is expected to be a PyFunctionObject*. is_jit_compiled() stands out here because unlike jit_suppress(), disassemble(), and print_hir(), it does not error when it gets an argument that's not a PyFunctionObject. It just returns false. This diff keeps this behavior. Reviewed By: swtaarrs Differential Revision: D53007959 fbshipit-source-id: 1d012eae2a5a01d4d7f03b5738e1005df93f9bb9 --- cinderx/Common/watchers.cpp | 2 +- cinderx/Interpreter/interpreter.c | 2 +- cinderx/Jit/lir/generator.cpp | 2 +- cinderx/Jit/pyjit.cpp | 39 +++++++++++++----------------- cinderx/Jit/pyjit.h | 2 +- cinderx/StaticPython/classloader.c | 2 +- 6 files changed, 22 insertions(+), 27 deletions(-) diff --git a/cinderx/Common/watchers.cpp b/cinderx/Common/watchers.cpp index ed3185898f2..341d152f40c 100644 --- a/cinderx/Common/watchers.cpp +++ b/cinderx/Common/watchers.cpp @@ -81,7 +81,7 @@ static int install_func_watcher() { break; case PyFunction_EVENT_MODIFY_QUALNAME: // allow reconsideration of whether this function should be compiled - if (!_PyJIT_IsCompiled((PyObject*)func)) { + if (!_PyJIT_IsCompiled(func)) { // func_set_qualname will assign this again, but we need to assign it // now so that PyEntry_init can consider the new qualname Py_INCREF(new_value); diff --git a/cinderx/Interpreter/interpreter.c b/cinderx/Interpreter/interpreter.c index f342d358e57..ddaa3ab18a0 100644 --- a/cinderx/Interpreter/interpreter.c +++ b/cinderx/Interpreter/interpreter.c @@ -5329,7 +5329,7 @@ PyEntry_AutoJIT(PyFunctionObject *func, void PyEntry_init(PyFunctionObject *func) { - assert(!_PyJIT_IsCompiled((PyObject *)func)); + assert(!_PyJIT_IsCompiled(func)); if (_PyJIT_IsAutoJITEnabled()) { func->vectorcall = (vectorcallfunc)PyEntry_AutoJIT; return; diff --git a/cinderx/Jit/lir/generator.cpp b/cinderx/Jit/lir/generator.cpp index 72cc09c8da9..d31136cbca7 100644 --- a/cinderx/Jit/lir/generator.cpp +++ b/cinderx/Jit/lir/generator.cpp @@ -1865,7 +1865,7 @@ LIRGenerator::TranslatedBlock LIRGenerator::TranslateOneBasicBlock( std::stringstream ss; Instruction* lir; - if (_PyJIT_IsCompiled((PyObject*)func)) { + if (_PyJIT_IsCompiled(func)) { lir = bbb.appendInstr( instr->dst(), Instruction::kCall, diff --git a/cinderx/Jit/pyjit.cpp b/cinderx/Jit/pyjit.cpp index 29344ace1ec..881830c9016 100644 --- a/cinderx/Jit/pyjit.cpp +++ b/cinderx/Jit/pyjit.cpp @@ -933,17 +933,19 @@ static PyObject* get_batch_compilation_time_ms(PyObject*, PyObject*) { return PyLong_FromLong(g_batch_compilation_time_ms); } -static PyObject* force_compile(PyObject* /* self */, PyObject* func) { - if (!PyFunction_Check(func)) { +static PyObject* force_compile(PyObject* /* self */, PyObject* func_obj) { + if (!PyFunction_Check(func_obj)) { PyErr_SetString(PyExc_TypeError, "force_compile expected a function"); return nullptr; } + BorrowedRef func = func_obj; + if (_PyJIT_IsCompiled(func)) { Py_RETURN_FALSE; } - switch (_PyJIT_CompileFunction(reinterpret_cast(func))) { + switch (_PyJIT_CompileFunction(func)) { case PYJIT_RESULT_OK: Py_RETURN_TRUE; case PYJIT_RESULT_CANNOT_SPECIALIZE: @@ -972,28 +974,21 @@ static PyObject* auto_jit_threshold(PyObject* /* self */, PyObject*) { return PyLong_FromLong(getConfig().auto_jit_threshold); } -int _PyJIT_IsCompiled(PyObject* func) { - if (jit_ctx == nullptr) { - return 0; - } - JIT_DCHECK( - PyFunction_Check(func), - "Expected PyFunctionObject, got '{:.200}'", - Py_TYPE(func)->tp_name); - - return int{jit_ctx->didCompile(func)}; +int _PyJIT_IsCompiled(PyFunctionObject* func) { + return jit_ctx != nullptr ? jit_ctx->didCompile(func) : 0; } static PyObject* is_jit_compiled(PyObject* /* self */, PyObject* func) { - int st = _PyJIT_IsCompiled(func); - PyObject* res = nullptr; - if (st == 1) { - res = Py_True; - } else if (st == 0) { - res = Py_False; - } - Py_XINCREF(res); - return res; + if (!PyFunction_Check(func)) { + PyErr_SetString( + PyExc_RuntimeError, "Must call is_jit_compiled with a function object"); + return nullptr; + } + + if (_PyJIT_IsCompiled(reinterpret_cast(func))) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; } static PyObject* print_hir(PyObject* /* self */, PyObject* func) { diff --git a/cinderx/Jit/pyjit.h b/cinderx/Jit/pyjit.h index 416515a4b37..bf6dae7a3c2 100644 --- a/cinderx/Jit/pyjit.h +++ b/cinderx/Jit/pyjit.h @@ -239,7 +239,7 @@ PyAPI_FUNC(PyObject*) _PyJIT_GenYieldFromValue(PyGenObject* gen); * Returns 1 if the function is JITed, 0 if not. */ -PyAPI_FUNC(int) _PyJIT_IsCompiled(PyObject* func); +PyAPI_FUNC(int) _PyJIT_IsCompiled(PyFunctionObject* func); /* * Returns a borrowed reference to the globals for the top-most Python function diff --git a/cinderx/StaticPython/classloader.c b/cinderx/StaticPython/classloader.c index cfb49022103..7a78d47204b 100644 --- a/cinderx/StaticPython/classloader.c +++ b/cinderx/StaticPython/classloader.c @@ -1305,7 +1305,7 @@ void set_entry_from_func(_PyType_VTableEntry *entry, PyFunctionObject *func) { /* this will always be invoked statically via the v-table */ entry->vte_entry = (vectorcallfunc)vtable_static_function_dont_bolt; } else { - assert(_PyJIT_IsCompiled((PyObject *)func)); + assert(_PyJIT_IsCompiled(func)); entry->vte_entry = JITRT_GET_STATIC_ENTRY(func->vectorcall); } }