Welcome back to my machine learning journey! In the last post, I explained the environment setup and demonstrated the first steps of building an interactive machine learning model. Today, with my limited knowledge, I’ll try to explain how to create the web application using Flask. It serves as one example of the many ways you could interact with a model in real time.
Introduction to Flask
Flask is a lightweight web framework for Python that’s perfect for building simple web applications quickly and efficiently. It provides the tools needed to set up routes, handle requests, and render templates.
To get started with Flask, you’ll need to install it:
```bash
pip install Flask
```
Once installed, you can create a basic Flask app by defining routes that handle different actions.
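To make the idea of a route concrete, here is a minimal sketch that is separate from the app built later in this post. The route paths, the `greet` function, and the name `Ada` are made up for illustration. It uses Flask's built-in test client, so nothing needs to be running on a port.

```python
from flask import Flask
from markupsafe import escape

app = Flask(__name__)

@app.route('/')
def home():
    # Each route maps a URL path to a Python function.
    return 'Hello from Flask!'

@app.route('/greet/<name>')
def greet(name):
    # Path segments in angle brackets arrive as function arguments.
    return f'Hello, {escape(name)}!'

# Flask's built-in test client calls routes without starting a server.
client = app.test_client()
home_response = client.get('/')
greet_response = client.get('/greet/Ada')
print(home_response.status_code, home_response.get_data(as_text=True))
print(greet_response.get_data(as_text=True))
```

The test client is handy while experimenting: you can check what a route returns before writing any HTML.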
Creating Routes for Querying and Labeling Instances
The core functionality of our web app revolves around querying instances from the pool and allowing users to label them. Let’s set up the necessary routes and create HTML templates for user interaction.
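Before looking at the routes, it may help to see what "querying an instance" can mean. The sketch below shows least-confidence sampling in plain NumPy: pick the pool instance whose most likely class has the lowest predicted probability. This is an assumption about the kind of strategy involved, not a copy of what `learner.query` does internally, and the probabilities are made up.

```python
import numpy as np

def least_confident_index(probabilities):
    """Return the row index whose top predicted class has the lowest probability."""
    probabilities = np.asarray(probabilities, dtype=float)
    # Uncertainty is high when even the best class has a low probability.
    uncertainty = 1.0 - probabilities.max(axis=1)
    return int(np.argmax(uncertainty))

# Hypothetical class probabilities for three unlabeled pool instances.
pool_probabilities = [
    [0.90, 0.10],
    [0.55, 0.45],
    [0.20, 0.80],
]
next_index = least_confident_index(pool_probabilities)
print(next_index)
```

Here the second instance is the one the model is least sure about, so it is the one a human would be asked to label next.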
Here’s the basic structure of our Flask app:
```python
import numpy as np
import joblib
from flask import Flask, request, render_template, redirect, url_for
from modAL.models import ActiveLearner
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Load the ActiveLearner model
learner = joblib.load('active_learner_model.pkl')

# Load the pool data
X_pool, y_pool = joblib.load('X_y_pool.pkl')

# Hold out 20% of the pool as a fixed test set
X_pool, X_test, y_pool, y_test = train_test_split(X_pool, y_pool, test_size=0.2, random_state=42)

# Initialize Flask app
app = Flask(__name__)

# Global variable to store labeled instances
labeled_instances = []

# The instance currently awaiting a label; set by the first GET to /query
query_idx, query_instance = None, None


@app.route('/')
def home():
    return render_template('index.html')


@app.route('/query', methods=['GET', 'POST'])
def query():
    global X_pool, y_pool, query_instance, query_idx, labeled_instances

    if request.method == 'POST':
        # A label cannot be applied before an instance has been queried
        if query_idx is None:
            return redirect(url_for('query'))

        # Get the label from the user
        label = int(request.form['label'])

        # Teach the model the new label
        learner.teach(X_pool[query_idx].reshape(1, -1), np.array([label]))

        # Store the labeled instance and its label
        labeled_instances.append((X_pool[query_idx].reshape(1, -1), label))

        # Remove the queried instance from the pool
        X_pool = np.delete(X_pool, query_idx, axis=0)
        y_pool = np.delete(y_pool, query_idx, axis=0)

        # Evaluate the model
        accuracy = learner.score(X_test, y_test)

        # Query the model for the next instance
        query_idx, query_instance = learner.query(X_pool)
        return render_template('query.html', instance=query_instance.tolist(), accuracy=accuracy)

    # Initial query
    query_idx, query_instance = learner.query(X_pool)
    accuracy = learner.score(X_test, y_test)
    return render_template('query.html', instance=query_instance.tolist(), accuracy=accuracy)


@app.route('/review')
def review():
    accuracy = learner.score(X_test, y_test)
    return render_template('review.html', labeled_instances=labeled_instances, accuracy=accuracy, enumerate=enumerate)


@app.route('/correct', methods=['POST'])
def correct():
    global labeled_instances, learner, X_pool, y_pool
    try:
        # Get the index and new label from the form
        index = int(request.form['index'])
        new_label = int(request.form['new_label'])

        # Ensure the index is within the valid range
        if 0 <= index < len(labeled_instances):
            # Update the label in the labeled instances
            instance, _ = labeled_instances[index]
            labeled_instances[index] = (instance, new_label)

            # Recreate the dataset with corrected labels
            X_corrected = np.vstack([instance for instance, label in labeled_instances])
            y_corrected = np.array([label for instance, label in labeled_instances])

            # Clear and reinitialize the learner with corrected data
            learner = ActiveLearner(
                estimator=RandomForestClassifier(),
                X_training=X_corrected,
                y_training=y_corrected
            )

            accuracy = learner.score(X_test, y_test)
            return render_template('review.html', labeled_instances=labeled_instances, accuracy=accuracy, enumerate=enumerate)
        else:
            return f"Error: Index {index} is out of range. Valid range is 0 to {len(labeled_instances) - 1}.", 400
    except Exception as e:
        return str(e), 500


if __name__ == '__main__':
    app.run(debug=True)
```
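The routes above render `query.html`, which this post does not show. As a rough idea of what such a template could contain, here is a hypothetical stand-in rendered from a string. The markup, the `QUERY_TEMPLATE` and `demo_app` names, and the feature values are all assumptions; the only things taken from the app are the `instance` and `accuracy` variables and the `label` form field that the `/query` route reads.

```python
from flask import Flask, render_template_string

# Hypothetical stand-in for templates/query.html; the real template is not shown in this post.
QUERY_TEMPLATE = """
<h1>Label this instance</h1>
<p>Features: {{ instance }}</p>
<p>Current accuracy: {{ "%.2f"|format(accuracy) }}</p>
<form method="post" action="/query">
  <input type="number" name="label" required>
  <button type="submit">Submit label</button>
</form>
<a href="/review">Review labels</a>
"""

demo_app = Flask(__name__)

# A request context is needed to render a template outside a live request.
with demo_app.test_request_context():
    html = render_template_string(
        QUERY_TEMPLATE,
        instance=[[5.1, 3.5, 1.4, 0.2]],  # made-up feature values
        accuracy=0.75,                    # made-up accuracy
    )

print(html)
```

The important detail is that the input's `name` attribute matches the key the route looks up in `request.form`; a mismatch there would show up as a 400 error when the form is submitted.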
Handling and Correcting Labels
One of the critical features of our app is the ability to correct labels. This ensures that any mistakes made during the labeling process can be rectified, improving the model’s learning over time.
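In the /review route, users can see all labeled instances. Submitting a correction to the /correct route updates the stored label and retrains a fresh learner on the labeled instances: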
```python
@app.route('/review')
def review():
    accuracy = learner.score(X_test, y_test)
    return render_template('review.html', labeled_instances=labeled_instances, accuracy=accuracy, enumerate=enumerate)


@app.route('/correct', methods=['POST'])
def correct():
    global labeled_instances, learner, X_pool, y_pool
    try:
        # Get the index and new label from the form
        index = int(request.form['index'])
        new_label = int(request.form['new_label'])

        # Ensure the index is within the valid range
        if 0 <= index < len(labeled_instances):
            # Update the label in the labeled instances
            instance, _ = labeled_instances[index]
            labeled_instances[index] = (instance, new_label)

            # Recreate the dataset with corrected labels
            X_corrected = np.vstack([instance for instance, label in labeled_instances])
            y_corrected = np.array([label for instance, label in labeled_instances])

            # Clear and reinitialize the learner with corrected data
            learner = ActiveLearner(
                estimator=RandomForestClassifier(),
                X_training=X_corrected,
                y_training=y_corrected
            )

            accuracy = learner.score(X_test, y_test)
            return render_template('review.html', labeled_instances=labeled_instances, accuracy=accuracy, enumerate=enumerate)
        else:
            return f"Error: Index {index} is out of range. Valid range is 0 to {len(labeled_instances) - 1}.", 400
    except Exception as e:
        return str(e), 500
```
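The data-handling part of the correction step can be tried on its own, without Flask or a trained model. The sketch below pulls it into a hypothetical helper, `apply_correction`, which is not part of the app. It validates the index, swaps the label, and stacks the stored instances back into training arrays. The labeled instances are made up.

```python
import numpy as np

def apply_correction(labeled_instances, index, new_label):
    """Return a copy of the labeled list with one label replaced, plus stacked X and y."""
    if not 0 <= index < len(labeled_instances):
        raise IndexError(f"index {index} is out of range")
    corrected = list(labeled_instances)  # shallow copy; the input list is left untouched
    instance, _ = corrected[index]
    corrected[index] = (instance, new_label)
    # Each stored instance has shape (1, n_features), so vstack yields (n, n_features).
    X = np.vstack([inst for inst, _ in corrected])
    y = np.array([label for _, label in corrected])
    return corrected, X, y

# Made-up labeled instances, each stored as a (1, n_features) array and an integer label.
labeled = [
    (np.array([[0.1, 0.2]]), 0),
    (np.array([[0.9, 0.8]]), 0),  # suppose this label was a mistake
]
corrected, X_corrected, y_corrected = apply_correction(labeled, index=1, new_label=1)
print(X_corrected.shape, y_corrected)
```

Returning a copy instead of mutating the list in place is a design choice for this sketch; the route in the post updates the global list directly.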
Challenges and Reflections
Creating the web application presented several challenges, from handling data correctly to ensuring that the user interface was intuitive and responsive. Here are some of the key challenges I faced and how I overcame them:
- Handling 3D Arrays: Initially, the model expected 2D arrays but received 3D arrays. This was resolved by ensuring data reshaping maintained the correct dimensions.

  ```python
  learner.teach(X_pool[query_idx].reshape(1, -1), np.array([label]))
  ```

- Model Accuracy Calculation: To ensure consistent accuracy measurement, I used a fixed test set derived from the initial pool data. This helped in evaluating the model’s performance on unseen data, providing a more reliable metric.

  ```python
  X_pool, X_test, y_pool, y_test = train_test_split(X_pool, y_pool, test_size=0.2, random_state=42)
  ```

- Correcting Labels: The ability to review and correct labels was crucial. By storing labeled instances and allowing corrections, the model could be retrained with updated labels, improving its learning over time.

  ```python
  X_corrected = np.vstack([instance for instance, label in labeled_instances])
  y_corrected = np.array([label for instance, label in labeled_instances])
  ```
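The reshaping fix in the first item of the list above is easier to see with shapes printed out. The sketch below uses a made-up pool and shows one way an extra dimension can sneak in: wrapping an already 2D selection in another array. This is only an assumed cause, not necessarily what happened in the app, but it shows why `reshape(1, -1)` brings the data back to a single 2D row.

```python
import numpy as np

X_pool = np.arange(12, dtype=float).reshape(4, 3)  # 4 made-up instances, 3 features
query_idx = np.array([2])                          # query indices arrive as an array

selected = X_pool[query_idx]        # array indexing keeps a leading axis: shape (1, 3)
wrapped = np.array([selected])      # wrapping again adds another axis: shape (1, 1, 3)
flattened = wrapped.reshape(1, -1)  # back to one row of features: shape (1, 3)

print(selected.shape, wrapped.shape, flattened.shape)
```

Because `reshape(1, -1)` always produces one row, it is safe to apply even when the selection is already 2D.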
Reflecting on this part of the journey, I realized the importance of building a user-friendly interface and the need for careful data handling. These elements are essential for creating a robust interactive machine learning model.
In the next post, we’ll explore how visualizing data can help us understand our model better. Stay tuned!