前文实现了 UDF 和 UDAF，还有一类函数是表函数：它出现在 FROM 子句中，返回一个表（结果集）。DuckDB 已有 PostgreSQL 扩展（插件），但我们也可以用 libpqxx 库自己实现一个简易的只读表函数 read_pg()。
提示词如下:
请将libpqxx库集成到我们的程序,使它能对postgresql数据库操作,并把数据与duckdb打通,比如能在一个sql中访问pg和duckdb中的表的关联结果,先做一个简单的表函数read_pg(db,table)返回一个表,可以执行select * from read_pg(db,table);select * from read_pg(db,table1) a,read_pg(db,table2) b where a.x=b.x;其中db是postgresql连接字符串,比如postgresql://user:secret@localhost/mydb,table是一个varchar
DeepSeek 编写的代码一开始总是无限重复调用 Function 函数（DuckDB 会反复调用表函数的 Function，直到某次调用输出 0 行才停止），后来在 Function 开头添加了下面这段代码
cpp
// Scan already finished: emit an empty chunk (cardinality 0) and return.
// DuckDB keeps calling the table function until a call produces 0 rows,
// so this is what stops the otherwise endless calls.
if (state.execution_finished) {
output.SetCardinality(0);
return;
}
后，就能正确输出 PostgreSQL 某个表的数据了，加上 WHERE 条件也能正确查询。另一个插曲是：DeepSeek 编写的表函数注册代码总是不对，我从网上找了一个例子——虽然我的程序不是扩展（插件），却也能调用 ExtensionUtil::RegisterFunction 来注册表函数。
源代码（readpg5.cpp）如下：
cpp
#include <pqxx/pqxx>
#include <memory>
#include <unordered_map>
#include <duckdb.hpp>
#include "duckdb/function/function_set.hpp"
#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp"
/// Caches one pqxx::connection per distinct connection string.
///
/// A connection is opened the first time its string is requested and is then
/// reused by every later caller passing the same string. The pool keeps
/// ownership; callers get a reference that stays valid as long as the pool
/// exists (entries are never removed).
class PGConnectionPool {
private:
    std::unordered_map<std::string, std::shared_ptr<pqxx::connection>> connections;

public:
    /// Returns the cached connection for `conn_str`, opening it first if needed.
    /// If connecting throws, the exception propagates and nothing is cached.
    pqxx::connection& getConnection(const std::string& conn_str) {
        if (auto found = connections.find(conn_str); found != connections.end()) {
            return *found->second;
        }
        // Connect before inserting so that a failed connect leaves no entry behind.
        auto fresh = std::make_shared<pqxx::connection>(conn_str);
        connections.emplace(conn_str, fresh);
        return *fresh;
    }
};
// Process-wide connection cache shared by read_pg's Bind and Function.
// Connections are never evicted, so one stays open per distinct connection string.
// NOTE(review): not synchronized — assumes read_pg scans never call it concurrently; confirm.
static PGConnectionPool pg_pool;
struct PGTableFunctionData : public duckdb::TableFunctionData {
std::string conn_str;
std::string table_name;
duckdb::vector<duckdb::LogicalType> return_types;
duckdb::vector<std::string> return_names;
PGTableFunctionData(std::string conn_str, std::string table_name,
duckdb::vector<duckdb::LogicalType> return_types,
duckdb::vector<std::string> return_names)
: conn_str(std::move(conn_str)), table_name(std::move(table_name)),
return_types(std::move(return_types)), return_names(std::move(return_names)) {}
duckdb::unique_ptr<duckdb::FunctionData> Copy() const override {
return duckdb::make_uniq<PGTableFunctionData>(conn_str, table_name, return_types, return_names);
}
bool Equals(const duckdb::FunctionData &other) const override {
auto &other_data = other.Cast<PGTableFunctionData>();
return conn_str == other_data.conn_str && table_name == other_data.table_name;
}
};
struct PGGlobalState : public duckdb::GlobalTableFunctionState {
pqxx::connection* conn = nullptr;
std::unique_ptr<pqxx::work> txn;
pqxx::result result;
pqxx::result::const_iterator it;
bool initialized = false;
bool execution_finished = false;
};
struct PGTableFunction {
static duckdb::TableFunction GetFunction() {
return duckdb::TableFunction(
"read_pg",
{duckdb::LogicalType::VARCHAR, duckdb::LogicalType::VARCHAR},
Function,
Bind,
InitGlobal
);
}
static duckdb::unique_ptr<duckdb::GlobalTableFunctionState> InitGlobal(
duckdb::ClientContext &context,
duckdb::TableFunctionInitInput &input
) {
return duckdb::make_uniq<PGGlobalState>();
}
static duckdb::unique_ptr<duckdb::FunctionData> Bind(
duckdb::ClientContext &context,
duckdb::TableFunctionBindInput &input,
duckdb::vector<duckdb::LogicalType> &return_types,
duckdb::vector<std::string> &return_names
) {
auto conn_str = input.inputs[0].GetValue<std::string>();
auto table_name = input.inputs[1].GetValue<std::string>();
try {
auto& conn = pg_pool.getConnection(conn_str);
pqxx::work txn(conn);
auto r = txn.exec("SELECT column_name, data_type FROM information_schema.columns "
"WHERE table_name = " + txn.quote(table_name) + " ORDER BY ordinal_position");
for (const auto& row : r) {
return_names.push_back(row[0].as<std::string>());
std::string pg_type = row[1].as<std::string>();
if (pg_type == "integer" || pg_type == "bigint") {
return_types.push_back(duckdb::LogicalType::BIGINT);
} else if (pg_type == "text" || pg_type == "varchar") {
return_types.push_back(duckdb::LogicalType::VARCHAR);
} else if (pg_type == "double precision") {
return_types.push_back(duckdb::LogicalType::DOUBLE);
} else if (pg_type == "boolean") {
return_types.push_back(duckdb::LogicalType::BOOLEAN);
} else {
return_types.push_back(duckdb::LogicalType::VARCHAR);
}
}
return duckdb::make_uniq<PGTableFunctionData>(conn_str, table_name, return_types, return_names);
} catch (const std::exception& e) {
throw std::runtime_error("PostgreSQL error: " + std::string(e.what()));
}
}
static void Function(
duckdb::ClientContext &context,
duckdb::TableFunctionInput &data,
duckdb::DataChunk &output
) {
auto &bind_data = data.bind_data->Cast<PGTableFunctionData>();
auto &state = data.global_state->Cast<PGGlobalState>();
if (state.execution_finished) {
output.SetCardinality(0);
return;
}
try {
if (!state.initialized) {
state.conn = &pg_pool.getConnection(bind_data.conn_str);
state.txn = std::make_unique<pqxx::work>(*state.conn);
state.result = state.txn->exec("SELECT * FROM " + state.txn->quote_name(bind_data.table_name));
state.it = state.result.begin();
state.initialized = true;
}
idx_t row_count = 0;
while (state.it != state.result.end() && row_count < STANDARD_VECTOR_SIZE) {
const auto& row = *state.it;
for (duckdb::idx_t col = 0; col < static_cast<duckdb::idx_t>(row.size()); col++) {
auto field = row[static_cast<pqxx::row::size_type>(col)];
if (field.is_null()) {
output.data[col].SetValue(row_count, duckdb::Value());
} else {
std::string value = field.as<std::string>();
switch (bind_data.return_types[col].id()) {
case duckdb::LogicalTypeId::BIGINT:
output.data[col].SetValue(row_count, duckdb::Value::BIGINT(std::stoll(value)));
break;
case duckdb::LogicalTypeId::DOUBLE:
output.data[col].SetValue(row_count, duckdb::Value::DOUBLE(std::stod(value)));
break;
case duckdb::LogicalTypeId::BOOLEAN:
output.data[col].SetValue(row_count, duckdb::Value::BOOLEAN(value == "t" || value == "true"));
break;
default:
output.data[col].SetValue(row_count, duckdb::Value(value));
}
}
}
row_count++;
++state.it;
}
output.SetCardinality(row_count);
if (state.it == state.result.end()) {
state.execution_finished = true;
state.txn->commit();
}
} catch (const std::exception& e) {
//if (state.txn && state.txn->is_open()) state.txn->abort();
throw std::runtime_error("PostgreSQL error: " + std::string(e.what()));
}
}
};
测试代码（testpg5.cpp，它直接 #include 了 readpg5.cpp）如下：
cpp
#include "duckdb.hpp"
#include "readpg5.cpp"
#include <iostream>
#include "duckdb/main/extension_util.hpp"
using namespace duckdb;
using namespace std;
int main() {
DuckDB db(nullptr);
Connection con(db);
try {
DatabaseInstance& db_instance = *db.instance;
ExtensionUtil::RegisterFunction(db_instance, PGTableFunction::GetFunction());
} catch (const exception &e) {
cerr << "初始化错误: " << e.what() << endl;
return 1;
}
cout << "=== 测试1: 查询PostgreSQL表函数 ===" << endl;
auto result = con.Query("SELECT * FROM read_pg('postgresql://[email protected]/postgres', 't2')");
if (result->HasError()) {
cerr << "查询错误: " << result->GetError() << endl;
} else {
result->Print();
}
cout << "\n=== 测试1.1: 带条件的同一个表查询 ===" << endl;
result = con.Query("SELECT * FROM read_pg('postgresql://[email protected]/postgres', 't2') WHERE tid = 2");
if (result->HasError()) {
cerr << "查询错误: " << result->GetError() << endl;
} else {
result->Print();
}
cout << "\n=== 测试2: 带条件的查询 ===" << endl;
result = con.Query("SELECT * FROM read_pg('postgresql://[email protected]/postgres', 't') WHERE a = 2");
if (result->HasError()) {
cerr << "查询错误: " << result->GetError() << endl;
} else {
result->Print();
}
cout << "\n=== 测试3: 多表查询 ===" << endl;
result = con.Query("SELECT tname FROM read_pg('postgresql://[email protected]/postgres', 't2') a, "
"read_pg('postgresql://[email protected]/postgres', 't') b WHERE a.tid = b.a");
if (result->HasError()) {
cerr << "查询错误: " << result->GetError() << endl;
} else {
result->Print();
}
cout << "\n=== 测试4: pg和duckdb多表查询 ===" << endl;
con.Query("create table duckdb_t as select 2 a union all select 3");
result = con.Query("SELECT tname FROM read_pg('postgresql://[email protected]/postgres', 't2') a, "
"duckdb_t b WHERE a.tid = b.a");
if (result->HasError()) {
cerr << "查询错误: " << result->GetError() << endl;
} else {
result->Print();
}
cout << "\n=== 测试完成 ===" << endl;
return 0;
}
编译命令如下（testpg5.cpp 已包含 readpg5.cpp，所以只需编译这一个文件）：
bash
# Directory of the locally built libduckdb (adjust to your DuckDB build tree).
# LIBRARY_PATH: searched at link time for -lduckdb.
export LIBRARY_PATH=/par/duck/build/src
# LD_LIBRARY_PATH: lets the dynamic loader find libduckdb when ./readpg2 runs.
export LD_LIBRARY_PATH=/par/duck/build/src
# testpg5.cpp #includes readpg5.cpp, so it is the only source file compiled.
# -I points at DuckDB's headers; -lpqxx is listed before -lpq because libpqxx is built on libpq.
g++ -std=c++17 -o readpg2 testpg5.cpp -lduckdb -lpqxx -lpq -I /par/duck/src/include