模型训练之爬取数据

安装依赖

pip install requests beautifulsoup4 trafilatura tqdm

爬虫代码

python 复制代码
#!/usr/bin/env python3
"""
Rocketech 知识库爬虫(用于 RAG 数据集)
用法:
    python crawl_rocketech.py [--max_pages 500] [--delay 2] [--output data.jsonl]
"""

import requests
import trafilatura
from bs4 import BeautifulSoup
from urllib.parse import urljoin, urlparse
from urllib.robotparser import RobotFileParser
import time
import json
import argparse
import os
from tqdm import tqdm


class RocketechCrawler:
    def __init__(self, start_url, output_file="rocketech_data.jsonl", max_pages=500, delay=2.0):
        self.start_url = start_url
        self.domain = urlparse(start_url).netloc
        self.base_url = f"https://{self.domain}/"
        self.output_file = output_file
        self.max_pages = max_pages
        self.delay = delay

        # 已访问/待访问队列
        self.visited = set()
        self.to_visit = [start_url]

        # robots.txt 规则
        self.rp = RobotFileParser()
        self.rp.set_url(f"https://{self.domain}/robots.txt")
        try:
            self.rp.read()
        except Exception:
            print("无法获取 robots.txt,将不限制爬取。")
            self.rp = None

        # 用于去重的 URL 归一化
        self.session = requests.Session()
        self.session.headers.update({
            "User-Agent": "RocketechRAGBot/1.0 (research project; contact@example.com)"
        })

    def can_fetch(self, url):
        """检查 robots.txt 是否允许抓取"""
        if self.rp is None:
            return True
        return self.rp.can_fetch(self.session.headers["User-Agent"], url)

    def normalize_url(self, url):
        """去掉 fragment,尾斜杠统一,忽略大小写(域名部分)"""
        parsed = urlparse(url)
        norm = parsed._replace(fragment="", query="").geturl()
        norm = norm.rstrip("/")
        return norm

    def extract_links(self, html, current_url):
        """从页面中提取所有同域链接"""
        soup = BeautifulSoup(html, "html.parser")
        links = set()
        for a in soup.find_all("a", href=True):
            href = a["href"].strip()
            full_url = urljoin(current_url, href)
            parsed = urlparse(full_url)
            # 只保留 http/https 同域链接
            if parsed.scheme in ("http", "https") and parsed.netloc == self.domain:
                # 过滤掉资源文件(图片、PDF等)
                if not any(parsed.path.lower().endswith(ext) for ext in [".pdf", ".zip", ".png", ".jpg", ".jpeg", ".gif", ".mp3", ".mp4", ".doc", ".docx"]):
                    normalized = self.normalize_url(full_url)
                    if normalized not in self.visited:
                        links.add(normalized)
        return links

    def process_page(self, url):
        """抓取并提取一个页面的文本内容"""
        try:
            resp = self.session.get(url, timeout=30)
            if resp.status_code != 200:
                return None, None
            html = resp.text
        except Exception as e:
            print(f" 请求失败 {url}: {e}")
            return None, None

        # 提取正文(trailatura 自动剔除导航/页脚/广告)
        downloaded = trafilatura.extract(html, include_comments=False, include_tables=False,
                                         no_fallback=False, favor_precision=True)
        if not downloaded:
            # 备选:直接用BeautifulSoup取文本
            soup = BeautifulSoup(html, "html.parser")
            for script in soup(["script", "style", "nav", "footer", "header"]):
                script.decompose()
            text = soup.get_text(separator="\n")
            text = "\n".join(line.strip() for line in text.splitlines() if line.strip())
        else:
            text = downloaded.strip()

        # 标题
        soup = BeautifulSoup(html, "html.parser")
        title = soup.title.string.strip() if soup.title else url

        return title, text

    def save_record(self, record):
        """追加一行到 JSONL"""
        with open(self.output_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    def crawl(self):
        print(f"开始爬取 {self.start_url},目标页数:{self.max_pages}")
        pbar = tqdm(total=self.max_pages, desc="已爬取页面")

        while self.to_visit and len(self.visited) < self.max_pages:
            # 取下一个 URL
            url = self.to_visit.pop(0)
            norm_url = self.normalize_url(url)
            if norm_url in self.visited:
                continue

            # 检查 robots.txt
            if not self.can_fetch(url):
                self.visited.add(norm_url)
                continue

            print(f"\n📄 正在处理: {url}")
            title, text = self.process_page(url)

            # 保存数据
            if text and len(text) > 50:  # 过滤过短页面
                record = {
                    "url": url,
                    "title": title,
                    "text": text,
                    "timestamp": time.time()
                }
                self.save_record(record)
                print(f" 保存成功,正文长度: {len(text)} 字符")
            else:
                print(f"  跳过(无有效正文)")

            # 标记已访问并提取链接
            self.visited.add(norm_url)
            if len(self.visited) < self.max_pages:
                html = None
                try:
                    resp = self.session.get(url, timeout=10)
                    if resp.status_code == 200:
                        html = resp.text
                except Exception:
                    pass
                if html:
                    new_links = self.extract_links(html, url)
                    self.to_visit.extend(new_links)
                    # 去重(保留顺序)
                    self.to_visit = list(dict.fromkeys(self.to_visit))

            pbar.update(1)
            # 礼貌延迟
            time.sleep(self.delay)

        pbar.close()
        print(f"\n爬取完成,共处理 {len(self.visited)} 个页面,数据保存至 {self.output_file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Rocketech 网站爬虫(RAG数据集)")
    parser.add_argument("--start_url", default="网页网址", help="起始URL")
    parser.add_argument("--output", default="rocketech_data.jsonl", help="输出JSONL文件")
    parser.add_argument("--max_pages", type=int, default=500, help="最大爬取页数")
    parser.add_argument("--delay", type=float, default=2.0, help="请求间隔(秒)")
    args = parser.parse_args()

    crawler = RocketechCrawler(
        start_url=args.start_url,
        output_file=args.output,
        max_pages=args.max_pages,
        delay=args.delay
    )
    crawler.crawl()

保存

crawl_rocketech.py

执行

python crawl_rocketech.py --max_pages 200 --delay 1.5

相关推荐
张二娃同学1 小时前
第12篇_深度学习学习路线总结
人工智能·python·深度学习·神经网络·学习
之歆1 小时前
DAY_10 JavaScript 深度解析:原型链 · 引用类型 · 内置对象 · 数组方法全攻略(上)
开发语言·javascript·ecmascript
zmzb01031 小时前
Python课后习题训练记录Day122
开发语言·python
陳土2 小时前
R语言jiebaR包使用摘要
开发语言·r语言
Evand J2 小时前
【MATLAB】多无人机编队协同控制与三维航迹规划仿真。障碍物斥力避障,输出编队误差、控制输入、三维轨迹等
开发语言·matlab·无人机
m0_702036532 小时前
如何从Oracle Java调用外部API_HTTP请求在数据库Java Source中的实现
jvm·数据库·python
Freak嵌入式2 小时前
WIZnet-EVB-Pico2开始,用MicroPython玩转以太网开发
arm开发·人工智能·python·嵌入式硬件·机器人·嵌入式·micropython
froginwe112 小时前
jQuery UI 小部件方法调用
开发语言