导入数据到OG GraphQL以及创建graph

1. 数据内容

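IMDb Data Files Download(https://datasets.imdbws.com/)。本文用到其中三个文件:title.basics.tsv.gz、name.basics.tsv.gz、title.principals.tsv.gz。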

2. 数据整理

2.1. 数据入库

python 复制代码
import gzip
import os
from dataclasses import dataclass
from typing import List

import psycopg


# ====== Connection string (edit for your environment) ======
# Example: connecting from the host to a docker-mapped port
# PG_DSN = "host=127.0.0.1 port=55432 dbname=postgres user=postgres password=123456"
# The PG_DSN environment variable, when set, takes precedence over this default.
# NOTE(review): a real password is committed as the fallback; prefer supplying it via env.
_DEFAULT_PG_DSN = (
    "host=192.168.152.134 port=55434 dbname=postgres user=postgres password=123456"
)
PG_DSN = os.environ.get("PG_DSN", _DEFAULT_PG_DSN)

# Schema that holds the raw IMDb tables created by DDL_SQL.
SCHEMA = "imdb"


@dataclass
class CopyJob:
    """One COPY task: load a gzipped IMDb TSV file into a table inside SCHEMA."""

    table: str  # table name without schema; copy_tsv_gz prefixes it with SCHEMA
    columns: List[str]  # COPY target columns, matched positionally to the TSV columns
    gz_path: str  # path of the .tsv.gz file to read


# DDL for the raw IMDb tables. Executing it DROPs and re-creates the three
# tables, so any data previously loaded into them is lost.
# Column names are unquoted, so PostgreSQL folds them to lower case
# (e.g. titleType -> titletype); the later SQL refers to the lower-case names.
# NOTE(review): the title_principals indexes are created here, before the bulk
# COPY, so they are maintained row by row during the load; building them after
# COPY is usually much faster for a file of this size — consider moving them.
DDL_SQL = f"""
CREATE SCHEMA IF NOT EXISTS {SCHEMA};

DROP TABLE IF EXISTS {SCHEMA}.title_basics;
CREATE TABLE {SCHEMA}.title_basics (
  tconst          text PRIMARY KEY,
  titleType       text,
  primaryTitle    text,
  originalTitle   text,
  isAdult         int,
  startYear       int,
  endYear         int,
  runtimeMinutes  int,
  genres          text
);

DROP TABLE IF EXISTS {SCHEMA}.name_basics;
CREATE TABLE {SCHEMA}.name_basics (
  nconst              text PRIMARY KEY,
  primaryName         text,
  birthYear           int,
  deathYear           int,
  primaryProfession   text,
  knownForTitles      text
);

DROP TABLE IF EXISTS {SCHEMA}.title_principals;
CREATE TABLE {SCHEMA}.title_principals (
  tconst      text,
  ordering    int,
  nconst      text,
  category    text,
  job         text,
  characters  text
);

CREATE INDEX IF NOT EXISTS idx_principals_tconst ON {SCHEMA}.title_principals(tconst);
CREATE INDEX IF NOT EXISTS idx_principals_nconst ON {SCHEMA}.title_principals(nconst);
CREATE INDEX IF NOT EXISTS idx_principals_cat    ON {SCHEMA}.title_principals(category);
"""


def run_ddl(conn: psycopg.Connection) -> None:
    """Create the schema and (re)create the raw IMDb tables, then commit.

    Destructive: DDL_SQL drops the existing tables before re-creating them.
    """
    ddl_cursor = conn.cursor()
    with ddl_cursor:
        ddl_cursor.execute(DDL_SQL)
    # The cursor is closed before the transaction is committed.
    conn.commit()


def copy_tsv_gz(conn: psycopg.Connection, job: CopyJob, chunk_size: int = 1 << 20) -> None:
    r"""Bulk-load one gzipped IMDb TSV file into ``SCHEMA.job.table`` via COPY FROM STDIN.

    COPY options used:
    - FORMAT csv + QUOTE E'\b': CSV parsing with a quote character (backspace)
      that is assumed never to occur in the data, which effectively disables
      quoting so literal '"' characters in titles are loaded as-is.
    - DELIMITER E'\t': tab-separated values.
    - HEADER true: skip the first (header) line.
    - NULL '\N': IMDb's marker for missing values.

    Args:
        conn: open psycopg connection; committed after a successful COPY.
        job: target table, column list and input file path.
        chunk_size: number of decoded characters sent to COPY per write.

    Raises:
        ValueError: if chunk_size is not positive (f.read(0) would load nothing).
        Any error raised by gzip/psycopg during reading or COPY; the
        transaction is left uncommitted in that case.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    full_table = f"{SCHEMA}.{job.table}"
    cols = ", ".join(job.columns)

    copy_sql = f"""
    COPY {full_table} ({cols})
    FROM STDIN
    WITH (
      FORMAT csv,
      DELIMITER E'\\t',
      HEADER true,
      NULL '\\N',
      QUOTE E'\\b'
    );
    """.strip()

    print(f"[COPY] {job.gz_path}  ->  {full_table}")

    # Text mode yields str, which cp.write() accepts; invalid UTF-8 bytes are
    # replaced rather than aborting the load; newline="" passes line endings
    # through untouched for COPY to parse.
    with gzip.open(job.gz_path, "rt", encoding="utf-8", errors="replace", newline="") as f:
        with conn.cursor() as cur:
            with cur.copy(copy_sql) as cp:
                # COPY input is a plain character stream, so chunk boundaries do
                # not need to align with rows; streaming fixed-size chunks avoids
                # one Python-level write() call per line on multi-million-row files.
                while True:
                    chunk = f.read(chunk_size)
                    if not chunk:
                        break
                    cp.write(chunk)

    conn.commit()
    print(f"[OK] {job.table}")


def quick_check(conn: psycopg.Connection) -> None:
    """Print the (titles, names, principals) row counts of the raw IMDb tables."""
    count_sql = f"""
    SELECT
      (SELECT count(*) FROM {SCHEMA}.title_basics)     AS titles,
      (SELECT count(*) FROM {SCHEMA}.name_basics)      AS names,
      (SELECT count(*) FROM {SCHEMA}.title_principals) AS principals;
    """
    counting_cursor = conn.cursor()
    try:
        counting_cursor.execute(count_sql)
        counts = counting_cursor.fetchone()
    finally:
        counting_cursor.close()
    print("[COUNT]", counts)


def main():
    """Create the raw tables, COPY the three IMDb dumps into them, then print row counts."""
    # ====== Change these to your actual file paths ======
    title_basics_job = CopyJob(
        table="title_basics",
        columns=["tconst", "titleType", "primaryTitle", "originalTitle", "isAdult", "startYear", "endYear", "runtimeMinutes", "genres"],
        gz_path="title.basics.tsv.gz",
    )
    name_basics_job = CopyJob(
        table="name_basics",
        columns=["nconst", "primaryName", "birthYear", "deathYear", "primaryProfession", "knownForTitles"],
        gz_path="name.basics.tsv.gz",
    )
    principals_job = CopyJob(
        table="title_principals",
        columns=["tconst", "ordering", "nconst", "category", "job", "characters"],
        gz_path="title.principals.tsv.gz",
    )
    copy_jobs = [title_basics_job, name_basics_job, principals_job]

    with psycopg.connect(PG_DSN) as conn:
        # 1) Create the tables. Comment this out if they already exist; note
        #    that DDL_SQL drops and re-creates them.
        run_ddl(conn)

        # 2) Load every file, in list order.
        for copy_job in copy_jobs:
            copy_tsv_gz(conn, copy_job)

        # 3) Sanity-check the row counts.
        quick_check(conn)


if __name__ == "__main__":
    main()

2.2. 抽取小规模数据子集

sql 复制代码
-- Build a small, self-contained subset of the raw IMDb tables so the graph
-- import stays fast. Re-runnable: the previous subset tables are dropped first.
DROP TABLE IF EXISTS imdb.small_name_basics;
DROP TABLE IF EXISTS imdb.small_title_basics;
DROP TABLE IF EXISTS imdb.small_title_principals;


-- Titles: non-adult movies / TV series / TV mini-series that started in
-- 2015-2018, capped at the first 2000 by tconst (ORDER BY makes it deterministic).
CREATE TABLE imdb.small_title_basics AS
SELECT *
FROM imdb.title_basics
WHERE startyear BETWEEN 2015 AND 2018
  AND isadult = 0
  AND titletype IN ('movie','tvSeries','tvMiniSeries')
ORDER BY tconst
LIMIT 2000;     -- ⭐ controls the subset size

ALTER TABLE imdb.small_title_basics
  ADD PRIMARY KEY (tconst);


-- Future edges: only actor/actress credits on the selected titles.
CREATE TABLE imdb.small_title_principals AS
SELECT p.*
FROM imdb.title_principals p
JOIN imdb.small_title_basics t
  ON t.tconst = p.tconst
WHERE p.category IN ('actor','actress');

CREATE INDEX ON imdb.small_title_principals(tconst);
CREATE INDEX ON imdb.small_title_principals(nconst);
CREATE INDEX ON imdb.small_title_principals(category);


-- People: everyone referenced by the selected credits.
CREATE TABLE imdb.small_name_basics AS
SELECT n.*
FROM imdb.name_basics n
JOIN (
  SELECT DISTINCT nconst
  FROM imdb.small_title_principals
) x ON x.nconst = n.nconst;

ALTER TABLE imdb.small_name_basics
  ADD PRIMARY KEY (nconst);


-- Sanity check: sizes of the three subset tables.
SELECT
  (SELECT count(*) FROM imdb.small_name_basics)      AS people,
  (SELECT count(*) FROM imdb.small_title_basics)     AS titles,
  (SELECT count(*) FROM imdb.small_title_principals) AS edges;

2.3. 创建graph

2.3.1. 设置默认的模式搜索路径(search_path)并创建 graph

sql 复制代码
-- Apache AGE must be loaded in every session unless it is listed in
-- shared_preload_libraries / session_preload_libraries.
CREATE EXTENSION IF NOT EXISTS age;
LOAD 'age';

-- Make ag_catalog (cypher(), agtype) and the imdb schema resolvable by default
-- for every NEW session on this database. The scripts connect with
-- dbname=postgres; imdb is a schema inside it, not a database, so
-- "ALTER DATABASE imdb" would fail with "database does not exist".
ALTER DATABASE postgres
SET search_path = ag_catalog, imdb, public;

-- Applies to the current session only
SET search_path = ag_catalog, imdb, public;

-- Create the graph the loader writes into (GRAPH = 'imdb_graph').
-- Run once: create_graph() raises an error if the graph already exists.
SELECT ag_catalog.create_graph('imdb_graph');

2.4. 导入graph节点信息

python 复制代码
import json
import os
import time

import psycopg
from psycopg.rows import dict_row
from tqdm import tqdm

# ==========================
# Configuration (each value can be overridden through an environment variable)
# ==========================
_TRUTHY = ("1", "true", "yes", "y")


def _env_flag(name: str, default: str) -> bool:
    """Return True when env var *name* (else *default*) is a truthy word, case-insensitively."""
    return os.getenv(name, default).lower() in _TRUTHY


def _env_int(name: str, default: str) -> int:
    """Return env var *name* (else *default*) parsed as an int."""
    return int(os.getenv(name, default))


DSN = os.getenv(
    "PG_DSN",
    "host=192.168.152.134 port=55434 dbname=postgres user=postgres password=123456",
)

GRAPH = "imdb_graph"

# Rows per Cypher batch for each stage.
BATCH_PERSON = _env_int("BATCH_PERSON", "5000")
BATCH_TITLE = _env_int("BATCH_TITLE", "5000")
BATCH_EDGE = _env_int("BATCH_EDGE", "5000")

RETRY = _env_int("RETRY", "3")
SYNC_COMMIT_OFF = _env_flag("SYNC_COMMIT_OFF", "true")

# Run only one stage: PERSON / TITLE / EDGE / ALL
RUN_STAGE = os.getenv("RUN_STAGE", "ALL").upper()

# Optional: load only actor/actress edges (the small_* table is already filtered
# that way; the switch is kept as a conservative safeguard).
ONLY_ACTORS = _env_flag("ONLY_ACTORS", "true")


# ==========================
# Fixed Cypher statements
# ==========================
# Each statement wraps a cypher() call (from ag_catalog, which init_session puts
# on search_path) in a plain SQL SELECT. The third argument, %s::agtype, is bound
# to the JSON built by make_payload ({"rows": [...]}); inside the Cypher text its
# "rows" key is referenced as $rows. {{ }} are f-string escapes that render as
# literal { } in the Cypher text.

# Upsert one :Person vertex per row, keyed by nconst, and (re)set its other properties.
SQL_PERSON_CYPHER = f"""
SELECT 1
FROM cypher('{GRAPH}', $$
  UNWIND $rows AS row
  MERGE (p:Person {{nconst: row.nconst}})
  SET p.primaryname = row.primaryname,
      p.birthyear = row.birthyear,
      p.deathyear = row.deathyear,
      p.primaryprofession = row.primaryprofession,
      p.knownfortitles = row.knownfortitles
  RETURN 1
$$, %s::agtype) AS (x agtype);
"""

# Upsert one :Title vertex per row, keyed by tconst, and (re)set its other properties.
SQL_TITLE_CYPHER = f"""
SELECT 1
FROM cypher('{GRAPH}', $$
  UNWIND $rows AS row
  MERGE (t:Title {{tconst: row.tconst}})
  SET t.titletype = row.titletype,
      t.primarytitle = row.primarytitle,
      t.originaltitle = row.originaltitle,
      t.isadult = row.isadult,
      t.startyear = row.startyear,
      t.endyear = row.endyear,
      t.runtimeminutes = row.runtimeminutes,
      t.genres = row.genres
  RETURN 1
$$, %s::agtype) AS (x agtype);
"""

# Upsert a WORKED_ON edge Person -> Title, keyed by (category, ordering).
# Both endpoints must already exist: the MATCH clauses drop any row whose Person
# or Title vertex is missing, so the PERSON and TITLE stages must run first.
SQL_EDGE_CYPHER = f"""
SELECT 1
FROM cypher('{GRAPH}', $$
  UNWIND $rows AS row
  MATCH (p:Person {{nconst: row.nconst}})
  MATCH (t:Title  {{tconst: row.tconst}})
  MERGE (p)-[r:WORKED_ON {{category: row.category, ordering: row.ordering}}]->(t)
  SET r.job = row.job,
      r.characters = row.characters
  RETURN 1
$$, %s::agtype) AS (x agtype);
"""


# ==========================
# Queries that read the small_* tables (single-threaded keyset pagination)
# ==========================
# Optional category filter shared by the edge count and the edge page query, so
# the progress-bar total matches the number of edges actually loaded.
_EDGE_CATEGORY_FILTER = "AND category IN ('actor','actress')" if ONLY_ACTORS else ""

SQL_TOTAL_PERSON = "SELECT count(*) AS total FROM imdb.small_name_basics;"
SQL_TOTAL_TITLE  = "SELECT count(*) AS total FROM imdb.small_title_basics;"
SQL_TOTAL_EDGE   = (
    "SELECT count(*) AS total FROM imdb.small_title_principals "
    f"WHERE TRUE {_EDGE_CATEGORY_FILTER};"
)

# Page queries: parameters are (last key seen, page size); results are ordered
# by the same key so the next page resumes right after the last row returned.
SQL_SELECT_PERSON = """
SELECT nconst, primaryname, birthyear, deathyear, primaryprofession, knownfortitles
FROM imdb.small_name_basics
WHERE nconst > %s
ORDER BY nconst
LIMIT %s
"""

SQL_SELECT_TITLE = """
SELECT tconst, titletype, primarytitle, originaltitle, isadult, startyear, endyear, runtimeminutes, genres
FROM imdb.small_title_basics
WHERE tconst > %s
ORDER BY tconst
LIMIT %s
"""

# Composite key (tconst, nconst, ordering) compared as a row value.
SQL_SELECT_EDGE = f"""
SELECT tconst, nconst, ordering, category, job, characters
FROM imdb.small_title_principals
WHERE (tconst, nconst, ordering) > (%s, %s, %s)
  {_EDGE_CATEGORY_FILTER}
ORDER BY tconst, nconst, ordering
LIMIT %s
"""


def make_payload(rows) -> str:
    """Serialize a batch of row dicts as the agtype parameter map ``{"rows": [...]}``."""
    envelope = {"rows": rows}
    # Keep non-ASCII text (e.g. CJK titles) verbatim instead of \\u-escaping it.
    serialized = json.dumps(envelope, ensure_ascii=False)
    return serialized


def init_session(cur, sync_commit_off=None):
    """Configure the loader's session and commit so the settings persist.

    Puts ag_catalog (cypher(), agtype) and imdb on search_path and, when
    enabled, sets synchronous_commit = off for this session.

    The SETs are committed immediately. psycopg opens a transaction on the
    first execute(), and a plain SET issued inside a transaction is undone if
    that transaction is rolled back. Without this commit, a failure in the very
    first Cypher batch (exec_with_retry calls conn.rollback()) would silently
    reset search_path, and every retry and later batch would then fail.

    Args:
        cur: psycopg cursor; its ``connection`` is committed.
        sync_commit_off: overrides the module-level SYNC_COMMIT_OFF flag;
            None (the default) reads the flag as before.
    """
    if sync_commit_off is None:
        sync_commit_off = SYNC_COMMIT_OFF
    cur.execute("SET search_path = ag_catalog, imdb, public;")
    if sync_commit_off:
        cur.execute("SET synchronous_commit = off;")
    cur.connection.commit()


def fetch_total(cur, sql_total: str) -> int:
    """Run a ``count(*) AS total`` query and return the count as an int.

    Expects a dict-like row (the loader connects with row_factory=dict_row).
    """
    cur.execute(sql_total)
    first_row = cur.fetchone()
    count_value = first_row["total"]
    return int(count_value)


def exec_with_retry(cur, conn, sql_cypher: str, payload: str, *, retries=None, base_delay=0.5, max_delay=2.0):
    """Execute one Cypher batch and commit it, retrying on failure.

    Each failed attempt is rolled back before the next one. Between attempts
    the function sleeps ``min(base_delay * attempt, max_delay)`` seconds; it
    does not sleep after the final attempt.

    Args:
        cur: psycopg cursor used to execute the statement.
        conn: connection that owns ``cur``; committed or rolled back here.
        sql_cypher: SQL wrapping the cypher() call, with one %s placeholder.
        payload: JSON string bound to that placeholder.
        retries: total number of attempts; None uses the module-level RETRY.
            Values below 1 are treated as 1, so at least one attempt is made
            and the real error is raised (previously RETRY <= 0 hit
            ``raise None`` -> TypeError).
        base_delay: seconds of linear back-off per attempt.
        max_delay: cap on a single back-off sleep, in seconds.

    Raises:
        The exception from the last failed attempt.
    """
    attempts = RETRY if retries is None else retries
    attempts = max(1, attempts)
    for attempt in range(1, attempts + 1):
        try:
            cur.execute(sql_cypher, (payload,))
            conn.commit()
            return
        except Exception:
            conn.rollback()
            if attempt == attempts:
                raise
            time.sleep(min(base_delay * attempt, max_delay))


def load_people(cur, conn):
    """Load imdb.small_name_basics into :Person vertices, BATCH_PERSON rows at a time.

    Uses keyset pagination on nconst (the table's primary key): each page asks
    for rows with nconst greater than the last one seen, so pages neither skip
    nor repeat rows. Every page is written and committed via exec_with_retry.
    """
    total = fetch_total(cur, SQL_TOTAL_PERSON)
    last = ""  # '' sorts before any non-empty nconst, so paging starts at the beginning
    done = 0

    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    with tqdm(total=total, desc="Person (single)", unit="rows", dynamic_ncols=True) as pbar:
        while True:
            cur.execute(SQL_SELECT_PERSON, (last, BATCH_PERSON))
            batch = cur.fetchall()
            if not batch:
                break

            rows = [dict(r) for r in batch]
            exec_with_retry(cur, conn, SQL_PERSON_CYPHER, make_payload(rows))

            last = rows[-1]["nconst"]
            done += len(rows)
            pbar.update(len(rows))

    elapsed = time.perf_counter() - t0
    # Guard against elapsed == 0 (e.g. an empty table), which used to raise ZeroDivisionError.
    rate = done / elapsed if elapsed > 0 else 0.0
    print(f"[Person single] DONE {done}/{total} elapsed={elapsed:.1f}s rate={rate:.0f} rows/s")


def load_titles(cur, conn):
    """Load imdb.small_title_basics into :Title vertices, BATCH_TITLE rows at a time.

    Uses keyset pagination on tconst (the table's primary key): each page asks
    for rows with tconst greater than the last one seen, so pages neither skip
    nor repeat rows. Every page is written and committed via exec_with_retry.
    """
    total = fetch_total(cur, SQL_TOTAL_TITLE)
    last = ""  # '' sorts before any non-empty tconst, so paging starts at the beginning
    done = 0

    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    with tqdm(total=total, desc="Title (single)", unit="rows", dynamic_ncols=True) as pbar:
        while True:
            cur.execute(SQL_SELECT_TITLE, (last, BATCH_TITLE))
            batch = cur.fetchall()
            if not batch:
                break

            rows = [dict(r) for r in batch]
            exec_with_retry(cur, conn, SQL_TITLE_CYPHER, make_payload(rows))

            last = rows[-1]["tconst"]
            done += len(rows)
            pbar.update(len(rows))

    elapsed = time.perf_counter() - t0
    # Guard against elapsed == 0 (e.g. an empty table), which used to raise ZeroDivisionError.
    rate = done / elapsed if elapsed > 0 else 0.0
    print(f"[Title single] DONE {done}/{total} elapsed={elapsed:.1f}s rate={rate:.0f} rows/s")


def load_edges(cur, conn):
    """Load imdb.small_title_principals into WORKED_ON edges, BATCH_EDGE rows at a time.

    Uses keyset pagination on the composite key (tconst, nconst, ordering).
    NOTE(review): this assumes that key is unique per row; it holds for IMDb,
    where (tconst, ordering) identifies a principal, but small_title_principals
    has no constraint enforcing it. Person and Title vertices must already be
    loaded, because the edge Cypher MATCHes both endpoints.
    """
    total = fetch_total(cur, SQL_TOTAL_EDGE)
    # ('', '', -1) sorts before every real key, so paging starts at the beginning.
    last_t, last_n, last_o = "", "", -1
    done = 0

    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    with tqdm(total=total, desc="Edge (single)", unit="rows", dynamic_ncols=True) as pbar:
        while True:
            cur.execute(SQL_SELECT_EDGE, (last_t, last_n, last_o, BATCH_EDGE))
            batch = cur.fetchall()
            if not batch:
                break

            rows = [dict(r) for r in batch]
            exec_with_retry(cur, conn, SQL_EDGE_CYPHER, make_payload(rows))

            last_t = rows[-1]["tconst"]
            last_n = rows[-1]["nconst"]
            last_o = rows[-1]["ordering"]
            done += len(rows)
            pbar.update(len(rows))

    elapsed = time.perf_counter() - t0
    # Guard against elapsed == 0 (e.g. an empty table), which used to raise ZeroDivisionError.
    rate = done / elapsed if elapsed > 0 else 0.0
    print(f"[Edge single] DONE {done}/{total} elapsed={elapsed:.1f}s rate={rate:.0f} rows/s")


def main():
    """Load the small_* tables into the AGE graph, one stage at a time.

    RUN_STAGE selects PERSON, TITLE, EDGE or ALL. ALL runs the stages in that
    order, so vertices exist before the edge stage MATCHes them.

    Raises:
        SystemExit: if RUN_STAGE is not a supported value. Previously an unknown
            value (e.g. the typo "PEOPLE") silently loaded nothing.
    """
    valid_stages = ("PERSON", "TITLE", "EDGE", "ALL")
    if RUN_STAGE not in valid_stages:
        raise SystemExit(
            f"Unknown RUN_STAGE={RUN_STAGE!r}; expected one of: {', '.join(valid_stages)}"
        )

    with psycopg.connect(DSN, row_factory=dict_row) as conn:
        with conn.cursor() as cur:
            init_session(cur)

            print("== small single-thread loader ==")
            print(f"RUN_STAGE={RUN_STAGE}  sync_commit_off={SYNC_COMMIT_OFF}  retry={RETRY}")
            print(f"BATCH person={BATCH_PERSON} title={BATCH_TITLE} edge={BATCH_EDGE}")
            print(f"ONLY_ACTORS={ONLY_ACTORS}")

            if RUN_STAGE in ("PERSON", "ALL"):
                load_people(cur, conn)

            if RUN_STAGE in ("TITLE", "ALL"):
                load_titles(cur, conn)

            if RUN_STAGE in ("EDGE", "ALL"):
                load_edges(cur, conn)

            print("== done ==")


if __name__ == "__main__":
    main()
相关推荐
强子感冒了2 小时前
Java学习笔记:String、StringBuilder与StringBuffer
java·开发语言·笔记·学习
程序员JerrySUN2 小时前
OP-TEE + YOLOv8:从“加密权重”到“内存中解密并推理”的完整实战记录
android·java·开发语言·redis·yolo·架构
+VX:Fegn08952 小时前
计算机毕业设计|基于springboot + vueOA工程项目管理系统(源码+数据库+文档)
java·数据库·vue.js·spring boot·后端·课程设计
JasmineWr3 小时前
Spring事务解析
java·spring
qq_336313933 小时前
java基础-IO流(缓冲流)
java·开发语言
青岛少儿编程-王老师3 小时前
CCF编程能力等级认证GESP—C++2级—20251227
java·开发语言·c++
高山上有一只小老虎3 小时前
小红的推荐系统
java·算法
萧曵 丶3 小时前
JDK各版本新增特性详解
java·面试
爱上妖精的尾巴3 小时前
7-3 WPS JS宏 keys、values、entries、JSON.stringify 循环对象中的属性
后端·restful·wps·jsa