Neo4j Python SDK手册

安装

pip3 install neo4j

常规用法

连接数据库

python 复制代码
from neo4j import GraphDatabase

# URI examples: "neo4j://localhost", "neo4j+s://xxx.databases.neo4j.io"
URI = "<database-uri>"
AUTH = ("<username>", "<password>")

with GraphDatabase.driver(URI, auth=AUTH) as driver:
    # 检查是否可以连接到数据库
    driver.verify_connectivity()
    print("Connection established.")

Driver创建了一个数据库的访问信息, 但是并不是实际建立连接, 只有当地一个查询发起时才会建立连接, Driver对象是不可变的、线程安全的,但创建起来开销很大

写数据(CREATE)

python 复制代码
summary = driver.execute_query("""
    CREATE (a:Person {name: $name})
    CREATE (b:Person {name: $friendName})
    CREATE (a)-[:KNOWS]->(b)
    """,
    name="Alice", friendName="David",
    database_="<database-name>",
).summary
print("Created {nodes_created} nodes in {time} ms.".format(
    nodes_created=summary.counters.nodes_created,
    time=summary.result_available_after
))

读数据(MATCH)

python 复制代码
records, summary, keys = driver.execute_query("""
    MATCH (p:Person)-[:KNOWS]->(:Person)
    RETURN p.name AS name
    """,
    database_="<database-name>",
)

# Loop through results and do something with them
for record in records:
    print(record.data())  # get record as dict

# Summary information
print("The query `{query}` returned {records_count} records in {time} ms.".format(
    query=summary.query, records_count=len(records),
    time=summary.result_available_after
))

summary为服务器返回的执行摘要

更新数据(MATCH/SET)

python 复制代码
records, summary, keys = driver.execute_query("""
    MATCH (p:Person {name: $name})
    SET p.age = $age
    """, name="Alice", age=42,
    database_="<database-name>",
)
print(f"Query counters: {summary.counters}.")

新建关系(MATCH/CREATE)

python 复制代码
records, summary, keys = driver.execute_query("""
    MATCH (alice:Person {name: $name})
    MATCH (bob:Person {name: $friend})
    CREATE (alice)-[:KNOWS]->(bob)
    """, name="Alice", friend="Bob",
    database_="<database-name>",
)
print(f"Query counters: {summary.counters}.")

首先通过MATCH匹配不同的节点, 然后通过CREATE将起连接

删除数据(DETACH/DELETE)

python 复制代码
# This does not delete _only_ p, but also all its relationships!
records, summary, keys = driver.execute_query("""
    MATCH (p:Person {name: $name})
    DETACH DELETE p
    """, name="Alice",
    database_="<database-name>",
)
print(f"Query counters: {summary.counters}.")

参数管理

不要将参数直接硬编码或拼接到查询中。而应始终使用占位符,并将动态数据作为Cypher 参数提供。这适用于:

  • 性能优势:Neo4j 可以编译和缓存查询,但只有在查询结构不变的情况下才能做到这一点;
  • 出于安全考虑:请参阅防止密码注入

关键字参数

python 复制代码
driver.execute_query(
    "MERGE (:Person {name: $name})",
    name="Alice", age=42,
    database_="<database-name>",
)

字典参数

python 复制代码
parameters = {
    "name": "Alice",
    "age": 42
}
driver.execute_query(
    "MERGE (:Person {name: $name})",
    parameters_=parameters,
    database_="<database-name>",
)

关键字参数优先级大于字典参数

异常处理

所有来自服务器的异常都是 的子类Neo4jError。您可以使用异常的代码来稳定地识别特定错误;而错误消息则不是稳定的标记,不应依赖它们。

常规异常

python 复制代码
# from neo4j.exceptions import Neo4jError

try:
    driver.execute_query('MATCH (p:Person) RETURN', database_='<database-name>')
except Neo4jError as e:
    print('Neo4j error code:', e.code)
    print('Exception message:', e.message)
'''
Neo4j error code: Neo.ClientError.Statement.SyntaxError
Exception message: Invalid input '': expected an expression, '*', 'ALL' or 'DISTINCT' (line 1, column 24 (offset: 23))
"MATCH (p:Person) RETURN"
                        ^
'''

嵌套异常

异常对象也会将错误以 GQL 状态对象的形式公开。Neo4j错误代码和GQL 错误代码的主要区别在于,GQL 错误代码更加细粒度:一个 Neo4j 错误代码可能细分为多个更具体的 GQL 错误代码。

