深入浅出生产者消费者模式

/ 并发编程设计模式 / 2 条评论 / 8779浏览

生产者-消费者模式是一个经典的多线程设计模式,它为多线程间的协作提供了良好的解决方案。也经常有面试官会让手写一个生产者消费者,从代码细节可以看出你对多线程编程的熟练程度,今天我们来详细看一下如何写出一个生产者消费者模式,并且逐步对其优化争取做到高性能。

结构剖析

在生产者-消费者模式中,通常有两类线程,一类是生产者线程一类是消费者线程。生产者线程负责提交用户请求,消费者线程则负责处理生产者提交的任务。

最简单粗暴的做法就是生产者每提交一个任务,消费者就立即处理,直到所有任务处理完。但是这样直接通信很容易出现性能上的问题,消费者必须等待它的生产者提交到任务才能执行,就不能达到真正的并行。同时生产者和消费者之间存在依赖关系,在设计上耦合度非常高,这是不可取的。那么最好的做法就是加一个中间层作为通信桥梁。

生产者和消费者之间通过共享内存缓存区进行通信。多个生产者线程将任务提交给共享内存缓存区,消费者线程并不直接与生产者线程通信,而在共享内存缓冲区获取任务,并行地处理。其中内存缓冲区的主要功能是数据再多线程间的共享,同时还可以通过该缓存区,缓解生产者和消费者间的性能差。它是生产者消费者模式的核心组件,既能作为通信的桥梁,又能避免两者直接通信,从而将生产者和消费者进行解耦。生产者不需要消费者的存在,消费者也不需要知道生产者的存在。

由于内存缓冲区的存在,允许生产者和消费者在执行速度上有差异,无论哪一方速度超过另一方,缓冲区都可以得到缓解,确保系统正常运行。

生产者-消费者模式UML图如下:

实测

下面是我用IDEA的工具生成的UML图

其中Queue的选择最为关键,一般情况下应该会选择阻塞式队列,也就是BlockingQueue,但是它的并发效果并不好,所以我选择了ConcurrentLinkedQueue,笔者经过实测,生产和消费100万个数据,使用LinkedBlockingQueue耗时是38秒,并且在处理了防超量生产的情况下仍然产生了超量生产,其原因就是其并发效果不如后者,后者使用的是CAS无锁实现,而使用ConcurrentLinkedQueue耗时是27秒,无超量生产情况产生。其中ConcurrentLinkedQueue运行时占用内存最大时为546MB(jconsole测得)。

代码实现

Data类

package thread.producerConsumer;

/**
 * @author beifengtz
 * <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
 * <p>location: thread.producerConsumer.javase_learning</p>
 * Created in 19:11 2019/5/19
 */
public class Data {
    private final int data;

    public Data(int data) {
        this.data = data;
    }

    public int getData() {
        return data;
    }

    @Override
    public String toString() {
        return "Data{" +
                "data=" + data +
                '}';
    }
}

生产者Producer

package thread.producerConsumer;

import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * @author beifengtz
 * <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
 * <p>location: thread.producerConsumer.javase_learning</p>
 * Created in 19:11 2019/5/19
 */
public class Producer implements Runnable {
    //  缓冲队列
    private ConcurrentLinkedQueue<Data> queue;
    //  每个生产者生产数据的数量
    private int produceNum;

    //  任务计数,同时模拟数据生产
    private static AtomicInteger count = new AtomicInteger(0);
    //  需要生产的总任务数量,出处有耦合性,实际开发时应当处理。
    private static AtomicInteger total = new AtomicInteger(Main.NUMS);

    public Producer(ConcurrentLinkedQueue<Data> queue, int produceNum) {
        this.queue = queue;
        this.produceNum = produceNum;
    }

    @Override
    public void run() {
        Data data = null;
        //  限定生产,防止超量生产
        if (total.get() <= 0) {
            Thread.currentThread().interrupt();
        }
        for (int i = 0; i < produceNum; i++) {
            //  再次检测,防止超量生产
            if (total.get() > 0) {
                //  构造数据
                data = new Data(count.incrementAndGet());
                //  向缓冲队列中提交
                if (!queue.offer(data)) {
                    System.out.println("Filed to put data:" + data);
                } else {
                    System.out.println("Successfully produced data:" + data);
                }
                //  需要生产的总量-1
                total.decrementAndGet();
            } else {
                Thread.currentThread().interrupt();
                break;
            }
        }
    }
}

消费者Consumer

package thread.producerConsumer;

import java.text.MessageFormat;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;

