Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use generic math for bindable numbers #6248

Merged
merged 9 commits into from
Apr 22, 2024
51 changes: 26 additions & 25 deletions osu.Framework.Tests/Visual/Bindables/TestSceneBindableNumbers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Globalization;
using System.Numerics;
using osu.Framework.Bindables;
using osu.Framework.Graphics;
using osu.Framework.Graphics.Containers;
Expand Down Expand Up @@ -167,45 +168,45 @@ private void testFractionalPrecision()
private bool checkExact(decimal value) => checkExact(value, value);

private bool checkExact(decimal floatValue, decimal intValue)
=> bindableInt.Value == Convert.ToInt32(intValue)
&& bindableLong.Value == Convert.ToInt64(intValue)
&& bindableFloat.Value == Convert.ToSingle(floatValue)
&& bindableDouble.Value == Convert.ToDouble(floatValue);
=> bindableInt.Value == (int)intValue
&& bindableLong.Value == (long)intValue
&& bindableFloat.Value == (float)floatValue
&& bindableDouble.Value == (double)floatValue;

private void setMin<T>(T value)
private void setMin<T>(T value) where T : INumber<T>
{
bindableInt.MinValue = Convert.ToInt32(value);
bindableLong.MinValue = Convert.ToInt64(value);
bindableFloat.MinValue = Convert.ToSingle(value);
bindableDouble.MinValue = Convert.ToDouble(value);
bindableInt.MinValue = int.CreateTruncating(value);
bindableLong.MinValue = long.CreateTruncating(value);
bindableFloat.MinValue = float.CreateTruncating(value);
bindableDouble.MinValue = double.CreateTruncating(value);
}

private void setMax<T>(T value)
private void setMax<T>(T value) where T : INumber<T>
{
bindableInt.MaxValue = Convert.ToInt32(value);
bindableLong.MaxValue = Convert.ToInt64(value);
bindableFloat.MaxValue = Convert.ToSingle(value);
bindableDouble.MaxValue = Convert.ToDouble(value);
bindableInt.MaxValue = int.CreateTruncating(value);
bindableLong.MaxValue = long.CreateTruncating(value);
bindableFloat.MaxValue = float.CreateTruncating(value);
bindableDouble.MaxValue = double.CreateTruncating(value);
}

private void setValue<T>(T value)
private void setValue<T>(T value) where T : INumber<T>
{
bindableInt.Value = Convert.ToInt32(value);
bindableLong.Value = Convert.ToInt64(value);
bindableFloat.Value = Convert.ToSingle(value);
bindableDouble.Value = Convert.ToDouble(value);
bindableInt.Value = int.CreateTruncating(value);
bindableLong.Value = long.CreateTruncating(value);
bindableFloat.Value = float.CreateTruncating(value);
bindableDouble.Value = double.CreateTruncating(value);
}

private void setPrecision<T>(T precision)
private void setPrecision<T>(T precision) where T : INumber<T>
{
bindableInt.Precision = Convert.ToInt32(precision);
bindableLong.Precision = Convert.ToInt64(precision);
bindableFloat.Precision = Convert.ToSingle(precision);
bindableDouble.Precision = Convert.ToDouble(precision);
bindableInt.Precision = int.CreateTruncating(precision);
bindableLong.Precision = long.CreateTruncating(precision);
bindableFloat.Precision = float.CreateTruncating(precision);
bindableDouble.Precision = double.CreateTruncating(precision);
}

