Skip to content

Commit

Permalink
fix bug where if one pipeline hyperparam optimization converges, run …
Browse files Browse the repository at this point in the history
…terminates (dotnet#36)
  • Loading branch information
daholste authored Jan 28, 2019
1 parent bf42ba5 commit 6609cd9
Showing 1 changed file with 34 additions and 16 deletions.
50 changes: 34 additions & 16 deletions src/AutoML/PipelineSuggesters/PipelineSuggester.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,33 @@ public static InferredPipeline GetNextInferredPipeline(IEnumerable<InferredPipel
return GetNextFirstStagePipeline(history, availableTrainers, transforms);
}

// get next trainer
// get top trainers from stage 1 runs
var topTrainers = GetTopTrainers(history, availableTrainers, isMaximizingMetric);
var nextTrainerIndex = (history.Count() - availableTrainers.Count()) % topTrainers.Count();
var trainer = topTrainers.ElementAt(nextTrainerIndex).Clone();

// make sure we have not seen pipeline before.
// repeat until passes or runs out of chances.
var visitedPipelines = new HashSet<InferredPipeline>(history.Select(h => h.Pipeline));
const int maxNumberAttempts = 10;
var count = 0;
do

// sort top trainers by # of times they've been run, from lowest to highest
var orderedTopTrainers = OrderTrainersByNumTrials(history, topTrainers);

// iterate over top trainers (from least run to most run),
// to find next pipeline
foreach(var trainer in orderedTopTrainers)
{
SampleHyperparameters(trainer, history, isMaximizingMetric);
var pipeline = new InferredPipeline(transforms, trainer);
if(!visitedPipelines.Contains(pipeline))
var newTrainer = trainer.Clone();

// make sure we have not seen pipeline before.
// repeat until passes or runs out of chances
var visitedPipelines = new HashSet<InferredPipeline>(history.Select(h => h.Pipeline));
const int maxNumberAttempts = 10;
var count = 0;
do
{
return pipeline;
}
} while (++count <= maxNumberAttempts);
SampleHyperparameters(newTrainer, history, isMaximizingMetric);
var pipeline = new InferredPipeline(transforms, newTrainer);
if (!visitedPipelines.Contains(pipeline))
{
return pipeline;
}
} while (++count <= maxNumberAttempts);
}

return null;
}
Expand All @@ -84,6 +92,16 @@ private static IEnumerable<SuggestedTrainer> GetTopTrainers(IEnumerable<Inferred
return topTrainers;
}

/// <summary>
/// Orders the selected trainers by how many runs each one has in <paramref name="history"/>,
/// from fewest runs to most runs.
/// </summary>
/// <param name="history">Results of the pipelines that have been run so far.</param>
/// <param name="selectedTrainers">Trainers whose names decide which history entries are considered.</param>
/// <returns>
/// A lazily evaluated sequence with one trainer per distinct trainer name found in the
/// filtered history. Each element is the trainer recorded on the first history entry for
/// that name, not the instance passed in <paramref name="selectedTrainers"/>. A selected
/// trainer with no matching history entry does not appear in the result.
/// </returns>
private static IEnumerable<SuggestedTrainer> OrderTrainersByNumTrials(IEnumerable<InferredPipelineRunResult> history,
    IEnumerable<SuggestedTrainer> selectedTrainers)
{
    // Names are collected eagerly so the membership test in the query below is a set lookup.
    var namesOfInterest = new HashSet<TrainerName>(selectedTrainers.Select(candidate => candidate.TrainerName));

    // Keep only runs of the selected trainers, bucket them per trainer name,
    // and emit one representative trainer per bucket, smallest bucket first.
    return from run in history
           where namesOfInterest.Contains(run.Pipeline.Trainer.TrainerName)
           group run by run.Pipeline.Trainer.TrainerName into runsForTrainer
           orderby runsForTrainer.Count()
           select runsForTrainer.First().Pipeline.Trainer;
}

private static InferredPipeline GetNextFirstStagePipeline(IEnumerable<InferredPipelineRunResult> history,
IEnumerable<SuggestedTrainer> availableTrainers,
IEnumerable<SuggestedTransform> transforms)
Expand Down

0 comments on commit 6609cd9

Please sign in to comment.