/**
 * @author beifengtz
 * <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
 * <p>location: thread.producerConsumer.javase_learning</p>
 * Created in 19:11 2019/5/19
 */
public class Consumer implements Runnable {
    private ConcurrentLinkedQueue<Data> queue;
    private CountDownLatch countDown;

    public Consumer(ConcurrentLinkedQueue<Data> queue, CountDownLatch countDown) {
        this.queue = queue;
        this.countDown = countDown;
    }

    @Override
    public void run() {
        Data data = queue.poll();
        if (data != null) {
            //  计算+1
            int res = data.getData() + 1;
            System.out.println("Data processing: " + MessageFormat.format("{0}+1={1}", data.getData(), res));
            //  记录该线程任务已结束
            countDown.countDown();
        }
    }
}

Main运行主类

package thread.producerConsumer;

import java.util.concurrent.*;

/**
 * @author beifengtz
 * <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
 * <p>location: thread.producerConsumer.javase_learning</p>
 * Created in 19:10 2019/5/19
 */
public class Main {
    public static final int NUMS = 1000000;

    public static void main(String[] args) {

        //  定义缓冲队列
        ConcurrentLinkedQueue<Data> queue = new ConcurrentLinkedQueue<>();
        //  创建生产者和消费者线程池,为防止冲爆内存,这里限定线程数
        ExecutorService producerEs = Executors.newFixedThreadPool(100);
        ExecutorService consumerEs = Executors.newFixedThreadPool(100);
        //  定义计时器,等待消费者全部消费完
        CountDownLatch countDown = new CountDownLatch(NUMS);

        long start = System.currentTimeMillis();
        long end = System.currentTimeMillis();

        //  向线程池中提交 NUMS 个生产任务
        for (int i = 0; i < NUMS; i++) {
            //  每个生产者以3倍的量生产,模拟生产者和消费者的速度差
            producerEs.execute(new Producer(queue,3));
            consumerEs.execute(new Consumer(queue, countDown));
        }
        try {
            countDown.await();
            end = System.currentTimeMillis();
            producerEs.shutdown();
            consumerEs.shutdown();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        System.out.println("Total time:" + (end - start));
    }
}

Disruptor

Disruptor是一个高效的无锁内存队列,它使用无锁的方式实现了一个有界环形队列,非常适合生产者-消费者模式。它相比于BlockingQueue具有更多的特点:

这里简单引用另一篇文章(https://www.jianshu.com/p/8473bbb556af)中的一部分Disruptor介绍

如其名,环形的缓冲区。曾经 RingBuffer 是 Disruptor 中的最主要的对象,但从3.0版本开始,其职责被简化为仅仅负责对通过 Disruptor 进行交换的数据(事件)进行存储和更新。在一些更高级的应用场景中,Ring Buffer 可以由用户的自定义实现来完全替代。

通过顺序递增的序号来编号管理通过其进行交换的数据(事件),对数据(事件)的处理过程总是沿着序号逐个递增处理。一个 Sequence 用于跟踪标识某个特定的事件处理者( RingBuffer/Consumer )的处理进度。虽然一个 AtomicLong 也可以用于标识进度,但定义 Sequence 来负责该问题还有另一个目的,那就是防止不同的 Sequence 之间的CPU缓存伪共享(Flase Sharing)问题。 (注:这是 Disruptor 实现高性能的关键点之一,网上关于伪共享问题的介绍已经汗牛充栋,在此不再赘述)。

Sequencer 是 Disruptor 的真正核心。此接口有两个实现类 SingleProducerSequencer、MultiProducerSequencer ,它们定义在生产者和消费者之间快速、正确地传递数据的并发算法。

用于保持对RingBuffer的 main published Sequence 和Consumer依赖的其它Consumer的 Sequence 的引用。 Sequence Barrier 还定义了决定 Consumer 是否还有可处理的事件的逻辑。

定义 Consumer 如何进行等待下一个事件的策略。 (注:Disruptor 定义了多种不同的策略,针对不同的场景,提供了不一样的性能表现)

在 Disruptor 的语义中,生产者和消费者之间进行交换的数据被称为事件(Event)。它不是一个被 Disruptor 定义的特定类型,而是由 Disruptor 的使用者定义并指定。

EventProcessor 持有特定消费者(Consumer)的 Sequence,并提供用于调用事件处理实现的事件循环(Event Loop)。

Disruptor 定义的事件处理接口,由用户实现,用于处理事件,是 Consumer 的真正实现。

即生产者,只是泛指调用 Disruptor 发布事件的用户代码,Disruptor 没有定义特定接口或类型。

Disruptor代码实现

笔者简单地写了一个用Disruptor实现的程序,结果惊人!同样是100万个数据,时间消耗只用了22秒,时间上和前面ConcurrentLinkedQueue差别不大,但有所提升,但是它的内存消耗居然最高时仅使用了60MB!!!在内存上节省了9倍的空间,可以看到这种内存复用的策略很有效果,在实际开发中很有帮助。

提示:Disruptor是第三方包,需要下载jar包导入项目或者使用Maven导入

数据类Data

package thread.producerConsumer.disruptor;

/**
 * @author beifengtz
 * <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
 * <p>location: thread.producerConsumer.disruptor.javase_learning</p>
 * Created in 21:57 2019/5/19
 */
public class Data {
    private long value;

    public long getValue() {
        return value;
    }

    public void setValue(long value) {
        this.value = value;
    }
}

数据生产工厂

package thread.producerConsumer.disruptor;

import com.lmax.disruptor.EventFactory;

/**
 * @author beifengtz
 * <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
 * <p>location: thread.producerConsumer.disruptor.javase_learning</p>
 * Created in 21:56 2019/5/19
 */
public class DataFactory implements EventFactory<Data> {
    @Override
    public Data newInstance() {
        return new Data();
    }
}

生产者Prodocer

package thread.producerConsumer.disruptor;

import com.lmax.disruptor.RingBuffer;

import java.nio.ByteBuffer;

/**
 * @author beifengtz
 * <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
 * <p>location: thread.producerConsumer.disruptor.javase_learning</p>
 * Created in 21:59 2019/5/19
 */
public class Producer {
    private final RingBuffer<Data> ringBuffer;

    public Producer(RingBuffer<Data> ringBuffer) {
        this.ringBuffer = ringBuffer;
    }

    public void pushData(ByteBuffer bb){
        long sequence = ringBuffer.next();
        try{
            Data even = ringBuffer.get(sequence);
            even.setValue(bb.getLong(0));
        }finally {
            ringBuffer.publish(sequence);
        }
    }
}

消费者Consumer

package thread.producerConsumer.disruptor;

import com.lmax.disruptor.WorkHandler;

import java.text.MessageFormat;

/**
 * @author beifengtz
 * <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
 * <p>location: thread.producerConsumer.disruptor.javase_learning</p>
 * Created in 21:51 2019/5/19
 */
public class Consumer implements WorkHandler<Data> {
    @Override
    public void onEvent(Data data) throws Exception {
        long res = data.getValue() + 1;
        System.out.println(MessageFormat.format("Data processing: {0} + 1 = {1}",data.getValue(),res));
    }
}

运行主类Main

package thread.producerConsumer.disruptor;

import com.lmax.disruptor.RingBuffer;
import com.lmax.disruptor.dsl.Disruptor;

import java.nio.ByteBuffer;
import java.util.concurrent.ThreadFactory;

/**
 * @author beifengtz
 * <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
 * <p>location: thread.producerConsumer.disruptor.javase_learning</p>
 * Created in 21:50 2019/5/19
 */
public class Main {
    private static final int NUMS = 10;
    private static final int SUM = 1000000;

    public static void main(String[] args) {
        try {
            Thread.sleep(10000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        long start = System.currentTimeMillis();
        DataFactory factory = new DataFactory();
        int bufferSize = 1024;
        Disruptor<Data> disruptor = new Disruptor<>(factory, bufferSize, new ThreadFactory() {
            @Override
            public Thread newThread(Runnable r) {
                return new Thread(r);
            }
        });

        //  构造100个消费者
        Consumer[] consumers = new Consumer[NUMS];
        for (int i = 0; i < NUMS; i++) {
            consumers[i] = new Consumer();
        }
        disruptor.handleEventsWithWorkerPool(consumers);
        disruptor.start();

        RingBuffer<Data> ringBuffer = disruptor.getRingBuffer();
        Producer producer = new Producer(ringBuffer);
        ByteBuffer bb = ByteBuffer.allocate(8);
        for (long l = 0; l < SUM; l++) {
            bb.putLong(0, l);
            producer.pushData(bb);
            System.out.println("Successfully produced data: " + l);
        }
        long end = System.currentTimeMillis();
        disruptor.shutdown();
        System.out.println("Total time: " + (end - start));
    }
}

微信公众号浏览体验更佳,在这里还有更多优秀文章为你奉上,快来关注吧!

北风IT之路

  1. 师兄,加油,还能评论吗?

  2. 加油,加油