Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GetSummaryDataView() implementation for Pca and Linear Predictors #185

Merged
merged 19 commits into from
Jun 8, 2018
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/Microsoft.ML.PCA/PcaTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ public static CommonOutputs.AnomalyDetectionOutput TrainPcaAnomaly(IHostEnvironm
// REVIEW: move the predictor to a different file and fold EigenUtils.cs to this file.
public sealed class PcaPredictor : PredictorBase<Float>,
IValueMapper,
ICanGetSummaryAsIDataView,
ICanSaveInTextFormat, ICanSaveModel, ICanSaveSummary
{
public const string LoaderSignature = "pcaAnomExec";
Expand Down Expand Up @@ -469,6 +470,29 @@ public void SaveAsText(TextWriter writer, RoleMappedSchema schema)
}
}

public IDataView GetSummaryDataView(RoleMappedSchema schema)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetSummaryDataView [](start = 25, length = 18)

Could you add a unit test for this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

{
var bldr = new ArrayDataViewBuilder(Host);

bldr.AddColumn("MeanVector", NumberType.R4, _mean);
bldr.AddColumn("ProjectedMeanVector", NumberType.R4, _meanProjected);

ValueGetter<VBuffer<DvText>> getSlotNames =

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getSlotNames [](start = 41, length = 12)

We probably don't need the slot names here since they don't give any additional information other than the slot index.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you end up defining slot names?


In reply to: 191581196 [](ancestors = 191581196)

(ref VBuffer<DvText> dst) =>
{
var values = new DvText[_rank];
for (var i = 0; i < _rank; ++i)
values[i] = new DvText("V" + i);

// should we reuse dst VBuffer or not?
var tmp = new VBuffer<DvText>(_rank, values);
tmp.CopyTo(ref dst);
};

bldr.AddColumn("EigenVectors", getSlotNames, NumberType.R4, _eigenVectors);
return bldr.GetDataView();
}

public ColumnType InputType
{
get { return _inputType; }
Expand Down
76 changes: 28 additions & 48 deletions src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public abstract class LinearPredictor : PredictorBase<Float>,
ICanSaveInTextFormat,
ICanSaveInSourceCode,
ICanSaveModel,
ICanGetSummaryAsIRow,
ICanSaveSummary,
IPredictorWithFeatureWeights<Float>,
IWhatTheFeatureValueMapper,
Expand Down Expand Up @@ -343,6 +344,30 @@ public void SaveAsCode(TextWriter writer, RoleMappedSchema schema)

public abstract void SaveSummary(TextWriter writer, RoleMappedSchema schema);

/// <summary>
/// Builds a summary row for this linear predictor with two columns: a scalar "Bias" column
/// and a vector "Weights" column. The "Weights" column carries a metadata row holding the
/// slot names requested for the feature-role column of <paramref name="schema"/>.
/// </summary>
/// <param name="schema">The role-mapped schema passed to the slot-name lookup.</param>
/// <returns>A row containing the "Bias" and "Weights" columns, in that order.</returns>
public virtual IRow GetSummaryIRowOrNull(RoleMappedSchema schema)
{
    var length = Weight.Length;

    // Request the feature slot names and wrap them in a one-column metadata row.
    var slotNames = default(VBuffer<DvText>);
    MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, length, ref slotNames);
    var slotNamesType = new VectorType(TextType.Instance, length);
    var metadata = RowColumnUtils.GetRow(
        null,
        RowColumnUtils.GetColumn(MetadataUtils.Kinds.SlotNames, slotNamesType, ref slotNames));

    // Bias and Weight are copied into locals, which are then handed over by reference.
    var biasValue = Bias;
    var biasColumn = RowColumnUtils.GetColumn("Bias", NumberType.R4, ref biasValue);

    var weightValues = Weight;
    var weightsType = new VectorType(NumberType.R4, length);
    var weightsColumn = RowColumnUtils.GetColumn("Weights", weightsType, ref weightValues, metadata);

    return RowColumnUtils.GetRow(null, new IColumn[] { biasColumn, weightsColumn });
}

/// <summary>
/// Default implementation: always returns null. Declared virtual so that derived
/// predictors can override it and return a row instead.
/// </summary>
/// <param name="schema">Unused by this implementation.</param>
/// <returns>Always null.</returns>
public virtual IRow GetStatsIRowOrNull(RoleMappedSchema schema) => null;

public abstract void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null);

public virtual void GetFeatureWeights(ref VBuffer<Float> weights)
Expand All @@ -366,8 +391,7 @@ public ValueMapper<TSrc, VBuffer<Float>> GetWhatTheFeatureMapper<TSrc, TDstContr

