Spark SQL 读取不带表头的 txt 文件时,如果不传入 schema 信息,则会自动把列命名为 _c0、_c1 等。这样一来也无法通过调用 df.as() 方法转换成 Dataset 对象(甚至会因为样例类的属性名称与 DataFrame 的列名不一致而抛出异常)。
这时候可以通过下面的方式添加 schema:
java
// Describe the single nullable string column "word".
StructField wordField = DataTypes.createStructField("word", DataTypes.StringType, true);
List<StructField> fields = new ArrayList<>();
fields.add(wordField);
// Wrap the column list into the schema object.
StructType schema = DataTypes.createStructType(fields);
// RDD -> DataFrame: attach the schema while converting the row RDD.
Dataset<Row> dataFrame = sparkSession.createDataFrame(RowRDD, schema);
但是,如果已经是 DataFrame 对象,则无法再更新它的 schema。所以我们需要在加载文件的时候,通过调用 schema() 方法传入构造好的 StructType 对象来创建 DataFrame。例如:
java
// Define the schema: two nullable string columns, "word" and "cnt".
List<StructField> fields = new ArrayList<>();
fields.add(DataTypes.createStructField("word", DataTypes.StringType, true));
fields.add(DataTypes.createStructField("cnt", DataTypes.StringType, true));
StructType schema = DataTypes.createStructType(fields);
// The article's premise is a tab-separated file WITHOUT a header row, so "header" is false;
// with true the first data row would be consumed as a header and lost.
// csv(...) selects the CSV source by itself, so no format(...) call is needed before it.
Dataset<Row> citydf = reader
        .option("delimiter", "\t")
        .option("header", false)
        .schema(schema)
        // Backslashes have to be escaped inside a Java string literal.
        .csv("D:\\project\\sparkDemo\\inputs\\city_info.txt");
那么这时候就有问题了:如果需要加载的文件很多,全都要手动创建列表、逐个添加字段,会非常麻烦。
这时可以封装 StructType 对象的实例化方法,传入目标字段名称以及数据类型;字段名称以及数据类型可以通过样例类获取。
StructType 对象的实例化方法:
java
package src.main.utils;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import javax.activation.UnsupportedDataTypeException;
/**
 * Collects (field name, type name) pairs and turns them into a Spark {@link StructType}.
 *
 * <p>Accepted type names, compared case-insensitively, are {@code string}, {@code integer},
 * {@code double} and {@code long}. Columns are emitted in insertion order; registering a
 * name a second time replaces its type. Every generated column is nullable.
 */
public class SchemaMaker {
    /** Field name -> normalized (lower-case) type name, in insertion order. */
    private LinkedHashMap<String, String> schemaMap = new LinkedHashMap<>();
    /** Type names this class knows how to map to a Spark {@link DataType}. */
    final private List<String> valueTypes = Stream.of("string", "integer", "double", "long").collect(Collectors.toList());
    /** Struct fields produced by the most recent {@link #getStructType()} call. */
    List<StructField> fields = new ArrayList<>();

    /** Creates an empty maker; register columns with {@link #add(String, String)}. */
    public SchemaMaker(){
    }

    /**
     * Creates a maker pre-populated from a list of two-element entries.
     *
     * @param dataList entries of the form {@code [fieldName, fieldType]}
     * @throws RuntimeException if an entry does not have exactly two elements or names an
     *                          unsupported type
     */
    public SchemaMaker(ArrayList<ArrayList<String>> dataList){
        for (ArrayList<String> data : dataList) {
            if (data.size() != 2){
                throw new RuntimeException("每个数据必须为2个参数,第一个为字段名,第二个为字段类型");
            }
            register(data.get(0), data.get(1));
        }
    }

    /**
     * Registers one column.
     *
     * @param fieldName column name
     * @param fieldType type name, case-insensitive
     * @throws RuntimeException if {@code fieldType} is null or not a supported type name
     */
    public void add(String fieldName, String fieldType){
        register(fieldName, fieldType);
    }

    /** Validates the type name and stores the column; shared by the constructor and add(). */
    private void register(String fieldName, String fieldType){
        // A null type is reported as an unsupported type rather than surfacing as a bare NPE.
        String normalized = fieldType == null ? null : getLowCase(fieldType);
        if (!checkType(normalized)){
            throw new RuntimeException("数据类型不符合预期" + this.valueTypes.toString());
        }
        this.schemaMap.put(fieldName, normalized);
    }

    /** Lower-cases with a fixed locale so the result does not depend on the JVM default locale. */
    private String getLowCase(String s){
        return s.toLowerCase(Locale.ROOT);
    }

    /** Returns whether the (already normalized) type name is supported. */
    private boolean checkType(String typeValue){
        return this.valueTypes.contains(typeValue);
    }

    /**
     * Maps a normalized type name to the Spark data type.
     *
     * @throws UnsupportedDataTypeException if the name is not one of the supported types
     */
    private DataType getDataType (String typeValue) throws UnsupportedDataTypeException {
        switch (typeValue) {
            case "string":
                return DataTypes.StringType;
            case "integer":
                return DataTypes.IntegerType;
            case "long":
                return DataTypes.LongType;
            case "double":
                return DataTypes.DoubleType;
            default:
                throw new UnsupportedDataTypeException(typeValue);
        }
    }

