spark里使用geohash处理数据之线程安全问题

1、背景:

一个兄弟在处理位置数据的时候,用到了geohash,从网上淘了一个工具类,上来就是一通干,但是发现一个很奇怪的问题:数据落到hive表里之后,每次运行的结果不一样,百思不得其姐,然后做了各种测试,以为是spark本身有bug,但是各种纠结之后,还是觉得应该是自己的问题,当时就想到了可能是线程安全问题,之所以说是这个问题,是因为把每个executor-cores设置为1的时候,就没问题,但是生产环境这个参数配置肯定是不行的,于是各种排查,到底是哪里出现了线程安全问题。

我先把原始代码贴上,供各位看官老爷们欣赏下:

java 复制代码
package com.weibo.sso.hive.utils;
import java.util.ArrayList;
import java.util.List;
/**
 * lbs工具类
 *
 *
 */

public class GeoHashUtil {
    public final double Max_Lat = 90;
    public final double Min_Lat = -90;
    public final double Max_Lng = 180;
    public final double Min_Lng = -180;
    /**
     * 纬度二值串长度
     */
    private static int latLength;
    /**
     * 经度二值串长度
     */
    private static int lngLength;


    //计算经纬度的最小单元,最小的区间的中心值
    private static double minLat;
    private static double minLng;




    private final String[] base32Lookup = {
            "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "b", "c", "d", "e", "f", "g", "h", "j", "k",
            "m", "n", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z"
    };
    /**
     * 二值化:对经纬度二分逼近,大于中间值的为1,小于中间值的为0,将其转为长度为length的二值串
     *
     * @param min   区间最小值
     * @param max   区间最大值
     * @param value 经度或纬度
     * @param count 二分次数
     * @param list  二值串
     */
    private void convert(double min, double max, double value, int count, List<Character> list) {
        if (list.size() > (count - 1)) {
            return;
        }
        double mid = (max + min) / 2;
        if (value < mid) {
            list.add('0');
            convert(min, mid, value, count, list);
        } else {
            list.add('1');
            convert(mid, max, value, count, list);
        }
    }
    /**
     * 将合并的二值串转为base32串
     *
     * @param str 合并的二值串
     * @return base32串
     */
    private String base32Encode(final String str) {
//        System.out.println("转换base32前:"+str);
        String unit = "";
        StringBuilder sb = new StringBuilder();
        for (int start = 0; start < str.length(); start = start + 5) {
            unit = str.substring(start, start + 5);
            sb.append(base32Lookup[convertToIndex(unit)]);
        }
        return sb.toString();
    }

    /**
     * 每五个一组将二进制转为十进制
     *
     * @param str 五个为一个unit
     * @return 十进制数
     */
    private int convertToIndex(String str) {
        int length = str.length();
        int result = 0;
        for (int index = 0; index < length; index++) {
            result += str.charAt(index) == '0' ? 0 : 1 << (length - 1 - index);
        }
//        System.out.println(result);
        return result;
    }
    /**
     * 经纬度二值串合并:偶数位放经度,奇数位放纬度,把2串编码组合生成新串
     *
     * @param lat 纬度
     * @param lng 经度
     */
    public String encode(double lat, double lng) {
        if (latLength < 1 || lngLength < 1) {
            return "";
        }
        List<Character> latList = new ArrayList<>(latLength);
        List<Character> lngList = new ArrayList<>(lngLength);
        // 获取维度二值串
        convert(Min_Lat, Max_Lat, lat, latLength, latList);
        // 获取经度二值串
        convert(Min_Lng, Max_Lng, lng, lngLength, lngList);
//        System.out.println("latList: " + latList);
//        System.out.println("lngList: " + lngList);
        StringBuilder sb = new StringBuilder();
        for (int index = 0; index < latList.size(); index++) {
            sb.append(lngList.get(index)).append(latList.get(index));
        }
//        如果二者长度不一样,说明要求的精度为奇数,经度长度比纬度长度大1
        if (lngLength != latLength) {
            sb.append(lngList.get(lngList.size() - 1));
        }

        return base32Encode(sb.toString());
    }
    /**
     * 根据精度获取GeoHash串
     *
     * @param precise 精度
     * @return GeoHash串
     */
    public  String getGeoHash(double lat, double lng, int precise) {
        if (precise < 1 || precise > 9) {
            return "";
        }
        latLength = (precise * 5) / 2;
        if (precise % 2 == 0) {
            lngLength = latLength;
        } else {
            lngLength = latLength + 1;
        }

        return encode(lat, lng);

    }

