如何限制创建的线程数并等待主线程直到任何一个线程找到答案?

问题描述

这是找到第一对数字(除了 1)的 LCM 和 HCF 之和等于该数字的代码

import java.util.*;
import java.util.concurrent.atomic.AtomicLong;

class PerfectPartition {
    static long gcd(long a,long b) {
        if (a == 0)
            return b;
        return gcd(b % a,a);
    }

    // method to return LCM of two numbers
    static long lcm(long a,long b) {
        return (a / gcd(a,b)) * b;
    }

    long[] getPartition(long n) {
        var ref = new Object() {
            long x;
            long y;
            long[] ret = null;
        };

        Thread mainThread = Thread.currentThread();
        ThreadGroup t = new ThreadGroup("InnerLoop");

        for (ref.x = 2; ref.x < (n + 2) / 2; ref.x++) {
            if (t.activeCount() < 256) {

                new Thread(t,() -> {
                    for (ref.y = 2; ref.y < (n + 2) / 2; ref.y++) {
                        long z = lcm(ref.x,ref.y) + gcd(ref.x,ref.y);
                        if (z == n) {
                            ref.ret = new long[]{ref.x,ref.y};

                            t.interrupt();
                            break;
                        }
                    }
                },"Thread_" + ref.x).start();

                if (ref.ret != null) {
                    return ref.ret;
                }
            } else {
                ref.x--;
            }
        }//return new long[]{1,n - 2};

        return Objects.requireNonNullElseGet(ref.ret,() -> new long[]{1,n - 2});
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(system.in);
        long n = sc.nextLong();
        long[] partition = new PerfectPartition().getPartition(n);
        System.out.println(partition[0] + " " + partition[1]);
    }
}

我想在找到第一对后立即停止代码执行。但是,main 线程只是继续运行并打印 1n-1
限制数量的最佳解决方案是什么?线程数(2 到 max of long)?

预期输出n=4):2 2
预期输出n=8):4 4

解决方法

