DeepSeek辅助实现的DuckDB copy to自定义函数

copy to自定义函数指将DuckDB 数据库中的数据导出成各种自己需要的格式,除了官方已经提供的csv、parquet、xlsx,只要知道某个格式用什么库来写,就将读出的数据填充到那个库的写函数中即可。

本来以为这是最简单的一个,因为有DuckDB自己实现copy to csv的源码在,又有读写网上的Google Gsheet插件在,随便抄哪个都能抄对,但DeepSeek折腾了一整天也没有输出一个正确的程序,最后还是人工比对他的实现和人家正确实现的不同点,才解决。

关键在于:

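1.InitializeGlobal里只能写标题(表头),不能写数据本身;数据要在Sink中写,而且必须注册InitializeLocal(即copy_to_initialize_local),否则连Sink函数都执行不到,就报Segmentation fault (core dumped)错误退出,至少今天的试验是这样的。(推测原因是DuckDB执行COPY TO时会直接调用copy_to_initialize_local,未注册时它是空函数指针。)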

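2.读数据要从DataChunk类型的参数(input)中读,而不是从LocalFunctionData参数(local.cast_chunk)读。否则,虽然local.size()和local.ColumnCount()的值都正确,GetValue读出来的却全是空白。(推测原因是local里的ExpressionExecutor没有添加任何转换表达式,Execute只设置了cast_chunk的行数,并没有把转换后的数据写进去。)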

这就是一整天的教训,也怪我让他做一个精简的实现,而不是原封不动地照抄。

源代码如下。它实现了一个名为mycsv的导出格式(用FORMAT mycsv指定,示例输出文件的后缀名为.mycsv),并在表头的每个列名前插入前缀(默认为myduck,可用PREFIX选项修改),以便与系统内置的csv区分。

cpp 复制代码
#include "duckdb.hpp"
#include "duckdb/common/file_system.hpp"
#include "duckdb/common/serializer/buffered_file_writer.hpp"
#include "duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp"
#include "duckdb/function/copy_function.hpp"
#include "duckdb/parser/parsed_data/create_copy_function_info.hpp"
#include "duckdb/main/extension_util.hpp"
#include <iostream>

#define DEBUG_LOG(msg) //std::cerr << "[DEBUG] " << msg << std::endl
namespace duckdb {

// 1. Global state: owns the output file writer for one COPY ... TO statement.
//
// The writer may be flushed and released earlier by a finalize callback; the
// destructor below is a best-effort fallback that flushes whatever is still
// buffered when the state is destroyed.
struct MyCSVCopyGlobalState : public GlobalFunctionData {
    explicit MyCSVCopyGlobalState(unique_ptr<BufferedFileWriter> writer, string file_path)
        : writer(std::move(writer)), file_path(std::move(file_path)), initialized(true) {
        DEBUG_LOG("Writer initialized for: " << this->file_path 
                  << ", writer valid: " << (this->writer != nullptr));
    }

    // Destructors must not throw: an exception escaping here (e.g. a failed
    // Flush on a full disk) would call std::terminate. Log and swallow instead.
    ~MyCSVCopyGlobalState() {
        if (!writer) {
            return;
        }
        try {
            DEBUG_LOG("Final flush for: " << file_path);
            writer->Flush();
        } catch (const std::exception &e) {
            DEBUG_LOG("Final flush failed for " << file_path << ": " << e.what());
        } catch (...) {
            DEBUG_LOG("Final flush failed for " << file_path);
        }
    }

    // Buffered writer for the output file; null once it has been released.
    unique_ptr<BufferedFileWriter> writer;
    // Path of the output file (used for logging).
    string file_path;
    // Set to true by the constructor; not read anywhere in this file.
    bool initialized = false;

};
// 2. Writer options, filled from defaults and the COPY ... WITH (...) clause.
struct MyCSVWriteOptions {
    // Column names of the exported relation (used for the header line).
    vector<string> name_list;
    // Field separator (COPY option `delimiter`).
    string delimiter {"|"};
    // Text written in front of every column name in the header (COPY option `prefix`).
    string prefix {"myduck"};
    // Whether a header line is written (COPY option `header`).
    bool header {true};
};

// Per-thread sink state: an expression executor plus a scratch chunk that has
// one VARCHAR column per exported column.
// NOTE(review): no expressions are ever added to `executor` in this file.
struct MyCSVLocalState : public LocalFunctionData {
    explicit MyCSVLocalState(ClientContext &context, 
                             const vector<LogicalType> &sql_types)
        : executor(context) {
        // Scratch chunk with the same column count as the input, all VARCHAR.
        cast_chunk.Initialize(Allocator::Get(context), GetVarcharTypes(sql_types));
    }

    // Executor intended for per-column conversions.
    ExpressionExecutor executor;
    
    // Scratch chunk intended to hold the converted (string) values.
    DataChunk cast_chunk;

private:
    // Returns a list of VARCHAR types, one entry per element of `sql_types`.
    static vector<LogicalType> GetVarcharTypes(const vector<LogicalType> &sql_types) {
        return vector<LogicalType>(sql_types.size(), LogicalType(LogicalType::VARCHAR));
    }
};

// 3. 绑定数据(实现Equals方法)
struct MyCSVWriteBindData : public TableFunctionData {
    vector<string> files;
    MyCSVWriteOptions options;
    vector<LogicalType> sql_types;

    MyCSVWriteBindData(string file_path, 
                      vector<LogicalType> sql_types, 
                      vector<string> names,
                      string delimiter = "|",
                      string prefix = "myduck",
                      bool header = true)
        : sql_types(std::move(sql_types)) {
        files.push_back(std::move(file_path));
        options.name_list = std::move(names);
        options.delimiter = std::move(delimiter);
        options.prefix = std::move(prefix);
        options.header = std::move(header);
    }

    unique_ptr<FunctionData> Copy() const override {
        return make_uniq<MyCSVWriteBindData>(
            files[0], sql_types, options.name_list, 
            options.delimiter, options.prefix, options.header
        );
    }

    bool Equals(const FunctionData &other) const override {
        auto &other_bind = other.Cast<MyCSVWriteBindData>();
        return files == other_bind.files && 
               options.delimiter == other_bind.options.delimiter &&
               options.prefix == other_bind.options.prefix &&
               options.header == other_bind.options.header;
    }
};

// 4. 主功能类
class MyCSVCopyFunction : public CopyFunction {
public:
    MyCSVCopyFunction() : CopyFunction("mycsv") {
        
    copy_to_bind = Bind;
    copy_to_initialize_global = InitializeGlobal;
    copy_to_initialize_local = InitializeLocal;
    copy_to_sink = Sink;


    DEBUG_LOG("Pointers registered: " 
              << (void*)copy_to_bind << ", "
              << (void*)copy_to_initialize_global << ", "
              << (void*)copy_to_initialize_local << ", "
              << (void*)copy_to_sink);
    }

    static unique_ptr<FunctionData> Bind(
        ClientContext &context, 
        CopyFunctionBindInput &input,
        const vector<string> &names,
        const vector<LogicalType> &sql_types);

    static unique_ptr<GlobalFunctionData> InitializeGlobal(
        ClientContext &context, 
        FunctionData &bind_data,
        const string &file_path);
    static unique_ptr<LocalFunctionData> InitializeLocal(duckdb::ExecutionContext&, duckdb::FunctionData&);
    static void Sink(
        ExecutionContext &context,
        FunctionData &bind_data,
        GlobalFunctionData &gstate,
        LocalFunctionData &lstate,
        DataChunk &input);
};

// Helper: appends the raw bytes of `str` to `writer` as-is (no quoting/escaping).
static void WriteCSVString(BufferedFileWriter &writer, const string &str) {
    const_data_ptr_t bytes = reinterpret_cast<const_data_ptr_t>(str.data());
    const auto byte_count = str.size();
    writer.WriteData(bytes, byte_count);
}

// 5. Bind: records the target path, column names/types and the COPY options.
//    Recognised options: delimiter, prefix, header. Only the first value of each
//    option is used; options without a value and unknown options are ignored.
unique_ptr<FunctionData> MyCSVCopyFunction::Bind(
    ClientContext &context, 
    CopyFunctionBindInput &input,
    const vector<string> &names,
    const vector<LogicalType> &sql_types) {

    auto result = make_uniq<MyCSVWriteBindData>(input.info.file_path, sql_types, names);
    auto &opts = result->options;

    for (auto &option : input.info.options) {
        const auto &key = option.first;
        const auto &values = option.second;
        if (values.empty()) {
            continue;
        }
        const auto &first = values[0];
        if (key == "delimiter") {
            opts.delimiter = first.ToString();
        } else if (key == "prefix") {
            opts.prefix = first.ToString();
        } else if (key == "header") {
            opts.header = first.CastAs(context, LogicalType::BOOLEAN).GetValue<bool>();
        }
    }

    return std::move(result);
}

// Creates the per-thread sink state (MyCSVLocalState) sized for the bound
// column types. No conversion expressions are added to its executor here.
unique_ptr<LocalFunctionData> MyCSVCopyFunction::InitializeLocal(
    ExecutionContext &context, 
    FunctionData &bind_data) {

    DEBUG_LOG("Initializing thread-local state for worker ");

    const auto &column_types = bind_data.Cast<MyCSVWriteBindData>().sql_types;
    unique_ptr<MyCSVLocalState> state = make_uniq<MyCSVLocalState>(context.client, column_types);

    DEBUG_LOG("Thread-local state initialized with "
             << column_types.size() << " columns");
    return std::move(state);
}
// Opens the output file (creating it, or truncating an existing one) and, when
// the header option is on, writes the header line: every column name preceded
// by the prefix, separated by the delimiter, terminated by "\n".
// Throws if the file cannot be opened or written.
unique_ptr<GlobalFunctionData> MyCSVCopyFunction::InitializeGlobal(
    ClientContext &context, 
    FunctionData &bind_data,
    const string &file_path) {
    
    DEBUG_LOG("Initializing global state for file: " << file_path);
    auto &data = bind_data.Cast<MyCSVWriteBindData>();
    auto &fs = FileSystem::GetFileSystem(context);
    
    try {
        // Open the file once. FILE_CREATE_NEW creates or truncates, so no bytes of
        // a previous, longer export can survive. The previous open/close/reopen
        // probe did the same work twice.
        DEBUG_LOG("Creating BufferedFileWriter");
        auto writer = make_uniq<BufferedFileWriter>(
            fs, file_path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW);
        
        if (data.options.header) {
            DEBUG_LOG("Writing header");
            // Build the whole header line first and write it in one call.
            string header_line;
            for (idx_t i = 0; i < data.options.name_list.size(); ++i) {
                if (i != 0) {
                    header_line += data.options.delimiter;
                }
                header_line += data.options.prefix;
                header_line += data.options.name_list[i];
            }
            header_line += "\n";
            WriteCSVString(*writer, header_line);
            writer->Flush();
            DEBUG_LOG("Header written successfully");
        }
        
        return make_uniq<MyCSVCopyGlobalState>(std::move(writer), file_path);
    } catch (const std::exception &e) {
        DEBUG_LOG("InitializeGlobal failed: " << e.what());
        throw;
    }
}

// Appends one line per row of `input` to the output file. Each field is the
// value's ToString() (the literal NULL for nulls), fields are separated by the
// configured delimiter (as in the header) and each row ends with "\n".
// No quoting or escaping is applied.
void MyCSVCopyFunction::Sink(
    ExecutionContext &context,
    FunctionData &bind_data,
    GlobalFunctionData &gstate,
    LocalFunctionData &lstate,
    DataChunk &input) {
    auto &options = bind_data.Cast<MyCSVWriteBindData>().options;
    auto &state = gstate.Cast<MyCSVCopyGlobalState>();
    // The local state's executor has no expressions, so executing it never filled
    // cast_chunk (and trips an assertion in debug builds). Read straight from
    // `input` instead.
    (void)lstate;

    // Render the whole chunk into one buffer, then write it in a single call.
    string buffer;
    for (idx_t row = 0; row < input.size(); row++) {
        for (idx_t col = 0; col < input.ColumnCount(); col++) {
            if (col != 0) {
                buffer += options.delimiter;
            }
            auto val = input.GetValue(col, row);
            buffer += val.IsNull() ? "NULL" : val.ToString();
        }
        buffer += "\n";
    }
    WriteCSVString(*state.writer, buffer);
}

// 8. Registration: makes COPY ... TO ... (FORMAT mycsv) available on `db`.
void RegisterMyCSVFunction(DatabaseInstance &db) {
    MyCSVCopyFunction mycsv_function;
    ExtensionUtil::RegisterFunction(db, mycsv_function);
    DEBUG_LOG("RegisterMyCSVFunction started");
}

} // namespace duckdb
using namespace duckdb;

// Demo driver: registers the "mycsv" COPY format on an in-memory database,
// creates a 5-row table, prints it and its COPY plan, then exports it to
// output.mycsv in the working directory.
// Returns 0 on success and 1 on failure. Failures are reported on stderr even
// when DEBUG_LOG is compiled out.
int main() {
    try {
        DuckDB db(nullptr);
        Connection con(db);

        // Register the custom format.
        RegisterMyCSVFunction(*db.instance);

        // Create test data.
        DEBUG_LOG("Creating test table");
        auto create_result = con.Query("CREATE TABLE test AS SELECT i, 'val_'||i AS text FROM range(5) t(i)");
        if (create_result->HasError()) {
            std::cerr << "Create table error: " << create_result->GetError() << std::endl;
            return 1;
        }
        auto result1 = con.Query("from test");
        result1->Print();

        // Show the plan (row 0, column 1 of the EXPLAIN result); informational only.
        auto explain = con.Query("EXPLAIN COPY test TO 'output.mycsv' WITH (FORMAT mycsv,DELIMITER '|',PREFIX 'myprefix')");
        if (explain && explain->HasError()) {
            std::cerr << "Explain error: " << explain->GetError() << std::endl;
        } else if (explain) {
            std::cout << explain->GetValue(1, 0).ToString() << std::endl;
        }

        DEBUG_LOG("Executing COPY command");
        auto result = con.Query("COPY test TO 'output.mycsv' WITH (FORMAT mycsv)");
        if (result->HasError()) {
            std::cerr << "Error: " << result->GetError() << std::endl;
            return 1;
        }
        DEBUG_LOG("Execution completed successfully");
        return 0;
    } catch (const std::exception &e) {
        std::cerr << "Fatal error in main: " << e.what() << std::endl;
        return 1;
    }
}
相关推荐
苏苏susuus3 小时前
机器学习:load_predict_project
人工智能·机器学习
科技小E3 小时前
打手机检测算法AI智能分析网关V4守护公共/工业/医疗等多场景安全应用
人工智能·安全·智能手机
猿饵块4 小时前
视觉slam--框架
人工智能
yvestine5 小时前
自然语言处理——Transformer
人工智能·深度学习·自然语言处理·transformer
SuperW5 小时前
OPENCV图形计算面积、弧长API讲解(1)
人工智能·opencv·计算机视觉
why1516 小时前
微服务商城-商品微服务
数据库·后端·golang
山海不说话6 小时前
视频行为标注工具BehaviLabel(源码+使用介绍+Windows.Exe版本)
人工智能·python·计算机视觉·视觉检测
柒间6 小时前
Elasticsearch 常用操作命令整合 (cURL 版本)
大数据·数据库·elasticsearch
虹科数字化与AR7 小时前
安宝特方案丨船舶智造的“AR+AI+作业标准化管理解决方案”(装配)
人工智能·ar·ar眼镜·船舶智造·数字工作流·智能装配
achene_ql7 小时前
select、poll、epoll 与 Reactor 模式
linux·服务器·网络·c++