触发异常的真正原因有时可以在可选的 GQL 状态对象中找到__cause__,该对象本身也是一个异常(或None)。您可能需要递归遍历原因链才能找到捕获到的异常的根本原因。在下面的示例中,异常的 GQL 状态码为 false 42001,但错误的实际来源的状态码为 false 42I06。

python 复制代码
# from neo4j.exceptions import Neo4jError

try:
    driver.execute_query('MATCH (p:Person) RETURN', database_='<database-name>')
except Neo4jError as e:
    print('Exception GQL status:', e.gql_status)
    print('Exception GQL status description:', e.gql_status_description)
    print('Exception GQL classification:', e.gql_classification)
    print('Exception GQL cause:', e.__cause__)
    print('Exception GQL diagnostic record:', e.diagnostic_record)
'''
Exception GQL status: 42001
Exception GQL status description: error: syntax error or access rule violation - invalid syntax
Exception GQL classification: GqlErrorClassification.CLIENT_ERROR
Exception GQL cause: {gql_status: 42I06} {gql_status_description: error: syntax error or access rule violation - invalid input. Invalid input '', expected: an expression, '*', 'ALL' or 'DISTINCT'.} {message: 42I06: Invalid input '', expected: an expression, '*', 'ALL' or 'DISTINCT'.} {diagnostic_record: {'_classification': 'CLIENT_ERROR', '_position': {'line': 1, 'column': 24, 'offset': 23}, 'OPERATION': '', 'OPERATION_CODE': '0', 'CURRENT_SCHEMA': '/'}} {raw_classification: CLIENT_ERROR}
Exception GQL diagnostic record: {'_classification': 'CLIENT_ERROR', '_position': {'line': 1, 'column': 24, 'offset': 23}, 'OPERATION': '', 'OPERATION_CODE': '0', 'CURRENT_SCHEMA': '/'}
'''

区分不同异常

当您希望应用程序根据服务器引发的具体错误而采取不同的行为时,GQL 状态码尤其有用

python 复制代码
# from neo4j.exceptions import Neo4jError

try:
    driver.execute_query('MATCH (p:Person) RETURN', database_='<database-name>')
except Neo4jError as e:
    if e.find_by_gql_status('42001'):
        # Neo.ClientError.Statement.SyntaxError
        # special handling of syntax error in query
        print(e.message)
    elif e.find_by_gql_status('42NFF'):
        # Neo.ClientError.Security.Forbidden
        # special handling of user not having CREATE permissions
        print(e.message)
    else:
        # handling of all other exceptions
        print(e.message)

查询配置

数据库选择

如果不指定数据库, 系统会进行服务器通信获取可用数据库造成资源浪费

python 复制代码
driver.execute_query(
    "MATCH (p:Person) RETURN p.name",
    database_="<database-name>",
)

请求路由

集群模式下默认情况下都会到leader节点可以使用参数路由到读取节点

python 复制代码
driver.execute_query(
    "MATCH (p:Person) RETURN p.name",
    routing_="r",  # short for neo4j.RoutingControl.READ
    database_="<database-name>",
)

自定义用户

python 复制代码
driver.execute_query(
    "MATCH (p:Person) RETURN p.name",
    auth_=("<username>", "<password>"),
    database_="<database-name>",
)

操作查询结果

自己组织查询结果的返回, 通过参数result_transformer_

返回DataFrame

python 复制代码
import neo4j

pandas_df = driver.execute_query(
    "UNWIND range(1, 10) AS n RETURN n, n+1 AS m",
    database_="<database-name>",
    result_transformer_=neo4j.Result.to_df
)
print(type(pandas_df))  # <class 'pandas.core.frame.DataFrame'>

自定义返回结果

python 复制代码
# Get a single record (or an exception) and the summary from a result.
def get_single_person(result):
    record = result.single(strict=True)
    summary = result.consume()
    return record, summary


record, summary = driver.execute_query(
    "MERGE (a:Person {name: $name}) RETURN a.name AS name",
    name="Alice",
    database_="<database-name>",
    result_transformer_=get_single_person,
)
print("The query `{query}` returned {record} in {time} ms.".format(
      query=summary.query, record=record, time=summary.result_available_after))
python 复制代码
# Get exactly 5 records, or an exception.
def exactly_5(result):
    records = result.fetch(5)

    if len(records) != 5:
        raise Exception(f"Expected exactly 5 records, found only {len(records)}.")
    if result.peek():
        raise Exception("Expected exactly 5 records, found more.")

    return records