    //查询该经纬度的geoHash以及该geoHash附近的8个geoHash点
    public String getArroundGeoHash(double lat, double lon, int precise) {

        if (precise < 1 || precise > 9) {
            return "";
        }
        latLength = (precise * 5) / 2;
        if (precise % 2 == 0) {
            lngLength = latLength;
        } else {
            lngLength = latLength + 1;
        }

        //计算经纬度的最小单元,最小的区间的中心值
        minLat = Max_Lat - Min_Lat;
        for (int i = 0; i < latLength; i++) {
            minLat /= 2.0;
        }
        minLng = Max_Lng - Min_Lng;
        for (int i = 0; i < lngLength; i++) {
            minLng /= 2.0;
        }

        double uplat = lat + minLat;
        double downLat = lat - minLat;

        double leftlng = lon - minLng;
        double rightLng = lon + minLng;

        String leftUp = encode(uplat, leftlng);


        String leftMid = encode(lat, leftlng);


        String leftDown = encode(downLat, leftlng);


        String midUp = encode(uplat, lon);


        String midMid = encode(lat, lon);


        String midDown = encode(downLat, lon);


        String rightUp = encode(uplat, rightLng);


        String rightMid = encode(lat, rightLng);


        String rightDown = encode(downLat, rightLng);



        return "nw="+leftUp+",w="+leftMid+",sw="+leftDown+",n="+midUp+",c="+midMid+",s="+midDown+",ne="+rightUp+",e="+rightMid+",se="+rightDown;
    }
    /**
     * 获取GeoHash6
     *
     * @return GeoHash6
     */
    public String getGeoHash6(double lat, double lng) {
        latLength = 15;
        lngLength = 15;
        return encode(lat, lng);
    }

    /**
     * 获取GeoHash7
     *
     * @return GeoHash7
     */
    public String getGeoHash7(double lat, double lng) {
        latLength = 17;
        lngLength = 18;
        return encode(lat, lng);
    }

    public static void main(String[] args) {

        System.out.println(new GeoHashUtil().getGeoHash(39.91092, 116.41338, 7));

        String arroundGeoHash = new GeoHashUtil().getArroundGeoHash(39.91092, 116.41338,8);
        System.out.println(arroundGeoHash);
    }
}

2、分析问题

我们来分析下上面的代码:

问题1:latLengthlngLength 是静态变量

latLengthlngLength 被声明为static,这意味着它们在所有线程间共享。如果多个线程同时调用 getGeoHashgetArroundGeoHash,它们会修改这些静态变量的值,可能导致线程间的相互影响。

具体情况:
  • 如果两个线程几乎同时调用getGeoHash,它们会修改静态的latLengthlngLength,导致计算出的经纬度二值串可能不一致。
  • 比如一个线程调用了 getGeoHash(39.91092, 116.41338, 7),设定了 latLength = 17,此时如果另一个线程调用 getGeoHash(39.91092, 116.41338, 5),会将 latLength 设置为一个新的值,影响了第一个线程的执行结果。

问题2:minLatminLng 是静态变量

同样,minLatminLng 也是静态变量,意味着它们在多线程环境下也会出现线程间干扰的问题。当不同线程同时调用 getArroundGeoHash 时,会因为这两个变量的值被不同的线程修改而导致错误的结果。

多说一句:很多人其实之所以写出来不安全的代码,是因为他不晓得什么是安全什么是不安全,如果说只读不写,也就是说大家都只读一个共享变量,那就不会有线程安全问题,就像spark的广播变量一样,就是readonly模式,线程不安全的两个常见条件:1、变量是共享的 2、有写操作 ,也就是涉及到了对共享变量的更改,问题往往就出在这里。

3、解决方案

要解决这些线程安全问题,主要是通过消除静态变量的竞争,将其变为实例变量或局部变量,以确保每个线程拥有自己的副本,最为传统的方案如下:

常规方案:
  1. latLengthlngLength 变为局部变量,只在每个方法中计算,而不是使用静态变量。
  2. minLatminLng 变为局部变量,避免多个线程共享这些值。
  3. 消除静态变量的共享状态,使每个线程都独立处理它们的值。

