forked from dotnet/machinelearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Bring ensembles into codebase (dotnet#379)
Introduce Ensemble codebase
- Loading branch information
1 parent
33c28a7
commit 05863c8
Showing
73 changed files
with
19,405 additions
and
10,044 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Runtime.Data; | ||
|
||
namespace Microsoft.ML.Runtime.Ensemble | ||
{ | ||
/// <summary>
/// Pairs one training data set with one testing data set. Both references are
/// stored in read-only fields and are fixed once the constructor returns.
/// </summary>
public sealed class Batch
{
    /// <summary>The data supplied to the constructor as <c>trainData</c>.</summary>
    public readonly RoleMappedData TrainInstances;

    /// <summary>The data supplied to the constructor as <c>testData</c>.</summary>
    public readonly RoleMappedData TestInstances;

    /// <summary>
    /// Creates a batch from the given train/test pair.
    /// </summary>
    /// <param name="trainData">Stored in <see cref="TrainInstances"/>; validated with <c>Contracts.CheckValue</c>.</param>
    /// <param name="testData">Stored in <see cref="TestInstances"/>; validated with <c>Contracts.CheckValue</c>.</param>
    public Batch(RoleMappedData trainData, RoleMappedData testData)
    {
        // Validate both arguments (train first, then test) before touching any field.
        // NOTE(review): CheckValue presumably rejects null references - confirm against Contracts.
        Contracts.CheckValue(trainData, nameof(trainData));
        Contracts.CheckValue(testData, nameof(testData));

        this.TrainInstances = trainData;
        this.TestInstances = testData;
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.Collections; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Internal.Utilities; | ||
|
||
namespace Microsoft.ML.Runtime.Ensemble | ||
{ | ||
/// <summary>
/// Helpers for restricting data to a subset of the feature slots, where the subset
/// is described by a <see cref="BitArray"/> with one bit per slot.
/// </summary>
internal static class EnsembleUtils
{
    /// <summary>
    /// Return a dataset with non-selected features zeroed out.
    /// </summary>
    /// <param name="host">Host handed to <c>LambdaColumnMapper.Create</c>.</param>
    /// <param name="data">Input data; its schema must have a feature column.</param>
    /// <param name="features">One bit per feature slot; a set bit means the slot is kept.</param>
    /// <returns>
    /// <paramref name="data"/> itself when every bit is set; otherwise a new
    /// <see cref="RoleMappedData"/> built over a mapped view of <c>data.Data</c>
    /// using the same column role names.
    /// </returns>
    public static RoleMappedData SelectFeatures(IHost host, RoleMappedData data, BitArray features)
    {
        Contracts.AssertValue(host);
        Contracts.AssertValue(data);
        Contracts.Assert(data.Schema.Feature != null);
        Contracts.AssertValue(features);

        var featureColumn = data.Schema.Feature;
        var featureType = featureColumn.Type;
        Contracts.Assert(features.Length == featureType.VectorSize);

        int selectedCount = Utils.GetCardinality(features);
        if (selectedCount == featureType.VectorSize)
        {
            // Every slot is selected, so there is nothing to mask out.
            return data;
        }

        // REVIEW: This doesn't preserve metadata on the features column. Should it?
        // The mapper's source and destination column names are the same feature column name.
        var columnName = featureColumn.Name;
        var maskedView = LambdaColumnMapper.Create(
            host, "FeatureSelector", data.Data, columnName, columnName, featureType, featureType,
            (ref VBuffer<Single> src, ref VBuffer<Single> dst) => SelectFeatures(ref src, features, selectedCount, ref dst));

        return RoleMappedData.Create(maskedView, data.Schema.GetColumnRoleNames());
    }

    /// <summary>
    /// Copies into <paramref name="dst"/> the entries of <paramref name="src"/> whose slot index is set in
    /// <paramref name="includedIndices"/>; every other slot ends up as default(T), either written
    /// explicitly (dense output) or left implicit (sparse output). The resulting buffer has logical
    /// length src.Length. The arrays currently held by <paramref name="dst"/> are reused when they
    /// are large enough.
    /// </summary>
    /// <param name="src">Source vector, dense or sparse.</param>
    /// <param name="includedIndices">Selection mask; asserted to have src.Length bits.</param>
    /// <param name="cardinality">Number of set bits in the mask; asserted to be less than src.Length.</param>
    /// <param name="dst">Receives the filtered vector.</param>
    public static void SelectFeatures<T>(ref VBuffer<T> src, BitArray includedIndices, int cardinality, ref VBuffer<T> dst)
    {
        Contracts.Assert(Utils.Size(includedIndices) == src.Length);
        Contracts.Assert(cardinality == Utils.GetCardinality(includedIndices));
        Contracts.Assert(cardinality < src.Length);

        int length = src.Length;
        var outValues = dst.Values;
        var outIndices = dst.Indices;

        if (!src.IsDense)
        {
            // Sparse source: walk only the explicitly stored entries and keep the selected ones.
            int valuesCapacity = Utils.Size(outValues);
            int indicesCapacity = Utils.Size(outIndices);

            // Reallocation is considered only when a buffer cannot hold src.Count items, and
            // then a buffer is replaced only if it is also smaller than the cardinality.
            bool mayNeedRoom = valuesCapacity < src.Count || indicesCapacity < src.Count;
            if (mayNeedRoom && valuesCapacity < cardinality)
            {
                outValues = new T[cardinality];
            }
            if (mayNeedRoom && indicesCapacity < cardinality)
            {
                outIndices = new int[cardinality];
            }

            int kept = 0;
            for (int slot = 0; slot < src.Count; slot++)
            {
                int featureIndex = src.Indices[slot];
                if (!includedIndices[featureIndex])
                {
                    continue;
                }

                outValues[kept] = src.Values[slot];
                outIndices[kept] = featureIndex;
                kept++;
            }

            dst = new VBuffer<T>(length, kept, outValues, outIndices);
            return;
        }

        if (cardinality >= length / 2)
        {
            // Dense source with at least half the slots kept: produce a dense result and
            // overwrite the dropped slots with default(T).
            if (Utils.Size(outValues) < length)
            {
                outValues = new T[length];
            }

            for (int slot = 0; slot < length; slot++)
            {
                if (includedIndices[slot])
                {
                    outValues[slot] = src.Values[slot];
                }
                else
                {
                    outValues[slot] = default(T);
                }
            }

            dst = new VBuffer<T>(length, outValues, outIndices);
            return;
        }

        // Dense source with fewer than half the slots kept: produce a sparse result that
        // stores exactly the selected slots.
        if (Utils.Size(outValues) < cardinality)
        {
            outValues = new T[cardinality];
        }
        if (Utils.Size(outIndices) < cardinality)
        {
            outIndices = new int[cardinality];
        }

        int taken = 0;
        for (int slot = 0; slot < length; slot++)
        {
            if (!includedIndices[slot])
            {
                continue;
            }

            Contracts.Assert(taken < cardinality);
            outValues[taken] = src.Values[slot];
            outIndices[taken] = slot;
            taken++;
        }

        Contracts.Assert(taken == cardinality);
        dst = new VBuffer<T>(length, taken, outValues, outIndices);
    }
}
} |
Oops, something went wrong.