records = driver.execute_query("""
    UNWIND ['Alice', 'Bob', 'Laura', 'John', 'Patricia'] AS name
    MERGE (a:Person {name: name}) RETURN a.name AS name
    """, database_="<database-name>",
    result_transformer_=exactly_5,
)

结果处理

处理结果最简单的方法是将其强制转换为列表,这将生成一个对象列表Record。此外,Result对象本身也实现了多种处理记录的方法。下面列出了最常用的方法。

姓名 描述
value(key=0, default=None) 将结果的剩余部分以列表形式返回。如果key指定了 is_property 参数,则仅包含给定的属性;default允许为缺少该属性的节点指定一个值。
fetch(n) 返回n结果中的记录数。
single(strict=False) 返回下一个也是唯一剩余的记录,或者None。调用此方法总是会遍历所有结果。如果存在多条(或少于一条)记录,
  • strict==False --- 生成警告并返回第一个警告(如果有);
  • strict==True --- aResultNotSingleError被提起。
函数 说明
peek() 返回下一条记录而不将其读取。这样,该记录将保留在缓冲区中以供进一步处理。
data(*keys) 返回类似 JSON 的数据转储文件。仅用于调试/原型设计目的。
consume() 返回查询结果摘要。它会遍历所有结果,因此只有在数据处理完成后才应调用此函数。
graph() 将结果转换为图对象集合。请参阅"转换为图"。
to_df(expand, parse_dates) 将结果转换为 Pandas 数据帧。请参阅"转换为 Pandas 数据帧"。

高级用法

自定义事物

默认情况系统会自动创建一个事务, 一个事务就是一个工作单元一个语句, 不支持在多个查询之间穿插客户端逻辑(事务交叉)

创建会话

会话不是线程安全的, 每个线程应当创建自己的会话

python 复制代码
with driver.session(database="<database-name>") as session:
    ...

运行和托管事务

一个事务可能包含多个查询语句, 要么全部成功要么全部失败

Session.execute_read()您可以使用 require() remove() 方法创建托管事务Session.execute_write(),具体取决于您是想从数据库检索数据还是修改数据。这两个方法都接受一个事务函数回调,该回调负责执行查询并处理结果。

python 复制代码
# 事务函数回调负责运行查询
def match_person_nodes(tx, name_filter):
    """
    @param tx: session对象
    @param name_filter: 查询过滤
    @return: List, 这里不应当返回result对象, 这样会导致事务提交
    """
    # 使用此方法Transaction.run()运行查询。每次查询运行都会返回一个Result对象
    result = tx.run("""
        MATCH (p:Person) WHERE p.name STARTS WITH $filter
        RETURN p.name AS name ORDER BY name
        """, filter=name_filter)
    # 使用以下任何方法处理结果Result。
    return list(result)  # a list of Record objects

# 创建会话。一个会话可以容纳多个查询。除非使用with构造函数创建,否则请记住在完成后关闭会话。
with driver.session(database="<database-name>") as session:
    # 	.execute_read()`(or )`方法.execute_write()是事务的入口点。它接受一个事务函数的回调函数以及任意数量的位置参数和关键字参数,这些参数将传递给事务函数。
    people = session.execute_read(
        match_person_nodes,
        "Al",
    )
    for person in people:
        print(person.data())  # obtain dict representation

事务参数配置

装饰器unit_of_work()进一步控制了事务功能。它允许指定:

  • 事务超时时间(以秒为单位)。超时事务将被服务器终止。默认值在服务器端设置。最小值为 1 毫秒(0.001)。
  • 一个包含附加到交易的元数据的字典。这些元数据会被记录在服务器中query.log,并显示在 Cypher 命令的输出中SHOW TRANSACTIONS。使用此功能可以标记交易
python 复制代码
from neo4j import unit_of_work

@unit_of_work(timeout=5, metadata={"app_name": "people_tracker"})
def count_people(tx):
    result = tx.run("MATCH (a:Person) RETURN count(a) AS people")
    record = result.single()
    return record["people"]


with driver.session(database="<database-name>") as session:
    people_n = session.execute_read(count_people)

手动控制事务

您可以通过使用该方法手动启动事务,从而实现对事务的完全控制Session.begin_transaction()。然后,您可以使用该方法在显式事务中运行查询Transaction.run()

可以使用 commit 提交显式事务, Transaction.commit()也可以使用 rollback 回滚 显式事务Transaction.rollback()。如果没有采取任何显式操作,驱动程序将在其生命周期结束时自动回滚事务。

