Welcome back to my machine learning journey! In the last post, I explained the environment setup and demonstrated the first steps of building an interactive machine learning model. Today, I’ll attempt (with my limited knowledge) to explain how I created the web application using Flask. It serves as one example of the many ways you could interact with a model in real time.


Introduction to Flask

Flask is a lightweight web framework for Python that’s perfect for building simple web applications quickly and efficiently. It provides the tools needed to set up routes, handle requests, and render templates.

To get started with Flask, you’ll need to install it:

```bash
pip install Flask
```

Once installed, you can create a basic Flask app by defining routes that handle different actions.
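
For example, a minimal Flask app with a single route looks like this (just a sanity check that the installation works; it isn’t part of the final app):

```python
from flask import Flask

app = Flask(__name__)

@app.route('/')
def hello():
    # Respond to GET requests at the root URL
    return 'Hello, Flask!'

if __name__ == '__main__':
    # debug=True enables auto-reload and an interactive debugger during development
    app.run(debug=True)
```

Running the script starts a development server, by default at http://127.0.0.1:5000/.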

Creating Routes for Querying and Labeling Instances

The core functionality of our web app revolves around querying instances from the pool and allowing users to label them. Let’s set up the necessary routes and create HTML templates for user interaction.
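
The app below expects the trained ActiveLearner and the remaining unlabeled pool from the previous post to be available on disk. If you’re following along, something like this at the end of the previous script would produce those files (a sketch; the file names simply match the joblib.load calls used below):

```python
import joblib

# Persist the trained ActiveLearner and the remaining unlabeled pool
# so the Flask app can load them at startup (file names are assumptions).
joblib.dump(learner, 'active_learner_model.pkl')
joblib.dump((X_pool, y_pool), 'X_y_pool.pkl')
```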

Here’s the basic structure of our Flask app:

```python
import numpy as np
import joblib
from flask import Flask, request, render_template, redirect, url_for
from modAL.models import ActiveLearner
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Load the ActiveLearner model
learner = joblib.load('active_learner_model.pkl')

# Load the pool data
X_pool, y_pool = joblib.load('X_y_pool.pkl')

# Split the pool data into train and test sets
X_pool, X_test, y_pool, y_test = train_test_split(X_pool, y_pool, test_size=0.2, random_state=42)

# Initialize Flask app
app = Flask(__name__)

# Global variable to store labeled instances
labeled_instances = []

@app.route('/')
def home():
    return render_template('index.html')

@app.route('/query', methods=['GET', 'POST'])
def query():
    global X_pool, y_pool, query_instance, query_idx, labeled_instances

    if request.method == 'POST':
        # Get the label from the user
        label = int(request.form['label'])

        # Teach the model the new label
        learner.teach(X_pool[query_idx].reshape(1, -1), np.array([label]))

        # Store the labeled instance and its label
        labeled_instances.append((X_pool[query_idx].reshape(1, -1), label))

        # Remove the queried instance from the pool
        X_pool = np.delete(X_pool, query_idx, axis=0)
        y_pool = np.delete(y_pool, query_idx, axis=0)

        # Evaluate the model
        accuracy = learner.score(X_test, y_test)

        # Query the model for the next instance
        query_idx, query_instance = learner.query(X_pool)

        return render_template('query.html', instance=query_instance.tolist(), accuracy=accuracy)

    # Initial query
    query_idx, query_instance = learner.query(X_pool)
    accuracy = learner.score(X_test, y_test)

    return render_template('query.html', instance=query_instance.tolist(), accuracy=accuracy)

@app.route('/review')
def review():
    accuracy = learner.score(X_test, y_test)
    return render_template('review.html', labeled_instances=labeled_instances, accuracy=accuracy, enumerate=enumerate)

@app.route('/correct', methods=['POST'])
def correct():
    global labeled_instances, learner, X_pool, y_pool

    try:
        # Get the index and new label from the form
        index = int(request.form['index'])
        new_label = int(request.form['new_label'])

        # Ensure the index is within the valid range
        if 0 <= index < len(labeled_instances):
            # Update the label in the labeled instances
            instance, _ = labeled_instances[index]
            labeled_instances[index] = (instance, new_label)

            # Recreate the dataset with corrected labels
            X_corrected = np.vstack([instance for instance, label in labeled_instances])
            y_corrected = np.array([label for instance, label in labeled_instances])

            # Clear and reinitialize the learner with corrected data
            learner = ActiveLearner(
                estimator=RandomForestClassifier(),
                X_training=X_corrected,
                y_training=y_corrected
            )

            accuracy = learner.score(X_test, y_test)

            return render_template('review.html', labeled_instances=labeled_instances, accuracy=accuracy, enumerate=enumerate)
        else:
            return f"Error: Index {index} is out of range. Valid range is 0 to {len(labeled_instances) - 1}.", 400
    except Exception as e:
        return str(e), 500

if __name__ == '__main__':
    app.run(debug=True)
```

Handling and Correcting Labels

One of the critical features of our app is the ability to correct labels. This ensures that any mistakes made during the labeling process can be rectified, improving the model’s learning over time.

In the review route, users can see all labeled instances and correct any mistakes:

```python
@app.route('/review')
def review():
    accuracy = learner.score(X_test, y_test)
    return render_template('review.html', labeled_instances=labeled_instances, accuracy=accuracy, enumerate=enumerate)

@app.route('/correct', methods=['POST'])
def correct():
    global labeled_instances, learner, X_pool, y_pool

    try:
        # Get the index and new label from the form
        index = int(request.form['index'])
        new_label = int(request.form['new_label'])

        # Ensure the index is within the valid range
        if 0 <= index < len(labeled_instances):
            # Update the label in the labeled instances
            instance, _ = labeled_instances[index]
            labeled_instances[index] = (instance, new_label)

            # Recreate the dataset with corrected labels
            X_corrected = np.vstack([instance for instance, label in labeled_instances])
            y_corrected = np.array([label for instance, label in labeled_instances])

            # Clear and reinitialize the learner with corrected data
            learner = ActiveLearner(
                estimator=RandomForestClassifier(),
                X_training=X_corrected,
                y_training=y_corrected
            )

            accuracy = learner.score(X_test, y_test)

            return render_template('review.html', labeled_instances=labeled_instances, accuracy=accuracy, enumerate=enumerate)
        else:
            return f"Error: Index {index} is out of range. Valid range is 0 to {len(labeled_instances) - 1}.", 400
    except Exception as e:
        return str(e), 500
```
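
One caveat worth noting: labeled_instances and the retrained learner live only in memory, so all corrections are lost when the app restarts. A small helper like the one below could persist them; this is just a sketch (the helper name, the file names, and the idea of calling it at the end of query() and correct() are my own additions, not part of the app above):

```python
import joblib

def save_state():
    # Hypothetical helper: persist the current learner and the labeled
    # instances so they survive an app restart (file names are placeholders).
    joblib.dump(learner, 'active_learner_model.pkl')
    joblib.dump(labeled_instances, 'labeled_instances.pkl')
```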

Challenges and Reflections

Creating the web application presented several challenges, from handling data correctly to ensuring that the user interface was intuitive and responsive. Here are some of the key challenges I faced and how I overcame them:

  1. Handling 3D Arrays: Initially, the model expected 2D arrays but received 3D arrays. This was resolved by ensuring data reshaping maintained the correct dimensions (there is a short standalone sketch after this list):

     ```python
     learner.teach(X_pool[query_idx].reshape(1, -1), np.array([label]))
     ```

  2. Model Accuracy Calculation: To ensure consistent accuracy measurement, I used a fixed test set derived from the initial pool data. This helped in evaluating the model’s performance on unseen data, providing a more reliable metric:

     ```python
     X_pool, X_test, y_pool, y_test = train_test_split(X_pool, y_pool, test_size=0.2, random_state=42)
     ```

  3. Correcting Labels: The ability to review and correct labels was crucial. By storing labeled instances and allowing corrections, the model could be retrained with updated labels, improving its learning over time:

     ```python
     X_corrected = np.vstack([instance for instance, label in labeled_instances])
     y_corrected = np.array([label for instance, label in labeled_instances])
     ```
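
To make the reshaping in point 1 concrete, here is a tiny standalone check (illustrative only, not part of the app):

```python
import numpy as np

x = np.arange(4.0)                # one instance as a 1D array
print(x.shape)                    # (4,)
print(x.reshape(1, -1).shape)     # (1, 4): a 2D "batch of one", which scikit-learn estimators expect
```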

Reflecting on this part of the journey, I realized the importance of building a user-friendly interface and the need for careful data handling. These elements are essential for creating a robust interactive machine learning model.

In the next post, we’ll explore how visualizing data can help us understand our model better. Stay tuned!