    /**
     * Builds the schema from the columns registered so far.
     *
     * <p>The field list is rebuilt on every call, so calling this method repeatedly (or
     * calling {@code add} between calls) never yields duplicated columns.
     *
     * @return a struct type with one nullable column per registered field
     * @throws UnsupportedDataTypeException if a stored type name cannot be mapped
     */
    public StructType getStructType() throws UnsupportedDataTypeException {
        // Build into a local list first so a failure half-way leaves this.fields untouched.
        List<StructField> rebuilt = new ArrayList<>(schemaMap.size());
        for (Map.Entry<String, String> schemaValue : schemaMap.entrySet()) {
            DataType fieldDataType = getDataType(schemaValue.getValue());
            rebuilt.add(DataTypes.createStructField(schemaValue.getKey(), fieldDataType, true));
        }
        this.fields.clear();
        this.fields.addAll(rebuilt);
        return DataTypes.createStructType(this.fields);
    }
}
再封装一层:通过传入样例类的 Xxx.class.getDeclaredFields() 方法获取到的字段信息来构造 StructType。
java
public static StructType getStructType(Field[] fields) throws UnsupportedDataTypeException {
ArrayList<ArrayList<String>> lists = new ArrayList<>();
for (Field field : fields) {
String name = field.getName();
AnnotatedType annotatedType = field.getAnnotatedType();
String[] typeSplit = annotatedType.getType().getTypeName().split("\.");
String type = typeSplit[typeSplit.length - 1];
ArrayList<String> tmpList = new ArrayList<String>();
tmpList.add(name);
tmpList.add(type);
lists.add(tmpList);
}
SchemaMaker schemaMaker = new SchemaMaker(lists);
return schemaMaker.getStructType();
}
样例类的定义
java
/**
 * Java bean describing one row of the city data: id, name and area.
 *
 * <p>Follows the JavaBean conventions (public no-arg constructor plus getter/setter pairs)
 * so that it can be used with {@code Encoders.bean(City.class)}.
 */
public static class City implements Serializable{
    private Long cityid;
    private String cityname;
    private String area;

    /** No-arg constructor required by bean-based instantiation; all properties start as null. */
    public City() {
    }

    /** Creates a fully populated city. */
    public City(Long cityid, String cityname, String area) {
        this.cityid = cityid;
        this.cityname = cityname;
        this.area = area;
    }

    public Long getCityid() {
        return cityid;
    }

    public void setCityid(Long cityid) {
        this.cityid = cityid;
    }

    public String getCityname() {
        return cityname;
    }

    public void setCityname(String cityname) {
        this.cityname = cityname;
    }

    public String getArea() {
        return area;
    }

    public void setArea(String area) {
        this.area = area;
    }
}
/**
 * Java bean describing one row of the product data: id, product name and origin.
 *
 * <p>Follows the JavaBean conventions (public no-arg constructor plus getter/setter pairs)
 * so that it can be used with a bean encoder such as {@code Encoders.bean(Product.class)}.
 */
public static class Product implements Serializable{
    private Long productid;
    private String product;
    private String product_from;

    /** No-arg constructor required by bean-based instantiation; all properties start as null. */
    public Product() {
    }

    /** Creates a fully populated product. */
    public Product(Long productid, String product, String product_from) {
        this.productid = productid;
        this.product = product;
        this.product_from = product_from;
    }

    public Long getProductid() {
        return productid;
    }

    public void setProductid(Long productid) {
        this.productid = productid;
    }

    public String getProduct() {
        return product;
    }

    public void setProduct(String product) {
        this.product = product;
    }

    public String getProduct_from() {
        return product_from;
    }

    public void setProduct_from(String product_from) {
        this.product_from = product_from;
    }
}
/**
 * Java bean describing one row of the user-visit-action data.
 *
 * <p>Property names use snake_case so that they match the column names generated from the
 * declared fields. Follows the JavaBean conventions (public no-arg constructor plus
 * getter/setter pairs) so that it can be used with a bean encoder such as
 * {@code Encoders.bean(UserVisitAction.class)}.
 */
public static class UserVisitAction implements Serializable{
    private String date;
    private Long user_id;
    private String session_id;
    private Long page_id;
    private String action_time;
    private String search_keyword;
    private Long click_category_id;
    private Long click_product_id;
    private String order_category_ids;
    private String order_product_ids;
    private String pay_category_ids;
    private String pay_product_ids;
    private Long city_id;

    /** No-arg constructor required by bean-based instantiation; all properties start as null. */
    public UserVisitAction() {
    }

