April 19, 2024

Machine Unlearning for LLMs: Build Apps that Self-Correct in Real-Time

Building a basic Large Language Model (LLM) application is easy: there are hundreds if not thousands of open-source solutions available on GitHub. Running a full-featured LLM application in production, however, is difficult. In particular, LLMs are not built to forget, so incorrect data points can contaminate your LLM pipeline and degrade the quality of its outputs. You can fix this by using a Retrieval-Augmented Generation (RAG) approach together with a real-time vector index that automatically updates your LLM knowledge base when compromised data entries are corrected or removed.

This article explains the methodology of machine unlearning in LLM-based apps: how to make a machine learning pipeline forget incorrect or unwanted information it has learned. You will learn what machine unlearning is, why it is important for Large Language Models (LLMs), and how to implement real-time machine unlearning in your own LLM application. You will use Pathway to build an LLM app that can automatically self-correct in real-time.

The following sections set the context for understanding machine unlearning for LLMs. You can also skip directly to the implementation code below or watch a live demonstration of this code presented at Conf42.

What is machine unlearning?

Machine unlearning is the practice of removing learned information from an ML system. A typical example is the removal of learned data points from a trained machine learning model. More broadly, machine unlearning refers to removing data from the ML pipeline and making sure the system no longer takes it into account. This can be necessary, for example, when a machine learning model has learned incorrect information. It is also relevant when a model contains information that it shouldn't have access to under privacy regulations such as the EU General Data Protection Regulation (GDPR). Machine unlearning can also be used to reduce the number of hallucinations in LLMs.
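For models whose parameters are simple sums over the training data, unlearning can be exact and cheap. The toy nearest-centroid classifier below is a hypothetical example, not part of the Pathway code: it stores per-class sums and counts, so forgetting a training point means subtracting its contribution, and the result matches a model retrained without that point. LLMs offer no comparably simple shortcut, which is why the rest of this article keeps the data outside the model instead.

```python
from collections import defaultdict


class CentroidClassifier:
    """Toy model whose parameters are per-class sums and counts.

    Because the parameters are plain sums, a training point can be
    unlearned exactly by subtracting its contribution.
    """

    def __init__(self, dim: int) -> None:
        self.dim = dim
        self.sums: dict[str, list[float]] = defaultdict(lambda: [0.0] * dim)
        self.counts: dict[str, int] = defaultdict(int)

    def learn(self, x: list[float], label: str) -> None:
        # Add the point's contribution to its class statistics.
        self.counts[label] += 1
        self.sums[label] = [s + v for s, v in zip(self.sums[label], x)]

    def unlearn(self, x: list[float], label: str) -> None:
        # Subtract the point's contribution; drop the class if it becomes empty.
        self.counts[label] -= 1
        self.sums[label] = [s - v for s, v in zip(self.sums[label], x)]
        if self.counts[label] == 0:
            del self.counts[label]
            del self.sums[label]

    def centroid(self, label: str) -> list[float]:
        return [s / self.counts[label] for s in self.sums[label]]

    def predict(self, x: list[float]) -> str:
        # Pick the class whose centroid is closest (squared Euclidean distance).
        def sq_dist(label: str) -> float:
            return sum((a - b) ** 2 for a, b in zip(x, self.centroid(label)))

        return min(self.counts, key=sq_dist)


model = CentroidClassifier(dim=2)
model.learn([0.0, 0.0], "a")
model.learn([10.0, 10.0], "b")
model.learn([9.0, 9.0], "a")  # a mislabeled point we later want to forget
model.unlearn([9.0, 9.0], "a")
print(model.centroid("a"))  # [0.0, 0.0]
```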

Why is machine unlearning important for LLMs?

Machine unlearning is important for LLM applications in situations where you have to supplement pre-trained LLMs with additional knowledge. This is often necessary to update LLMs with the latest knowledge (after their training cut-off dates) or to provide access to specific information, such as private domain knowledge. This is usually achieved either by fine-tuning the LLM with additional training or by Retrieval-Augmented Generation (RAG).

But what if the data you have added to the LLM system is compromised and needs to be corrected or forgotten? For example, maybe your LLM app has learned something incorrectly based on outdated knowledge (Pluto is a planet) or it has picked up an unverified rumor from a website and incorporated it as ground truth (Season 2 of 3 Body Problem will be out next week!). Or maybe it has learned sensitive information from private documents that it shouldn't have access to, risking a breach of privacy. In situations like this, you need a way to remove the bad data input from your LLM knowledge base. Your LLM app needs to unlearn the incorrect data points so that it can maintain a high level of output quality.

Machine unlearning for LLMs: Fine-Tuning vs RAG

There are two ways you can apply machine unlearning for your LLM app: fine-tuning and Retrieval-Augmented Generation (RAG). Fine-tuning a generic LLM can be useful when working with batch workloads. Use RAG with a real-time vector index for workloads that need to self-correct in real-time.

Fine-Tuning in Batch

Imagine you are developing an LLM application that works with private data that is updated in real-time. This LLM app will be used internally by your colleagues to retrieve important information about the company's data.

Your company does not have the resources to build or train an LLM from scratch, so you will use a generic pre-trained model, for example from OpenAI, Mistral, or Cohere. This generic model knows a lot... but nothing about your private data. You decide to fine-tune the generic model by training it further on your company's private data. This is a tedious process involving data preprocessing and strict adherence to the training data schema. But it's worth the effort, because once the model has been fine-tuned you can use it to ask questions about the company's data.

But remember that your company's data changes in real-time. This means that the moment you have finished fine-tuning, new data will be entering the system and your model is already outdated. Some records may be missing and others may be entirely incorrect. You could perform another round of fine-tuning, but since you are working with real-time data this would be an endless process. And fine-tuning is expensive.

RAG for Real-Time LLM Unlearning

Instead, you choose to implement Retrieval-Augmented Generation (RAG), a popular strategy for updating or improving the knowledge base of your LLM. RAG is a form of prompt engineering: we include the documents we want to query as part of our prompt to the LLM. Your prompt will look something like:

"Based on {these documents}: answer the following question: …"

Passing all the documents for every query would be inefficient and costly. Instead, we only supply the documents that are similar to the query. This is done by converting the documents into vector embeddings, storing them in a vector index, and embedding each incoming query with the same model. A similarity search over the index gives us the relevant documents for each query, and we pass only these documents to the LLM. To save money and time, you can even use Adaptive RAG to incrementally increase the number of documents until you get the desired result.
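Here is a minimal sketch of that retrieval step, assuming a toy bag-of-words vector in place of a real embedding model such as text-embedding-ada-002. The names embed, build_index, and top_k are hypothetical, and the brute-force scan stands in for the approximate nearest-neighbor search a production vector index would use.

```python
import math
import re
from collections import Counter


def embed(text: str) -> Counter:
    """Toy stand-in for an embedding model: a bag-of-words count vector."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def build_index(documents: dict[str, str]) -> dict[str, Counter]:
    # Documents are embedded once, when they enter the index.
    return {doc_id: embed(text) for doc_id, text in documents.items()}


def top_k(query: str, index: dict[str, Counter], k: int = 1) -> list[str]:
    # The query is embedded at request time and compared to every stored vector.
    q = embed(query)
    ranked = sorted(index, key=lambda doc_id: cosine(q, index[doc_id]), reverse=True)
    return ranked[:k]


docs = {
    "alphabet": "Alphabet reported total revenue for the fiscal year 2022.",
    "pluto": "Pluto was reclassified as a dwarf planet in 2006.",
}
index = build_index(docs)
print(top_k("What was Alphabet revenue in 2022?", index))  # ['alphabet']
```

Only the ids returned by top_k would be pasted into the prompt, which keeps the prompt short regardless of how many documents the index holds.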

Using RAG with a real-time vector index means that you can easily forget or unlearn parts of the context that are outdated, incorrect, or private. The moment you remove a document from the database, the vector index will be immediately updated. Any query you submit after this will no longer include the compromised data point.
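The sketch below shows why deletion is enough for unlearning in this setup. It uses a toy token-set similarity (Jaccard) instead of real embeddings, and LiveIndex, its methods, and the min_score threshold are all hypothetical. Once a document is deleted, no query can retrieve it, so it can never reach the prompt.

```python
import re


def tokens(text: str) -> set[str]:
    """Toy stand-in for an embedding model: the set of lowercase word tokens."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))


def jaccard(a: set[str], b: set[str]) -> float:
    union = a | b
    return len(a & b) / len(union) if union else 0.0


class LiveIndex:
    """Minimal in-memory index kept in sync with inserts and deletes."""

    def __init__(self) -> None:
        self._vectors: dict[str, set[str]] = {}

    def upsert(self, doc_id: str, text: str) -> None:
        # Re-embedding on every change keeps the index identical to the source data.
        self._vectors[doc_id] = tokens(text)

    def delete(self, doc_id: str) -> None:
        # Unlearning is just removing the entry: no retraining is involved.
        self._vectors.pop(doc_id, None)

    def search(self, query: str, k: int = 1, min_score: float = 0.1) -> list[str]:
        q = tokens(query)
        scored = [(jaccard(q, vec), doc_id) for doc_id, vec in self._vectors.items()]
        hits = sorted((pair for pair in scored if pair[0] >= min_score), reverse=True)
        return [doc_id for _, doc_id in hits[:k]]


index = LiveIndex()
index.upsert("alphabet_10K", "Alphabet total revenue in 2022 was 282,836 million dollars")
index.upsert("pluto", "Pluto is a dwarf planet")
question = "What is the revenue of Alphabet in 2022?"
print(index.search(question))  # ['alphabet_10K']
index.delete("alphabet_10K")
print(index.search(question) or "No information found.")  # No information found.
```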

Let's see this in action with some code.

Implementing machine unlearning for LLMs in Python

We'll be working with some Python code that you can access on GitHub. Our code will read two documents from disk and use RAG with a real-time vector index to self-correct when we remove one of the documents.


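Start by importing Pathway and its LLM extension modules, then read the documents:

import pathway as pw
from pathway.xpacks.llm import embedders, llms
from pathway.xpacks.llm.parsers import ParseUnstructured
from pathway.xpacks.llm.splitters import TokenCountSplitter
from pathway.xpacks.llm.vector_store import VectorStoreServer

documents = pw.io.fs.read("./documents/", format="binary", with_metadata=True)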

Then define your embedder, LLM, and text splitter:

embedder = embedders.OpenAIEmbedder(model="text-embedding-ada-002")
chat = llms.OpenAIChat(model="gpt-3.5-turbo", temperature=0.05)
text_splitter = TokenCountSplitter(max_tokens=400)
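TokenCountSplitter(max_tokens=400) cuts each document into chunks before they are embedded. The sketch below approximates that idea with whitespace-separated words instead of real model tokens; split_by_token_count is a hypothetical helper, not Pathway's implementation. Chunking matters for unlearning too: every chunk of a removed document has to leave the index, not just one of them.

```python
def split_by_token_count(text: str, max_tokens: int = 400) -> list[str]:
    """Split text into chunks of at most max_tokens whitespace-separated words.

    Approximation only: real tokenizers count sub-word tokens, so a
    400-token chunk usually holds fewer than 400 English words.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be positive")
    words = text.split()
    return [" ".join(words[i : i + max_tokens]) for i in range(0, len(words), max_tokens)]


chunks = split_by_token_count("one two three four five", max_tokens=2)
print(chunks)  # ['one two', 'three four', 'five']
```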

Initialize the Pathway real-time vector store:

vector_server = VectorStoreServer(
   documents,
   embedder=embedder,
   splitter=text_splitter,
   parser=ParseUnstructured(),
)

Start a Pathway webserver and attach a REST connector to receive the queries:

webserver = pw.io.http.PathwayWebserver(host="0.0.0.0", port=8000)
queries, writer = pw.io.http.rest_connector(
   webserver=webserver,
   schema=PWAIQuerySchema,
   autocommit_duration_ms=50,
   delete_completed_queries=True,
)

Then query the vector index to get the most relevant document chunk for each query (k=1):

results = queries + vector_server.retrieve_query(
   queries.select(
       query=pw.this.query,
       k=1,
       metadata_filter=pw.cast(str | None, None),
       filepath_globpattern=pw.cast(str | None, None)
   )
).select(
   docs=pw.this.result,
)

Now you're all set to create the RAG prompt using only the relevant documents:

# Generate the prompt
results += results.select(
   rag_prompt=prep_rag_prompt(pw.this.query, pw.this.docs)
)
# Query the LLM with the prompt
results += results.select(
   result=chat(
       llms.prompt_chat_single_qa(pw.this.rag_prompt),
   )
)

# Send back the answer
writer(results)
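The prep_rag_prompt helper used above is defined in the project's GitHub code and is not shown in this article. As an illustration only, here is a plain-Python function, build_rag_prompt (a hypothetical name), that fills the "Based on {these documents}: answer the following question" template from earlier, assuming the retrieved chunks arrive as plain strings. The fallback instruction at the end is our assumption; it is one way to obtain the "No information found." reply shown later.

```python
def build_rag_prompt(question: str, documents: list[str]) -> str:
    """Assemble a RAG prompt: retrieved context first, then the question.

    Hypothetical helper; the wording of Pathway's prep_rag_prompt may differ.
    """
    context = "\n\n".join(
        f"Document {i}:\n{doc.strip()}" for i, doc in enumerate(documents, start=1)
    )
    return (
        f"Based on these documents:\n\n{context}\n\n"
        f"Answer the following question: {question}\n"
        "If the documents do not contain the answer, reply 'No information found.'"
    )


print(build_rag_prompt("What is the revenue of Alphabet in 2022?", ["(retrieved chunk)"]))
```

When the index returns nothing relevant, the context section is empty and the fallback instruction takes over.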

Use pw.run() to run the pipeline in real-time:

pw.run()

You're now all set to query your documents from the command line. For example, use curl to submit a query to your application:

curl --data '{
  "user": "user",
  "query": "What is the revenue of Alphabet in 2022 in millions of dollars?"
}' http://localhost:8000/  

This returns the answer (in millions of dollars):

$282,836

You can open the relevant PDF in the documents folder to confirm that this number is correct. Great!
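If you would rather query the app from Python than with curl, the standard library is enough. This is a minimal sketch: build_query_request and ask are hypothetical helper names, and we assume the server accepts the same JSON body as the curl example when it is sent with a JSON Content-Type header.

```python
import json
import urllib.request


def build_query_request(query: str, user: str = "user",
                        url: str = "http://localhost:8000/") -> urllib.request.Request:
    """Build the same POST request as the curl command above."""
    payload = json.dumps({"user": user, "query": query}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def ask(query: str) -> str:
    """Send the query and return the raw response body.

    Requires the Pathway app from this article to be running on port 8000.
    """
    with urllib.request.urlopen(build_query_request(query), timeout=60) as response:
        return response.read().decode("utf-8")


# print(ask("What is the revenue of Alphabet in 2022 in millions of dollars?"))
```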

Now comes the fun part. Let's see if we can make our LLM app unlearn this information.

Remove the Alphabet financial document from the documents folder, for example by running the following command from the project root:

rm documents/20230203_alphabet_10K.pdf

Now immediately re-run the same query you ran above with curl. Your LLM app will show the following output:

No information found.

Congratulations! You've successfully removed information from your LLM knowledge base.
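Behind the scenes, the file-system connector noticed that the PDF disappeared and retracted it from the pipeline. The sketch below imitates that with simple polling: it compares two snapshots of the folder and emits insert, update, and delete events. It is a simplified stand-in written for this explanation, not how pw.io.fs.read is implemented, and snapshot, diff, and watch are hypothetical helpers.

```python
import os
import time
from typing import Callable


def snapshot(folder: str) -> dict[str, float]:
    """Map every regular file in folder to its last-modification time."""
    with os.scandir(folder) as entries:
        return {e.path: e.stat().st_mtime for e in entries if e.is_file()}


def diff(old: dict[str, float], new: dict[str, float]) -> list[tuple[str, str]]:
    """Turn two snapshots into sorted (event, path) change events."""
    events = [("delete", p) for p in old.keys() - new.keys()]
    events += [("insert", p) for p in new.keys() - old.keys()]
    events += [("update", p) for p in old.keys() & new.keys() if old[p] != new[p]]
    return sorted(events)


def watch(folder: str, on_event: Callable[[str, str], None], interval: float = 1.0) -> None:
    """Poll folder forever and report every change to on_event."""
    previous = snapshot(folder)
    while True:
        time.sleep(interval)
        current = snapshot(folder)
        for event, path in diff(previous, current):
            on_event(event, path)
        previous = current
```

For example, running watch("./documents/", print) in a separate process should report a delete event for ./documents/20230203_alphabet_10K.pdf when you run the rm command above. Polling is the simplest design to reason about; real connectors may rely on operating-system file notifications instead.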

Reactive Vector Index for LLM Unlearning

This example shows how to build an LLM app that can react instantly to changes to its knowledge context. This process scales to real-time databases processing thousands of updates per minute.

The real-time reactivity of the vector index is key here. Because the index is event-driven, it is updated every time a change is made: new data is available to the LLM app as soon as it is inserted, and irrelevant or compromised data is unlearned as soon as it is removed. Effective machine unlearning for LLM apps therefore requires a streaming architecture.
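To make "event-driven" concrete, here is a toy model of such an index. Each change arrives as a row with a diff of +1 (added) or -1 (removed), and an update is a retraction of the old version followed by an insertion of the new one. This mirrors the general idea behind streaming engines such as Pathway, but Change, ReactiveIndex, and update are hypothetical names, and a plain dict stands in for the embeddings.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Change:
    doc_id: str
    text: str
    diff: int  # +1 = row added, -1 = row removed


class ReactiveIndex:
    """Applies every change event as soon as it arrives (no batch rebuilds)."""

    def __init__(self) -> None:
        # doc_id -> current text; a real index would store chunk embeddings.
        self.docs: dict[str, str] = {}

    def apply(self, change: Change) -> None:
        if change.diff == 1:
            self.docs[change.doc_id] = change.text
        elif change.diff == -1:
            # Only retract the exact version that is currently indexed.
            if self.docs.get(change.doc_id) == change.text:
                del self.docs[change.doc_id]
        else:
            raise ValueError(f"unexpected diff: {change.diff}")


def update(doc_id: str, old_text: str, new_text: str) -> list[Change]:
    """Express an update as a retraction followed by an insertion."""
    return [Change(doc_id, old_text, -1), Change(doc_id, new_text, 1)]


index = ReactiveIndex()
index.apply(Change("pluto", "Pluto is a planet", 1))
for change in update("pluto", "Pluto is a planet", "Pluto is a dwarf planet"):
    index.apply(change)
print(index.docs)  # {'pluto': 'Pluto is a dwarf planet'}
index.apply(Change("pluto", "Pluto is a dwarf planet", -1))
print(index.docs)  # {}
```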

If you want to learn more about adaptive RAG, don't hesitate to read our article about how to decrease RAG cost or join us on Discord.

Avril Aysha

Developer Advocate
