Skip to content

Commit

Permalink
Ranker train context and FastTree ranking xtensions (#1068)
Browse files Browse the repository at this point in the history
* Adding the Ranker TrainContext, the Ranker TrainerEstimatorReconcilier, and an Evaluate method + metrics class to the existing RankerEvaluator.
* Adding the FastTree ranking xtension and test.
* Grouping the xtensions in classes with more meaningful names, since the docs site displays the methods per class, not file.
  • Loading branch information
sfilipi authored Sep 29, 2018
1 parent 95f5f27 commit 3cdd3c8
Show file tree
Hide file tree
Showing 12 changed files with 398 additions and 84 deletions.
Binary file added docs/images/DCG.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/images/NDCG.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
35 changes: 35 additions & 0 deletions src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -200,5 +200,40 @@ public static RegressionEvaluator.Result Evaluate<T>(
args.LossFunction = new TrivialRegressionLossFactory(loss);
return new RegressionEvaluator(env, args).Evaluate(data.AsDynamic, labelName, scoreName);
}

/// <summary>
/// Evaluates scored ranking data.
/// </summary>
/// <typeparam name="T">The shape type for the input data.</typeparam>
/// <typeparam name="TVal">The type of data, before being converted to a key.</typeparam>
/// <param name="ctx">The ranking context.</param>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The index delegate for the label column.</param>
/// <param name="groupId">The index delegate for the groupId column. </param>
/// <param name="score">The index delegate for predicted score column.</param>
/// <returns>The evaluation metrics.</returns>
public static RankerEvaluator.Result Evaluate<T, TVal>(
this RankerContext ctx,
DataView<T> data,
Func<T, Scalar<float>> label,
Func<T, Key<uint, TVal>> groupId,
Func<T, Scalar<float>> score)
{
Contracts.CheckValue(data, nameof(data));
var env = StaticPipeUtils.GetEnvironment(data);
Contracts.AssertValue(env);
env.CheckValue(label, nameof(label));
env.CheckValue(groupId, nameof(groupId));
env.CheckValue(score, nameof(score));

var indexer = StaticPipeUtils.GetIndexer(data);
string labelName = indexer.Get(label(indexer.Indices));
string scoreName = indexer.Get(score(indexer.Indices));
string groupIdName = indexer.Get(groupId(indexer.Indices));

var args = new RankerEvaluator.Arguments() { };

return new RankerEvaluator(env, args).Evaluate(data.AsDynamic, labelName, groupIdName, scoreName);
}
}
}
74 changes: 71 additions & 3 deletions src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,17 @@ public sealed class Arguments
public bool OutputGroupSummary;
}

public const string LoadName = "RankingEvaluator";
internal const string LoadName = "RankingEvaluator";

public const string Ndcg = "NDCG";
public const string Dcg = "DCG";
public const string MaxDcg = "MaxDCG";

/// <summary>
/// <value>
/// The ranking evaluator outputs a data view by this name, which contains metrics aggregated per group.
/// It contains four columns: GroupId, NDCG, DCG and MaxDCG. Each row in the data view corresponds to one
/// group in the scored data.
/// </summary>
/// </value>
public const string GroupSummary = "GroupSummary";

private const string GroupId = "GroupId";
Expand Down Expand Up @@ -234,6 +234,40 @@ protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, A
};
}

/// <summary>
/// Evaluates scored regression data.
/// </summary>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The name of the label column.</param>
/// <param name="groupId">The name of the groupId column.</param>
/// <param name="score">The name of the predicted score column.</param>
/// <returns>The evaluation metrics for these outputs.</returns>
public Result Evaluate(IDataView data, string label, string groupId, string score)
{
Host.CheckValue(data, nameof(data));
Host.CheckNonEmpty(label, nameof(label));
Host.CheckNonEmpty(score, nameof(score));
var roles = new RoleMappedData(data, opt: false,
RoleMappedSchema.ColumnRole.Label.Bind(label),
RoleMappedSchema.ColumnRole.Group.Bind(groupId),
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score));

var resultDict = Evaluate(roles);
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
var overall = resultDict[MetricKinds.OverallMetrics];

Result result;
using (var cursor = overall.GetRowCursor(i => true))
{
var moved = cursor.MoveNext();
Host.Assert(moved);
result = new Result(Host, cursor);
moved = cursor.MoveNext();
Host.Assert(!moved);
}
return result;
}

public sealed class Aggregator : AggregatorBase
{
public sealed class Counters
Expand Down Expand Up @@ -509,6 +543,40 @@ public void GetSlotNames(ref VBuffer<ReadOnlyMemory<char>> slotNames)
slotNames = new VBuffer<ReadOnlyMemory<char>>(UnweightedCounters.TruncationLevel, values);
}
}

public sealed class Result
{
/// <summary>
/// Normalized Discounted Cumulative Gain
/// <a href="https://github.com/dotnet/machinelearning/tree/master/docs/images/ndcg.png"></a>
/// </summary>
public double[] Ndcg { get; }

/// <summary>
/// <a href="https://en.wikipedia.org/wiki/Discounted_cumulative_gain">Discounted Cumulative gain</a>
/// is the sum of the gains, for all the instances i, normalized by the natural logarithm of the instance + 1.
/// Note that unline the Wikipedia article, ML.Net uses the natural logarithm.
/// <a href="https://github.com/dotnet/machinelearning/tree/master/docs/images/dcg.png"></a>
/// </summary>
public double[] Dcg { get; }

private static T Fetch<T>(IExceptionContext ectx, IRow row, string name)
{
if (!row.Schema.TryGetColumnIndex(name, out int col))
throw ectx.Except($"Could not find column '{name}'");
T val = default;
row.GetGetter<T>(col)(ref val);
return val;
}

internal Result(IExceptionContext ectx, IRow overallResult)
{
VBuffer<double> Fetch(string name) => Fetch<VBuffer<double>>(ectx, overallResult, name);

Dcg = Fetch(RankerEvaluator.Dcg).Values;
Ndcg = Fetch(RankerEvaluator.Ndcg).Values;
}
}
}