python 复制代码
with driver.session(database="<database-name>") as session:
    with session.begin_transaction() as tx:
        # use tx.run() to run queries and tx.commit() when done
        tx.run("<QUERY 1>")
        tx.run("<QUERY 2>")

        tx.commit()

完整示例

如果失败被判定为暂时性失败(例如由于服务器暂时不可用),驱动程序会自动重试运行失败的事务。如果在配置的最大重试次数后操作仍然失败,则会引发错误。

由于事务可能会被重新执行,因此事务函数必须是幂等的(即,多次执行时应产生相同的结果),因为你无法预先知道它们会被执行多少次。实际上,这意味着你不应该修改或依赖全局变量。请注意,尽管事务函数可能会被多次执行,但其中的数据库查询始终只会执行一次。

自定义函数

python 复制代码
from neo4j import GraphDatabase


URI = "<database-uri>"
AUTH = ("<username>", "<password>")
employee_threshold=10


def main():
    with GraphDatabase.driver(URI, auth=AUTH) as driver:
        with driver.session(database="<database-name>") as session:
            for i in range(100):
                name = f"Thor{i}"
                org_id = session.execute_write(employ_person_tx, name)
                print(f"User {name} added to organization {org_id}")


def employ_person_tx(tx, name):
    # Create new Person node with given name, if not exists already
    result = tx.run("""
        MERGE (p:Person {name: $name})
        RETURN p.name AS name
        """, name=name
    )

    # Obtain most recent organization ID and the number of people linked to it
    result = tx.run("""
        MATCH (o:Organization)
        RETURN o.id AS id, COUNT{(p:Person)-[r:WORKS_FOR]->(o)} AS employees_n
        ORDER BY o.created_date DESC
        LIMIT 1
    """)
    org = result.single()

    if org is not None and org["employees_n"] == 0:
        raise Exception("Most recent organization is empty.")
        # Transaction will roll back -> not even Person is created!

    # If org does not have too many employees, add this Person to that
    if org is not None and org.get("employees_n") < employee_threshold:
        result = tx.run("""
            MATCH (o:Organization {id: $org_id})
            MATCH (p:Person {name: $name})
            MERGE (p)-[r:WORKS_FOR]->(o)
            RETURN $org_id AS id
            """, org_id=org["id"], name=name
        )

    # Otherwise, create a new Organization and link Person to it
    else:
        result = tx.run("""
            MATCH (p:Person {name: $name})
            CREATE (o:Organization {id: randomuuid(), created_date: datetime()})
            MERGE (p)-[r:WORKS_FOR]->(o)
            RETURN o.id AS id
            """, name=name
        )

    # Return the Organization ID to which the new Person ends up in
    return result.single()["id"]


if __name__ == "__main__":
    main()

自定义事务提交

python 复制代码
import neo4j


URI = "<database-uri>"
AUTH = ("<username>", "<password>")


def main():
    with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
        customer_id = create_customer(driver)
        other_bank_id = 42
        transfer_to_other_bank(driver, customer_id, other_bank_id, 999)


def create_customer(driver):
    result, _, _ = driver.execute_query("""
        MERGE (c:Customer {id: rand()})
        RETURN c.id AS id
    """, database_ = "<database-name>")
    return result[0]["id"]


def transfer_to_other_bank(driver, customer_id, other_bank_id, amount):
    with driver.session(database="<database-name>") as session:
        with session.begin_transaction() as tx:
            if not customer_balance_check(tx, customer_id, amount):
                # give up
                return

            other_bank_transfer_api(customer_id, other_bank_id, amount)
            # Now the money has been transferred => can't rollback anymore
            # (cannot rollback external services interactions)

            try:
                decrease_customer_balance(tx, customer_id, amount)
                tx.commit()
            except Exception as e:
                request_inspection(customer_id, other_bank_id, amount, e)
                raise  # roll back


def customer_balance_check(tx, customer_id, amount):
    query = ("""
        MATCH (c:Customer {id: $id})
        RETURN c.balance >= $amount AS sufficient
    """)
    result = tx.run(query, id=customer_id, amount=amount)
    record = result.single(strict=True)
    return record["sufficient"]


def other_bank_transfer_api(customer_id, other_bank_id, amount):
    # make some API call to other bank
    pass


def decrease_customer_balance(tx, customer_id, amount):
    query = ("""
        MATCH (c:Customer {id: $id})
        SET c.balance = c.balance - $amount
    """)
    result = tx.run(query, id=customer_id, amount=amount)
    result.consume()