    /** Creates a fully populated action record. */
    public UserVisitAction(String date, Long user_id, String session_id, Long page_id, String action_time, String search_keyword, Long click_category_id, Long click_product_id, String order_category_ids, String order_product_ids, String pay_category_ids, String pay_product_ids, Long city_id) {
        this.date = date;
        this.user_id = user_id;
        this.session_id = session_id;
        this.page_id = page_id;
        this.action_time = action_time;
        this.search_keyword = search_keyword;
        this.click_category_id = click_category_id;
        this.click_product_id = click_product_id;
        this.order_category_ids = order_category_ids;
        this.order_product_ids = order_product_ids;
        this.pay_category_ids = pay_category_ids;
        this.pay_product_ids = pay_product_ids;
        this.city_id = city_id;
    }

    public String getDate() {
        return date;
    }

    public void setDate(String date) {
        this.date = date;
    }

    public Long getUser_id() {
        return user_id;
    }

    public void setUser_id(Long user_id) {
        this.user_id = user_id;
    }

    public String getSession_id() {
        return session_id;
    }

    public void setSession_id(String session_id) {
        this.session_id = session_id;
    }

    public Long getPage_id() {
        return page_id;
    }

    public void setPage_id(Long page_id) {
        this.page_id = page_id;
    }

    public String getAction_time() {
        return action_time;
    }

    public void setAction_time(String action_time) {
        this.action_time = action_time;
    }

    public String getSearch_keyword() {
        return search_keyword;
    }

    public void setSearch_keyword(String search_keyword) {
        this.search_keyword = search_keyword;
    }

    public Long getClick_category_id() {
        return click_category_id;
    }

    public void setClick_category_id(Long click_category_id) {
        this.click_category_id = click_category_id;
    }

    public Long getClick_product_id() {
        return click_product_id;
    }

    public void setClick_product_id(Long click_product_id) {
        this.click_product_id = click_product_id;
    }

    public String getOrder_category_ids() {
        return order_category_ids;
    }

    public void setOrder_category_ids(String order_category_ids) {
        this.order_category_ids = order_category_ids;
    }

    public String getOrder_product_ids() {
        return order_product_ids;
    }

    public void setOrder_product_ids(String order_product_ids) {
        this.order_product_ids = order_product_ids;
    }

    public String getPay_category_ids() {
        return pay_category_ids;
    }

    public void setPay_category_ids(String pay_category_ids) {
        this.pay_category_ids = pay_category_ids;
    }

    public String getPay_product_ids() {
        return pay_product_ids;
    }

    public void setPay_product_ids(String pay_product_ids) {
        this.pay_product_ids = pay_product_ids;
    }

    public Long getCity_id() {
        return city_id;
    }

    public void setCity_id(Long city_id) {
        this.city_id = city_id;
    }
}
主程序部分
java
DataFrameReader reader = sparkSession.read();
// Derive one schema per input file from the declared fields of the matching bean class.
StructType citySchema = getStructType(City.class.getDeclaredFields());
StructType productSchema = getStructType(Product.class.getDeclaredFields());
StructType actionSchema = getStructType(UserVisitAction.class.getDeclaredFields());
// The article's premise is tab-separated files WITHOUT a header row, so "header" is false;
// with true the first data row of every file would be consumed as a header and lost.
// csv(...) selects the CSV source by itself, so no format(...) call is needed before it.
// Backslashes in the Windows paths have to be escaped inside Java string literals.
Dataset<Row> citydf = reader
        .option("delimiter", "\t")
        .option("header", false)
        .schema(citySchema)
        .csv("D:\\project\\sparkDemo\\inputs\\city_info.txt");
Dataset<Row> productdf = reader
        .option("delimiter", "\t")
        .option("header", false)
        .schema(productSchema)
        .csv("D:\\project\\sparkDemo\\inputs\\product_info.txt");
Dataset<Row> actiondf = reader
        .option("delimiter", "\t")
        .option("header", false)
        .schema(actionSchema)
        .csv("D:\\project\\sparkDemo\\inputs\\user_visit_action.txt");
// Convert the untyped DataFrame into a typed Dataset of City beans.
Dataset<City> cityDataset = citydf.as(Encoders.bean(City.class));
// cityDataset.show();
// Write each DataFrame to its MySQL table, replacing any existing table content.
// NOTE(review): the JDBC URL and credentials are hard-coded; load them from configuration
// before using this outside a local demo.
citydf.write().format("jdbc").option("url", "jdbc:mysql://172.20.143.219:3306/test")
        .option("driver", "com.mysql.cj.jdbc.Driver").option("user", "root")
        .option("password", "mysql").option("dbtable", "city_info").mode("overwrite").save();
productdf.write().format("jdbc").option("url", "jdbc:mysql://172.20.143.219:3306/test")
        .option("driver", "com.mysql.cj.jdbc.Driver").option("user", "root")
        .option("password", "mysql").option("dbtable", "product_info").mode("overwrite").save();
actiondf.write().format("jdbc").option("url", "jdbc:mysql://172.20.143.219:3306/test")
        .option("driver", "com.mysql.cj.jdbc.Driver").option("user", "root")
        .option("password", "mysql").option("dbtable", "user_visit_action").mode("overwrite").save();
通过这个方法,自定义好样例类之后,就可以批量读取与处理 txt 文件了。
PS:在缺乏文件信息的时候不要贸然加载文件,否则可能会造成严重的后果。