Skip to content

Commit

Permalink
__resultRef implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
kohanis committed Feb 23, 2024
1 parent 2a09e7a commit 83bd183
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 14 deletions.
4 changes: 4 additions & 0 deletions Harmony/Documentation/articles/patching-injections.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ Patches can use an argument called **`__instance`** to access the instance value

Patches can use an argument called **`__result`** to access the returned value. The type must match the return type of the original or be assignable from it. For prefixes, as the original method hasn't run yet, the value of `__result` is the default for that type. For most reference types, that would be `null`. If you wish to **alter** the `__result`, you need to define it **by reference** like `ref string name`.

### __resultRef

Patches can use an argument called **`__resultRef`** to alter the "**ref return**" reference itself. The type must be `RefResult<T>` by reference, where `T` must match the return type of the original, without `ref` modifier. For example `ref RefResult<string> __resultRef`.

### __state

Patches can use an argument called **`__state`** to store information in the prefix method that can be accessed again in the postfix method. Think of it as a local variable. It can be any type and you are responsible to initialize its value in the prefix. **Note:** It only works if both patches are defined in the same class.
Expand Down
5 changes: 5 additions & 0 deletions Harmony/Extras/RefResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
namespace HarmonyLib;

/// <summary>Delegate type for "ref return" injections</summary>
/// <typeparam name="T">Return type of the original method, without ref modifier</typeparam>
public delegate ref T RefResult<T>();
98 changes: 88 additions & 10 deletions Harmony/Internal/MethodPatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ internal class MethodPatcher
const string ORIGINAL_METHOD_PARAM = "__originalMethod";
const string ARGS_ARRAY_VAR = "__args";
const string RESULT_VAR = "__result";
const string RESULT_REF_VAR = "__resultRef";
const string STATE_VAR = "__state";
const string EXCEPTION_VAR = "__exception";
const string RUN_ORIGINAL_VAR = "__runOriginal";
Expand Down Expand Up @@ -76,6 +77,19 @@ internal MethodInfo CreateReplacement(out Dictionary<int, CodeInstruction> final
privateVars[RESULT_VAR] = resultVariable;
}

if (fixes.Any(fix => fix.GetParameters().Any(p => p.Name == RESULT_REF_VAR)))
{
if(returnType.IsByRef)
{
var resultRefVariable = il.DeclareLocal(
typeof(RefResult<>).MakeGenericType(returnType.GetElementType())
);
emitter.Emit(OpCodes.Ldnull);
emitter.Emit(OpCodes.Stloc, resultRefVariable);
privateVars[RESULT_REF_VAR] = resultRefVariable;
}
}

LocalBuilder argsArrayVariable = null;
if (fixes.Any(fix => fix.GetParameters().Any(p => p.Name == ARGS_ARRAY_VAR)))
{
Expand Down Expand Up @@ -432,10 +446,11 @@ bool EmitOriginalBaseMethod()
return true;
}