上述是常规思路,总感觉跟我的气质不符,所以想到了一个老朋友:threadlocal,为什么是它?不介绍下,恐怕不能服众.

ThreadLocal 是一种常见的线程安全技术,它允许每个线程拥有自己独立的变量副本,从而避免了线程之间的竞争问题,在本场景里,使用 ThreadLocal 可以确保每个线程都独立地处理 latLengthlngLength 以及 minLatminLng 这些变量,避免线程之间的相互影响。

为什么使用 ThreadLocal

  • 多个线程同时调用同一个方法时,如果方法中有共享的可变状态变量,可能会出现竞争问题。
  • 通过 ThreadLocal,可以让每个线程独立拥有自己的这些变量,而不会相互干扰。

优化方案:

我们可以使用 ThreadLocal 来管理 latLengthlngLengthminLatminLng 这些变量,每个线程会有自己独立的副本,线程间不会相互影响。接下来,我们将为 GeoHashUtil 类进行优化。

java 复制代码
package com.weibo.sso.hive.utils;

import java.util.ArrayList;
import java.util.List;

public class GeoHashUtil {

    public final double Max_Lat = 90;
    public final double Min_Lat = -90;
    public final double Max_Lng = 180;
    public final double Min_Lng = -180;

    private final String[] base32Lookup = {
            "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "b", "c", "d", "e", "f", "g", "h", "j", "k",
            "m", "n", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z"
    };

    // 使用ThreadLocal来管理每个线程的latLength、lngLength、minLat和minLng
    private final ThreadLocal<Integer> latLength = ThreadLocal.withInitial(() -> 0);
    private final ThreadLocal<Integer> lngLength = ThreadLocal.withInitial(() -> 0);
    private final ThreadLocal<Double> minLat = ThreadLocal.withInitial(() -> 0.0);
    private final ThreadLocal<Double> minLng = ThreadLocal.withInitial(() -> 0.0);

    private void convert(double min, double max, double value, int count, List<Character> list) {
        if (list.size() > (count - 1)) {
            return;
        }
        double mid = (max + min) / 2;
        if (value < mid) {
            list.add('0');
            convert(min, mid, value, count, list);
        } else {
            list.add('1');
            convert(mid, max, value, count, list);
        }
    }

    private String base32Encode(final String str) {
        StringBuilder sb = new StringBuilder();
        for (int start = 0; start < str.length(); start = start + 5) {
            String unit = str.substring(start, start + 5);
            sb.append(base32Lookup[convertToIndex(unit)]);
        }
        return sb.toString();
    }

    private int convertToIndex(String str) {
        int length = str.length();
        int result = 0;
        for (int index = 0; index < length; index++) {
            result += str.charAt(index) == '0' ? 0 : 1 << (length - 1 - index);
        }
        return result;
    }

    // 用于编码的encode方法,将使用ThreadLocal中的变量
    public String encode(double lat, double lng) {
        if (latLength.get() < 1 || lngLength.get() < 1) {
            return "";
        }
        List<Character> latList = new ArrayList<>(latLength.get());
        List<Character> lngList = new ArrayList<>(lngLength.get());

        convert(Min_Lat, Max_Lat, lat, latLength.get(), latList);
        convert(Min_Lng, Max_Lng, lng, lngLength.get(), lngList);

        StringBuilder sb = new StringBuilder();
        for (int index = 0; index < latList.size(); index++) {
            sb.append(lngList.get(index)).append(latList.get(index));
        }

        if (lngLength.get() != latLength.get()) {
            sb.append(lngList.get(lngList.size() - 1));
        }

        return base32Encode(sb.toString());
    }

    // 获取GeoHash
    public String getGeoHash(double lat, double lng, int precise) {
        if (precise < 1 || precise > 9) {
            return "";
        }

        // 设置每个线程独立的 latLength 和 lngLength
        latLength.set((precise * 5) / 2);
        if (precise % 2 == 0) {
            lngLength.set(latLength.get());
        } else {
            lngLength.set(latLength.get() + 1);
        }

        return encode(lat, lng);
    }