def request_inspection(customer_id, other_bank_id, amount, e):
    # manual cleanup required; log this or similar
    print("WARNING: transaction rolled back due to exception:", repr(e))
    print("customer_id:", customer_id, "other_bank_id:", other_bank_id,
          "amount:", amount)


if __name__ == "__main__":
    main()

获取执行摘要

执行摘要

查询结果的所有记录处理完毕后,服务器会返回执行摘要以结束事务。该摘要以ResultSummary对象的形式返回,其中包含以下信息:

  • 查询计数器------服务器上触发的查询发生了哪些变化
  • 查询执行计划------数据库将如何执行(或已经执行)查询
  • 通知--- 服务器在运行查询时发出的额外信息
  • 时间信息和查询请求摘要

普通任务获取摘要

使用 运行查询时Driver.execute_query(),执行摘要是默认返回值的一部分,作为第二个对象。

python 复制代码
records, result_summary, keys = driver.execute_query("""
    UNWIND ["Alice", "Bob"] AS name
    MERGE (p:Person {name: name})
    """, database_="<database-name>",
)
# or result_summary = driver.execute_query('<QUERY>').summary

自定义事务函数和自定义转换器

如果您使用事务函数或自定义转换器,Driver.execute_query()则可以使用 getExecutiveSummary 方法检索查询执行摘要Result.consume()。 请注意,一旦您请求执行摘要,结果流就会耗尽:任何尚未处理的记录都将不再可用。

python 复制代码
def create_people(tx):
    result = tx.run("""
        UNWIND ["Alice", "Bob"] AS name
        MERGE (p:Person {name: name})
    """)
    return result.consume()

with driver.session(database="<database-name>") as session:
    result_summary = session.execute_write(create_people)
计数器

该属性ResultSummary.counters包含查询触发的操作计数器(以SummaryCounters对象的形式)。

python 复制代码
summary = driver.execute_query("""
    MERGE (p:Person {name: $name})
    MERGE (f:Person {name: $friend})
    MERGE (p)-[:KNOWS]->(f)
    """, name="Mark", friend="Bob",
    database_="<database-name>",
).summary
print(summary.counters)
"""
{'_contains_updates': True, 'labels_added': 2, 'relationships_created': 1,
 'nodes_created': 2, 'properties_set': 2}
"""

另外两个布尔属性用作元计数器:

  • contains_updates--- 查询是否触发了对其运行所在的数据库的任何写入操作
  • contains_system_updates--- 查询是否触发了对system数据库的任何写入操作
执行计划

如果在查询语句前加上 --query-plan 前缀EXPLAIN,服务器会返回用于执行该查询的计划,但实际上并不会执行该查询。该计划位于 --query-plan 属性下,其中包含用于检索结果集的Cypher 运算ResultSummary.plan符列表。利用此信息可以定位潜在的性能瓶颈或潜在的性能改进方案(例如创建索引)。

python 复制代码
_, summary, _ = driver.execute_query("EXPLAIN MATCH (p {name: $name}) RETURN p", name="Alice")
print(summary.plan['args']['string-representation'])

"""
Planner COST
Runtime PIPELINED
Runtime version 5.0
Batch size 128

+-----------------+----------------+----------------+---------------------+
| Operator        | Details        | Estimated Rows | Pipeline            |
+-----------------+----------------+----------------+---------------------+
| +ProduceResults | p              |              1 |                     |
| |               +----------------+----------------+                     |
| +Filter         | p.name = $name |              1 |                     |
| |               +----------------+----------------+                     |
| +AllNodesScan   | p              |             10 | Fused in Pipeline 0 |
+-----------------+----------------+----------------+---------------------+

Total database accesses: ?
"""

查询封装

通用 Neo4j 客户端

python 复制代码
## neo4j_client.py
import re
from typing import (
    Type, Dict, Any, List, Optional, Callable, TypeVar, Union
)
from pydantic import BaseModel
from neo4j import GraphDatabase, Driver, Record, Session

T = TypeVar('T', bound=BaseModel)