private partial class BindableDisplayContainer<T> : CompositeDrawable
where T : struct, IComparable<T>, IConvertible, IEquatable<T>
where T : struct, INumber<T>, IMinMaxValue<T>, IConvertible
{
public BindableDisplayContainer(BindableNumber<T> bindable)
{
Expand Down
204 changes: 23 additions & 181 deletions osu.Framework/Bindables/BindableNumber.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@
#nullable disable

using System;
using System.Diagnostics;
using System.Globalization;
using System.Numerics;
using JetBrains.Annotations;
using osu.Framework.Extensions.TypeExtensions;
using osu.Framework.Utils;

namespace osu.Framework.Bindables
{
public class BindableNumber<T> : RangeConstrainedBindable<T>, IBindableNumber<T>
where T : struct, IComparable<T>, IConvertible, IEquatable<T>
where T : struct, INumber<T>, IMinMaxValue<T>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically this can be combination of operator interfaces rather than the meta one, but you shouldn't be happy about that.

{
[CanBeNull]
public event Action<T> PrecisionChanged;
Expand All @@ -40,10 +38,10 @@ public T Precision
get => precision;
set
{
if (precision.Equals(value))
if (precision == value)
return;

if (value.CompareTo(default) <= 0)
if (value <= T.Zero)
throw new ArgumentOutOfRangeException(nameof(Precision), value, "Must be greater than 0.");

SetPrecision(value, true, this);
Expand Down Expand Up @@ -76,102 +74,22 @@ public override T Value

private void setValue(T value)
{
if (Precision.CompareTo(DefaultPrecision) > 0)
if (Precision > DefaultPrecision)
{
// this rounding is purposefully performed on `decimal` to ensure that the resulting value is the closest possible floating-point
// number to actual real-world base-10 decimals, as that is the most common usage of precision.
decimal accurateResult = ClampValue(value, MinValue, MaxValue).ToDecimal(NumberFormatInfo.InvariantInfo);
accurateResult = Math.Round(accurateResult / Precision.ToDecimal(NumberFormatInfo.InvariantInfo)) * Precision.ToDecimal(NumberFormatInfo.InvariantInfo);
decimal accurateResult = decimal.CreateTruncating(T.Clamp(value, MinValue, MaxValue));
accurateResult = Math.Round(accurateResult / decimal.CreateTruncating(Precision)) * decimal.CreateTruncating(Precision);

base.Value = convertFromDecimal(accurateResult);
base.Value = T.CreateTruncating(accurateResult);
}
else
base.Value = value;
}

private T convertFromDecimal(decimal value)
{
if (typeof(T) == typeof(sbyte))
return (T)(object)Convert.ToSByte(value);
if (typeof(T) == typeof(byte))
return (T)(object)Convert.ToByte(value);
if (typeof(T) == typeof(short))
return (T)(object)Convert.ToInt16(value);
if (typeof(T) == typeof(ushort))
return (T)(object)Convert.ToUInt16(value);
if (typeof(T) == typeof(int))
return (T)(object)Convert.ToInt32(value);
if (typeof(T) == typeof(uint))
return (T)(object)Convert.ToUInt32(value);
if (typeof(T) == typeof(long))
return (T)(object)Convert.ToInt64(value);
if (typeof(T) == typeof(ulong))
return (T)(object)Convert.ToUInt64(value);
if (typeof(T) == typeof(float))
return (T)(object)Convert.ToSingle(value);
if (typeof(T) == typeof(double))
return (T)(object)Convert.ToDouble(value);

throw new InvalidCastException($"Cannot convert from decimal to {typeof(T).ReadableName()}");
}
protected override T DefaultMinValue => T.MinValue;

protected override T DefaultMinValue
{
get
{
Debug.Assert(Validation.IsSupportedBindableNumberType<T>());

if (typeof(T) == typeof(sbyte))
return (T)(object)sbyte.MinValue;
if (typeof(T) == typeof(byte))
return (T)(object)byte.MinValue;
if (typeof(T) == typeof(short))
return (T)(object)short.MinValue;
if (typeof(T) == typeof(ushort))
return (T)(object)ushort.MinValue;
if (typeof(T) == typeof(int))
return (T)(object)int.MinValue;
if (typeof(T) == typeof(uint))
return (T)(object)uint.MinValue;
if (typeof(T) == typeof(long))
return (T)(object)long.MinValue;
if (typeof(T) == typeof(ulong))
return (T)(object)ulong.MinValue;
if (typeof(T) == typeof(float))
return (T)(object)float.MinValue;

return (T)(object)double.MinValue;
}
}

protected override T DefaultMaxValue
{
get
{
Debug.Assert(Validation.IsSupportedBindableNumberType<T>());

if (typeof(T) == typeof(sbyte))
return (T)(object)sbyte.MaxValue;
if (typeof(T) == typeof(byte))
return (T)(object)byte.MaxValue;
if (typeof(T) == typeof(short))
return (T)(object)short.MaxValue;
if (typeof(T) == typeof(ushort))
return (T)(object)ushort.MaxValue;
if (typeof(T) == typeof(int))
return (T)(object)int.MaxValue;
if (typeof(T) == typeof(uint))
return (T)(object)uint.MaxValue;
if (typeof(T) == typeof(long))
return (T)(object)long.MaxValue;
if (typeof(T) == typeof(ulong))
return (T)(object)ulong.MaxValue;
if (typeof(T) == typeof(float))
return (T)(object)float.MaxValue;

return (T)(object)double.MaxValue;
}
}
protected override T DefaultMaxValue => T.MaxValue;

/// <summary>
/// The default <see cref="Precision"/>.
Expand All @@ -180,26 +98,12 @@ protected virtual T DefaultPrecision
{
get
{
if (typeof(T) == typeof(sbyte))
return (T)(object)(sbyte)1;
if (typeof(T) == typeof(byte))
return (T)(object)(byte)1;
if (typeof(T) == typeof(short))
return (T)(object)(short)1;
if (typeof(T) == typeof(ushort))
return (T)(object)(ushort)1;
if (typeof(T) == typeof(int))
return (T)(object)1;
if (typeof(T) == typeof(uint))
return (T)(object)1U;
if (typeof(T) == typeof(long))
return (T)(object)1L;
if (typeof(T) == typeof(ulong))
return (T)(object)1UL;
if (typeof(T) == typeof(float))
return (T)(object)float.Epsilon;
if (typeof(T) == typeof(double))
return (T)(object)double.Epsilon;

return (T)(object)double.Epsilon;
return T.One;
Comment on lines 101 to +106
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dotnet/csharplang#6308
Just testing floating-point types here to make things simpler.

}
}

Expand Down Expand Up @@ -249,63 +153,11 @@ public override void UnbindEvents()
typeof(T) != typeof(float) &&
typeof(T) != typeof(double); // Will be **constant** after JIT.

public void Set<TNewValue>(TNewValue val) where TNewValue : struct,
IFormattable, IConvertible, IComparable<TNewValue>, IEquatable<TNewValue>
{
Debug.Assert(Validation.IsSupportedBindableNumberType<T>());

// Comparison between typeof(T) and type literals are treated as **constant** on value types.
// Code paths for other types will be eliminated.
if (typeof(T) == typeof(byte))
((BindableNumber<byte>)(object)this).Value = val.ToByte(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(sbyte))
((BindableNumber<sbyte>)(object)this).Value = val.ToSByte(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(ushort))
((BindableNumber<ushort>)(object)this).Value = val.ToUInt16(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(short))
((BindableNumber<short>)(object)this).Value = val.ToInt16(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(uint))
((BindableNumber<uint>)(object)this).Value = val.ToUInt32(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(int))
((BindableNumber<int>)(object)this).Value = val.ToInt32(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(ulong))
((BindableNumber<ulong>)(object)this).Value = val.ToUInt64(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(long))
((BindableNumber<long>)(object)this).Value = val.ToInt64(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(float))
((BindableNumber<float>)(object)this).Value = val.ToSingle(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(double))
((BindableNumber<double>)(object)this).Value = val.ToDouble(NumberFormatInfo.InvariantInfo);
}
public void Set<TNewValue>(TNewValue val) where TNewValue : struct, INumber<TNewValue>
=> Value = T.CreateTruncating(val);

public void Add<TNewValue>(TNewValue val) where TNewValue : struct,
IFormattable, IConvertible, IComparable<TNewValue>, IEquatable<TNewValue>
{
Debug.Assert(Validation.IsSupportedBindableNumberType<T>());

// Comparison between typeof(T) and type literals are treated as **constant** on value types.
// Code pathes for other types will be eliminated.
if (typeof(T) == typeof(byte))
((BindableNumber<byte>)(object)this).Value += val.ToByte(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(sbyte))
((BindableNumber<sbyte>)(object)this).Value += val.ToSByte(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(ushort))
((BindableNumber<ushort>)(object)this).Value += val.ToUInt16(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(short))
((BindableNumber<short>)(object)this).Value += val.ToInt16(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(uint))
((BindableNumber<uint>)(object)this).Value += val.ToUInt32(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(int))
((BindableNumber<int>)(object)this).Value += val.ToInt32(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(ulong))
((BindableNumber<ulong>)(object)this).Value += val.ToUInt64(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(long))
((BindableNumber<long>)(object)this).Value += val.ToInt64(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(float))
((BindableNumber<float>)(object)this).Value += val.ToSingle(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(double))
((BindableNumber<double>)(object)this).Value += val.ToDouble(NumberFormatInfo.InvariantInfo);
}
public void Add<TNewValue>(TNewValue val) where TNewValue : struct, INumber<TNewValue>
=> Value += T.CreateTruncating(val);

/// <summary>
/// Sets the value of the bindable to Min + (Max - Min) * amt
Expand All @@ -314,8 +166,10 @@ public void Add<TNewValue>(TNewValue val) where TNewValue : struct,
/// </summary>
public void SetProportional(float amt, float snap = 0)
{
double min = MinValue.ToDouble(NumberFormatInfo.InvariantInfo);
double max = MaxValue.ToDouble(NumberFormatInfo.InvariantInfo);
// TODO: Use IFloatingPointIeee754<T>.Lerp when applicable

double min = double.CreateTruncating(MinValue);
double max = double.CreateTruncating(MaxValue);
double value = min + (max - min) * amt;
if (snap > 0)
value = Math.Round(value / snap) * snap;
Expand Down Expand Up @@ -350,20 +204,8 @@ public override bool IsDefault

protected override Bindable<T> CreateInstance() => new BindableNumber<T>();

protected sealed override T ClampValue(T value, T minValue, T maxValue) => max(minValue, min(maxValue, value));

protected sealed override bool IsValidRange(T min, T max) => min.CompareTo(max) <= 0;
protected sealed override T ClampValue(T value, T minValue, T maxValue) => T.Clamp(value, minValue, maxValue);

private static T max(T value1, T value2)
{
int comparison = value1.CompareTo(value2);
return comparison > 0 ? value1 : value2;
}

private static T min(T value1, T value2)
{
int comparison = value1.CompareTo(value2);
return comparison > 0 ? value2 : value1;
}
protected sealed override bool IsValidRange(T min, T max) => min <= max;
}
}
3 changes: 2 additions & 1 deletion osu.Framework/Bindables/BindableNumberWithCurrent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#nullable disable

using System;
using System.Numerics;

namespace osu.Framework.Bindables
{
Expand All @@ -12,7 +13,7 @@ namespace osu.Framework.Bindables
/// </summary>
/// <typeparam name="T">The type of our stored <see cref="Bindable{T}.Value"/>.</typeparam>
public class BindableNumberWithCurrent<T> : BindableNumber<T>, IBindableWithCurrent<T>
where T : struct, IComparable<T>, IConvertible, IEquatable<T>
where T : struct, INumber<T>, IMinMaxValue<T>
{
private BindableNumber<T> currentBound;

Expand Down
6 changes: 4 additions & 2 deletions osu.Framework/Bindables/RangeConstrainedBindable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,10 @@ public override void CopyTo(Bindable<T> them)
// as Value assignment (in the base call below) automatically clamps to [MinValue, MaxValue].
if (them is RangeConstrainedBindable<T> other)
{
other.MinValue = MinValue;
other.MaxValue = MaxValue;
// copy the bounds over without updating the current value, to avoid clamping on invalid ranges.
// there is no need to clamp `Value` after that directly - the `base.CopyTo()` call will change `Value` anyway.
other.SetMinValue(MinValue, false, this);
other.SetMaxValue(MaxValue, false, this);
}

base.CopyTo(them);
Expand Down
Loading
Loading