COPY TO 自定义函数是指将 DuckDB 数据库中的数据导出成各种自己需要的格式:除了官方已经提供的 csv、parquet、xlsx 之外,只要知道某个格式用什么库来写,把读出的数据填充到那个库的写函数中即可。
本来以为这是最简单的一个,因为前有 DuckDB 自己实现 COPY TO csv 的源码,后有读写在线 Google Sheets 的 GSheets 插件,随便抄哪个都能抄对。但 DeepSeek 折腾了一整天也没有输出一个正确的程序,最后还是人工比对它的实现和人家正确实现的不同点,才解决。
关键在于:
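1. `InitializeGlobal` 只能做到写标题,不能写数据本身;数据要在 `Sink` 中写。而且还必须注册 `InitializeLocal`,否则就连 `Sink` 函数都够不着执行,程序直接报 `Segmentation fault (core dumped)` 错误退出,至少今天的试验是这样的。(推测原因:DuckDB 的 COPY 算子会直接调用 `copy_to_initialize_local`,没有注册就等于调用空函数指针。)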
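2. 读数据要从 `DataChunk` 类型的参数(`input`)中读,而不是从 `LocalFunctionData` 参数读。否则,虽然 `local.cast_chunk.size()` 和 `local.cast_chunk.ColumnCount()` 的值都正确,`GetValue` 读出来的却全是空白。原因是本地状态里的 `ExpressionExecutor` 没有任何表达式,`cast_chunk` 只被设置了行数,数据从未被填进去。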
这就是一整天的教训,也怪我让它做一个精简的实现,而不是原封不动地照抄。
源代码如下。它实现了一个名为 `mycsv` 的 COPY 格式(`COPY ... TO ... (FORMAT mycsv)`),并在表头的每个列名前插入 `myduck` 前缀,以便与系统内置的 csv 区分。
cpp
#include <iostream>

#include "duckdb.hpp"
#include "duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/file_system.hpp"
#include "duckdb/common/serializer/buffered_file_writer.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/function/copy_function.hpp"
#include "duckdb/main/extension_util.hpp"
#include "duckdb/parser/parsed_data/create_copy_function_info.hpp"
#define DEBUG_LOG(msg) //std::cerr << "[DEBUG] " << msg << std::endl
namespace duckdb {
// 1. 全局状态
// 修改全局状态管理,确保文件正确关闭
// 1. 修改全局状态类,添加写入状态跟踪
struct MyCSVCopyGlobalState : public GlobalFunctionData {
explicit MyCSVCopyGlobalState(unique_ptr<BufferedFileWriter> writer, string file_path)
: writer(std::move(writer)), file_path(std::move(file_path)), initialized(true) {
DEBUG_LOG("Writer initialized for: " << this->file_path
<< ", writer valid: " << (this->writer != nullptr));
}
~MyCSVCopyGlobalState() {
if (writer) {
DEBUG_LOG("Final flush for: " << file_path);
writer->Flush();
}
}
unique_ptr<BufferedFileWriter> writer;
string file_path;
bool initialized = false;
};
// 2. User-configurable writer settings, carried inside the bind data.
struct MyCSVWriteOptions {
	//! Names of the result columns, in output order (used for the header line).
	vector<string> name_list;
	//! Column separator used by the writer.
	string delimiter {"|"};
	//! Text prepended to every column name in the header line.
	string prefix {"myduck"};
	//! Whether a header line is written before the data.
	bool header {true};
};
// Per-thread state produced by InitializeLocal. It holds an expression executor and a
// scratch chunk whose columns are all VARCHAR (one per result column), intended for
// casting input values to text.
struct MyCSVLocalState : public LocalFunctionData {
	explicit MyCSVLocalState(ClientContext &context, const vector<LogicalType> &sql_types)
	    : executor(context) {
		auto &allocator = Allocator::Get(context);
		cast_chunk.Initialize(allocator, GetVarcharTypes(sql_types));
	}

	//! Executor for type-conversion expressions (none are installed in this file).
	ExpressionExecutor executor;
	//! VARCHAR scratch chunk with the same column count as the input.
	DataChunk cast_chunk;

private:
	//! Returns a list of VARCHAR types with the same length as `sql_types`.
	static vector<LogicalType> GetVarcharTypes(const vector<LogicalType> &sql_types) {
		return vector<LogicalType>(sql_types.size(), LogicalType::VARCHAR);
	}
};
// 3. 绑定数据(实现Equals方法)
struct MyCSVWriteBindData : public TableFunctionData {
vector<string> files;
MyCSVWriteOptions options;
vector<LogicalType> sql_types;
MyCSVWriteBindData(string file_path,
vector<LogicalType> sql_types,
vector<string> names,
string delimiter = "|",
string prefix = "myduck",
bool header = true)
: sql_types(std::move(sql_types)) {
files.push_back(std::move(file_path));
options.name_list = std::move(names);
options.delimiter = std::move(delimiter);
options.prefix = std::move(prefix);
options.header = std::move(header);
}
unique_ptr<FunctionData> Copy() const override {
return make_uniq<MyCSVWriteBindData>(
files[0], sql_types, options.name_list,
options.delimiter, options.prefix, options.header
);
}
bool Equals(const FunctionData &other) const override {
auto &other_bind = other.Cast<MyCSVWriteBindData>();
return files == other_bind.files &&
options.delimiter == other_bind.options.delimiter &&
options.prefix == other_bind.options.prefix &&
options.header == other_bind.options.header;
}
};
// 4. 主功能类
class MyCSVCopyFunction : public CopyFunction {
public:
MyCSVCopyFunction() : CopyFunction("mycsv") {
copy_to_bind = Bind;
copy_to_initialize_global = InitializeGlobal;
copy_to_initialize_local = InitializeLocal;
copy_to_sink = Sink;
DEBUG_LOG("Pointers registered: "
<< (void*)copy_to_bind << ", "
<< (void*)copy_to_initialize_global << ", "
<< (void*)copy_to_initialize_local << ", "
<< (void*)copy_to_sink);
}
static unique_ptr<FunctionData> Bind(
ClientContext &context,
CopyFunctionBindInput &input,
const vector<string> &names,
const vector<LogicalType> &sql_types);
static unique_ptr<GlobalFunctionData> InitializeGlobal(
ClientContext &context,
FunctionData &bind_data,
const string &file_path);
static unique_ptr<LocalFunctionData> InitializeLocal(duckdb::ExecutionContext&, duckdb::FunctionData&);
static void Sink(
ExecutionContext &context,
FunctionData &bind_data,
GlobalFunctionData &gstate,
LocalFunctionData &lstate,
DataChunk &input);
};
// Helper: appends the raw bytes of `str` to `writer` verbatim (no quoting or escaping).
static void WriteCSVString(BufferedFileWriter &writer, const string &str) {
	const auto bytes = reinterpret_cast<const_data_ptr_t>(str.data());
	const auto byte_count = str.size();
	writer.WriteData(bytes, byte_count);
}
// 5. Bind: runs once per COPY statement. Captures the output path, the result column
//    names/types and the user options (DELIMITER, PREFIX, HEADER) into bind data.
//    Throws BinderException for an unknown option or an invalid value, so a misspelled
//    option is reported instead of silently ignored.
unique_ptr<FunctionData> MyCSVCopyFunction::Bind(
    ClientContext &context,
    CopyFunctionBindInput &input,
    const vector<string> &names,
    const vector<LogicalType> &sql_types) {
	auto bind_data = make_uniq<MyCSVWriteBindData>(input.info.file_path, sql_types, names);

	for (auto &option : input.info.options) {
		// The options map is case-insensitive for lookup, but its keys may keep the
		// spelling the user typed (e.g. "PREFIX"), so normalise before comparing.
		const auto key = StringUtil::Lower(option.first);
		const auto &values = option.second;
		if (key == "delimiter") {
			if (values.empty() || values[0].ToString().empty()) {
				throw BinderException("COPY FORMAT mycsv: option \"%s\" requires a non-empty string",
				                      option.first);
			}
			bind_data->options.delimiter = values[0].ToString();
		} else if (key == "prefix") {
			if (values.empty()) {
				throw BinderException("COPY FORMAT mycsv: option \"%s\" requires a value", option.first);
			}
			bind_data->options.prefix = values[0].ToString();
		} else if (key == "header") {
			// A bare HEADER (no value) keeps the header enabled.
			bind_data->options.header =
			    values.empty() ? true : values[0].CastAs(context, LogicalType::BOOLEAN).GetValue<bool>();
		} else {
			throw BinderException("COPY FORMAT mycsv: unrecognized option \"%s\"", option.first);
		}
	}
	return std::move(bind_data);
}
// Creates the per-thread local state (an executor plus a VARCHAR scratch chunk sized to
// the result columns). This callback must be registered even though Sink does not need
// much from it: leaving it unset made the COPY crash before Sink ran.
unique_ptr<LocalFunctionData> MyCSVCopyFunction::InitializeLocal(
    ExecutionContext &context,
    FunctionData &bind_data) {
	DEBUG_LOG("Initializing thread-local state for worker ");
	const auto &copy_data = bind_data.Cast<MyCSVWriteBindData>();
	const auto &column_types = copy_data.sql_types;

	unique_ptr<LocalFunctionData> result = make_uniq<MyCSVLocalState>(context.client, column_types);

	DEBUG_LOG("Thread-local state initialized with "
	          << column_types.size() << " columns");
	return result;
}
// Opens (creating or truncating) the output file and, if enabled, writes the header
// line: each column name preceded by `prefix`, columns separated by `delimiter`.
// Returns the global state that owns the writer. I/O errors propagate to the caller.
unique_ptr<GlobalFunctionData> MyCSVCopyFunction::InitializeGlobal(
    ClientContext &context,
    FunctionData &bind_data,
    const string &file_path) {
	DEBUG_LOG("Initializing global state for file: " << file_path);
	auto &data = bind_data.Cast<MyCSVWriteBindData>();
	auto &fs = FileSystem::GetFileSystem(context);
	try {
		// Open once, with the same flags the old "access check" used. An unwritable path
		// throws here, so a separate probe open/close is unnecessary.
		DEBUG_LOG("Creating BufferedFileWriter");
		auto writer = make_uniq<BufferedFileWriter>(
		    fs, file_path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW);

		if (data.options.header) {
			DEBUG_LOG("Writing header");
			string header_line;
			const auto &columns = data.options.name_list;
			for (idx_t i = 0; i < columns.size(); i++) {
				if (i != 0) {
					header_line += data.options.delimiter;
				}
				header_line += data.options.prefix;
				header_line += columns[i];
			}
			header_line += '\n';
			WriteCSVString(*writer, header_line);
			writer->Flush();
			DEBUG_LOG("Header written successfully");
		}
		return make_uniq<MyCSVCopyGlobalState>(std::move(writer), file_path);
	} catch (const std::exception &e) {
		// Log for debugging only; the original exception is rethrown unchanged.
		DEBUG_LOG("InitializeGlobal failed: " << e.what());
		throw;
	}
}
// Called for each chunk of rows produced by the COPY source query. Writes one line per
// row: the column values (Value::ToString(), SQL NULL as the text NULL) joined with the
// configured delimiter, the same separator used in the header line.
// Values are written verbatim; embedded delimiters/newlines are not quoted.
void MyCSVCopyFunction::Sink(
    ExecutionContext &context,
    FunctionData &bind_data,
    GlobalFunctionData &gstate,
    LocalFunctionData &lstate,
    DataChunk &input) {
	auto &data = bind_data.Cast<MyCSVWriteBindData>();
	auto &state = gstate.Cast<MyCSVCopyGlobalState>();
	const auto &delimiter = data.options.delimiter;

	// Read straight from `input`. The local state's executor has no expressions, so
	// running it into cast_chunk never filled any data.
	string line;
	for (idx_t row = 0; row < input.size(); row++) {
		line.clear();
		for (idx_t col = 0; col < input.ColumnCount(); col++) {
			if (col != 0) {
				line += delimiter;
			}
			auto val = input.GetValue(col, row);
			line += val.IsNull() ? string("NULL") : val.ToString();
		}
		line += '\n';
		WriteCSVString(*state.writer, line);
	}
}
// 8. Registers the "mycsv" COPY TO format with `db`, so that
//    COPY ... TO ... (FORMAT mycsv) resolves to MyCSVCopyFunction.
void RegisterMyCSVFunction(DatabaseInstance &db) {
	DEBUG_LOG("RegisterMyCSVFunction started");
	MyCSVCopyFunction mycsv_copy_function;
	ExtensionUtil::RegisterFunction(db, mycsv_copy_function);
	DEBUG_LOG("RegisterMyCSVFunction finished");
}
} // namespace duckdb
using namespace duckdb;
// Demo driver: creates an in-memory database, registers the mycsv format, builds a small
// table, prints it and the EXPLAIN plan, then exports it to output.mycsv.
// Returns 0 on success, 1 on any error (the error is always printed to stderr).
int main() {
	try {
		DuckDB db(nullptr);
		Connection con(db);
		// Register the custom format.
		RegisterMyCSVFunction(*db.instance);

		// Create test data.
		DEBUG_LOG("Creating test table");
		auto create_result = con.Query("CREATE TABLE test AS SELECT i, 'val_'||i AS text FROM range(5) t(i)");
		if (create_result->HasError()) {
			std::cerr << "Create table error: " << create_result->GetError() << std::endl;
			return 1;
		}

		auto result1 = con.Query("from test");
		result1->Print();

		// Query() always returns a result object; failure is reported via HasError().
		auto explain = con.Query("EXPLAIN COPY test TO 'output.mycsv' WITH (FORMAT mycsv,DELIMITER '|',PREFIX 'myprefix')");
		if (explain->HasError()) {
			std::cerr << "Explain error: " << explain->GetError() << std::endl;
			return 1;
		}
		std::cout << explain->GetValue(1, 0).ToString() << std::endl;

		DEBUG_LOG("Executing COPY command");
		auto result = con.Query(R"(
COPY test TO 'output.mycsv' WITH (
FORMAT mycsv
)
)");
		if (result->HasError()) {
			std::cerr << "Error: " << result->GetError() << std::endl;
			return 1;
		}
		DEBUG_LOG("Execution completed successfully");
		return 0;
	} catch (const std::exception &e) {
		std::cerr << "Fatal error in main: " << e.what() << std::endl;
		return 1;
	}
}