2024년 5월 21일 화요일

LLM 기반 그래프 RAG 기술 구현하기

이 글은 그래프 RAG 기술을 구현하는 방법을 정리한다. 

이 글에서는 Microsoft에서 소개한 Graph RAG 개념을 기반으로 실습한다.

개발 환경 설치
다음 명령을 실행해 설치한다. 
# LangChain core plus the community package that backs langchain.graphs / vectorstores / chat_models
pip install langchain langchain-community
pip install openai
pip install neo4j
# Also imported by the code below
pip install pandas python-dotenv

개발
다음과 같이 코딩한다.
import os
import json 
import pandas as pd
from dotenv import load_dotenv
load_dotenv()

OpenAI API 키를 설정한다. .env 파일에 OPENAI_API_KEY를 넣어 두면 load_dotenv()가 환경 변수로 읽어 온다.
# os.environ["OPENAI_API_KEY"] = "<your_api_key>"
# print(os.environ["OPENAI_API_KEY"])

실습할 데이터셋을 로딩한다.
# Loading a json dataset from a file
file_path = 'data/amazon_product_kg.json'
with open(file_path, 'r') as file:
    jsonData = json.load(file)
df =  pd.read_json(file_path)
df.head()

Neo4j 데이터베이스에 연결하고, Cypher 쿼리에 넣을 텍스트에서 따옴표·중괄호 등 쿼리를 깨뜨릴 수 있는 문자를 제거하는 sanitize 함수를 정의한다.
# Neo4j connection settings. Environment variables (for example, loaded from a
# .env file by load_dotenv()) take precedence so credentials need not be
# hard-coded; the literals are only fallbacks for a local setup.
url = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "<your_password_here>")

from langchain.graphs import Neo4jGraph

graph = Neo4jGraph(
    url=url,
    username=username,
    password=password,
)

# Characters that would break a hand-built, double-quoted Cypher string literal:
# quotes end the literal, braces clash with map syntax, and a backslash (e.g. a
# trailing one) escapes the closing quote.
_SANITIZE_TABLE = str.maketrans("", "", "'\"{}\\")


def sanitize(text):
    """Return ``str(text)`` with quotes, braces and backslashes removed.

    Makes arbitrary values safe to paste inside a double-quoted Cypher string
    literal. Non-string inputs (numbers, None, ...) are converted with ``str``.
    """
    return str(text).translate(_SANITIZE_TABLE)

JSON 파일의 각 객체를 데이터베이스에 추가한다.
# Insert every JSON record into Neo4j as (Product)-[relationship]->(entity).
# Property values are sent as query parameters, so quotes, braces or backslashes
# in product text can neither break nor inject into the Cypher statement, and
# numeric/None values keep their type. Labels and relationship types cannot be
# parameterized in Cypher, so they are wrapped in backticks with any embedded
# backtick doubled (Cypher's escape for backtick-quoted names).
for i, obj in enumerate(jsonData, start=1):
    print(f"{i}. {obj['product_id']} -{obj['relationship']}-> {obj['entity_value']}")
    entity_label = str(obj['entity_type']).replace('`', '``')
    relationship_type = str(obj['relationship']).replace('`', '``')
    query = f'''
        MERGE (product:Product {{id: $product_id}})
        ON CREATE SET product.name = $name,
                       product.title = $title,
                       product.bullet_points = $bullet_points,
                       product.size = $size

        MERGE (entity:`{entity_label}` {{value: $entity_value}})

        MERGE (product)-[:`{relationship_type}`]->(entity)
        '''
    params = {
        "product_id": obj['product_id'],
        # str() keeps the text properties string-typed even if a JSON value is not.
        "name": str(obj['product']),
        "title": str(obj['TITLE']),
        "bullet_points": str(obj['BULLET_POINTS']),
        "size": obj['PRODUCT_LENGTH'],
        "entity_value": str(obj['entity_value']),
    }
    graph.query(query, params=params)

질의 처리를 위해 제품(Product) 노드와 각 엔티티 유형 노드의 속성에 대한 벡터 인덱스를 생성한다. 속성 텍스트는 OpenAI 임베딩 모델(text-embedding-3-small)로 벡터화한다.
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.embeddings.openai import OpenAIEmbeddings
embeddings_model = "text-embedding-3-small"

# Vector index named 'products' over Product nodes: text built from each node's
# `name` and `title` properties is embedded with `embeddings_model`, and the
# vector is kept in the node's `embedding` property.
vector_index = Neo4jVector.from_existing_graph(
    OpenAIEmbeddings(model=embeddings_model),
    node_label="Product",
    text_node_properties=['name', 'title'],
    embedding_node_property='embedding',
    index_name='products',
    url=url,
    username=username,
    password=password,
)

def embed_entities(entity_type):
    """Create (or reuse) a vector index over nodes labelled ``entity_type``.

    The index is named after the entity type. Each node's ``value`` property is
    embedded with ``embeddings_model`` and the vector is stored in the node's
    ``embedding`` property, which the cosine-similarity filter in
    ``create_query`` reads.

    Args:
        entity_type: node label / index name, e.g. ``"color"``.

    Returns:
        The ``Neo4jVector`` store for that index, so callers may run similarity
        searches against it directly. Existing callers that ignore the result
        are unaffected.
    """
    return Neo4jVector.from_existing_graph(
        OpenAIEmbeddings(model=embeddings_model),
        url=url,
        username=username,
        password=password,
        index_name=entity_type,
        node_label=entity_type,
        text_node_properties=['value'],
        embedding_node_property='embedding',
    )


entities_list = df['entity_type'].unique()

# One vector index per distinct entity type present in the dataset.
for t in entities_list:
    embed_entities(t)

LangChain의 GraphCypherQAChain을 이용해 데이터베이스에 질의한다.
from langchain.chains import GraphCypherQAChain
from langchain.chat_models import ChatOpenAI

# GraphCypherQAChain: the LLM writes a Cypher query for the question, the query
# runs against `graph`, and the LLM answers from the returned rows (see the
# LangChain docs). temperature=0 keeps the generated Cypher deterministic.
chain = GraphCypherQAChain.from_llm(
    llm=ChatOpenAI(temperature=0),
    graph=graph,
    verbose=True,
)
question = """
Help me find curtains
"""
chain.run(question)

프롬프트 템플릿을 이용해, 좀 더 상세한 결과가 나올 수 있도록 한다.
# Entity types the LLM may extract from a question, with a description of each.
# The descriptions are embedded verbatim (as JSON) into the system prompt.
entity_types = dict(
    product="Item detailed type, for example 'high waist pants', 'outdoor plant pot', 'chef kitchen knife'",
    category="Item category, for example 'home decoration', 'women clothing', 'office supply'",
    characteristic="if present, item characteristics, for example 'waterproof', 'adhesive', 'easy to use'",
    measurement="if present, dimensions of the item",
    brand="if present, brand of the item",
    color="if present, color of the item",
    age_group="target age group for the product, one of 'babies', 'children', 'teenagers', 'adults'. If suitable for multiple age groups, pick the oldest (latter in the list).",
)

# Relationship types linking a Product to an entity node, with their meaning.
relation_types = dict(
    hasCategory="item is of this category",
    hasCharacteristic="item has this characteristic",
    hasMeasurement="item is of this measurement",
    hasBrand="item is of this brand",
    hasColor="item is of this color",
    isFor="item is for this age_group",
)

# Entity type -> relationship type used to reach it from a Product.
# 'product' has no entry: it describes the Product itself, not a linked entity.
entity_relationship_match = dict(
    category="hasCategory",
    characteristic="hasCharacteristic",
    measurement="hasMeasurement",
    brand="hasBrand",
    color="hasColor",
    age_group="isFor",
)

# System prompt for the entity-extraction step. It lists the graph's entity and
# relationship types (serialized with json.dumps) and asks the model to reply
# with a JSON object mapping entity types to values found in the user's prompt.
system_prompt = f'''
    You are a helpful agent designed to fetch information from a graph database. 
    
    The graph database links products to the following entity types:
    {json.dumps(entity_types)}
    
    Each link has one of the following relationships:
    {json.dumps(relation_types)}

    Depending on the user prompt, determine if it is possible to answer with the graph database.
        
    The graph database can match products with multiple relationships to several entities.
    
    Example user input:
    "Which blue clothing items are suitable for adults?"
    
    There are three relationships to analyse:
    1. The mention of the blue color means we will search for a color similar to "blue"
    2. The mention of the clothing items means we will search for a category similar to "clothing"
    3. The mention of adults means we will search for an age_group similar to "adults"
        
    Return a json object following the following rules:
    For each relationship to analyse, add a key value pair with the key being an exact match for one of the entity types provided, and the value being the value relevant to the user query.
    
    For the example provided, the expected output would be:
    {{
        "color": "blue",
        "category": "clothing",
        "age_group": "adults"
    }}
    
    If there are no relevant entities in the user prompt, return an empty json object.
'''
print(system_prompt)

from openai import OpenAI
# OpenAI client for entity extraction and embeddings. The key comes from the
# OPENAI_API_KEY environment variable; otherwise a placeholder is passed that
# must be replaced with a real key.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"))

사용자 질문에서 엔티티를 추출해 JSON 객체로 반환하는 질의 함수를 정의하고, 예제 질문으로 실행해 본다. 이 JSON은 다음 단계에서 Cypher 쿼리를 만드는 데 사용된다.
def define_query(prompt, model="gpt-4-1106-preview"):
    """Ask the chat model to extract graph entities from a question.

    Sends ``system_prompt`` followed by the user's ``prompt`` at temperature 0,
    requesting a JSON-object response format.

    Args:
        prompt: the user's natural-language question.
        model: chat completion model name.

    Returns:
        The message content of the first choice: expected to be a JSON object
        string mapping entity types to values (as the system prompt requests).
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    completion = client.chat.completions.create(
        model=model,
        temperature=0,
        response_format={"type": "json_object"},
        messages=conversation,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

# Sample questions used to check the entity extractor's JSON output.
example_queries = [
    "Which pink items are suitable for children?",
    "Help me find gardening gear that is waterproof",
    "I'm looking for a bench with dimensions 100x50 for my living room",
]

for q in example_queries:
    extracted = define_query(q)
    print(f"Q: '{q}'\n{extracted}\n")

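LLM이 추출한 엔티티 값이 데이터베이스에 저장된 값과 정확히 일치하지 않을 수 있으므로, 임베딩 간 GDS 코사인 유사도 함수(gds.similarity.cosine)를 이용해 유사한 엔티티와 연결된 제품을 반환하도록 한다. 이 함수를 사용하려면 Neo4j에 Graph Data Science(GDS) 플러그인이 설치되어 있어야 한다.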
def create_embedding(text):
    """Embed ``text`` with ``embeddings_model`` and return the first vector.

    Uses the same model as the graph's vector indexes, so the resulting vector is
    comparable with the stored node ``embedding`` properties.
    """
    response = client.embeddings.create(model=embeddings_model, input=text)
    first_item = response.data[0]
    return first_item.embedding

# The threshold defines how closely related words should be. Adjust the threshold to return more or less results
def create_query(text, threshold=0.81, relationships=None):
    """Build a Cypher query for products that match every extracted entity.

    For each entity type in the LLM's JSON output, the query follows that type's
    relationship from a Product to an entity node. It keeps the product only if
    the cosine similarity (``gds.similarity.cosine``) between the entity's stored
    ``embedding`` and the query parameter ``$<type>Embedding`` exceeds
    ``threshold``. The caller must supply those parameters.

    Args:
        text: JSON object string mapping entity types to values, e.g.
            ``'{"color": "blue", "age_group": "adults"}'``.
        threshold: minimum cosine similarity for an entity to count as a match.
        relationships: mapping of entity type -> relationship type. Defaults to
            the module-level ``entity_relationship_match``.

    Returns:
        The Cypher query string. Matching products are returned as ``p``.

    Raises:
        ValueError: if ``text`` contains no entity type with a known relationship
            (e.g. ``{}`` or only ``"product"``). The query would otherwise be
            syntactically invalid Cypher.
    """
    if relationships is None:
        relationships = entity_relationship_match
    query_data = json.loads(text)

    with_items, match_items, similarity_items = [], [], []
    for key in query_data:
        relationship = relationships.get(key)
        # 'product' describes the Product itself, not a linked entity. Keys with no
        # known relationship are skipped: they come from LLM output and must not be
        # pasted into Cypher as labels or variable names.
        if key == 'product' or relationship is None:
            continue
        with_items.append(f"${key}Embedding AS {key}Embedding")
        match_items.append(f"(p)-[:{relationship}]->({key}Var:{key})")
        similarity_items.append(
            f"gds.similarity.cosine({key}Var.embedding, ${key}Embedding) > {threshold}"
        )

    if not with_items:
        raise ValueError(f"No filterable entity types in query data: {text!r}")

    return (
        "WITH " + ",\n".join(with_items)
        + "\nMATCH (p:Product)\nMATCH " + ",\n".join(match_items)
        + "\nWHERE " + " AND ".join(similarity_items)
        + "\nRETURN p"
    )

def query_graph(response, threshold=None):
    """Run the similarity-filtered product search for an entity-extraction result.

    Args:
        response: JSON object string mapping entity types to values, as returned
            by ``define_query``.
        threshold: optional cosine-similarity cutoff forwarded to ``create_query``.
            ``None`` keeps ``create_query``'s own default.

    Returns:
        The rows from ``graph.query``. Each row holds the matched product under ``'p'``.
    """
    if threshold is None:
        query = create_query(response)
    else:
        query = create_query(response, threshold=threshold)
    query_data = json.loads(response)
    # Only entity types that map to a relationship appear in the Cypher query.
    # Embedding anything else (e.g. 'product') would be a wasted API call.
    embeddings_params = {
        f"{key}Embedding": create_embedding(val)
        for key, val in query_data.items()
        if key in entity_relationship_match
    }
    return graph.query(query, params=embeddings_params)

# A hand-written extraction result, standing in for define_query() output.
example_response = '''{
    "category": "clothes",
    "color": "blue",
    "age_group": "adults"
}'''

result = query_graph(example_response)

# Result
print(f"Found {len(result)} matching product(s):\n")
for record in result:
    product = record['p']
    print(f"{product['name']} ({product['id']})")

정리하면, LLM이 사용자 질문에서 엔티티를 추출해 JSON 형식으로 생성하고, 이를 바탕으로 Neo4j 데이터베이스를 검색해 조건에 맞는 제품 목록을 얻을 수 있다.

레퍼런스

댓글 없음:

댓글 쓰기