模型训练之爬取数据

安装依赖

pip install requests beautifulsoup4 trafilatura tqdm

爬虫代码

python 复制代码
#!/usr/bin/env python3
"""
Rocketech 知识库爬虫(用于 RAG 数据集)
用法:
    python crawl_rocketech.py --start_url https://example.com/ [--max_pages 500] [--delay 2] [--output data.jsonl]
"""

import argparse
import json
import os
import time
from collections import deque
from urllib.parse import urljoin, urlparse
from urllib.robotparser import RobotFileParser

import requests
import trafilatura
from bs4 import BeautifulSoup
from tqdm import tqdm


class RocketechCrawler:
    """Single-domain crawler that stores extracted page text as JSONL records.

    Starting from ``start_url`` the crawler follows same-host ``<a href>``
    links in first-in-first-out order, honours robots.txt when it can be
    fetched, extracts the main text of every page and appends one JSON object
    per page (``url``, ``title``, ``text``, ``timestamp``) to ``output_file``.
    """

    # Path suffixes that mark a link as a non-HTML resource; such links are
    # never queued.
    SKIPPED_EXTENSIONS = (
        ".pdf", ".zip", ".png", ".jpg", ".jpeg", ".gif",
        ".mp3", ".mp4", ".doc", ".docx",
    )

    def __init__(self, start_url, output_file="rocketech_data.jsonl", max_pages=500, delay=2.0):
        """Set up crawl state, the robots.txt parser and the HTTP session.

        Args:
            start_url: First URL to fetch; its host defines the crawl scope.
            output_file: Path of the JSONL file that records are appended to.
            max_pages: Stop once this many URLs have been marked as visited.
            delay: Seconds to sleep after each processed page.
        """
        parsed_start = urlparse(start_url)
        # Follow the start URL's own scheme so an http-only site still gets
        # its robots.txt read; fall back to https when none was given.
        scheme = parsed_start.scheme or "https"

        self.start_url = start_url
        self.domain = self._lowercase_host(parsed_start.netloc)
        self.base_url = f"{scheme}://{self.domain}/"
        self.output_file = output_file
        self.max_pages = max_pages
        self.delay = delay

        # Visited set and FIFO frontier. ``_queued`` remembers every
        # normalized URL ever put on the frontier so duplicates are rejected
        # in O(1) instead of rebuilding the queue on each iteration.
        self.visited = set()
        self.to_visit = deque([start_url])
        self._queued = {self.normalize_url(start_url)}

        # robots.txt rules; a fetch failure disables the check (best effort).
        self.rp = RobotFileParser()
        self.rp.set_url(f"{scheme}://{self.domain}/robots.txt")
        try:
            self.rp.read()
        except Exception:
            print("无法获取 robots.txt,将不限制爬取。")
            self.rp = None

        # One shared session so every request carries the same User-Agent.
        self.session = requests.Session()
        self.session.headers.update({
            "User-Agent": "RocketechRAGBot/1.0 (research project; contact@example.com)"
        })

    @staticmethod
    def _lowercase_host(netloc):
        """Lower-case the host[:port] part of *netloc*, leaving userinfo intact."""
        userinfo, sep, hostport = netloc.rpartition("@")
        return f"{userinfo}{sep}{hostport.lower()}"

    def can_fetch(self, url):
        """Return True when robots.txt allows fetching *url* (or is unavailable)."""
        if self.rp is None:
            return True
        return self.rp.can_fetch(self.session.headers["User-Agent"], url)

    def normalize_url(self, url):
        """Return the canonical form of *url* used for de-duplication.

        The fragment and the query string are dropped, the host is
        lower-cased (``urlparse`` already lower-cases the scheme) and trailing
        slashes are removed.
        """
        parsed = urlparse(url)
        canonical = parsed._replace(
            netloc=self._lowercase_host(parsed.netloc),
            fragment="",
            query="",
        )
        return canonical.geturl().rstrip("/")

    def extract_links(self, html, current_url):
        """Return the set of normalized, not-yet-visited same-host links in *html*.

        Args:
            html: Page markup to scan for ``<a href>`` elements.
            current_url: URL of the page, used to resolve relative hrefs.
        """
        soup = BeautifulSoup(html, "html.parser")
        links = set()
        for anchor in soup.find_all("a", href=True):
            href = anchor["href"].strip()
            try:
                full_url = urljoin(current_url, href)
                parsed = urlparse(full_url)
            except ValueError:
                # Malformed href (e.g. an unbalanced IPv6 bracket): skip the
                # link instead of letting it abort the whole crawl.
                continue
            # Keep only http/https links that stay on the crawl host.
            if parsed.scheme not in ("http", "https"):
                continue
            if self._lowercase_host(parsed.netloc) != self.domain:
                continue
            # Drop resource files (images, PDFs, archives, ...).
            if parsed.path.lower().endswith(self.SKIPPED_EXTENSIONS):
                continue
            normalized = self.normalize_url(full_url)
            if normalized not in self.visited:
                links.add(normalized)
        return links

    def _fetch_html(self, url):
        """Download *url* and return the decoded body, or None on failure.

        A failure is either a raised request error (reported on stdout) or
        any status code other than 200.
        """
        try:
            resp = self.session.get(url, timeout=30)
            if resp.status_code != 200:
                return None
            # requests assumes ISO-8859-1 for text responses whose
            # Content-Type carries no charset, which garbles non-Latin pages;
            # use the encoding detected from the body in that case.
            if "charset" not in resp.headers.get("Content-Type", "").lower():
                resp.encoding = resp.apparent_encoding
            return resp.text
        except Exception as e:
            # Deliberately broad: one bad URL must not stop the crawl.
            print(f" 请求失败 {url}: {e}")
            return None

    @staticmethod
    def _extract_title(soup, fallback):
        """Return the stripped ``<title>`` text, or *fallback* if it is missing or empty."""
        title_tag = soup.title
        if title_tag is None:
            return fallback
        # get_text() also copes with an empty <title> or one with nested
        # markup, where ``.string`` is None.
        return title_tag.get_text(strip=True) or fallback

    def _parse_page(self, html, url):
        """Extract ``(title, text)`` from already-downloaded *html*."""
        soup = BeautifulSoup(html, "html.parser")
        title = self._extract_title(soup, url)

        # Main-content extraction (trafilatura strips navigation, footers and
        # ads). ``no_fallback`` is left at its default because newer
        # trafilatura releases renamed that argument to ``fast``.
        extracted = trafilatura.extract(html, include_comments=False, include_tables=False,
                                        favor_precision=True)
        if extracted:
            return title, extracted.strip()

        # Fallback: plain text of the document minus boilerplate elements.
        for element in soup(["script", "style", "nav", "footer", "header"]):
            element.decompose()
        raw_text = soup.get_text(separator="\n")
        text = "\n".join(line.strip() for line in raw_text.splitlines() if line.strip())
        return title, text

    def process_page(self, url):
        """Fetch *url* and return ``(title, text)``, or ``(None, None)`` on failure."""
        html = self._fetch_html(url)
        if html is None:
            return None, None
        return self._parse_page(html, url)

    def save_record(self, record):
        """Append *record* as one JSON line to the output file."""
        with open(self.output_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    def _enqueue(self, links):
        """Add unseen *links* to the frontier.

        Links are added in sorted order so the crawl order does not depend on
        set iteration order, which varies between interpreter runs.
        """
        for link in sorted(links):
            if link in self._queued or link in self.visited:
                continue
            self._queued.add(link)
            self.to_visit.append(link)

    def crawl(self):
        """Run the crawl until the frontier is empty or ``max_pages`` is reached.

        Each page is downloaded once; the same HTML is used both for text
        extraction and for link discovery.
        """
        print(f"开始爬取 {self.start_url},目标页数:{self.max_pages}")

        # The context manager closes the progress bar even if an error escapes.
        with tqdm(total=self.max_pages, desc="已爬取页面") as pbar:
            while self.to_visit and len(self.visited) < self.max_pages:
                url = self.to_visit.popleft()
                norm_url = self.normalize_url(url)
                if norm_url in self.visited:
                    continue
                self.visited.add(norm_url)

                # Pages disallowed by robots.txt count as visited but are
                # never requested.
                if not self.can_fetch(url):
                    continue

                print(f"\n📄 正在处理: {url}")
                html = self._fetch_html(url)
                if html is None:
                    title, text = None, None
                else:
                    title, text = self._parse_page(html, url)

                # Persist only pages with a meaningful amount of text.
                if text and len(text) > 50:
                    self.save_record({
                        "url": url,
                        "title": title,
                        "text": text,
                        "timestamp": time.time(),
                    })
                    print(f" 保存成功,正文长度: {len(text)} 字符")
                else:
                    print("  跳过(无有效正文)")

                # Link discovery is pointless once the page budget is used up.
                if html and len(self.visited) < self.max_pages:
                    self._enqueue(self.extract_links(html, url))

                pbar.update(1)
                # Politeness delay between requests.
                time.sleep(self.delay)

        print(f"\n爬取完成,共处理 {len(self.visited)} 个页面,数据保存至 {self.output_file}")


def main(argv=None):
    """Parse command-line options, validate them and run the crawler.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Invalid options end the process through ``parser.error`` (exit status 2)
    before any network request is made.
    """
    parser = argparse.ArgumentParser(description="Rocketech 网站爬虫(RAG数据集)")
    parser.add_argument("--start_url", default="网页网址", help="起始URL")
    parser.add_argument("--output", default="rocketech_data.jsonl", help="输出JSONL文件")
    parser.add_argument("--max_pages", type=int, default=500, help="最大爬取页数")
    parser.add_argument("--delay", type=float, default=2.0, help="请求间隔(秒)")
    args = parser.parse_args(argv)

    # The default start URL is only a placeholder, so fail fast with a clear
    # message instead of starting a crawl that cannot fetch anything.
    parsed = urlparse(args.start_url)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        parser.error(f"--start_url 必须是有效的 http/https 网址,当前值: {args.start_url!r}")
    if args.max_pages < 1:
        parser.error("--max_pages 必须大于 0")
    # time.sleep() raises ValueError on negative values.
    if args.delay < 0:
        parser.error("--delay 不能为负数")

    crawler = RocketechCrawler(
        start_url=args.start_url,
        output_file=args.output,
        max_pages=args.max_pages,
        delay=args.delay,
    )
    crawler.crawl()


if __name__ == "__main__":
    main()

保存

crawl_rocketech.py

执行

python crawl_rocketech.py --start_url https://example.com/ --max_pages 200 --delay 1.5

相关推荐
Warson_L6 小时前
Python `Annotated` 与 LangGraph Reducer 学习笔记
python
韩师傅6 小时前
海天线算法的前世今生
python·计算机视觉
韩师傅6 小时前
当你的甲方设备过烂,要如何快速出效果?
python·计算机视觉
Warson_L6 小时前
LangGraph的MessageState and HumanMessage
python
韩师傅6 小时前
当你的甲方吐槽天空不够蓝,你应该如何应对
python·计算机视觉
Warson_L7 小时前
python的类&继承
python
Warson_L7 小时前
类型标注/type annotation
python
ThreeS9 小时前
手搓MiniVLA全实战教程-一步一步用pytorch解释原理与思路
人工智能·python
金銀銅鐵11 小时前
[Python] 模 n 乘法的逆元计算器
python·数学·游戏