Skip to content

Commit

Permalink
[UR] Improve handling of error cases in urProgramLink
Browse files Browse the repository at this point in the history
Note that this change includes a specification change:
urProgramLink now requires the output parameter to contain either
nullptr or some unspecified binary on failure.

As well as this change, a number of bugs have been fixed:
* The Level Zero adapter now correctly returns
  `UR_RESULT_ERROR_PROGRAM_LINK_FAILURE` when linking fails, rather
  than `UR_RESULT_ERROR_UNKNOWN`.
* A workaround has been added for some OpenCL devices that return
  `CL_INVALID_BINARY` rather than `CL_LINK_PROGRAM_FAILURE` on
  linker failure.
* The `phProgram` handle is wrapped in a loader handle by the
  loader even if an error would be returned. This is required by
  Level Zero, which outputs a "dummy" program to store the linker
  log.

Conformance tests have also been added.
  • Loading branch information
RossBrunton committed Jul 9, 2024
1 parent 9d3bce6 commit 25c75c0
Show file tree
Hide file tree
Showing 25 changed files with 252 additions and 33 deletions.
10 changes: 10 additions & 0 deletions include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4277,6 +4277,11 @@ urProgramCompile(
/// in `phProgram` will contain a binary of the
/// ::UR_PROGRAM_BINARY_TYPE_EXECUTABLE type for each device in
/// `hContext`.
/// - If a non-success code is returned and `phProgram` is not `nullptr`, it
/// will contain an unspecified program or `nullptr`. Implementations may
/// use the build log of this program (accessible via
/// ::urProgramGetBuildInfo) to provide an error log for the linking
/// failure.
///
/// @remarks
/// _Analogues_
Expand Down Expand Up @@ -9278,6 +9283,11 @@ urProgramCompileExp(
/// in `phProgram` will contain a binary of the
/// ::UR_PROGRAM_BINARY_TYPE_EXECUTABLE type for each device in
/// `phDevices`.
/// - If a non-success code is returned and `phProgram` is not `nullptr`, it
/// will contain an unspecified program or `nullptr`. Implementations may
/// use the build log of this program (accessible via
/// ::urProgramGetBuildInfo) to provide an error log for the linking
/// failure.
///
/// @remarks
/// _Analogues_
Expand Down
1 change: 1 addition & 0 deletions scripts/core/exp-multi-device-compile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ analogue:
details:
- "The application may call this function from simultaneous threads."
- "Following a successful call to this entry point the program returned in `phProgram` will contain a binary of the $X_PROGRAM_BINARY_TYPE_EXECUTABLE type for each device in `phDevices`."
- "If a non-success code is returned and `phProgram` is not `nullptr`, it will contain an unspecified program or `nullptr`. Implementations may use the build log of this program (accessible via $xProgramGetBuildInfo) to provide an error log for the linking failure."
params:
- type: $x_context_handle_t
name: hContext
Expand Down
1 change: 1 addition & 0 deletions scripts/core/program.yml
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ analogue:
details:
- "The application may call this function from simultaneous threads."
- "Following a successful call to this entry point the program returned in `phProgram` will contain a binary of the $X_PROGRAM_BINARY_TYPE_EXECUTABLE type for each device in `hContext`."
- "If a non-success code is returned and `phProgram` is not `nullptr`, it will contain an unspecified program or `nullptr`. Implementations may use the build log of this program (accessible via $xProgramGetBuildInfo) to provide an error log for the linking failure."
params:
- type: $x_context_handle_t
name: hContext
Expand Down
24 changes: 24 additions & 0 deletions scripts/templates/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,30 @@ def get_pfntables(specs, meta, namespace, tags):

return tables

"""
Public:
returns an expression setting required output parameters to null on entry
"""
def get_initial_null_set(obj):
cname = obj_traits.class_name(obj)
lvalue = {
('$xProgram', 'Link'): 'phProgram',
('$xProgram', 'LinkExp'): 'phProgram',
}.get((cname, obj['name']))
if lvalue is not None:
return 'if (nullptr != {0}) {{*{0} = nullptr;}}'.format(lvalue)
return ""

"""
Public:
returns true if the function always wraps output pointers in loader handles
"""
def always_wrap_outputs(obj):
cname = obj_traits.class_name(obj)
return (cname, obj['name']) in [
('$xProgram', 'Link'),
('$xProgram', 'LinkExp'),
]

"""
Private:
Expand Down
6 changes: 3 additions & 3 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace ur_loader
{
${x}_result_t result = ${X}_RESULT_SUCCESS;<%
add_local = False
%>
%>${th.get_initial_null_set(obj)}

%if re.match(r"\w+AdapterGet$", th.make_func_name(n, tags, obj)):

Expand Down Expand Up @@ -271,7 +271,7 @@ namespace ur_loader
del add_local
%>
%for i, item in enumerate(epilogue):
%if 0 == i and not item['release']:
%if 0 == i and not item['release'] and not th.always_wrap_outputs(obj):
if( ${X}_RESULT_SUCCESS != result )
return result;
Expand Down Expand Up @@ -309,7 +309,7 @@ namespace ur_loader
${item['factory']}.getInstance( ${item['name']}[ i ], dditable ) );
%else:
// convert platform handle to loader handle
%if item['optional']:
%if item['optional'] or th.always_wrap_outputs(obj):
if( nullptr != ${item['name']} )
*${item['name']} = reinterpret_cast<${item['type']}>(
${item['factory']}.getInstance( *${item['name']}, dditable ) );
Expand Down
2 changes: 1 addition & 1 deletion scripts/templates/libapi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ try {
%elif th.obj_traits.is_loader_only(obj):
return ur_lib::${th.make_func_name(n, tags, obj)}(${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
%else:
auto ${th.make_pfn_name(n, tags, obj)} = ${x}_lib::context->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
${th.get_initial_null_set(obj)}auto ${th.make_pfn_name(n, tags, obj)} = ${x}_lib::context->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
return ${X}_RESULT_ERROR_UNINITIALIZED;

Expand Down
1 change: 1 addition & 0 deletions scripts/templates/nullddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace driver
)
try {
${x}_result_t result = ${X}_RESULT_SUCCESS;
${th.get_initial_null_set(obj)}

// if the driver has created a custom function, then call it instead of using the generic path
auto ${th.make_pfn_name(n, tags, obj)} = d_context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
Expand Down
2 changes: 1 addition & 1 deletion scripts/templates/trcddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace ur_tracing_layer
${line}
%endfor
)
{
{${th.get_initial_null_set(obj)}
auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};

if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
Expand Down
2 changes: 1 addition & 1 deletion scripts/templates/valddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ namespace ur_validation_layer
${line}
%endfor
)
{
{${th.get_initial_null_set(obj)}
auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};

if( nullptr == ${th.make_pfn_name(n, tags, obj)} ) {
Expand Down
9 changes: 8 additions & 1 deletion source/adapters/cuda/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuild(ur_context_handle_t hContext,

UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t,
const ur_program_handle_t *, const char *, ur_program_handle_t *) {
const ur_program_handle_t *, const char *, ur_program_handle_t *phProgram) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

Expand All @@ -284,6 +287,10 @@ urProgramLink(ur_context_handle_t hContext, uint32_t count,
const ur_program_handle_t *phPrograms, const char *pOptions,
ur_program_handle_t *phProgram) {
ur_result_t Result = UR_RESULT_SUCCESS;
if (nullptr != phProgram) {
*phProgram = nullptr;
}

// All programs must be associated with the same device
for (auto i = 1u; i < count; ++i)
UR_ASSERT(phPrograms[i]->getDevice() == phPrograms[0]->getDevice(),
Expand Down
15 changes: 10 additions & 5 deletions source/adapters/hip/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,14 +326,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuild(ur_context_handle_t,

UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t,
const ur_program_handle_t *, const char *, ur_program_handle_t *) {
const ur_program_handle_t *, const char *, ur_program_handle_t *phProgram) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramLink(ur_context_handle_t, uint32_t,
const ur_program_handle_t *,
const char *,
ur_program_handle_t *) {
UR_APIEXPORT ur_result_t UR_APICALL
urProgramLink(ur_context_handle_t, uint32_t, const ur_program_handle_t *,
const char *, ur_program_handle_t *phProgram) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

Expand Down
2 changes: 2 additions & 0 deletions source/adapters/level_zero/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ ur_result_t ze2urResult(ze_result_t ZeResult) {
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
case ZE_RESULT_ERROR_UNSUPPORTED_FEATURE:
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
case ZE_RESULT_ERROR_MODULE_LINK_FAILURE:
return UR_RESULT_ERROR_PROGRAM_LINK_FAILURE;
default:
return UR_RESULT_ERROR_UNKNOWN;
}
Expand Down
9 changes: 4 additions & 5 deletions source/adapters/level_zero/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_program_handle_t
*phProgram ///< [out] pointer to handle of program object created.
) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
for (uint32_t i = 0; i < numDevices; i++) {
UR_ASSERT(hContext->isValidDevice(phDevices[i]),
UR_RESULT_ERROR_INVALID_DEVICE);
Expand Down Expand Up @@ -445,11 +448,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
// because the ZeBuildLog tells which symbols are unresolved.
if (ZeResult == ZE_RESULT_SUCCESS) {
ZeResult = checkUnresolvedSymbols(ZeModule, &ZeBuildLog);
if (ZeResult == ZE_RESULT_ERROR_MODULE_LINK_FAILURE) {
UrResult =
UR_RESULT_ERROR_UNKNOWN; // TODO:
// UR_RESULT_ERROR_PROGRAM_LINK_FAILURE;
} else if (ZeResult != ZE_RESULT_SUCCESS) {
if (ZeResult != ZE_RESULT_SUCCESS) {
return ze2urResult(ZeResult);
}
}
Expand Down
9 changes: 7 additions & 2 deletions source/adapters/native_cpu/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,13 @@ UR_APIEXPORT ur_result_t UR_APICALL
urProgramLink(ur_context_handle_t hContext, uint32_t count,
const ur_program_handle_t *phPrograms, const char *pOptions,
ur_program_handle_t *phProgram) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
std::ignore = hContext;
std::ignore = count;
std::ignore = phPrograms;
std::ignore = pOptions;
std::ignore = phProgram;

DIE_NO_IMPLEMENTATION
}
Expand All @@ -144,7 +146,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t,

UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t,
const ur_program_handle_t *, const char *, ur_program_handle_t *) {
const ur_program_handle_t *, const char *, ur_program_handle_t *phProgram) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

Expand Down
6 changes: 6 additions & 0 deletions source/adapters/null/ur_nullddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1920,6 +1920,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink(
*phProgram ///< [out] pointer to handle of program object created.
) try {
ur_result_t result = UR_RESULT_SUCCESS;
if (nullptr != phProgram) {
*phProgram = nullptr;
}

// if the driver has created a custom function, then call it instead of using the generic path
auto pfnLink = d_context.urDdiTable.Program.pfnLink;
Expand Down Expand Up @@ -5728,6 +5731,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp(
*phProgram ///< [out] pointer to handle of program object created.
) try {
ur_result_t result = UR_RESULT_SUCCESS;
if (nullptr != phProgram) {
*phProgram = nullptr;
}

// if the driver has created a custom function, then call it instead of using the generic path
auto pfnLinkExp = d_context.urDdiTable.ProgramExp.pfnLinkExp;
Expand Down
12 changes: 11 additions & 1 deletion source/adapters/opencl/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ urProgramLink(ur_context_handle_t hContext, uint32_t count,
pOptions, cl_adapter::cast<cl_uint>(count),
cl_adapter::cast<const cl_program *>(phPrograms), nullptr,
nullptr, &CLResult));

if (CL_INVALID_BINARY == CLResult) {
// Some OpenCL drivers incorrectly return CL_INVALID_BINARY here, convert it
// to CL_LINK_PROGRAM_FAILURE
CLResult = CL_LINK_PROGRAM_FAILURE;
}

CL_RETURN_ON_FAILURE(CLResult);

return UR_RESULT_SUCCESS;
Expand All @@ -236,7 +243,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t,

UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t,
const ur_program_handle_t *, const char *, ur_program_handle_t *) {
const ur_program_handle_t *, const char *, ur_program_handle_t *phProgram) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

Expand Down
6 changes: 6 additions & 0 deletions source/loader/layers/tracing/ur_trcddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2486,6 +2486,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink(
ur_program_handle_t
*phProgram ///< [out] pointer to handle of program object created.
) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
auto pfnLink = context.urDdiTable.Program.pfnLink;

if (nullptr == pfnLink) {
Expand Down Expand Up @@ -7639,6 +7642,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp(
ur_program_handle_t
*phProgram ///< [out] pointer to handle of program object created.
) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
auto pfnLinkExp = context.urDdiTable.ProgramExp.pfnLinkExp;

if (nullptr == pfnLinkExp) {
Expand Down
6 changes: 6 additions & 0 deletions source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2849,6 +2849,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink(
ur_program_handle_t
*phProgram ///< [out] pointer to handle of program object created.
) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
auto pfnLink = context.urDdiTable.Program.pfnLink;

if (nullptr == pfnLink) {
Expand Down Expand Up @@ -9285,6 +9288,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp(
ur_program_handle_t
*phProgram ///< [out] pointer to handle of program object created.
) {
if (nullptr != phProgram) {
*phProgram = nullptr;
}
auto pfnLinkExp = context.urDdiTable.ProgramExp.pfnLinkExp;

if (nullptr == pfnLinkExp) {
Expand Down
26 changes: 14 additions & 12 deletions source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2556,6 +2556,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink(
*phProgram ///< [out] pointer to handle of program object created.
) {
ur_result_t result = UR_RESULT_SUCCESS;
if (nullptr != phProgram) {
*phProgram = nullptr;
}

// extract platform's function pointer table
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
Expand All @@ -2578,14 +2581,12 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink(
result =
pfnLink(hContext, count, phProgramsLocal.data(), pOptions, phProgram);

if (UR_RESULT_SUCCESS != result) {
return result;
}

try {
// convert platform handle to loader handle
*phProgram = reinterpret_cast<ur_program_handle_t>(
ur_program_factory.getInstance(*phProgram, dditable));
if (nullptr != phProgram) {
*phProgram = reinterpret_cast<ur_program_handle_t>(
ur_program_factory.getInstance(*phProgram, dditable));
}
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}
Expand Down Expand Up @@ -7909,6 +7910,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp(
*phProgram ///< [out] pointer to handle of program object created.
) {
ur_result_t result = UR_RESULT_SUCCESS;
if (nullptr != phProgram) {
*phProgram = nullptr;
}

// extract platform's function pointer table
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
Expand Down Expand Up @@ -7938,14 +7942,12 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp(
result = pfnLinkExp(hContext, numDevices, phDevicesLocal.data(), count,
phProgramsLocal.data(), pOptions, phProgram);

if (UR_RESULT_SUCCESS != result) {
return result;
}

try {
// convert platform handle to loader handle
*phProgram = reinterpret_cast<ur_program_handle_t>(
ur_program_factory.getInstance(*phProgram, dditable));
if (nullptr != phProgram) {
*phProgram = reinterpret_cast<ur_program_handle_t>(
ur_program_factory.getInstance(*phProgram, dditable));
}
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}
Expand Down
Loading

0 comments on commit 25c75c0

Please sign in to comment.