import os
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.types import Command, interrupt
from langchain_core.tools import InjectedToolCallId, tool
from langchain_core.messages import ToolMessage
load_dotenv()
memory = InMemorySaver()
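# NOTE: A checkpointer is required for interrupt/resume to work: it persists
# the graph state at every step, keyed by thread_id, so execution can pause
# at interrupt() and later pick up exactly where it left off. InMemorySaver
# keeps checkpoints in process memory; for anything beyond a demo, swap in a
# persistent checkpointer (e.g. a SQLite- or Postgres-backed one).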
class State(TypedDict):
    messages: Annotated[list, add_messages]
    # Extra state keys beyond the message list; the human_assistance tool
    # below populates them via a Command(update=...).
    name: str
    birthday: str
graph_builder = StateGraph(State)
llm = ChatOpenAI(
    model=os.getenv("LLM_MODEL_NAME"),
    api_key=os.getenv("LLM_API_KEY"),
    base_url=os.getenv("LLM_BASE_URL"),
    # os.getenv returns strings, but these parameters expect numbers, so cast.
    # The fallback defaults ("0" and "1024") are arbitrary choices for this
    # example, not values taken from the original environment.
    temperature=float(os.getenv("LLM_TEMPERATURE", "0")),
    max_tokens=int(os.getenv("LLM_MAX_TOKENS", "1024")),
)
@tool
def human_assistance(
    name: str,
    birthday: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
    """Request assistance from a human."""
    # interrupt() pauses the whole graph run here and surfaces this payload
    # to the caller; execution resumes when the caller streams a
    # Command(resume=...), whose value becomes interrupt()'s return value.
    human_response = interrupt(
        {
            "question": "Is this correct?",
            "name": name,
            "birthday": birthday,
        },
    )
    if human_response.get("correct", "").lower().startswith("y"):
        verified_name = name
        verified_birthday = birthday
        response = "Correct"
    else:
        verified_name = human_response.get("name", name)
        verified_birthday = human_response.get("birthday", birthday)
        response = f"Made a correction: {human_response}"

    # Returning a Command (rather than a plain string) lets the tool write
    # directly to graph state, so the return annotation is Command.
    state_update = {
        "name": verified_name,
        "birthday": verified_birthday,
        "messages": [ToolMessage(response, tool_call_id=tool_call_id)],
    }
    return Command(update=state_update)
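# When a tool returns a Command, ToolNode applies the update to graph state
# instead of just appending the tool's output. The update must still include
# a ToolMessage carrying the originating tool_call_id, so the model sees a
# response to its tool call on the next turn.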
# Renamed from `tool` to avoid shadowing the @tool decorator imported above.
search_tool = TavilySearch(max_results=2)
tools = [search_tool, human_assistance]
llm_with_tools = llm.bind_tools(tools)
def chatbot(state: State):
    message = llm_with_tools.invoke(state["messages"])
    # Since execution is interrupted during tool handling, allow at most one
    # tool call per turn so no tool invocation is repeated on resume.
    assert len(message.tool_calls) <= 1
    return {"messages": [message]}
graph_builder.add_node("chatbot", chatbot)
tool_node = ToolNode(tools=tools)
graph_builder.add_node("tools", tool_node)
graph_builder.add_conditional_edges(
    "chatbot",
    tools_condition,
)
graph_builder.add_edge("tools", "chatbot") graph_builder.add_edge(START, "chatbot") graph = graph_builder.compile(checkpointer=memory)
png_bytes = graph.get_graph().draw_mermaid_png()
with open("graph.png", "wb") as f: f.write(png_bytes)
user_input = (
    "Can you look up when LangGraph was released? "
    "When you have the answer, use the human_assistance tool for review."
)
# The thread_id selects which checkpointed conversation this run belongs to.
config = {"configurable": {"thread_id": "1"}}
events = graph.stream(
    {"messages": [{"role": "user", "content": user_input}]},
    config,
    stream_mode="values",
)
for event in events:
    if "messages" in event:
        event["messages"][-1].pretty_print()
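# The run is now paused inside human_assistance at interrupt(). As an
# optional sanity check, the checkpoint should show "tools" as the next
# node to execute:
snapshot = graph.get_state(config)
print(snapshot.next)  # expected: ('tools',)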
# Resuming: the dict passed as `resume` becomes interrupt()'s return value
# inside human_assistance, so its keys must match what the tool reads
# ("correct", "name", "birthday"). Omitting "correct" here exercises the
# correction branch.
human_command = Command(
    resume={
        "name": "LangGraph",
        "birthday": "Jan 17, 2024",
    },
)
events = graph.stream(human_command, config, stream_mode="values")
for event in events:
    if "messages" in event:
        event["messages"][-1].pretty_print()
snapshot = graph.get_state(config)
print({k: v for k, v in snapshot.values.items() if k in ("name", "birthday")})
graph.update_state(config, {"name": "LangGraph (library)"})
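# update_state writes directly to the checkpointed state for this thread,
# outside any node; only the keys provided are changed (here, "name").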
snapshot = graph.get_state(config)
print({k: v for k, v in snapshot.values.items() if k in ("name", "birthday")})