    // 获取周围的GeoHash
    public String getArroundGeoHash(double lat, double lon, int precise) {
        if (precise < 1 || precise > 9) {
            return "";
        }

        latLength.set((precise * 5) / 2);
        if (precise % 2 == 0) {
            lngLength.set(latLength.get());
        } else {
            lngLength.set(latLength.get() + 1);
        }

        // 计算经纬度的最小单元
        minLat.set(Max_Lat - Min_Lat);
        for (int i = 0; i < latLength.get(); i++) {
            minLat.set(minLat.get() / 2.0);
        }
        minLng.set(Max_Lng - Min_Lng);
        for (int i = 0; i < lngLength.get(); i++) {
            minLng.set(minLng.get() / 2.0);
        }

        double uplat = lat + minLat.get();
        double downLat = lat - minLat.get();
        double leftlng = lon - minLng.get();
        double rightLng = lon + minLng.get();

        String leftUp = encode(uplat, leftlng);
        String leftMid = encode(lat, leftlng);
        String leftDown = encode(downLat, leftlng);
        String midUp = encode(uplat, lon);
        String midMid = encode(lat, lon);
        String midDown = encode(downLat, lon);
        String rightUp = encode(uplat, rightLng);
        String rightMid = encode(lat, rightLng);
        String rightDown = encode(downLat, rightLng);

        return "nw=" + leftUp + ",w=" + leftMid + ",sw=" + leftDown + ",n=" + midUp + ",c=" + midMid + ",s=" + midDown +
                ",ne=" + rightUp + ",e=" + rightMid + ",se=" + rightDown;
    }

    // 获取GeoHash6
    public String getGeoHash6(double lat, double lng) {
        latLength.set(15);
        lngLength.set(15);
        return encode(lat, lng);
    }

    // 获取GeoHash7
    public String getGeoHash7(double lat, double lng) {
        latLength.set(17);
        lngLength.set(18);
        return encode(lat, lng);
    }

    public static void main(String[] args) {
        GeoHashUtil util = new GeoHashUtil();
        System.out.println(util.getGeoHash(39.91092, 116.41338, 7));
        System.out.println(util.getArroundGeoHash(39.91092, 116.41338, 8));
    }
}

优化后的代码说明:

  1. ThreadLocal 的使用

    • latLengthlngLength 这两个原本是共享的静态变量现在使用 ThreadLocal 来处理,确保每个线程都有自己独立的长度变量,不会受到其他线程的影响。
    • minLatminLng 也通过 ThreadLocal 处理,避免多线程之间的干扰。
  2. ThreadLocal 的初始化

    • ThreadLocal 使用 ThreadLocal.withInitial() 方法进行初始化,为每个线程单独管理变量副本。
    • 每个线程会独立拥有自己的 latLengthlngLengthminLatminLng,避免线程间的竞争。
  3. 线程安全性

    • 每个线程在计算 GeoHash 时,使用的 latLengthlngLength 等变量都是当前线程独有的,互不影响,从而确保线程安全

这种优化的好处:

  • 完全隔离线程的变量状态 :通过 ThreadLocal 确保每个线程拥有自己的副本,变量不会相互干扰。
  • 简化了锁的使用 :通过 ThreadLocal 直接避免了线程间的竞争,无需显式加锁来保证线程安全。

以上

相关推荐
Elastic 中国社区官方博客3 小时前
使用 Elastic AI Assistant for Search 和 Azure OpenAI 实现从 0 到 60 的转变
大数据·人工智能·elasticsearch·microsoft·搜索引擎·ai·azure
Francek Chen5 小时前
【大数据技术基础 | 实验十二】Hive实验:Hive分区
大数据·数据仓库·hive·hadoop·分布式
Natural_yz9 小时前
大数据学习17之Spark-Core
大数据·学习·spark
莫叫石榴姐10 小时前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
魔珐科技11 小时前
以3D数字人AI产品赋能教育培训人才发展,魔珐科技亮相AI+教育创新与人才发展大会
大数据·人工智能
上优11 小时前
uniapp 选择 省市区 省市 以及 回显
大数据·elasticsearch·uni-app
陌小呆^O^12 小时前
Cmakelist.txt之Liunx-rabbitmq
分布式·rabbitmq
samLi062012 小时前
【更新】中国省级产业集聚测算数据及协调集聚指数数据(2000-2022年)
大数据
Mephisto.java12 小时前
【大数据学习 | Spark-Core】Spark提交及运行流程
大数据·学习·spark