使用线程池按行并发取二维数组最大值
这是快手后端二面的一道题。由于在网上直接搜索竟然没有找到现成的答案,于是自己写了一下。
生成二维数组
生成二维数组的公共类
java
/**
 * Helpers that build 2-D arrays of random test data. Every cell of row {@code r} is
 * {@code Math.random()} scaled by {@code 10 * r + 1}, so later rows tend to hold larger values.
 */
class RandomArray {
    /**
     * Builds a {@code row x col} matrix of random doubles.
     *
     * @param row number of rows
     * @param col number of columns
     * @return the matrix; row {@code r} holds {@code Math.random() * (10 * r + 1)} values
     */
    public static double[][] getDoubleArray(int row, int col) {
        double[][] matrix = new double[row][col];
        int r = 0;
        while (r < row) {
            // Same value the original computed inline: the long bound promoted to double.
            double upper = r * 10L + 1;
            double[] cells = matrix[r];
            for (int c = 0; c < col; c++) {
                cells[c] = Math.random() * upper;
            }
            r++;
        }
        return matrix;
    }

    /**
     * Builds a {@code row x col} matrix of random longs, obtained by truncating the
     * scaled random doubles used by {@link #getDoubleArray(int, int)}.
     *
     * @param row number of rows
     * @param col number of columns
     * @return the matrix; row {@code r} holds {@code (long) (Math.random() * (10 * r + 1))}
     */
    public static long[][] getLongArray(int row, int col) {
        long[][] matrix = new long[row][col];
        int r = 0;
        while (r < row) {
            double upper = r * 10L + 1;
            long[] cells = matrix[r];
            for (int c = 0; c < col; c++) {
                cells[c] = (long) (Math.random() * upper);
            }
            r++;
        }
        return matrix;
    }
}
使用Callable实现线程
主要是实现可返回结果的任务,即通过 Future.get() 获取线程的返回值。
java
import java.util.*;
import java.util.concurrent.*;
/**
 * Task that computes the maximum element of one row of a 2-D array.
 *
 * <p>An empty row yields {@link Double#NEGATIVE_INFINITY}.
 */
class ArrayMax implements Callable<Double> {
    private final double[] array;
    private final int ind;
    // Start from NEGATIVE_INFINITY: Double.MIN_VALUE is the smallest *positive* double,
    // so it would wrongly beat every element of a row whose values are all <= 0.
    private double max = Double.NEGATIVE_INFINITY;

    /**
     * @param array the row to scan
     * @param ind   index of the row, used only in the progress message
     */
    public ArrayMax(double[] array, int ind) {
        this.array = array;
        this.ind = ind;
    }

    /**
     * Scans the row, prints which thread handled it, and returns the row maximum.
     *
     * @return the largest element of the row, or NEGATIVE_INFINITY if the row is empty
     */
    @Override
    public Double call() throws Exception {
        for (double value : array) {
            max = Math.max(max, value);
        }
        // Uncomment to simulate tasks with different execution times:
        // Thread.sleep((long) (Math.random() * 1000));
        System.out.println(Thread.currentThread().getName() + " of task " + ind + " max value: " + max);
        return max;
    }

    /**
     * Returns the maximum seen so far: the row maximum once {@link #call()} has run,
     * NEGATIVE_INFINITY before that.
     */
    public double getMax() {
        return max;
    }
}
/**
 * Computes the maximum of a 2-D array by handing each row to an {@link ArrayMax} task on a
 * thread pool and combining the per-row maxima on the main thread.
 *
 * <p>{@link #useThreadPool()} waits for each row before submitting the next one (no real
 * parallelism). {@link #useThreadPool2()} submits every row first and then collects the
 * results in completion order. {@link #useNormal()} is the single-threaded baseline.
 */
public class ArrayMaxTest {
    /**
     * Shared pool. Both pool-based strategies shut it down, so only one of them can be used
     * per run. Note that CallerRunsPolicy silently discards tasks submitted after shutdown.
     */
    static ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(
            20,
            40,
            60 * 60,
            TimeUnit.SECONDS,
            new ArrayBlockingQueue<>(5000),
            new ThreadPoolExecutor.CallerRunsPolicy());
    // Running maximum. Starts at NEGATIVE_INFINITY: Double.MIN_VALUE is the smallest
    // positive double and would beat rows whose elements are all <= 0.
    private static double max = Double.NEGATIVE_INFINITY;
    private static int row, col;
    private static double[][] array;

    /**
     * Folds {@code input} into {@link #max}. Not synchronized; every caller in this class
     * invokes it from the main thread only, which is what makes it safe here.
     */
    private static void unsafeMaintainMax(double input) {
        if (input > max) {
            max = input;
        }
    }

    /**
     * Submits one task per row and immediately blocks on its result. Because each
     * {@code get()} returns before the next row is submitted, rows are processed one after
     * another: the max update is safe (main thread only) but there is no parallelism.
     * Shuts the pool down afterwards.
     */
    private static void useThreadPool() {
        for (int i = 0; i < row; i++) {
            try {
                Future<Double> f = threadPoolExecutor.submit(new ArrayMax(array[i], i));
                unsafeMaintainMax(f.get());
            } catch (InterruptedException interrupted) {
                // Preserve the interrupt status and stop submitting further rows.
                Thread.currentThread().interrupt();
                break;
            } catch (ExecutionException | RejectedExecutionException exception) {
                System.out.println(exception);
            }
        }
        threadPoolExecutor.shutdown();
    }

    /**
     * Submits every row first, so the tasks really run in parallel, then folds each row
     * maximum into {@link #max} as soon as that task finishes. Shuts the pool down and
     * waits for it to terminate.
     *
     * @throws ExecutionException   if a row task failed
     * @throws InterruptedException if the main thread is interrupted while waiting
     */
    private static void useThreadPool2() throws ExecutionException, InterruptedException {
        // take() hands back futures in completion order and blocks while none is ready.
        // This replaces a busy-wait loop polling isDone(), which kept a core spinning.
        CompletionService<Double> completionService =
                new ExecutorCompletionService<>(threadPoolExecutor);
        int submitted = 0;
        for (int i = 0; i < row; i++) {
            try {
                completionService.submit(new ArrayMax(array[i], i));
                submitted++;
            } catch (RejectedExecutionException exception) {
                System.out.println(exception);
            }
        }
        // Only the main thread updates max, so no synchronization is needed.
        // (Calling get() on a list of futures in submission order would also be correct and
        // parallel, since all tasks are already running; it just waits in a fixed order.)
        for (int k = 0; k < submitted; k++) {
            unsafeMaintainMax(completionService.take().get());
        }
        threadPoolExecutor.shutdown();
        threadPoolExecutor.awaitTermination(Long.MAX_VALUE, TimeUnit.MINUTES);
    }

    /** Single-threaded baseline: scans the whole {@code row x col} array on the caller thread. */
    private static void useNormal() {
        double submax = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < row; i++) {
            for (int j = 0; j < col; j++) {
                submax = Math.max(submax, array[i][j]);
            }
        }
        unsafeMaintainMax(submax);
    }

    /**
     * Builds a 400 x 2000 random array, computes its maximum with {@link #useThreadPool2()}
     * and prints the result and the elapsed time. Swap in {@link #useThreadPool()} or
     * {@link #useNormal()} to compare the strategies.
     */
    public static void main(String[] args) throws ExecutionException, InterruptedException {
        row = 400;
        col = 2000;
        array = RandomArray.getDoubleArray(row, col);
        long startTime = System.currentTimeMillis();
        useThreadPool2();
        System.out.println("unsafe Max : " + max);
        long endTime = System.currentTimeMillis();
        System.out.println("开始时间:" + startTime +
                "\n结束时间:" + endTime +
                "\n用时:" + (endTime - startTime));
    }
}
这里要注意:必须先把所有任务都提交给线程池运行起来,再调用 Future.get() 获取返回值。如果每提交一个任务就马上 get(),主线程会阻塞到该任务完成后才提交下一个,任务实际上是顺序执行的,并发也就没有意义了。因此需要先全部提交,再收集结果:可以自旋遍历并配合 Future.isDone() 按完成顺序获取结果,后期也可以考虑使用 CountDownLatch 或 ExecutorCompletionService 获取结果。
使用Runnable实现线程
用一个所有任务共享的静态变量(文中称为“单例模式”)维护全局最大值。
java
import java.util.concurrent.*;
public class ArrayMaxTest2 implements Runnable {
static ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(
20,
40,
60 * 60,
TimeUnit.SECONDS,
new ArrayBlockingQueue<>(5000),
new ThreadPoolExecutor.CallerRunsPolicy());
private double[] array;
private double max = Double.MIN_VALUE;
private int ind;
public ArrayMaxTest2(double[] array, int ind){
this.array = array;
this.ind = ind;
}
@Override
public void run(){
for(int i=0;i<array.length;i++){
max = Math.max(max, array[i]);
}
System.out.println(Thread.currentThread().getName() + " of task " + ind + " max value: " + max);
unsafeMaintainMax(max);
// dclMainTainMax(max);
}
/
private static Double ans = Double.MIN_VALUE;
private static void unsafeMaintainMax(Double input){
ans = Math.max(ans, input);
}
private static Double dclAns = Double.MIN_VALUE;
private static void dclMainTainMax(Double input){
synchronized (dclAns){
dclAns = Math.max(dclAns, input);
}
}
public static void main(String[] args){
int row = 400, col = 2000;
long startTime = 0, endTime = 0;
double[][] twoDimArray = RandomArray.getDoubleArray(row, col);
startTime = System.currentTimeMillis();
for(int i=0;i<row;i++){
try{
ArrayMaxTest2 arrayMax = new ArrayMaxTest2(twoDimArray[i], i);
threadPoolExecutor.execute(arrayMax);
} catch (Exception exception){
System.out.println(exception);
}
}
threadPoolExecutor.shutdown(); // 阻止新来任务的提交
// 这样前面的线程还没有执行完
// System.out.println("unsafe Max : " + ans);
// System.out.println("dcl Max : " + dclAns);
// endTime = System.currentTimeMillis();
// System.out.println("开始时间:" + startTime + "\n结束时间:" + endTime + "\n用时:" + (endTime - startTime));
try {
// 等待所有线程执行完
threadPoolExecutor.awaitTermination(Long.MAX_VALUE, TimeUnit.MINUTES);
// TODO:还可以使用 CountDownLatch
System.out.println("unsafe Max : " + ans);
System.out.println("dcl Max : " + dclAns);
endTime = System.currentTimeMillis();
System.out.println("开始时间:" + startTime + "\n结束时间:" + endTime + "\n用时:" + (endTime - startTime));
} catch (InterruptedException interruptedException){
interruptedException.printStackTrace();
}
}
}
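这里也可以发现:如果多个线程共享同一个静态最大值变量(即上文的“单例模式”),却不使用 synchronized 和 volatile 进行同步,确实会出现问题(部分更新会被覆盖丢失)。另外要注意,synchronized 必须锁在一个固定不变的对象上;若锁的是一个会被重新赋值的 Double 变量,不同线程拿到的是不同的锁,起不到互斥作用。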