diff --git a/README.md b/README.md index 79a8af3..1d2d9bb 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ In the tradition of Tkinter SVM GUI, the purpose of this app is to demonstrate how machine learning model forms are affected by the shape of the underlying dataset. By selecting a dataset or by creating one of your own, you can fit a model to the data and see how the model would make decisions based on the data it has been trained on. Although this is a toy example, hopefully it helps give you the intuition that the machine learning process is a model selection search for the best combination of features, algorithm, and hyperparameter that generalize well in a bounded feature space. +![Screenshot](static/img/screenshot.png) + ## Getting Started To run this app locally, first clone the repository and install the requirements: diff --git a/app.py b/app.py index f70666d..5978736 100644 --- a/app.py +++ b/app.py @@ -19,6 +19,7 @@ ########################################################################## import json +import numpy as np from flask import Flask from flask import render_template, jsonify, request @@ -94,6 +95,7 @@ def fit(): data = request.get_json() params = data.get("model", {}) dataset = data.get("dataset", []) + grid = data.get("grid", []) model = { 'gaussiannb': GaussianNB(), 'multinomialnb': MultinomialNB(), @@ -132,8 +134,22 @@ def fit(): yhat = model.predict(X) metrics = prfs(y, yhat, average="macro") + # Make probability predictions on the grid to implement contours + # The returned value is the class index + the probability + # To get the selected class in JavaScript, use Math.floor(p) + # Where p is the probability returned by the grid. 
Note that this + # method guarantees that no P(c) == 1 to prevent class misidentification + Xp = asarray([ + [point["x"], point["y"]] for point in grid + ]) + preds = [] + for proba in model.predict_proba(Xp): + c = np.argmax(proba) + preds.append(float(c+proba[c])-0.000001) + return jsonify({ - "metrics": dict(zip(["precision", "recall", "f1", "support"], metrics)) + "metrics": dict(zip(["precision", "recall", "f1", "support"], metrics)), + "grid": preds, }) ########################################################################## diff --git a/static/img/screenshot.png b/static/img/screenshot.png new file mode 100644 index 0000000..c18f40c Binary files /dev/null and b/static/img/screenshot.png differ diff --git a/static/js/dataspace.js b/static/js/dataspace.js index 532e9ec..85140c4 100644 --- a/static/js/dataspace.js +++ b/static/js/dataspace.js @@ -24,48 +24,50 @@ function alertMessage(message) { class Dataspace { constructor(selector) { - this.svg = d3.select(selector); - this.$svg = $(selector); - this.dataset = []; - - // drawing properties are hardcoded for now - this.width = this.$svg.width(); - this.height = this.$svg.height(); - this.color = d3.scaleOrdinal(d3.schemeCategory10); - - this.xScale = d3.scaleLinear() - .domain([0, 1]) - .range([margin.left, this.width - margin.right]); - - this.yScale = d3.scaleLinear() - .domain([0, 1]) - .range([margin.top, this.height - margin.bottom]) + this.svg = d3.select(selector); + this.$svg = $(selector); + this.dataset = []; + this.grid = null; + + // drawing properties are hardcoded for now + this.width = this.$svg.width(); + this.height = this.$svg.height(); + this.color = d3.scaleOrdinal(d3.schemeCategory10); + + this.xScale = d3.scaleLinear() + .domain([0, 1]) + .range([margin.left, this.width - margin.right]); + + this.yScale = d3.scaleLinear() + .domain([0, 1]) + .range([margin.top, this.height - margin.bottom]) } draw() { var self = this; self.svg.selectAll("circle") - .data(self.dataset) - .enter() - 
.append("circle") - .attr('cx', function (d) { return self.xScale(d.x); }) - .attr('cy', function (d) { return self.yScale(d.y); }) - .attr('fill', function (d) { return self.color(d.c); }) - .attr('r', radius); + .data(self.dataset) + .enter() + .append("circle") + .attr('cx', function (d) { return self.xScale(d.x); }) + .attr('cy', function (d) { return self.yScale(d.y); }) + .attr('fill', function (d) { return self.color(d.c); }) + .attr("stroke", "#FFFFFF") + .attr('r', radius); } // Add raw data point (e.g. where x and y are between 0 and 1) addPoint(point) { - this.dataset.push(point); - this.draw(); + this.dataset.push(point); + this.draw(); } // Add coordinates data point (e.g. where x and y are in the svg) addCoords(coords) { var point = { - x: this.xScale.invert(coords[0]), - y: this.yScale.invert(coords[1]), - c: currentClass + x: this.xScale.invert(coords[0]), + y: this.yScale.invert(coords[1]), + c: currentClass }; this.addPoint(point); } @@ -74,14 +76,14 @@ class Dataspace { fetch(data) { this.reset(); d3.json("/generate", { - method: "POST", - body: JSON.stringify(data), - headers: { - "Content-Type": "application/json; charset=UTF-8" - } + method: "POST", + body: JSON.stringify(data), + headers: { + "Content-Type": "application/json; charset=UTF-8" + } }).then(json => { - this.dataset = json; - this.draw(); + this.dataset = json; + this.draw(); }).catch(error => { console.log(error); alertMessage("Server could not generate dataset!"); @@ -92,13 +94,18 @@ class Dataspace { fit(model) { $("#metrics").removeClass("visible").addClass("invisible"); if (this.dataset.length == 0) { - console.log("cannot fit model to no data!"); - return + console.log("cannot fit model to no data!"); + return } + // The contours grid determines what to make predictions on. + // TODO: don't pass this to the server but allow the server to compute it. 
+ var self = this; + self.grid = self.contoursGrid() var data = { - model: model, - dataset: this.dataset, + model: model, + dataset: self.dataset, + grid: self.grid, } d3.json("/fit", { @@ -108,18 +115,64 @@ class Dataspace { "Content-Type": "application/json; charset=UTF-8" } }).then(json => { + // Reset the old contours + self.svg.selectAll("g").remove(); + + // Update the metrics $("#f1score").text(json.metrics.f1); $("#metrics").removeClass("invisible").addClass("visible"); + + // Update the grid with the predictions values. + $.each(json.grid, function(i, val) { + self.grid[i] = val; + }) + + // Compute the thresholds from the classes, then compute the colors + var thresholds = self.classes().map(i => d3.range(i, i + 1, 0.1)).flat().sort(); + var colorMap = {} + $.each(self.classes(), c => { + colorMap[c] = d3.scaleLinear().domain([c, c+1]) + .interpolate(d3.interpolateHcl) + .range(["#FFFFFF", self.color(c)]) + }); + + var getColor = d => { + console.log(d.value) + return colorMap[Math.floor(d.value)](d.value) + } + + // Add the contours from the predictions for each class + var contours = d3.contours() + .size([self.grid.n, self.grid.m]) + .thresholds(thresholds) + .smooth(true) + (self.grid) + .map(self.grid.transform) + + // Draw the contours on the SVG + self.svg.insert("g", ":first-child") + .attr("fill", "none") + .attr("stroke", "#FFFFFF") + .attr("stroke-opacity", 0.65) + .selectAll("path") + .data(contours) // Here is where the contours gets added + .join("path") + .attr("fill", getColor) // Here is the color value! 
+ .style("opacity", 0.85) + .attr("d", d3.geoPath()); + }).catch(error => { + console.log(error); alertMessage("Could not fit model, check JSON hyperparams and try again!"); }); } // Reset the plotting area reset() { - this.dataset = []; - this.svg.selectAll("circle").remove(); - $("#metrics").removeClass("visible").addClass("invisible"); + this.dataset = []; + this.svg.selectAll("circle").remove(); + this.svg.selectAll("g").remove(); + $("#metrics").removeClass("visible").addClass("invisible"); } // Count the number of classes in the dataset @@ -132,6 +185,46 @@ class Dataspace { }, []); } + // Create the contours grid to pass to the predict function. + contoursGrid() { + var self = this; + const q = 4; + const x0 = -q / 2, x1 = this.width + margin.right + q; + const y0 = -q / 2, y1 = this.height + q; + const n = Math.ceil((x1-x0) / q); + const m = Math.ceil((y1-y0) / q); + const grid = new Array(n*m); + grid.x = -q; + grid.y = -q; + grid.k = q; + grid.n = n; + grid.m = m; + + // Converts from grid coordinates (indexes) to screen coordinates (pixels). + grid.transform = ({ type, value, coordinates }) => { + return { + type, value, coordinates: coordinates.map(rings => { + return rings.map(points => { + return points.map(([x, y]) => ([ + grid.x + grid.k * x, + grid.y + grid.k * y + ])); + }); + }) + }; + } + + // We just have to pass the x and y values to the server to predict them using the model, then the rest of the code is the sames? + for (let j = 0; j < m; ++j) { + for (let i = 0; i < n; ++i) { + var obj = { x: this.xScale.invert(i * q + x0), y: this.yScale.invert(j * q + y0) }; + grid[j * grid.n + i] = obj; + } + } + + return grid; + } + } $(document).ready(function() { diff --git a/templates/index.html b/templates/index.html index ec1ab5c..310db2f 100644 --- a/templates/index.html +++ b/templates/index.html @@ -200,6 +200,7 @@
+
@@ -445,6 +446,7 @@ machine learning model forms are affected by the shape of the underlying dataset. By selecting a dataset or by creating one of your own, you can fit a model to the data and see how the model would make decisions based on the data it has been trained on. + The fitted contours are shaded with the color of the class the model would predict in each region, with stronger color where the predicted probability is higher. Although this is a toy example, hopefully it helps give you the intuition that the machine learning process is a model selection search for the best combination of features, algorithm, and hyperparameter that generalize well in a bounded feature space.