class GenericNeo4jClient:
    """
    通用 Neo4j 客户端,支持任意 Pydantic 模型。
    要求模型定义:
      - __label__: ClassVar[str] (节点标签)
      - 字段可通过 Field(..., unique=True) 标记为唯一键(用于 MERGE)
    """

    def __init__(self, uri: str, user: str, password: str):
        self.driver: Driver = GraphDatabase.driver(uri, auth=(user, password))
    
    def close(self):
        self.driver.close()

    ## --- 工具方法 ---
    @staticmethod
    def _validate_identifier(name: str) -> bool:
        """校验标签/关系类型是否合法(防注入)"""
        return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))

    def _extract_node_schema(self, model: Type[BaseModel]) -> Dict[str, Any]:
        """从 Pydantic 模型提取节点元数据"""
        label = getattr(model, "__label__", model.__name__)
        if not self._validate_identifier(label):
            raise ValueError(f"Invalid label: {label}")
        
        fields = model.model_fields
        all_props = list(fields.keys())
        merge_keys = [
            k for k, v in fields.items()
            if v.json_schema_extra and v.json_schema_extra.get("unique")
        ]
        return {"label": label, "all_props": all_props, "merge_keys": merge_keys}

    def _extract_rel_schema(self, model: Type[BaseModel]) -> Dict[str, Any]:
        """从 Pydantic 模型提取关系元数据"""
        rel_type = getattr(model, "__rel_type__", model.__name__.upper())
        if not self._validate_identifier(rel_type):
            raise ValueError(f"Invalid relationship type: {rel_type}")
        fields = model.model_fields
        all_props = list(fields.keys())
        return {"rel_type": rel_type, "all_props": all_props}

    def _model_to_dict(self, obj: BaseModel) -> Dict[str, Any]:
        return obj.model_dump(exclude_none=True)

    def _node_to_model(self, node, model_class: Type[T]) -> T:
        data = dict(node.items())
        return model_class(**data)

    def _record_to_model(self, record: Record, model_class: Type[T]) -> T:
        """将 Neo4j Record 转为 Pydantic 模型(支持嵌套 Node/Relationship)"""
        data = {}
        for field_name in model_class.model_fields:
            if field_name in record:
                value = record[field_name]
                if hasattr(value, 'items'):  ## Node or Relationship
                    data[field_name] = dict(value.items())
                else:
                    data[field_name] = value
            else:
                data[field_name] = record.get(field_name)
        return model_class(**data)

    ## --- 节点操作 ---
    def create_node(self, obj: BaseModel) -> BaseModel:
        schema = self._extract_node_schema(type(obj))
        props = self._model_to_dict(obj)
        cypher = f"CREATE (n:{schema['label']}) SET n = $props RETURN n"
        with self.driver.session() as session:
            result = session.run(cypher, props=props)
            return self._node_to_model(result.single()["n"], type(obj))

    def merge_node(self, obj: BaseModel) -> BaseModel:
        schema = self._extract_node_schema(type(obj))
        if not schema["merge_keys"]:
            raise ValueError(f"No 'unique=True' field in {obj.__class__.__name__}")
        
        props = self._model_to_dict(obj)
        merge_props = {k: props[k] for k in schema["merge_keys"] if k in props}
        other_props = {k: v for k, v in props.items() if k not in schema["merge_keys"]}
        
        merge_clause = ", ".join([f"{k}: ${k}" for k in merge_props.keys()])
        cypher = f"""
        MERGE (n:{schema['label']} {{{merge_clause}}})
        ON CREATE SET n += $other_props
        ON MATCH SET n += $other_props
        RETURN n
        """
        params = {**merge_props, "other_props": other_props}
        
        with self.driver.session() as session:
            result = session.run(cypher, **params)
            return self._node_to_model(result.single()["n"], type(obj))

    def find_nodes(
        self,
        model_class: Type[BaseModel],
        match_props: Optional[Dict[str, Any]] = None
    ) -> List[BaseModel]:
        schema = self._extract_node_schema(model_class)
        if match_props:
            where_clause = " AND ".join([f"n.{k} = ${k}" for k in match_props.keys()])
            cypher = f"MATCH (n:{schema['label']}) WHERE {where_clause} RETURN n"
        else:
            cypher = f"MATCH (n:{schema['label']}) RETURN n"
        
        with self.driver.session() as session:
            results = session.run(cypher, **(match_props or {}))
            return [self._node_to_model(record["n"], model_class) for record in results]

    def update_nodes(
        self,
        model_class: Type[BaseModel],
        set_props: Dict[str, Any],
        match_props: Dict[str, Any]
    ) -> int:
        if not match_props:
            raise ValueError("match_props must be provided for update")
        schema = self._extract_node_schema(model_class)
        set_clause = ", ".join([f"n.{k} = ${k}" for k in set_props.keys()])
        where_clause = " AND ".join([f"n.{k} = ${k}" for k in match_props.keys()])
        cypher = f"MATCH (n:{schema['label']}) WHERE {where_clause} SET {set_clause} RETURN count(n) AS count"
        
        params = {**match_props, **set_props}
        with self.driver.session() as session:
            result = session.run(cypher, **params)
            return result.single()["count"]

    def delete_nodes(
        self,
        model_class: Type[BaseModel],
        match_props: Dict[str, Any]
    ) -> int:
        if not match_props:
            raise ValueError("match_props must be provided for delete")
        schema = self._extract_node_schema(model_class)
        where_clause = " AND ".join([f"n.{k} = ${k}" for k in match_props.keys()])
        cypher = f"MATCH (n:{schema['label']}) WHERE {where_clause} DETACH DELETE n RETURN count(n) AS count"
        
        with self.driver.session() as session:
            result = session.run(cypher, **match_props)
            return result.single()["count"]

    ## --- 关系操作 ---
    def create_relationship(
        self,
        from_obj: BaseModel,
        to_obj: BaseModel,
        rel_obj: BaseModel
    ) -> BaseModel:
        from_schema = self._extract_node_schema(type(from_obj))
        to_schema = self._extract_node_schema(type(to_obj))
        rel_schema = self._extract_rel_schema(type(rel_obj))

        from_merge = {k: getattr(from_obj, k) for k in from_schema["merge_keys"]}
        to_merge = {k: getattr(to_obj, k) for k in to_schema["merge_keys"]}
        if not from_merge or not to_merge:
            raise ValueError("Source and target nodes must have unique fields")

        rel_props = self._model_to_dict(rel_obj)
        from_match = ", ".join([f"a.{k} = ${'from_' + k}" for k in from_merge.keys()])
        to_match = ", ".join([f"b.{k} = ${'to_' + k}" for k in to_merge.keys()])
        
        cypher = f"""
        MATCH (a:{from_schema['label']}), (b:{to_schema['label']})
        WHERE {from_match} AND {to_match}
        CREATE (a)-[r:{rel_schema['rel_type']}]->(b)
        SET r = $rel_props
        RETURN r
        """
        params = {
            **{f"from_{k}": v for k, v in from_merge.items()},
            **{f"to_{k}": v for k, v in to_merge.items()},
            "rel_props": rel_props
        }
        
        with self.driver.session() as session:
            result = session.run(cypher, **params)
            if not result.single():
                raise ValueError("Source or target node not found")
            return rel_obj

    ## --- 批量操作 ---
    def batch_merge_nodes(self, objects: List[BaseModel]) -> List[BaseModel]:
        if not objects:
            return []
        model_class = type(objects[0])
        schema = self._extract_node_schema(model_class)
        if not schema["merge_keys"]:
            raise ValueError("Batch merge requires unique fields")

        def _batch_tx(tx):
            results = []
            for obj in objects:
                props = self._model_to_dict(obj)
                merge_props = {k: props[k] for k in schema["merge_keys"] if k in props}
                other_props = {k: v for k, v in props.items() if k not in schema["merge_keys"]}
                merge_clause = ", ".join([f"{k}: ${k}" for k in merge_props.keys()])
                cypher = f"""
                MERGE (n:{schema['label']} {{{merge_clause}}})
                ON CREATE SET n += $other_props
                ON MATCH SET n += $other_props
                RETURN n
                """
                params = {**merge_props, "other_props": other_props}
                result = tx.run(cypher, **params)
                node = result.single()["n"]
                results.append(self._node_to_model(node, model_class))
            return results

        with self.driver.session() as session:
            return session.write_transaction(_batch_tx)

    def batch_create_relationships(
        self,
        rel_data: List[tuple[BaseModel, BaseModel, BaseModel]]
    ) -> List[BaseModel]:
        if not rel_data:
            return []

        def _batch_rel_tx(tx):
            results = []
            for from_obj, to_obj, rel_obj in rel_data:
                from_schema = self._extract_node_schema(type(from_obj))
                to_schema = self._extract_node_schema(type(to_obj))
                rel_schema = self._extract_rel_schema(type(rel_obj))

                from_merge = {k: getattr(from_obj, k) for k in from_schema["merge_keys"]}
                to_merge = {k: getattr(to_obj, k) for k in to_schema["merge_keys"]}
                rel_props = self._model_to_dict(rel_obj)

                from_match = ", ".join([f"a.{k} = ${'from_' + k}" for k in from_merge.keys()])
                to_match = ", ".join([f"b.{k} = ${'to_' + k}" for k in to_merge.keys()])
                
                cypher = f"""
                MATCH (a:{from_schema['label']}), (b:{to_schema['label']})
                WHERE {from_match} AND {to_match}
                CREATE (a)-[r:{rel_schema['rel_type']}]->(b)
                SET r = $rel_props
                RETURN r
                """
                params = {
                    **{f"from_{k}": v for k, v in from_merge.items()},
                    **{f"to_{k}": v for k, v in to_merge.items()},
                    "rel_props": rel_props
                }
                tx.run(cypher, **params)
                results.append(rel_obj)
            return results

        with self.driver.session() as session:
            return session.write_transaction(_batch_rel_tx)

    ## --- 通用查询 & 分页 ---
    def query(
        self,
        cypher: str,
        parameters: Optional[Dict[str, Any]] = None,
        result_transformer: Optional[Callable[[Record], T]] = None,
        model_class: Optional[Type[T]] = None
    ) -> List[T]:
        if parameters is None:
            parameters = {}
        with self.driver.session() as session:
            results = session.run(cypher, **parameters)
            if result_transformer:
                return [result_transformer(record) for record in results]
            elif model_class:
                return [self._record_to_model(record, model_class) for record in results]
            else:
                return [dict(record) for record in results]

    def paginate(
        self,
        model_class: Type[BaseModel],
        match_conditions: Optional[Dict[str, Any]] = None,
        order_by: str = "id(n)",
        ascending: bool = True,
        page: int = 1,
        size: int = 10
    ) -> List[BaseModel]:
        if page < 1 or size < 1:
            raise ValueError("page and size must be positive integers")
        
        schema = self._extract_node_schema(model_class)
        skip = (page - 1) * size
        direction = "ASC" if ascending else "DESC"
        
        where_clause = ""
        params = {}
        if match_conditions:
            where_clause = "WHERE " + " AND ".join([f"n.{k} = ${k}" for k in match_conditions.keys()])
            params.update(match_conditions)
        
        cypher = f"""
        MATCH (n:{schema['label']})
        {where_clause}
        RETURN n
        ORDER BY {order_by} {direction}
        SKIP $skip
        LIMIT $limit
        """
        params.update({"skip": skip, "limit": size})
        
        return self.query(cypher, params, model_class=model_class)

