Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[UR] Improve handling of error cases in urProgramLink #1458

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4277,6 +4277,11 @@ urProgramCompile(
/// in `phProgram` will contain a binary of the
/// ::UR_PROGRAM_BINARY_TYPE_EXECUTABLE type for each device in
/// `hContext`.
/// - If a non-success code is returned and `phProgram` is not `nullptr`, it
/// will contain an unspecified program or `nullptr`. Implementations may
/// use the build log of this program (accessible via
/// ::urProgramGetBuildInfo) to provide an error log for the linking
/// failure.
///
/// @remarks
/// _Analogues_
Expand Down Expand Up @@ -9278,6 +9283,11 @@ urProgramCompileExp(
/// in `phProgram` will contain a binary of the
/// ::UR_PROGRAM_BINARY_TYPE_EXECUTABLE type for each device in
/// `phDevices`.
/// - If a non-success code is returned and `phProgram` is not `nullptr`, it
/// will contain an unspecified program or `nullptr`. Implementations may
/// use the build log of this program (accessible via
/// ::urProgramGetBuildInfo) to provide an error log for the linking
/// failure.
///
/// @remarks
/// _Analogues_
Expand Down
1 change: 1 addition & 0 deletions scripts/core/exp-multi-device-compile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ analogue:
details:
- "The application may call this function from simultaneous threads."
- "Following a successful call to this entry point the program returned in `phProgram` will contain a binary of the $X_PROGRAM_BINARY_TYPE_EXECUTABLE type for each device in `phDevices`."
- "If a non-success code is returned and `phProgram` is not `nullptr`, it will contain an unspecified program or `nullptr`. Implementations may use the build log of this program (accessible via $xProgramGetBuildInfo) to provide an error log for the linking failure."
params:
- type: $x_context_handle_t
name: hContext
Expand Down
1 change: 1 addition & 0 deletions scripts/core/program.yml
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ analogue:
details:
- "The application may call this function from simultaneous threads."
- "Following a successful call to this entry point the program returned in `phProgram` will contain a binary of the $X_PROGRAM_BINARY_TYPE_EXECUTABLE type for each device in `hContext`."
- "If a non-success code is returned and `phProgram` is not `nullptr`, it will contain an unspecified program or `nullptr`. Implementations may use the build log of this program (accessible via $xProgramGetBuildInfo) to provide an error log for the linking failure."
params:
- type: $x_context_handle_t
name: hContext
Expand Down
24 changes: 24 additions & 0 deletions scripts/templates/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,30 @@ def get_pfntables(specs, meta, namespace, tags):

return tables

"""
Public:
returns an expression setting required output parameters to null on entry
"""
def get_initial_null_set(obj):
cname = obj_traits.class_name(obj)
lvalue = {
('$xProgram', 'Link'): 'phProgram',
('$xProgram', 'LinkExp'): 'phProgram',
}.get((cname, obj['name']))
if lvalue is not None:
return 'if (nullptr != {0}) {{*{0} = nullptr;}}'.format(lvalue)
return ""

"""
Public:
returns true if the function always wraps output pointers in loader handles
"""
def always_wrap_outputs(obj):
cname = obj_traits.class_name(obj)
return (cname, obj['name']) in [
('$xProgram', 'Link'),
('$xProgram', 'LinkExp'),
]

"""
Private:
Expand Down
6 changes: 3 additions & 3 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace ur_loader
{
${x}_result_t result = ${X}_RESULT_SUCCESS;<%
add_local = False
%>
%>${th.get_initial_null_set(obj)}

%if re.match(r"\w+AdapterGet$", th.make_func_name(n, tags, obj)):

Expand Down Expand Up @@ -271,7 +271,7 @@ namespace ur_loader
del add_local
%>
%for i, item in enumerate(epilogue):
%if 0 == i and not item['release']:
%if 0 == i and not item['release'] and not th.always_wrap_outputs(obj):
if( ${X}_RESULT_SUCCESS != result )
return result;

Expand Down Expand Up @@ -309,7 +309,7 @@ namespace ur_loader
${item['factory']}.getInstance( ${item['name']}[ i ], dditable ) );
%else:
// convert platform handle to loader handle
%if item['optional']:
%if item['optional'] or th.always_wrap_outputs(obj):
if( nullptr != ${item['name']} )
*${item['name']} = reinterpret_cast<${item['type']}>(
${item['factory']}.getInstance( *${item['name']}, dditable ) );
Expand Down
2 changes: 1 addition & 1 deletion scripts/templates/libapi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ try {
%elif th.obj_traits.is_loader_only(obj):
return ur_lib::${th.make_func_name(n, tags, obj)}(${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
%else:
auto ${th.make_pfn_name(n, tags, obj)} = ${x}_lib::context->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
${th.get_initial_null_set(obj)}auto ${th.make_pfn_name(n, tags, obj)} = ${x}_lib::context->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
return ${X}_RESULT_ERROR_UNINITIALIZED;

Expand Down
1 change: 1 addition & 0 deletions scripts/templates/nullddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace driver
)
try {
${x}_result_t result = ${X}_RESULT_SUCCESS;
${th.get_initial_null_set(obj)}

// if the driver has created a custom function, then call it instead of using the generic path
auto ${th.make_pfn_name(n, tags, obj)} = d_context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
Expand Down
2 changes: 1 addition & 1 deletion scripts/templates/trcddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace ur_tracing_layer
${line}
%endfor
)
{
{${th.get_initial_null_set(obj)}
auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};

if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
Expand Down
2 changes: 1 addition & 1 deletion scripts/templates/valddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ namespace ur_validation_layer
${line}
%endfor
)
{
{${th.get_initial_null_set(obj)}
auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};

if( nullptr == ${th.make_pfn_name(n, tags, obj)} ) {
Expand Down
9 changes: 8 additions & 1 deletion source/adapters/cuda/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuild(ur_context_handle_t hContext,

UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t,
const ur_program_handle_t *, const char *, ur_program_handle_t *) {
const ur_program_handle_t *, const char *, ur_program_handle_t *phProgram) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

Expand All @@ -284,6 +287,10 @@ urProgramLink(ur_context_handle_t hContext, uint32_t count,
const ur_program_handle_t *phPrograms, const char *pOptions,
ur_program_handle_t *phProgram) {
ur_result_t Result = UR_RESULT_SUCCESS;
if (nullptr != phProgram) {
*phProgram = nullptr;
}

// All programs must be associated with the same device
for (auto i = 1u; i < count; ++i)
UR_ASSERT(phPrograms[i]->getDevice() == phPrograms[0]->getDevice(),
Expand Down
15 changes: 10 additions & 5 deletions source/adapters/hip/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,14 +326,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuild(ur_context_handle_t,

UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t,
const ur_program_handle_t *, const char *, ur_program_handle_t *) {
const ur_program_handle_t *, const char *, ur_program_handle_t *phProgram) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramLink(ur_context_handle_t, uint32_t,
const ur_program_handle_t *,
const char *,
ur_program_handle_t *) {
UR_APIEXPORT ur_result_t UR_APICALL
urProgramLink(ur_context_handle_t, uint32_t, const ur_program_handle_t *,
const char *, ur_program_handle_t *phProgram) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

Expand Down
2 changes: 2 additions & 0 deletions source/adapters/level_zero/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ ur_result_t ze2urResult(ze_result_t ZeResult) {
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
case ZE_RESULT_ERROR_UNSUPPORTED_FEATURE:
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
case ZE_RESULT_ERROR_MODULE_LINK_FAILURE:
return UR_RESULT_ERROR_PROGRAM_LINK_FAILURE;
default:
return UR_RESULT_ERROR_UNKNOWN;
}
Expand Down
9 changes: 4 additions & 5 deletions source/adapters/level_zero/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_program_handle_t
*phProgram ///< [out] pointer to handle of program object created.
) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
for (uint32_t i = 0; i < numDevices; i++) {
UR_ASSERT(hContext->isValidDevice(phDevices[i]),
UR_RESULT_ERROR_INVALID_DEVICE);
Expand Down Expand Up @@ -445,11 +448,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
// because the ZeBuildLog tells which symbols are unresolved.
if (ZeResult == ZE_RESULT_SUCCESS) {
ZeResult = checkUnresolvedSymbols(ZeModule, &ZeBuildLog);
if (ZeResult == ZE_RESULT_ERROR_MODULE_LINK_FAILURE) {
UrResult =
UR_RESULT_ERROR_UNKNOWN; // TODO:
// UR_RESULT_ERROR_PROGRAM_LINK_FAILURE;
} else if (ZeResult != ZE_RESULT_SUCCESS) {
if (ZeResult != ZE_RESULT_SUCCESS) {
return ze2urResult(ZeResult);
}
}
Expand Down
9 changes: 7 additions & 2 deletions source/adapters/native_cpu/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,13 @@ UR_APIEXPORT ur_result_t UR_APICALL
urProgramLink(ur_context_handle_t hContext, uint32_t count,
const ur_program_handle_t *phPrograms, const char *pOptions,
ur_program_handle_t *phProgram) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
std::ignore = hContext;
std::ignore = count;
std::ignore = phPrograms;
std::ignore = pOptions;
std::ignore = phProgram;

DIE_NO_IMPLEMENTATION
}
Expand All @@ -144,7 +146,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t,

UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t,
const ur_program_handle_t *, const char *, ur_program_handle_t *) {
const ur_program_handle_t *, const char *, ur_program_handle_t *phProgram) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

Expand Down
6 changes: 6 additions & 0 deletions source/adapters/null/ur_nullddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1920,6 +1920,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink(
*phProgram ///< [out] pointer to handle of program object created.
) try {
ur_result_t result = UR_RESULT_SUCCESS;
if (nullptr != phProgram) {
*phProgram = nullptr;
}

// if the driver has created a custom function, then call it instead of using the generic path
auto pfnLink = d_context.urDdiTable.Program.pfnLink;
Expand Down Expand Up @@ -5728,6 +5731,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp(
*phProgram ///< [out] pointer to handle of program object created.
) try {
ur_result_t result = UR_RESULT_SUCCESS;
if (nullptr != phProgram) {
*phProgram = nullptr;
}

// if the driver has created a custom function, then call it instead of using the generic path
auto pfnLinkExp = d_context.urDdiTable.ProgramExp.pfnLinkExp;
Expand Down
12 changes: 11 additions & 1 deletion source/adapters/opencl/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ urProgramLink(ur_context_handle_t hContext, uint32_t count,
pOptions, cl_adapter::cast<cl_uint>(count),
cl_adapter::cast<const cl_program *>(phPrograms), nullptr,
nullptr, &CLResult));

if (CL_INVALID_BINARY == CLResult) {
// Some OpenCL drivers incorrectly return CL_INVALID_BINARY here, convert it
// to CL_LINK_PROGRAM_FAILURE
CLResult = CL_LINK_PROGRAM_FAILURE;
}

CL_RETURN_ON_FAILURE(CLResult);

return UR_RESULT_SUCCESS;
Expand All @@ -236,7 +243,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t,

UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t,
const ur_program_handle_t *, const char *, ur_program_handle_t *) {
const ur_program_handle_t *, const char *, ur_program_handle_t *phProgram) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

Expand Down
6 changes: 6 additions & 0 deletions source/loader/layers/tracing/ur_trcddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2486,6 +2486,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink(
ur_program_handle_t
*phProgram ///< [out] pointer to handle of program object created.
) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
auto pfnLink = context.urDdiTable.Program.pfnLink;

if (nullptr == pfnLink) {
Expand Down Expand Up @@ -7639,6 +7642,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp(
ur_program_handle_t
*phProgram ///< [out] pointer to handle of program object created.
) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
auto pfnLinkExp = context.urDdiTable.ProgramExp.pfnLinkExp;

if (nullptr == pfnLinkExp) {
Expand Down
6 changes: 6 additions & 0 deletions source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2849,6 +2849,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink(
ur_program_handle_t
*phProgram ///< [out] pointer to handle of program object created.
) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
auto pfnLink = context.urDdiTable.Program.pfnLink;

if (nullptr == pfnLink) {
Expand Down Expand Up @@ -9285,6 +9288,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp(
ur_program_handle_t
*phProgram ///< [out] pointer to handle of program object created.
) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
auto pfnLinkExp = context.urDdiTable.ProgramExp.pfnLinkExp;

if (nullptr == pfnLinkExp) {
Expand Down
26 changes: 14 additions & 12 deletions source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2556,6 +2556,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink(
*phProgram ///< [out] pointer to handle of program object created.
) {
ur_result_t result = UR_RESULT_SUCCESS;
if (nullptr != phProgram) {
*phProgram = nullptr;
}

// extract platform's function pointer table
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
Expand All @@ -2578,14 +2581,12 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink(
result =
pfnLink(hContext, count, phProgramsLocal.data(), pOptions, phProgram);

if (UR_RESULT_SUCCESS != result) {
return result;
}

try {
// convert platform handle to loader handle
*phProgram = reinterpret_cast<ur_program_handle_t>(
ur_program_factory.getInstance(*phProgram, dditable));
if (nullptr != phProgram) {
*phProgram = reinterpret_cast<ur_program_handle_t>(
ur_program_factory.getInstance(*phProgram, dditable));
}
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}
Expand Down Expand Up @@ -7909,6 +7910,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp(
*phProgram ///< [out] pointer to handle of program object created.
) {
ur_result_t result = UR_RESULT_SUCCESS;
if (nullptr != phProgram) {
*phProgram = nullptr;
}

// extract platform's function pointer table
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
Expand Down Expand Up @@ -7938,14 +7942,12 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp(
result = pfnLinkExp(hContext, numDevices, phDevicesLocal.data(), count,
phProgramsLocal.data(), pOptions, phProgram);

if (UR_RESULT_SUCCESS != result) {
return result;
}

try {
// convert platform handle to loader handle
*phProgram = reinterpret_cast<ur_program_handle_t>(
ur_program_factory.getInstance(*phProgram, dditable));
if (nullptr != phProgram) {
*phProgram = reinterpret_cast<ur_program_handle_t>(
ur_program_factory.getInstance(*phProgram, dditable));
}
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}
Expand Down
Loading
Loading