限制数量的最佳解决方案是什么?线程数(

首先,您应该考虑将执行代码的硬件(例如内核数量)以及您要并行化的算法类型,即是否受 CPU 限制?、内存限制?、IO 限制,等等。

您的代码受 CPU 限制,因此,从性能的角度来看,如果运行的线程数超过系统中可用内核的数量,通常不会带来回报。与往常一样,尽你所能。

其次,在您的情况下,您需要以证明并行性的方式在线程之间分配工作:

  for (ref.x = 2; ref.x < (n + 2) / 2; ref.x++) {
        if (t.activeCount() < 256) {

            new Thread(t,() -> {
                for (ref.y = 2; ref.y < (n + 2) / 2; ref.y++) {
                    long z = lcm(ref.x,ref.y) + gcd(ref.x,ref.y);
                    if (z == n) {
                        ref.ret = new long[]{ref.x,ref.y};

                        t.interrupt();
                        break;
                    }
                }
            },"Thread_" + ref.x).start();

            if (ref.ret != null) {
                return ref.ret;
            }
        } else {
            ref.x--;
        }
    }//return new long[]{1,n - 2};

你是怎么做的,但是 IMO 以一种令人费解的方式;更简单的 IMO 是显式并行化循环,在线程之间拆分其迭代,并删除所有 ThreadGroup 相关逻辑。

第三,注意竞争条件,例如:

var ref = new Object() {
    long x;
    long y;
    long[] ret = null;
};

这个对象在线程之间共享,并由它们更新,从而导致竞争条件。正如我们即将看到的,您实际上并不需要这样的共享对象。

让我们一步一步来:

首先,找出您应该在线程数与内核数相同的情况下执行代码的线程数:

int cores = Runtime.getRuntime().availableProcessors();

定义并行工作(这是循环分布的一个可能示例):

public void run() {
    for (int x = 2; && x < (n + 2) / 2; x ++) {
        for (int y = 2 + threadID; y < (n + 2) / 2; y += total_threads) {
            long z = lcm(x,y) + gcd(x,y);
            if (z == n) {
                // do something 
            }
        }
    }
}

在下面的代码中,我们以循环的方式在线程之间拆分要并行完成的工作,如下图所示:

enter image description here

我想在找到第一对后立即停止代码执行。

有几种方法可以实现这一点。我将提供最简单的 IMO,尽管不是最复杂的。当已经找到结果时,您可以使用变量向线程发送信号,例如:

final AtomicBoolean found;

每个线程将共享相同的 AtomicBoolean 变量,以便在其中一个线程中执行的更改对其他线程也是可见的:

@Override
public void run() {
    for (int x = 2 ; !found.get() && x < (n + 2) / 2; x ++) {
        for (int y = 2 + threadID; y < (n + 2) / 2; y += total_threads)  {
            long z = lcm(x,y);
            if (z == n) {
                synchronized (found) {
                    if(!found.get()) {
                        rest[0] = x;
                        rest[1] = y;
                        found.set(true);
                    }
                    return;
                }
            }
        }
    }
}

由于您要求提供代码片段示例,因此这里是一个简单的非防弹(且未经过适当测试)运行编码示例:

class ThreadWork implements Runnable{

    final long[] rest;
    final AtomicBoolean found;
    final int threadID;
    final int total_threads;
    final long n;

    ThreadWork(long[] rest,AtomicBoolean found,int threadID,int total_threads,long n) {
        this.rest = rest;
        this.found = found;
        this.threadID = threadID;
        this.total_threads = total_threads;
        this.n = n;
    }

    static long gcd(long a,long b) {
        return (a == 0) ? b : gcd(b % a,a);
    }

    static long lcm(long a,long b,long gcd) {
        return (a / gcd) * b;
    }

    @Override
    public void run() {
        for (int x = 2; !found.get() && x < (n + 2) / 2; x ++) {
            for (int y = 2 + threadID; !found.get() && y < (n + 2) / 2; y += total_threads) {
                long result = gcd(x,y);
                long z = lcm(x,y,result) + result;
                if (z == n) {
                    synchronized (found) {
                        if(!found.get()) {
                            rest[0] = x;
                            rest[1] = y;
                            found.set(true);
                        }
                        return;
                    }
                }
            }
        }
    }
}

class PerfectPartition {

    public static void main(String[] args) throws InterruptedException {
        Scanner sc = new Scanner(System.in);
        final long n = sc.nextLong();
       final int total_threads = Runtime.getRuntime().availableProcessors();

        long[] rest = new long[2];
        AtomicBoolean found = new AtomicBoolean();

        double startTime = System.nanoTime();
        Thread[] threads = new Thread[total_threads];
        for(int i = 0; i < total_threads; i++){
            ThreadWork task = new ThreadWork(rest,found,i,total_threads,n);
            threads[i] = new Thread(task);
            threads[i].start();
        }

        for(int i = 0; i < total_threads; i++){
            threads[i].join();
        }

        double estimatedTime = System.nanoTime() - startTime;
        System.out.println(rest[0] + " " + rest[1]);


        double elapsedTimeInSecond = estimatedTime / 1_000_000_000;
        System.out.println(elapsedTimeInSecond + " seconds");
    }
}

输出:

4 -> 2 2
8 -> 4 4

以此代码为灵感,想出最符合您要求的解决方案。在您完全理解这些基础知识后,尝试使用更复杂的 Java 功能(例如 ExecutorsFuturesCountDownLatch)改进该方法。


新更新:顺序优化

查看 gcd 方法:

  static long gcd(long a,long b) {
        return (a == 0)? b : gcd(b % a,a);
  }

lcm 方法:

static long lcm(long a,long b) {
    return (a / gcd(a,b)) * b;
}

以及它们的使用方式:

long z = lcm(ref.x,ref.y);

您可以通过不在 gcd(a,b) 方法中再次调用 lcm 来优化您的顺序代码。所以将 lcm 方法更改为:

static long lcm(long a,long gcd) {
    return (a / gcd) * b;
}

long z = lcm(ref.x,ref.y);

long result = gcd(ref.x,ref.y)
long z = lcm(ref.x,ref.y,gcd) + gcd;

我在这个答案中提供的代码已经反映了这些变化。

,

首先,你错过了在线程上调用“start”。

new Thread(t,() -> {
    ...
    ...
},"Thread_" + ref.x).start();

关于您的问题,要限制您可以使用线程池的线程数,例如 Executors.newFixedThreadPool(int nThreads)。

并且要停止执行,您可以让主线程在单次计数 CountDownLatch 上等待,并在工作线程中成功匹配时倒计时锁存器,并在主线程中在锁存器上的等待完成时关闭线程池。

如您所问,以下是使用线程池和 CountDownLatch 的示例代码:

import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

public class LcmHcmSum {

    static long gcd(long a,long b) {
        if (a == 0)
            return b;
        return gcd(b % a,a);
    }

    // method to return LCM of two numbers
    static long lcm(long a,long b) {
        return (a / gcd(a,b)) * b;
    }
    
    long[] getPartition(long n) {
        singleThreadJobSubmitter.execute(() -> {
            for (int x = 2; x < (n + 2) / 2; x++) {
                    submitjob(n,x);
                    if(numberPair != null) break;  // match found,exit the loop
            }
            try {
                jobsExecutor.shutdown();  // process the already submitted jobs
                jobsExecutor.awaitTermination(10,TimeUnit.SECONDS);  // wait for the completion of the jobs
                
                if(numberPair == null) {  // no match found,all jobs processed,nothing more to do,count down the latch 
                    latch.countDown();
                }
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        });
        
        try {
            latch.await();
            singleThreadJobSubmitter.shutdownNow();
            jobsExecutor.shutdownNow();
            
        } catch (InterruptedException e1) {
            e1.printStackTrace();
        }
        return Objects.requireNonNullElseGet(numberPair,() -> new long[]{1,n - 2});
    }

    private Future<?> submitjob(long n,long x) {
        return jobsExecutor.submit(() -> {
            for (int y = 2; y < (n + 2) / 2; y++) {
                long z = lcm(x,y);
                if (z == n) {
                    synchronized(LcmHcmSum.class) {  numberPair = new long[]{x,y}; }
                    latch.countDown();
                    break;
                }
            }
        });
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        long n = sc.nextLong();
        long[] partition = new LcmHcmSum().getPartition(n);
        System.out.println(partition[0] + " " + partition[1]);
    }
    
    private static CountDownLatch latch = new CountDownLatch(1);
    private static ExecutorService jobsExecutor = Executors.newFixedThreadPool(4);
    private static volatile long[] numberPair = null;
    private static ExecutorService singleThreadJobSubmitter = Executors.newSingleThreadExecutor();      
    

}
,

您可以使用线程池。类似的东西:

ExecutorService executor = Executors.newFixedThreadPool(256);

然后将任务(或可运行的)调度到其中。

完成后,停止添加任务,并终止线程池(终止也会阻止向线程池添加新任务的能力)。