Skip to content

Commit

Permalink
prefix estimator type variables with "T" and use single-letters for c…
Browse files Browse the repository at this point in the history
…ategory in classes that don't have other type variables
  • Loading branch information
douira committed Feb 22, 2025
1 parent 68bec84 commit 554685f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import java.util.Locale;

public abstract class Average1DEstimator<Category> extends Estimator<Category, Average1DEstimator.Value<Category>, Average1DEstimator.ValueBatch<Category>, Void, Long, Average1DEstimator.Average<Category>> {
public abstract class Average1DEstimator<C> extends Estimator<C, Average1DEstimator.Value<C>, Average1DEstimator.ValueBatch<C>, Void, Long, Average1DEstimator.Average<C>> {
private final float newDataRatio;
private final long initialEstimate;

Expand Down Expand Up @@ -39,11 +39,11 @@ public float getAverage() {
}

@Override
protected ValueBatch<Category> createNewDataBatch() {
protected ValueBatch<C> createNewDataBatch() {
return new ValueBatch<>();
}

protected static class Average<ModelCategory> implements Estimator.Model<Void, Long, ValueBatch<ModelCategory>, Average<ModelCategory>> {
protected static class Average<C> implements Estimator.Model<Void, Long, ValueBatch<C>, Average<C>> {
private final float newDataRatio;
private boolean hasRealData = false;
private float average;
Expand All @@ -54,7 +54,7 @@ public Average(float newDataRatio, float initialValue) {
}

@Override
public Average<ModelCategory> update(ValueBatch<ModelCategory> batch) {
public Average<C> update(ValueBatch<C> batch) {
if (batch.count > 0) {
if (this.hasRealData) {
this.average = MathUtil.exponentialMovingAverage(this.average, batch.getAverage(), this.newDataRatio);
Expand All @@ -79,11 +79,11 @@ public String toString() {
}

@Override
protected Average<Category> createNewModel() {
protected Average<C> createNewModel() {
return new Average<>(this.newDataRatio, this.initialEstimate);
}

public Long predict(Category category) {
public Long predict(C category) {
return super.predict(category, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,46 +5,46 @@
/**
* This generic model learning class that can be used to estimate values based on a set of data points. It performs batch-wise model updates. The actual data aggregation and model updates are delegated to the implementing classes. The estimator stores multiple models in a map, one for each category.
*
* @param <Category> The type of the category key
* @param <Point> A data point contains a category and one piece of data
* @param <Batch> A data batch contains multiple data points
* @param <Input> The input to the model
* @param <Output> The output of the model
* @param <Model> The model that is used to predict values
* @param <TCategory> The type of the category key
* @param <TPoint> A data point contains a category and one piece of data
* @param <TBatch> A data batch contains multiple data points
* @param <TInput> The input to the model
* @param <TOutput> The output of the model
* @param <TModel> The model that is used to predict values
*/
public abstract class Estimator<
Category,
Point extends Estimator.DataPoint<Category>,
Batch extends Estimator.DataBatch<Point>,
Input,
Output,
Model extends Estimator.Model<Input, Output, Batch, Model>> {
protected final Map<Category, Model> models = createMap();
protected final Map<Category, Batch> batches = createMap();

protected interface DataBatch<BatchPoint> {
void addDataPoint(BatchPoint input);
TCategory,
TPoint extends Estimator.DataPoint<TCategory>,
TBatch extends Estimator.DataBatch<TPoint>,
TInput,
TOutput,
TModel extends Estimator.Model<TInput, TOutput, TBatch, TModel>> {
protected final Map<TCategory, TModel> models = createMap();
protected final Map<TCategory, TBatch> batches = createMap();

protected interface DataBatch<TBatchPoint> {
void addDataPoint(TBatchPoint input);

void reset();
}

protected interface DataPoint<PointCategory> {
PointCategory category();
protected interface DataPoint<TPointCategory> {
TPointCategory category();
}

protected interface Model<ModelInput, ModelOutput, ModelBatch, ModelSelf extends Model<ModelInput, ModelOutput, ModelBatch, ModelSelf>> {
ModelSelf update(ModelBatch batch);
protected interface Model<TModelInput, TModelOutput, TModelBatch, TModelSelf extends Model<TModelInput, TModelOutput, TModelBatch, TModelSelf>> {
TModelSelf update(TModelBatch batch);

ModelOutput predict(ModelInput input);
TModelOutput predict(TModelInput input);
}

protected abstract Batch createNewDataBatch();
protected abstract TBatch createNewDataBatch();

protected abstract Model createNewModel();
protected abstract TModel createNewModel();

protected abstract <T> Map<Category, T> createMap();
protected abstract <T> Map<TCategory, T> createMap();

public void addData(Point data) {
public void addData(TPoint data) {
var category = data.category();
var batch = this.batches.get(category);
if (batch == null) {
Expand All @@ -54,7 +54,7 @@ public void addData(Point data) {
batch.addDataPoint(data);
}

private Model ensureModel(Category category) {
private TModel ensureModel(TCategory category) {
var model = this.models.get(category);
if (model == null) {
model = this.createNewModel();
Expand All @@ -77,11 +77,11 @@ public void updateModels() {
});
}

public Output predict(Category category, Input input) {
return (Output) this.ensureModel(category).predict(input);
public TOutput predict(TCategory category, TInput input) {
return this.ensureModel(category).predict(input);
}

public String toString(Category category) {
public String toString(TCategory category) {
var model = this.models.get(category);
if (model == null) {
return "-";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import java.util.Locale;

public abstract class Linear2DEstimator<Category> extends Estimator<Category, Linear2DEstimator.DataPair<Category>, Linear2DEstimator.LinearRegressionBatch<Category>, Long, Long, Linear2DEstimator.LinearFunction<Category>> {
public abstract class Linear2DEstimator<C> extends Estimator<C, Linear2DEstimator.DataPair<C>, Linear2DEstimator.LinearRegressionBatch<C>, Long, Long, Linear2DEstimator.LinearFunction<C>> {
private final float newDataRatio;
private final int initialSampleTarget;
private final long initialOutput;
Expand Down Expand Up @@ -34,7 +34,7 @@ public void reset() {
}

@Override
protected LinearRegressionBatch<Category> createNewDataBatch() {
protected LinearRegressionBatch<C> createNewDataBatch() {
return new LinearRegressionBatch<>();
}

Expand Down Expand Up @@ -137,7 +137,7 @@ public String toString() {
}

@Override
protected LinearFunction<Category> createNewModel() {
protected LinearFunction<C> createNewModel() {
return new LinearFunction<>(this.newDataRatio, this.initialSampleTarget, this.initialOutput);
}
}

0 comments on commit 554685f

Please sign in to comment.