Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow CachedModelDependencyContainer to cache models with non-bindable fields #6416

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.

#nullable disable

using System;
using System.Diagnostics.CodeAnalysis;
using NUnit.Framework;
using osu.Framework.Allocation;
Expand All @@ -15,24 +12,6 @@ namespace osu.Framework.Tests.Dependencies.Reflection
[SuppressMessage("Performance", "OFSG001:Class contributes to dependency injection and should be partial")]
public class CachedModelDependenciesTest
{
[Test]
public void TestModelWithNonBindableFieldsFails()
{
IReadOnlyDependencyContainer unused;

Assert.Throws<TypeInitializationException>(() => unused = new CachedModelDependencyContainer<NonBindablePublicFieldModel>(null));
Assert.Throws<TypeInitializationException>(() => unused = new CachedModelDependencyContainer<NonBindablePrivateFieldModel>(null));
}

[Test]
public void TestModelWithNonReadOnlyFieldsFails()
{
IReadOnlyDependencyContainer unused;

Assert.Throws<TypeInitializationException>(() => unused = new CachedModelDependencyContainer<NonReadOnlyFieldModel>(null));
Assert.Throws<TypeInitializationException>(() => unused = new CachedModelDependencyContainer<PropertyModel>(null));
}

[Test]
public void TestSettingNoModelResolvesDefault()
{
Expand Down Expand Up @@ -195,7 +174,7 @@ public void TestSetModelToNullAfterResolved()

var model = new FieldModel { Bindable = { Value = 2 } };

var dependencies = new CachedModelDependencyContainer<FieldModel>(null)
var dependencies = new CachedModelDependencyContainer<FieldModel?>(null)
{
Model = { Value = model }
};
Expand Down Expand Up @@ -248,7 +227,7 @@ public void TestResolveIndividualProperties()
BindableString = { Value = "3" }
};

var dependencies = new CachedModelDependencyContainer<DerivedFieldModel>(null)
var dependencies = new CachedModelDependencyContainer<DerivedFieldModel?>(null)
{
Model = { Value = model1 }
};
Expand All @@ -269,33 +248,6 @@ public void TestResolveIndividualProperties()
Assert.AreEqual(null, resolver.BindableString.Value);
}

private class NonBindablePublicFieldModel : IDependencyInjectionCandidate
{
#pragma warning disable 649
public readonly int FailingField;
#pragma warning restore 649
}

private class NonBindablePrivateFieldModel : IDependencyInjectionCandidate
{
#pragma warning disable 169
private readonly int failingField;
#pragma warning restore 169
}

private class NonReadOnlyFieldModel : IDependencyInjectionCandidate
{
#pragma warning disable 649
public Bindable<int> Bindable;
#pragma warning restore 649
}

private class PropertyModel : IDependencyInjectionCandidate
{
// ReSharper disable once UnusedMember.Local
public Bindable<int> Bindable { get; private set; }
}