void EmitCallParameter(MethodInfo patch, Dictionary<string, LocalBuilder> variables, LocalBuilder runOriginalVariable, bool allowFirsParamPassthrough, out LocalBuilder tmpInstanceBoxingVar, out LocalBuilder tmpObjectVar, List<KeyValuePair<LocalBuilder, Type>> tmpBoxVars)
void EmitCallParameter(MethodInfo patch, Dictionary<string, LocalBuilder> variables, LocalBuilder runOriginalVariable, bool allowFirsParamPassthrough, out LocalBuilder tmpInstanceBoxingVar, out LocalBuilder tmpObjectVar, out bool refResultUsed, List<KeyValuePair<LocalBuilder, Type>> tmpBoxVars)
{
tmpInstanceBoxingVar = null;
tmpObjectVar = null;
refResultUsed = false;

var isInstance = original.IsStatic is false;
var originalParameters = original.GetParameters();
Expand Down Expand Up @@ -474,10 +489,10 @@ void EmitCallParameter(MethodInfo patch, Dictionary<string, LocalBuilder> variab
else
{
var paramType = patchParam.ParameterType;

var parameterIsRef = paramType.IsByRef;
var parameterIsObject = paramType == typeof(object) || paramType == typeof(object).MakeByRefType();

if (AccessTools.IsStruct(originalType))
{
if (parameterIsObject)
Expand Down Expand Up @@ -571,7 +586,6 @@ void EmitCallParameter(MethodInfo patch, Dictionary<string, LocalBuilder> variab
// treat __result var special
if (patchParam.Name == RESULT_VAR)
{
var returnType = AccessTools.GetReturnedType(original);
if (returnType == typeof(void))
throw new Exception($"Cannot get result from void method {original.FullDescription()}");
var resultType = patchParam.ParameterType;
Expand All @@ -597,6 +611,25 @@ void EmitCallParameter(MethodInfo patch, Dictionary<string, LocalBuilder> variab
continue;
}

// treat __resultRef delegate special
if (patchParam.Name == RESULT_REF_VAR)
{
if (!returnType.IsByRef)
throw new Exception(
$"Cannot use {RESULT_REF_VAR} with non-ref return type {returnType.FullName} of method {original.FullDescription()}");

var resultType = patchParam.ParameterType;
var expectedTypeRef = typeof(RefResult<>).MakeGenericType(returnType.GetElementType()).MakeByRefType();
if (resultType != expectedTypeRef)
throw new Exception(
$"Wrong type of {RESULT_REF_VAR} for method {original.FullDescription()}. Expected {expectedTypeRef.FullName}, got {resultType.FullName}");

emitter.Emit(OpCodes.Ldloca, variables[RESULT_REF_VAR]);

refResultUsed = true;
continue;
}

// any other declared variables
if (variables.TryGetValue(patchParam.Name, out var localBuilder))
{
Expand Down Expand Up @@ -763,7 +796,7 @@ void AddPrefixes(Dictionary<string, LocalBuilder> variables, LocalBuilder runOri
}

var tmpBoxVars = new List<KeyValuePair<LocalBuilder, Type>>();
EmitCallParameter(fix, variables, runOriginalVariable, false, out var tmpInstanceBoxingVar, out var tmpObjectVar, tmpBoxVars);
EmitCallParameter(fix, variables, runOriginalVariable, false, out var tmpInstanceBoxingVar, out var tmpObjectVar, out var refResultUsed, tmpBoxVars);
emitter.Emit(OpCodes.Call, fix);
if (fix.GetParameters().Any(p => p.Name == ARGS_ARRAY_VAR))
RestoreArgumentArray(variables);
Expand All @@ -774,7 +807,22 @@ void AddPrefixes(Dictionary<string, LocalBuilder> variables, LocalBuilder runOri
emitter.Emit(OpCodes.Unbox_Any, original.DeclaringType);
emitter.Emit(OpCodes.Stobj, original.DeclaringType);
}
if (tmpObjectVar != null)
if (refResultUsed)
{
var label = il.DefineLabel();
emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]);
emitter.Emit(OpCodes.Brfalse_S, label);

emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]);
emitter.Emit(OpCodes.Callvirt, AccessTools.Method(variables[RESULT_REF_VAR].LocalType, "Invoke"));
emitter.Emit(OpCodes.Stloc, variables[RESULT_VAR]);
emitter.Emit(OpCodes.Ldnull);
emitter.Emit(OpCodes.Stloc, variables[RESULT_REF_VAR]);

emitter.MarkLabel(label);
emitter.Emit(OpCodes.Nop);
}
else if (tmpObjectVar != null)
{
emitter.Emit(OpCodes.Ldloc, tmpObjectVar);
emitter.Emit(OpCodes.Unbox_Any, AccessTools.GetReturnedType(original));
Expand Down Expand Up @@ -815,7 +863,7 @@ bool AddPostfixes(Dictionary<string, LocalBuilder> variables, LocalBuilder runOr
// throw new Exception("Methods without body cannot have postfixes. Use a transpiler instead.");

var tmpBoxVars = new List<KeyValuePair<LocalBuilder, Type>>();
EmitCallParameter(fix, variables, runOriginalVariable, true, out var tmpInstanceBoxingVar, out var tmpObjectVar, tmpBoxVars);
EmitCallParameter(fix, variables, runOriginalVariable, true, out var tmpInstanceBoxingVar, out var tmpObjectVar, out var refResultUsed, tmpBoxVars);
emitter.Emit(OpCodes.Call, fix);
if (fix.GetParameters().Any(p => p.Name == ARGS_ARRAY_VAR))
RestoreArgumentArray(variables);
Expand All @@ -826,7 +874,22 @@ bool AddPostfixes(Dictionary<string, LocalBuilder> variables, LocalBuilder runOr
emitter.Emit(OpCodes.Unbox_Any, original.DeclaringType);
emitter.Emit(OpCodes.Stobj, original.DeclaringType);
}
if (tmpObjectVar != null)
if (refResultUsed)
{
var label = il.DefineLabel();
emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]);
emitter.Emit(OpCodes.Brfalse_S, label);

emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]);
emitter.Emit(OpCodes.Callvirt, AccessTools.Method(variables[RESULT_REF_VAR].LocalType, "Invoke"));
emitter.Emit(OpCodes.Stloc, variables[RESULT_VAR]);
emitter.Emit(OpCodes.Ldnull);
emitter.Emit(OpCodes.Stloc, variables[RESULT_REF_VAR]);

emitter.MarkLabel(label);
emitter.Emit(OpCodes.Nop);
}
else if (tmpObjectVar != null)
{
emitter.Emit(OpCodes.Ldloc, tmpObjectVar);
emitter.Emit(OpCodes.Unbox_Any, AccessTools.GetReturnedType(original));
Expand Down Expand Up @@ -871,7 +934,7 @@ bool AddFinalizers(Dictionary<string, LocalBuilder> variables, LocalBuilder runO
emitter.MarkBlockBefore(new ExceptionBlock(ExceptionBlockType.BeginExceptionBlock), out var label);

var tmpBoxVars = new List<KeyValuePair<LocalBuilder, Type>>();
EmitCallParameter(fix, variables, runOriginalVariable, false, out var tmpInstanceBoxingVar, out var tmpObjectVar, tmpBoxVars);
EmitCallParameter(fix, variables, runOriginalVariable, false, out var tmpInstanceBoxingVar, out var tmpObjectVar, out var refResultUsed, tmpBoxVars);
emitter.Emit(OpCodes.Call, fix);
if (fix.GetParameters().Any(p => p.Name == ARGS_ARRAY_VAR))
RestoreArgumentArray(variables);
Expand All @@ -882,7 +945,22 @@ bool AddFinalizers(Dictionary<string, LocalBuilder> variables, LocalBuilder runO
emitter.Emit(OpCodes.Unbox_Any, original.DeclaringType);
emitter.Emit(OpCodes.Stobj, original.DeclaringType);
}
if (tmpObjectVar != null)
if (refResultUsed)
{
var label = il.DefineLabel();
emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]);
emitter.Emit(OpCodes.Brfalse_S, label);

emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]);
emitter.Emit(OpCodes.Callvirt, AccessTools.Method(variables[RESULT_REF_VAR].LocalType, "Invoke"));
emitter.Emit(OpCodes.Stloc, variables[RESULT_VAR]);
emitter.Emit(OpCodes.Ldnull);
emitter.Emit(OpCodes.Stloc, variables[RESULT_REF_VAR]);