public sealed class RankerPerInstanceTransform : IDataTransform
Expand Down
64 changes: 64 additions & 0 deletions src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -470,5 +470,69 @@ public ImplScore(MulticlassClassifier<TVal> rec) : base(rec, rec.Inputs) { }
}
}

/// <summary>
/// A reconciler for ranking capable of handling the most common cases for ranking.
/// </summary>
public sealed class Ranker<TVal> : TrainerEstimatorReconciler
{
/// <summary>
/// The delegate to create the ranking trainer instance.
/// </summary>
/// <param name="env">The environment with which to create the estimator</param>
/// <param name="label">The label column name</param>
/// <param name="features">The features column name</param>
/// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights</param>
/// <param name="groupId">The groupId column name.</param>
/// <returns>A estimator producing columns with the fixed name <see cref="DefaultColumnNames.Score"/>.</returns>
public delegate IEstimator<ITransformer> EstimatorFactory(IHostEnvironment env, string label, string features, string weights, string groupId);

private readonly EstimatorFactory _estFact;

/// <summary>
/// The output score column for ranking. This will have this instance as its reconciler.
/// </summary>
public Scalar<float> Score { get; }

protected override IEnumerable<PipelineColumn> Outputs => Enumerable.Repeat(Score, 1);

private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score };

/// <summary>
/// Constructs a new general ranker reconciler.
/// </summary>
/// <param name="estimatorFactory">The delegate to create the training estimator. It is assumed that this estimator
/// will produce a single new scalar <see cref="float"/> column named <see cref="DefaultColumnNames.Score"/>.</param>
/// <param name="label">The input label column.</param>
/// <param name="features">The input features column.</param>
/// <param name="weights">The input weights column, or <c>null</c> if there are no weights.</param>
/// <param name="groupId">The input groupId column.</param>
public Ranker(EstimatorFactory estimatorFactory, Scalar<float> label, Vector<float> features, Key<uint, TVal> groupId, Scalar<float> weights)
: base(MakeInputs(Contracts.CheckRef(label, nameof(label)),
Contracts.CheckRef(features, nameof(features)),
Contracts.CheckRef(groupId, nameof(groupId)),
weights),
_fixedOutputNames)
{
Contracts.CheckValue(estimatorFactory, nameof(estimatorFactory));
_estFact = estimatorFactory;
Contracts.Assert(Inputs.Length == 3 || Inputs.Length == 4);
Score = new Impl(this);
}

private static PipelineColumn[] MakeInputs(Scalar<float> label, Vector<float> features, Key<uint, TVal> groupId, Scalar<float> weights)
=> weights == null ? new PipelineColumn[] { label, features, groupId } : new PipelineColumn[] { label, features, groupId, weights };

protected override IEstimator<ITransformer> ReconcileCore(IHostEnvironment env, string[] inputNames)
{
Contracts.AssertValue(env);
env.Assert(Utils.Size(inputNames) == Inputs.Length);
return _estFact(env, inputNames[0], inputNames[1], inputNames[2], inputNames.Length > 3 ? inputNames[3] : null);
}

private sealed class Impl : Scalar<float>
{
public Impl(Ranker<TVal> rec) : base(rec, rec.Inputs) { }
}
}
}
}
44 changes: 44 additions & 0 deletions src/Microsoft.ML.Data/Training/TrainContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -504,4 +504,48 @@ public RegressionEvaluator.Result Evaluate(IDataView data, string label, string
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
}
}

/// <summary>
/// The central context for regression trainers.
/// </summary>
public sealed class RankerContext : TrainContextBase
{
/// <summary>
/// For trainers for performing regression.
/// </summary>
public RankerTrainers Trainers { get; }

public RankerContext(IHostEnvironment env)
: base(env, nameof(RankerContext))
{
Trainers = new RankerTrainers(this);
}

public sealed class RankerTrainers : ContextInstantiatorBase
{
internal RankerTrainers(RankerContext ctx)
: base(ctx)
{
}
}

/// <summary>
/// Evaluates scored ranking data.
/// </summary>
/// <param name="data">The scored data.</param>
/// <param name="label">The name of the label column in <paramref name="data"/>.</param>
/// <param name="groupId">The name of the groupId column in <paramref name="data"/>.</param>
/// <param name="score">The name of the score column in <paramref name="data"/>.</param>
/// <returns>The evaluation results for these calibrated outputs.</returns>
public RankerEvaluator.Result Evaluate(IDataView data, string label, string groupId, string score = DefaultColumnNames.Score)
{
Host.CheckValue(data, nameof(data));
Host.CheckNonEmpty(label, nameof(label));
Host.CheckNonEmpty(score, nameof(score));
Host.CheckNonEmpty(groupId, nameof(groupId));

var eval = new RankerEvaluator(Host, new RankerEvaluator.Arguments() { });
return eval.Evaluate(data, label, groupId, score);
}
}
}
Loading

0 comments on commit 3cdd3c8

Please sign in to comment.