测试代码

python 复制代码
## test_example.py
from pydantic import BaseModel, Field
from typing import ClassVar, Optional
from neo4j_client import GenericNeo4jClient

## === 业务模型(仅测试用)===
class Person(BaseModel):
    __label__: ClassVar[str] = "Person"
    email: str = Field(..., unique=True)
    name: str
    age: Optional[int] = None

class Book(BaseModel):
    __label__: ClassVar[str] = "Book"
    isbn: str = Field(..., unique=True)
    title: str

class Read(BaseModel):
    __rel_type__: ClassVar[str] = "READ"
    rating: int

## === 测试脚本 ===
if __name__ == "__main__":
    client = GenericNeo4jClient("bolt://localhost:7687", "neo4j", "password")
    
    ## 创建/更新
    alice = client.merge_node(Person(email="alice@example.com", name="Alice", age=30))
    
    ## 查找
    persons = client.find_nodes(Person, {"email": "alice@example.com"})
    
    ## 分页
    page = client.paginate(Person, page=1, size=10, order_by="n.name")
    
    ## 复杂查询
    books = client.query(
        "MATCH (p:Person {email: $email})-[:READ]->(b:Book) RETURN b",
        {"email": "alice@example.com"},
        model_class=Book
    )
    
    client.close()
相关推荐
%xiao Q2 小时前
GESP C++五级-202406
android·开发语言·c++
Traced back2 小时前
# C# + SQL Server 实现自动清理功能的完整方案:按数量与按日期双模式
开发语言·c#
sin22012 小时前
MyBatis的执行流程
java·开发语言·mybatis
web3.08889992 小时前
1688图片搜索API,相似商品精准推荐
开发语言·python
二哈喇子!2 小时前
JAVA环境变量配置步骤及测试(JDK的下载 & 安装 & 环境配置教程)
java·开发语言
少云清2 小时前
【性能测试】15_JMeter _JMeter插件安装使用
开发语言·python·jmeter
yj爆裂鼓手2 小时前
c#万能变量
开发语言·c#
光羽隹衡2 小时前
机器学习——TF-IDF实战(红楼梦数据处理)
python·tf-idf
GGGG寄了3 小时前
HTML——文本标签
开发语言·前端·html