private class FieldModel : IDependencyInjectionCandidate
{
[Cached]
Expand All @@ -311,22 +263,22 @@ private class DerivedFieldModel : FieldModel
private class FieldModelResolver : IDependencyInjectionCandidate
{
[Resolved]
public FieldModel Model { get; private set; }
public FieldModel Model { get; private set; } = null!;
}

private class DerivedFieldModelResolver : IDependencyInjectionCandidate
{
[Resolved]
public DerivedFieldModel Model { get; private set; }
public DerivedFieldModel Model { get; private set; } = null!;
}

private class DerivedFieldModelPropertyResolver : IDependencyInjectionCandidate
{
[Resolved(typeof(DerivedFieldModel))]
public Bindable<int> Bindable { get; private set; }
public Bindable<int> Bindable { get; private set; } = null!;

[Resolved(typeof(DerivedFieldModel))]
public Bindable<string> BindableString { get; private set; }
public Bindable<string> BindableString { get; private set; } = null!;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.

#nullable disable

using System;
using NUnit.Framework;
using osu.Framework.Allocation;
using osu.Framework.Bindables;
Expand All @@ -13,24 +10,6 @@ namespace osu.Framework.Tests.Dependencies.SourceGeneration
[TestFixture]
public partial class CachedModelDependenciesTest
{
[Test]
public void TestModelWithNonBindableFieldsFails()
{
IReadOnlyDependencyContainer unused;

Assert.Throws<TypeInitializationException>(() => unused = new CachedModelDependencyContainer<NonBindablePublicFieldModel>(null));
Assert.Throws<TypeInitializationException>(() => unused = new CachedModelDependencyContainer<NonBindablePrivateFieldModel>(null));
}

[Test]
public void TestModelWithNonReadOnlyFieldsFails()
{
IReadOnlyDependencyContainer unused;

Assert.Throws<TypeInitializationException>(() => unused = new CachedModelDependencyContainer<NonReadOnlyFieldModel>(null));
Assert.Throws<TypeInitializationException>(() => unused = new CachedModelDependencyContainer<PropertyModel>(null));
}

[Test]
public void TestSettingNoModelResolvesDefault()
{
Expand Down Expand Up @@ -193,7 +172,7 @@ public void TestSetModelToNullAfterResolved()

var model = new FieldModel { Bindable = { Value = 2 } };

var dependencies = new CachedModelDependencyContainer<FieldModel>(null)
var dependencies = new CachedModelDependencyContainer<FieldModel?>(null)
{
Model = { Value = model }
};
Expand Down Expand Up @@ -246,7 +225,7 @@ public void TestResolveIndividualProperties()
BindableString = { Value = "3" }
};

var dependencies = new CachedModelDependencyContainer<DerivedFieldModel>(null)
var dependencies = new CachedModelDependencyContainer<DerivedFieldModel?>(null)
{
Model = { Value = model1 }
};
Expand All @@ -267,33 +246,6 @@ public void TestResolveIndividualProperties()
Assert.AreEqual(null, resolver.BindableString.Value);
}

private partial class NonBindablePublicFieldModel : IDependencyInjectionCandidate
{
#pragma warning disable 649
public readonly int FailingField;
#pragma warning restore 649
}

private partial class NonBindablePrivateFieldModel : IDependencyInjectionCandidate
{
#pragma warning disable 169
private readonly int failingField;
#pragma warning restore 169
}

private partial class NonReadOnlyFieldModel : IDependencyInjectionCandidate
{
#pragma warning disable 649
public Bindable<int> Bindable;
#pragma warning restore 649
}

private partial class PropertyModel : IDependencyInjectionCandidate
{
// ReSharper disable once UnusedMember.Local
public Bindable<int> Bindable { get; private set; }
}

private partial class FieldModel : IDependencyInjectionCandidate
{
[Cached]
Expand All @@ -309,22 +261,22 @@ private partial class DerivedFieldModel : FieldModel
private partial class FieldModelResolver : IDependencyInjectionCandidate
{
[Resolved]
public FieldModel Model { get; private set; }
public FieldModel Model { get; private set; } = null!;
}

private partial class DerivedFieldModelResolver : IDependencyInjectionCandidate
{
[Resolved]
public DerivedFieldModel Model { get; private set; }
public DerivedFieldModel Model { get; private set; } = null!;
}

private partial class DerivedFieldModelPropertyResolver : IDependencyInjectionCandidate
{
[Resolved(typeof(DerivedFieldModel))]
public Bindable<int> Bindable { get; private set; }
public Bindable<int> Bindable { get; private set; } = null!;

[Resolved(typeof(DerivedFieldModel))]
public Bindable<string> BindableString { get; private set; }
public Bindable<string> BindableString { get; private set; } = null!;
}
}
}
69 changes: 22 additions & 47 deletions osu.Framework/Allocation/CachedModelDependencyContainer.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.

#nullable disable

using System;
using System.Reflection;
using osu.Framework.Bindables;
Expand All @@ -19,7 +17,7 @@ namespace osu.Framework.Allocation
/// </remarks>
/// <typeparam name="TModel">The type of the model to cache. Must contain only <see cref="Bindable{T}"/> fields or auto-properties.</typeparam>
public class CachedModelDependencyContainer<TModel> : IReadOnlyDependencyContainer
where TModel : class, IDependencyInjectionCandidate, new()
where TModel : class?, IDependencyInjectionCandidate?, new()
{
private const BindingFlags activator_flags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.DeclaredOnly;

Expand All @@ -32,17 +30,16 @@ public class CachedModelDependencyContainer<TModel> : IReadOnlyDependencyContain
public readonly Bindable<TModel> Model = new Bindable<TModel>();

private readonly TModel shadowModel = new TModel();

private readonly IReadOnlyDependencyContainer parent;
private readonly IReadOnlyDependencyContainer? parent;
private readonly IReadOnlyDependencyContainer shadowDependencies;

public CachedModelDependencyContainer(IReadOnlyDependencyContainer parent)
public CachedModelDependencyContainer(IReadOnlyDependencyContainer? parent)
{
this.parent = parent;

shadowDependencies = DependencyActivator.MergeDependencies(shadowModel, null, new CacheInfo(parent: typeof(TModel)));

TModel currentModel = null;
TModel? currentModel = null;
Model.BindValueChanged(e =>
{
// When setting a null model, we actually want to reset the shadow model to a default state
Expand All @@ -55,9 +52,9 @@ public CachedModelDependencyContainer(IReadOnlyDependencyContainer parent)
});
}

public object Get(Type type) => Get(type, default);
public object? Get(Type type) => Get(type, default);

public object Get(Type type, CacheInfo info)
public object? Get(Type type, CacheInfo info)
{
if (info.Parent == null)
return type == typeof(TModel) ? createChildShadowModel() : parent?.Get(type, info);
Expand Down Expand Up @@ -87,65 +84,43 @@ private TModel createChildShadowModel()
/// <param name="targetShadowModel">The shadow model to update.</param>
/// <param name="lastModel">The model to unbind from.</param>
/// <param name="newModel">The model to bind to.</param>
private void updateShadowModel(TModel targetShadowModel, TModel lastModel, TModel newModel)
private void updateShadowModel(TModel targetShadowModel, TModel? lastModel, TModel newModel)
{
// Due to static-constructor checks, we are guaranteed that all fields will be IBindable

foreach (var type in typeof(TModel).EnumerateBaseTypes())
if (lastModel != null)
{
foreach (var field in type.GetFields(activator_flags))
foreach (var type in typeof(TModel).EnumerateBaseTypes())
{
perform(targetShadowModel, field, lastModel, (shadowProp, modelProp) => shadowProp.UnbindFrom(modelProp));
foreach (var field in type.GetFields(activator_flags))
perform(field, targetShadowModel, lastModel, (shadowProp, modelProp) => shadowProp.UnbindFrom(modelProp));
}
}

foreach (var type in typeof(TModel).EnumerateBaseTypes())
{
foreach (var field in type.GetFields(activator_flags))
{
perform(targetShadowModel, field, newModel, (shadowProp, modelProp) => shadowProp.BindTo(modelProp));
}
perform(field, targetShadowModel, newModel, (shadowProp, modelProp) => shadowProp.BindTo(modelProp));
}
}

/// <summary>
/// Perform an arbitrary action across a shadow model and model.
/// </summary>
private void perform(TModel targetShadowModel, MemberInfo member, TModel target, Action<IBindable, IBindable> action)
private static void perform(FieldInfo field, TModel shadowModel, TModel targetModel, Action<IBindable, IBindable> action)
{
if (target == null) return;
IBindable? shadowBindable = null;
IBindable? targetBindable = null;

switch (member)
try
{
case PropertyInfo pi:
action((IBindable)pi.GetValue(targetShadowModel), (IBindable)pi.GetValue(target));
break;

case FieldInfo fi:
action((IBindable)fi.GetValue(targetShadowModel), (IBindable)fi.GetValue(target));
break;
shadowBindable = field.GetValue(shadowModel) as IBindable;
targetBindable = field.GetValue(targetModel) as IBindable;
}
}

static CachedModelDependencyContainer()
{
foreach (var type in typeof(TModel).EnumerateBaseTypes())
catch
{
foreach (var field in type.GetFields(activator_flags))
{
if (!typeof(IBindable).IsAssignableFrom(field.FieldType))
{
throw new InvalidOperationException($"\"{field.DeclaringType}.{field.Name}\" does not subclass {nameof(IBindable)}. "
+ $"All fields of {typeof(TModel)} must subclass {nameof(IBindable)} to be used in a {nameof(CachedModelDependencyContainer<TModel>)}.");
}

if (!field.IsInitOnly)
{
throw new InvalidOperationException($"\"{field.DeclaringType}.{field.Name}\" is not readonly. "
+ $"All fields of {typeof(TModel)} must be readonly to be used in a {nameof(CachedModelDependencyContainer<TModel>)}.");
}
}
}

if (shadowBindable != null && targetBindable != null)
action(shadowBindable, targetBindable);
}
}
}
Loading