Agentic RAG with Root Signals Relevance Judge
A replication of the Agentic RAG tutorial from LangGraph, in which the decision whether to use the retrieved content to answer a question is made by Root Signals evaluators.
The following code is taken from the LangGraph docs:
%%capture --no-stderr
%pip install -U --quiet langchain-community tiktoken langchain-openai langchainhub chromadb langchain langgraph langchain-text-splitters
import getpass
import os
def _set_env(key: str):
if key not in os.environ:
os.environ[key] = getpass.getpass(f"{key}:")
import pprint
from typing import Annotated, Literal, Sequence

from langchain import hub
from langchain.tools.retriever import create_retriever_tool
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
# Build the retrieval index: load the blog pages, split them into chunks,
# index the chunks in a Chroma vector store and wrap its retriever as an agent tool.
urls = [
# NOTE(review): the URL entries and the closing "]" of this list were lost when
# the page was extracted — restore the list of Root Signals blog post URLs here.
docs = [WebBaseLoader(url).load() for url in urls]
# Each load() returns a list of Documents; flatten them into one list.
docs_list = [item for sublist in docs for item in sublist]
# from_tiktoken_encoder measures chunk_size / chunk_overlap with a tiktoken tokenizer.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=100, chunk_overlap=50
# NOTE(review): the closing ")" of the splitter call is missing — restore.
doc_splits = text_splitter.split_documents(docs_list)
# Add to vectorDB
vectorstore = Chroma.from_documents(
# NOTE(review): the from_documents arguments (presumably documents=doc_splits,
# a collection name and embedding=OpenAIEmbeddings()) and the closing ")" were
# lost — TODO restore.
retriever = vectorstore.as_retriever()
retriever_tool = create_retriever_tool(
# NOTE(review): the retriever and tool-name arguments that precede the description,
# and the closing ")", appear to be missing — restore.
"Search and return information about Root Signals blog posts on LLM evaluation.",
# The tool list is bound to the agent model so it can decide to call the retriever.
tools = [retriever_tool]
class AgentState(TypedDict):
    """Graph state shared by every node: the running list of conversation messages."""

    # The add_messages function defines how an update should be processed.
    # Default is to replace; add_messages says "append" the returned messages.
    messages: Annotated[Sequence[BaseMessage], add_messages]
### Nodes
def agent(state):
    """
    Invokes the agent model to generate a response based on the current state. Given
    the question, it will decide to retrieve using the retriever tool, or simply end.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the agent response appended to messages
    """
    print("---CALL AGENT---")
    messages = state["messages"]
    model = ChatOpenAI(temperature=0, streaming=True, model="gpt-4-turbo")
    # Bind the retriever tool so the model can emit a tool call instead of answering.
    model = model.bind_tools(tools)
    response = model.invoke(messages)
    # We return a list, because this will get added to the existing list
    return {"messages": [response]}
def rewrite(state):
    """
    Transform the query to produce a better question.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with re-phrased question
    """
    print("---TRANSFORM QUERY---")
    messages = state["messages"]
    # The first message in the history is the user's original question.
    question = messages[0].content

    msg = [
        HumanMessage(
            content=f""" \n
    Look at the input and try to reason about the underlying semantic intent / meaning. \n
    Here is the initial question:
    \n ------- \n
    {question}
    \n ------- \n
    Formulate an improved question: """,
        )
    ]

    # Question re-writer model (temperature=0 for deterministic rephrasing)
    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)
    response = model.invoke(msg)
    # The rephrased question is appended to the history; the graph routes back to the agent.
    return {"messages": [response]}
def generate(state):
    """
    Generate an answer to the original question from the retrieved documents.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state whose messages list contains the generated answer
    """
    print("---GENERATE---")
    messages = state["messages"]
    question = messages[0].content
    # The last message carries the retrieved content (presumably the retrieve
    # node's tool output — confirm); its text is used directly as the context.
    last_message = messages[-1]
    docs = last_message.content

    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # LLM
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True)

    # Chain
    rag_chain = prompt | llm | StrOutputParser()

    # Run
    response = rag_chain.invoke({"context": docs, "question": question})
    return {"messages": [response]}
# Show what the RAG prompt used by `generate` looks like.
print("*" * 20 + "Prompt[rlm/rag-prompt]" + "*" * 20)
# pretty_print() prints the prompt and returns None, so its result is not bound to a name.
hub.pull("rlm/rag-prompt").pretty_print()
Define the Decision-maker as a Root Judge
Now we define the Root Signals Relevance evaluator as the decision maker for whether the answer should come from the retrieved documents or not. The advantages of using Root Signals (as opposed to the original LangGraph method) are:
We can control the relevance threshold, because Root Signals evaluators always return a normalized score between 0 and 1.
If we want, we can incorporate the justification in the decision-making process.
The code is much shorter, i.e. about ⅓ the length of the equivalent code in the LangGraph tutorial.
from root import RootSignals

client = RootSignals()

# Minimum Root Signals Relevance score (a normalized score) above which the
# retrieved documents are considered relevant enough to answer from.
RELEVANCE_THRESHOLD = 0.5


def grade_relevance(state) -> Literal["generate", "rewrite"]:
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (messages): The current state

    Returns:
        str: "generate" if the Relevance score exceeds RELEVANCE_THRESHOLD,
        otherwise "rewrite"
    """
    messages = state["messages"]
    question = messages[0].content
    # The last message holds the retrieved content that is graded against the question.
    docs = messages[-1].content

    result = client.evaluators.Relevance(
        request=question,
        response=docs,
    )

    if result.score > RELEVANCE_THRESHOLD:  # we can control the threshold
        return "generate"
    return "rewrite"
The rest of the tutorial follows the original LangGraph version:
# Define a new graph over the shared AgentState
workflow = StateGraph(AgentState)

# Define the nodes we will cycle between
workflow.add_node("agent", agent)  # agent
retrieve = ToolNode([retriever_tool])
workflow.add_node("retrieve", retrieve)  # retrieval
workflow.add_node("rewrite", rewrite)  # Re-writing the question
workflow.add_node(
    "generate", generate
)  # Generating a response after we know the documents are relevant
# Call agent node to decide to retrieve or not
workflow.add_edge(START, "agent")

# Decide whether to retrieve
workflow.add_conditional_edges(
    "agent",
    # Assess agent decision: tools_condition yields "tools" when the agent's last
    # message requests a tool call, otherwise END.
    tools_condition,
    {
        # Translate the condition outputs to nodes in our graph
        "tools": "retrieve",
        END: END,
    },
)

# Edges taken after the `retrieve` node is called.
workflow.add_conditional_edges(
    "retrieve",
    # Assess document relevance; grade_relevance returns the next node's name
    # ("generate" or "rewrite") directly.
    grade_relevance,  # this is Root Signals evaluator
)
workflow.add_edge("generate", END)
workflow.add_edge("rewrite", "agent")

# Compile
graph = workflow.compile()
Our RAG Agent is ready:
inputs = {
    "messages": [
        ("user", "What is EvalOps?"),
    ]
}
# Stream the graph run; each chunk maps the node that just ran to its state update.
for output in graph.stream(inputs):
    for key, value in output.items():
        pprint.pprint(f"Output from node '{key}':")
        pprint.pprint(value, indent=2, width=80, depth=None)
Last updated