public sealed partial class LinearBinaryPredictor : LinearPredictor,
ICanGetSummaryInKeyValuePairs,
IParameterMixer<Float>,
ICanGetSummaryAsIRow
IParameterMixer<Float>
{
public const string LoaderSignature = "Linear2CExec";
public const string RegistrationName = "LinearBinaryPredictor";
Expand Down Expand Up @@ -503,26 +527,7 @@ public IList<KeyValuePair<string, object>> GetSummaryInKeyValuePairs(RoleMappedS
return results;
}

/// <summary>
/// Builds a summary row with two columns: a scalar "Bias" column and a vector "Weights"
/// column. The "Weights" column carries a metadata row holding the slot names requested
/// for the feature-role column of <paramref name="schema"/>.
/// </summary>
/// <param name="schema">The role-mapped schema passed to the slot-name lookup.</param>
/// <returns>A row containing the "Bias" and "Weights" columns, in that order.</returns>
public IRow GetSummaryIRowOrNull(RoleMappedSchema schema)
{
    var length = Weight.Length;

    // Request the feature slot names and wrap them in a one-column metadata row.
    var slotNames = default(VBuffer<DvText>);
    MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, length, ref slotNames);
    var slotNamesType = new VectorType(TextType.Instance, length);
    var metadata = RowColumnUtils.GetRow(
        null,
        RowColumnUtils.GetColumn(MetadataUtils.Kinds.SlotNames, slotNamesType, ref slotNames));

    // Bias and Weight are copied into locals, which are then handed over by reference.
    var biasValue = Bias;
    var biasColumn = RowColumnUtils.GetColumn("Bias", NumberType.R4, ref biasValue);

    var weightValues = Weight;
    var weightsType = new VectorType(NumberType.R4, length);
    var weightsColumn = RowColumnUtils.GetColumn("Weights", weightsType, ref weightValues, metadata);

    return RowColumnUtils.GetRow(null, new IColumn[] { biasColumn, weightsColumn });
}

public IRow GetStatsIRowOrNull(RoleMappedSchema schema)
public override IRow GetStatsIRowOrNull(RoleMappedSchema schema)
{
if (_stats == null)
return null;
Expand Down Expand Up @@ -582,8 +587,7 @@ public override void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICali

public sealed class LinearRegressionPredictor : RegressionPredictor,
IParameterMixer<Float>,
ICanGetSummaryInKeyValuePairs,
ICanGetSummaryAsIRow
ICanGetSummaryInKeyValuePairs
{
public const string LoaderSignature = "LinearRegressionExec";
public const string RegistrationName = "LinearRegressionPredictor";
Expand Down Expand Up @@ -663,30 +667,6 @@ public IList<KeyValuePair<string, object>> GetSummaryInKeyValuePairs(RoleMappedS

return results;
}

/// <summary>
/// Builds a summary row with two columns: a scalar "Bias" column and a vector "Weights"
/// column. The "Weights" column carries a metadata row holding the slot names requested
/// for the feature-role column of <paramref name="schema"/>.
/// </summary>
/// <param name="schema">The role-mapped schema passed to the slot-name lookup.</param>
/// <returns>A row containing the "Bias" and "Weights" columns, in that order.</returns>
public IRow GetSummaryIRowOrNull(RoleMappedSchema schema)
{
    var length = Weight.Length;

    // Request the feature slot names and wrap them in a one-column metadata row.
    var slotNames = default(VBuffer<DvText>);
    MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, length, ref slotNames);
    var slotNamesType = new VectorType(TextType.Instance, length);
    var metadata = RowColumnUtils.GetRow(
        null,
        RowColumnUtils.GetColumn(MetadataUtils.Kinds.SlotNames, slotNamesType, ref slotNames));

    // Bias and Weight are copied into locals, which are then handed over by reference.
    var biasValue = Bias;
    var biasColumn = RowColumnUtils.GetColumn("Bias", NumberType.R4, ref biasValue);

    var weightValues = Weight;
    var weightsType = new VectorType(NumberType.R4, length);
    var weightsColumn = RowColumnUtils.GetColumn("Weights", weightsType, ref weightValues, metadata);

    return RowColumnUtils.GetRow(null, new IColumn[] { biasColumn, weightsColumn });
}

/// <summary>
/// Always returns null.
/// </summary>
/// <param name="schema">Unused by this implementation.</param>
/// <returns>Always null.</returns>
public IRow GetStatsIRowOrNull(RoleMappedSchema schema) => null;
}

public sealed class PoissonRegressionPredictor : RegressionPredictor, IParameterMixer<Float>
Expand Down