emitter.MarkLabel(label);
emitter.Emit(OpCodes.Nop);
}
else if (tmpObjectVar != null)
{
emitter.Emit(OpCodes.Ldloc, tmpObjectVar);
emitter.Emit(OpCodes.Unbox_Any, AccessTools.GetReturnedType(original));
Expand Down
64 changes: 64 additions & 0 deletions HarmonyTests/Patching/Assets/Specials.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,70 @@ public static void ResetTest()

// -----------------------------------------------------

public class ResultRefStruct
{
// ReSharper disable FieldCanBeMadeReadOnly.Global
public static int[] numbersPrefix = [0, 0];
public static int[] numbersPostfix = [0, 0];
public static int[] numbersPostfixWithNull = [0];
public static int[] numbersFinalizer = [0];
public static int[] numbersMixed = [0, 0];
// ReSharper restore FieldCanBeMadeReadOnly.Global

[MethodImpl(MethodImplOptions.NoInlining)]
public ref int ToPrefix() => ref numbersPrefix[0];

[MethodImpl(MethodImplOptions.NoInlining)]
public ref int ToPostfix() => ref numbersPostfix[0];

[MethodImpl(MethodImplOptions.NoInlining)]
public ref int ToPostfixWithNull() => ref numbersPostfixWithNull[0];

[MethodImpl(MethodImplOptions.NoInlining)]
public ref int ToFinalizer() => throw new Exception();

[MethodImpl(MethodImplOptions.NoInlining)]
public ref int ToMixed() => ref numbersMixed[0];
}

[HarmonyPatch(typeof(ResultRefStruct))]
public class ResultRefStruct_Patch
{
[HarmonyPatch(nameof(ResultRefStruct.ToPrefix))]
[HarmonyPrefix]
public static bool Prefix(ref RefResult<int> __resultRef)
{
__resultRef = () => ref ResultRefStruct.numbersPrefix[1];
return false;
}

[HarmonyPatch(nameof(ResultRefStruct.ToPostfix))]
[HarmonyPostfix]
public static void Postfix(ref RefResult<int> __resultRef) => __resultRef = () => ref ResultRefStruct.numbersPostfix[1];

[HarmonyPatch(nameof(ResultRefStruct.ToPostfixWithNull))]
[HarmonyPostfix]
public static void PostfixWithNull(ref RefResult<int> __resultRef) => __resultRef = null;

[HarmonyPatch(nameof(ResultRefStruct.ToFinalizer))]
[HarmonyFinalizer]
public static Exception Finalizer(ref RefResult<int> __resultRef)
{
__resultRef = () => ref ResultRefStruct.numbersFinalizer[0];
return null;
}

[HarmonyPatch(nameof(ResultRefStruct.ToMixed))]
[HarmonyPostfix]
public static void PostfixMixed(ref int __result, ref RefResult<int> __resultRef)
{
__result = 42;
__resultRef = () => ref ResultRefStruct.numbersMixed[1];
}
}

// -----------------------------------------------------

public class DeadEndCode
{
[MethodImpl(MethodImplOptions.NoInlining)]
Expand Down
51 changes: 47 additions & 4 deletions HarmonyTests/Patching/Specials.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,49 @@ public void Test_HttpWebRequestGetResponse()
Assert.True(HttpWebRequestPatches.postfixCalled, "Postfix not called");
}

[Test]
public void Test_PatchResultRef()
{
ResultRefStruct.numbersPrefix = [0, 0];
ResultRefStruct.numbersPostfix = [0, 0];
ResultRefStruct.numbersPostfixWithNull = [0];
ResultRefStruct.numbersFinalizer = [0];
ResultRefStruct.numbersMixed = [0, 0];

var test = new ResultRefStruct();

var instance = new Harmony("result-ref-test");
Assert.NotNull(instance);
var processor = instance.CreateClassProcessor(typeof(ResultRefStruct_Patch));
Assert.NotNull(processor, "processor");

test.ToPrefix() = 1;
test.ToPostfix() = 2;
test.ToPostfixWithNull() = 3;
test.ToMixed() = 5;

Assert.AreEqual(new[] { 1, 0 }, ResultRefStruct.numbersPrefix);
Assert.AreEqual(new[] { 2, 0 }, ResultRefStruct.numbersPostfix);
Assert.AreEqual(new[] { 3 }, ResultRefStruct.numbersPostfixWithNull);
Assert.Throws<Exception>(() => test.ToFinalizer(), "ToFinalizer method does not throw");
Assert.AreEqual(new[] { 5, 0 }, ResultRefStruct.numbersMixed);

var replacements = processor.Patch();
Assert.NotNull(replacements, "replacements");

test.ToPrefix() = -1;
test.ToPostfix() = -2;
test.ToPostfixWithNull() = -3;
test.ToFinalizer() = -4;
test.ToMixed() = -5;

Assert.AreEqual(new[] { 1, -1 }, ResultRefStruct.numbersPrefix);
Assert.AreEqual(new[] { 2, -2 }, ResultRefStruct.numbersPostfix);
Assert.AreEqual(new[] { -3 }, ResultRefStruct.numbersPostfixWithNull);
Assert.AreEqual(new[] { -4 }, ResultRefStruct.numbersFinalizer);
Assert.AreEqual(new[] { 42, -5 }, ResultRefStruct.numbersMixed);
}

[Test]
public void Test_Patch_ConcreteClass()
{
Expand Down Expand Up @@ -327,7 +370,7 @@ public void Test_PatchExternalMethod()
Assert.NotNull(patcher, "Patch processor");
_ = patcher.Patch();
}

[Test]
public void Test_PatchEventHandler()
{
Expand All @@ -348,7 +391,7 @@ public void Test_PatchEventHandler()
new EventHandlerTestClass().Run();
Console.WriteLine($"### EventHandlerTestClass AFTER");
}

[Test]
public void Test_PatchMarshalledClass()
{
Expand All @@ -369,7 +412,7 @@ public void Test_PatchMarshalledClass()
new MarshalledTestClass().Run();
Console.WriteLine($"### MarshalledTestClass AFTER");
}

[Test]
public void Test_MarshalledWithEventHandler1()
{
Expand All @@ -390,7 +433,7 @@ public void Test_MarshalledWithEventHandler1()
new MarshalledWithEventHandlerTest1Class().Run();
Console.WriteLine($"### MarshalledWithEventHandlerTest1 AFTER");
}

[Test]
public void Test_MarshalledWithEventHandler2()
{
Expand Down

0 comments on commit 83bd183

Please sign in to comment.