Following up on "How to use MDC with thread pools?", how can one use MDC with a ForkJoinPool? Specifically, how can one wrap a ForkJoinTask so that MDC values are set before executing a task?
The following seems to work for me:
import java.lang.Thread.UncaughtExceptionHandler;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.MDC;
/**
 * A {@link ForkJoinPool} that inherits MDC contexts from the thread that queues a task.
 * <p>
 * The MDC map of the calling thread is copied when a task is passed to one of the {@code execute} or
 * {@code submit} methods. The task is wrapped so that the copy is installed before the task runs and the
 * worker thread's previous MDC map is restored afterwards.
 * <p>
 * Only tasks that enter the pool through the overridden {@code execute}/{@code submit} methods are wrapped.
 * {@code invoke} and {@code invokeAll} are not overridden by this class, and tasks forked from inside a
 * running task do not pass through these methods either.
 *
 * @author Gili Tzabari
 */
public final class MdcForkJoinPool extends ForkJoinPool
{
    /**
     * Creates a new MdcForkJoinPool.
     *
     * @param parallelism the parallelism level. For default value, use {@link java.lang.Runtime#availableProcessors}.
     * @param factory     the factory for creating new threads. For default value, use
     *                    {@link #defaultForkJoinWorkerThreadFactory}.
     * @param handler     the handler for internal worker threads that terminate due to unrecoverable errors encountered
     *                    while executing tasks. For default value, use {@code null}.
     * @param asyncMode   if true, establishes local first-in-first-out scheduling mode for forked tasks that are never
     *                    joined. This mode may be more appropriate than default locally stack-based mode in applications
     *                    in which worker threads only process event-style asynchronous tasks. For default value, use
     *                    {@code false}.
     * @throws IllegalArgumentException if parallelism less than or equal to zero, or greater than implementation limit
     * @throws NullPointerException     if the factory is null
     * @throws SecurityException        if a security manager exists and the caller is not permitted to modify threads
     *                                  because it does not hold
     *                                  {@link java.lang.RuntimePermission}{@code ("modifyThread")}
     */
    public MdcForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler,
        boolean asyncMode)
    {
        super(parallelism, factory, handler, asyncMode);
    }
    /**
     * Arranges for asynchronous execution of {@code task} under the caller's MDC context.
     *
     * @param task the task to run
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public void execute(ForkJoinTask<?> task)
    {
        // Fail fast on the calling thread: once wrapped, a null task would only blow up later on a worker.
        Objects.requireNonNull(task, "task");
        // See http://stackoverflow.com/a/19329668/14731
        super.execute(wrap(task, MDC.getCopyOfContextMap()));
    }
    /**
     * Arranges for asynchronous execution of {@code task} under the caller's MDC context.
     *
     * @param task the task to run
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public void execute(Runnable task)
    {
        Objects.requireNonNull(task, "task");
        // See http://stackoverflow.com/a/19329668/14731
        super.execute(wrap(task, MDC.getCopyOfContextMap()));
    }
    /**
     * Submits {@code task} for execution under the caller's MDC context.
     *
     * @param <T>  the type of the task's result
     * @param task the task to run
     * @return the wrapping task that was actually queued; its result is the result of {@code task}
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task)
    {
        Objects.requireNonNull(task, "task");
        return super.submit(wrap(task, MDC.getCopyOfContextMap()));
    }
    /**
     * Submits {@code task} for execution under the caller's MDC context.
     *
     * @param <T>  the type of the task's result
     * @param task the task to run
     * @return a task representing pending completion of {@code task}
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public <T> ForkJoinTask<T> submit(Callable<T> task)
    {
        Objects.requireNonNull(task, "task");
        return super.submit(wrap(task, MDC.getCopyOfContextMap()));
    }
    /**
     * Submits {@code task} for execution under the caller's MDC context.
     *
     * @param <T>    the type of the result
     * @param task   the task to run
     * @param result the value to return on completion
     * @return a task representing pending completion of {@code task}
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public <T> ForkJoinTask<T> submit(Runnable task, T result)
    {
        Objects.requireNonNull(task, "task");
        return super.submit(wrap(task, MDC.getCopyOfContextMap()), result);
    }
    /**
     * Submits {@code task} for execution under the caller's MDC context.
     *
     * @param task the task to run
     * @return a task representing pending completion of {@code task}
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public ForkJoinTask<?> submit(Runnable task)
    {
        Objects.requireNonNull(task, "task");
        return super.submit(wrap(task, MDC.getCopyOfContextMap()));
    }
    /**
     * Wraps a {@link ForkJoinTask} so that it runs with {@code newContext} as the MDC context.
     *
     * @param <T>        the type of the task's result
     * @param task       the task to delegate to
     * @param newContext the MDC context to install while {@code task} runs; {@code null} means an empty context
     * @return the wrapping task
     */
    private static <T> ForkJoinTask<T> wrap(ForkJoinTask<T> task, Map<String, String> newContext)
    {
        return new ForkJoinTask<T>()
        {
            private static final long serialVersionUID = 1L;
            /**
             * If non-null, overrides the value returned by the underlying task.
             */
            private final AtomicReference<T> override = new AtomicReference<>();
            @Override
            public T getRawResult()
            {
                T result = override.get();
                if (result != null)
                    return result;
                return task.getRawResult();
            }
            @Override
            protected void setRawResult(T value)
            {
                override.set(value);
            }
            @Override
            protected boolean exec()
            {
                // According to ForkJoinTask.fork() "it is a usage error to fork a task more than once unless it has completed
                // and been reinitialized". We therefore assume that this method does not have to be thread-safe.
                Map<String, String> oldContext = beforeExecution(newContext);
                try
                {
                    task.invoke();
                    return true;
                }
                finally
                {
                    // Always restore the worker's previous context, even if the task threw.
                    afterExecution(oldContext);
                }
            }
        };
    }
    /**
     * Wraps a {@link Runnable} so that it runs with {@code newContext} as the MDC context.
     *
     * @param task       the task to delegate to
     * @param newContext the MDC context to install while {@code task} runs; {@code null} means an empty context
     * @return the wrapping task
     */
    private static Runnable wrap(Runnable task, Map<String, String> newContext)
    {
        return () ->
        {
            Map<String, String> oldContext = beforeExecution(newContext);
            try
            {
                task.run();
            }
            finally
            {
                afterExecution(oldContext);
            }
        };
    }
    /**
     * Wraps a {@link Callable} so that it runs with {@code newContext} as the MDC context.
     *
     * @param <T>        the type of the task's result
     * @param task       the task to delegate to
     * @param newContext the MDC context to install while {@code task} runs; {@code null} means an empty context
     * @return the wrapping task
     */
    private static <T> Callable<T> wrap(Callable<T> task, Map<String, String> newContext)
    {
        return () ->
        {
            Map<String, String> oldContext = beforeExecution(newContext);
            try
            {
                return task.call();
            }
            finally
            {
                afterExecution(oldContext);
            }
        };
    }
    /**
     * Invoked before running a task.
     *
     * @param newValue the new MDC context; {@code null} clears the current context
     * @return the old MDC context
     */
    private static Map<String, String> beforeExecution(Map<String, String> newValue)
    {
        Map<String, String> previous = MDC.getCopyOfContextMap();
        if (newValue == null)
            MDC.clear();
        else
            MDC.setContextMap(newValue);
        return previous;
    }
    /**
     * Invoked after running a task.
     *
     * @param oldValue the old MDC context; {@code null} clears the current context
     */
    private static void afterExecution(Map<String, String> oldValue)
    {
        if (oldValue == null)
            MDC.clear();
        else
            MDC.setContextMap(oldValue);
    }
}
and
import java.util.Map;
import java.util.concurrent.CountedCompleter;
import org.slf4j.MDC;
/**
 * A {@link CountedCompleter} whose computation runs under the MDC context of the thread that created it.
 * <p>
 * Subclasses implement {@link #computeWithContext()} instead of {@link #compute()}.
 *
 * @author Gili Tzabari
 * @param <T> The result type returned by this task's {@code get} method
 */
public abstract class MdcCountedCompleter<T> extends CountedCompleter<T>
{
    private static final long serialVersionUID = 1L;
    /**
     * Snapshot of the MDC map taken in the constructor; may be {@code null}.
     */
    private final Map<String, String> newContext;

    /**
     * Creates a completer without a parent, capturing the current thread's MDC context.
     */
    protected MdcCountedCompleter()
    {
        this(null);
    }

    /**
     * Creates a completer, capturing the current thread's MDC context.
     *
     * @param completer this task's completer; {@code null} if none
     */
    protected MdcCountedCompleter(CountedCompleter<?> completer)
    {
        super(completer);
        this.newContext = MDC.getCopyOfContextMap();
    }

    /**
     * The main computation performed by this task. Runs with the captured MDC context installed.
     */
    protected abstract void computeWithContext();

    /**
     * Installs the captured MDC context, delegates to {@link #computeWithContext()}, and then puts back whatever
     * context the executing thread had before, whether or not the computation threw.
     */
    @Override
    public final void compute()
    {
        final Map<String, String> threadContext = MDC.getCopyOfContextMap();
        install(newContext);
        try
        {
            computeWithContext();
        }
        finally
        {
            install(threadContext);
        }
    }

    /**
     * Makes {@code context} the MDC context of the current thread.
     *
     * @param context the context to install; {@code null} clears the MDC instead
     */
    private static void install(Map<String, String> context)
    {
        if (context != null)
        {
            MDC.setContextMap(context);
            return;
        }
        MDC.clear();
    }
}
Usage: use MdcForkJoinPool instead of the common ForkJoinPool, and extend MdcCountedCompleter instead of CountedCompleter.
Here is some additional information to go along with @Gili's answer.
Test that shows that the solution works (note that there will be lines without the Context, but at least they won't be the WRONG context, which is what was happening with a normal ForkJoinPool).
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.startsWith;
import static org.junit.Assert.assertThat;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;
import org.junit.Test;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import ch.qos.logback.classic.Level;
import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.LoggerContext;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.OutputStreamAppender;
/**
 * Checks that log lines written by tasks running in the pool carry either the MDC value of the simulated
 * request that submitted them or no MDC value at all, but never the value of a different request.
 */
public class MDCForkJoinPoolTest {
    /** Number of simulated requests; each runs on its own thread with its own MDC value. */
    private static final int REQUEST_COUNT = 100;
    private static final Logger log = (Logger) LoggerFactory.getLogger("mdc-test");
    // you can demonstrate the problem I'm trying to fix by changing the below to a normal ForkJoinPool and then running the test
    private ForkJoinPool threads =
        new MdcForkJoinPool(16, ForkJoinPool.defaultForkJoinWorkerThreadFactory, null, false);
    // Starts at 1 - REQUEST_COUNT so that a permit only becomes available after every request has released once.
    private Semaphore threadsRunning = new Semaphore(1 - REQUEST_COUNT);
    private ByteArrayOutputStream bio = new ByteArrayOutputStream();

    @Test
    public void shouldCopyManagedDiagnosticContextWhenUsingForkJoinPool() throws Exception {
        // Set up the appender to grab the output BEFORE any request thread starts. Otherwise log lines emitted
        // while the appender is still being configured never reach the buffer and the final count is racy.
        LoggerContext lc = (LoggerContext) LoggerFactory.getILoggerFactory();
        OutputStreamAppender<ILoggingEvent> appender = new OutputStreamAppender<>();
        // NOTE(review): LogbackEncoder is not declared or imported in this snippet; confirm which encoder is meant.
        LogbackEncoder encoder = new LogbackEncoder();
        // Each line is "<mdc value or empty>=<message>", which is what the parsing below relies on.
        encoder.setPattern("%X{mdc_val:-}=%m%n");
        encoder.setContext(lc);
        encoder.start();
        appender.setEncoder(encoder);
        appender.setImmediateFlush(true);
        appender.setContext(lc);
        appender.setOutputStream(bio);
        appender.start();
        log.addAppender(appender);
        log.setAdditive(false);
        log.setLevel(Level.INFO);

        for (int i = 0; i < REQUEST_COUNT; i++) {
            Thread t = new Thread(simulatedRequest(), "MDC-Test-" + i);
            t.setDaemon(true);
            t.start();
        }

        assertThat("timed out waiting for threads to complete.", threadsRunning.tryAcquire(300, TimeUnit.SECONDS), is(true));

        Set<String> ids = new HashSet<>();
        try (BufferedReader r = new BufferedReader(new InputStreamReader(new ByteArrayInputStream(bio.toByteArray()), Charset.forName("utf8")))) {
            r.lines().forEach(line -> {
                System.out.println(line);
                String[] vals = line.split("=");
                // Lines without an MDC value are tolerated; lines with one must belong to the matching request.
                if (!vals[0].isEmpty()) {
                    ids.add(vals[0]);
                    assertThat(vals[1], startsWith(vals[0]));
                }
            });
        }
        assertThat(ids.size(), is(REQUEST_COUNT));
    }

    /**
     * Builds a request that tags its thread with a fresh MDC value, logs 100 messages from a parallel stream
     * running inside the pool, waits for that work to finish, and then releases one semaphore permit.
     */
    private Runnable simulatedRequest() {
        return () -> {
            String id = UUID.randomUUID().toString();
            MDC.put("mdc_val", id);
            Map<String, String> context = MDC.getCopyOfContextMap();
            threads.submit(() -> {
                MDC.setContextMap(context);
                IntStream.range(0, 100).parallel().forEach((i) -> {
                    log.info("{} - {}", id, i);
                });
            }).join();
            threadsRunning.release();
        };
    }
}
Also, here are the additional methods that should be overridden in the original answer.
    /**
     * Submits {@code task} for execution under the MDC context of the calling thread.
     *
     * @param <T>  the type of the task's result
     * @param task the task to run
     * @return the task returned by {@code super.submit} for the wrapped task
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
        // Reject null here: once wrapped, the pool would accept it and the failure would surface later.
        Objects.requireNonNull(task, "task");
        return super.submit(wrap(task, MDC.getCopyOfContextMap()));
    }
    /**
     * Submits {@code task} for execution under the MDC context of the calling thread.
     *
     * @param <T>  the type of the task's result
     * @param task the task to run
     * @return the task returned by {@code super.submit} for the wrapped callable
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public <T> ForkJoinTask<T> submit(Callable<T> task) {
        // Reject null here: once wrapped, the pool would accept it and the failure would surface later.
        Objects.requireNonNull(task, "task");
        return super.submit(wrap(task, MDC.getCopyOfContextMap()));
    }
    /**
     * Submits {@code task} for execution under the MDC context of the calling thread.
     *
     * @param <T>    the type of the result
     * @param task   the task to run
     * @param result the value passed through to {@code super.submit} as the completion result
     * @return the task returned by {@code super.submit} for the wrapped runnable
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public <T> ForkJoinTask<T> submit(Runnable task, T result) {
        // Reject null here: once wrapped, the pool would accept it and the failure would surface later.
        Objects.requireNonNull(task, "task");
        return super.submit(wrap(task, MDC.getCopyOfContextMap()), result);
    }
    /**
     * Submits {@code task} for execution under the MDC context of the calling thread.
     *
     * @param task the task to run
     * @return the task returned by {@code super.submit} for the wrapped runnable
     * @throws NullPointerException if {@code task} is null
     */
    @Override
    public ForkJoinTask<?> submit(Runnable task) {
        // Reject null here: once wrapped, the pool would accept it and the failure would surface later.
        Objects.requireNonNull(task, "task");
        return super.submit(wrap(task, MDC.getCopyOfContextMap()));
    }
    /**
     * Decorates a {@link Callable} so that {@code newContext} is passed to {@code beforeExecution} before the
     * call, and the value it returned is passed to {@code afterExecution} once the call finishes or throws.
     *
     * @param <T>        the type of the task's result
     * @param task       the callable to delegate to
     * @param newContext the MDC context handed to {@code beforeExecution}
     * @return a callable that returns whatever {@code task} returns
     */
    private <T> Callable<T> wrap(Callable<T> task, Map<String, String> newContext)
    {
        return new Callable<T>()
        {
            @Override
            public T call() throws Exception
            {
                final Map<String, String> previous = beforeExecution(newContext);
                try
                {
                    return task.call();
                }
                finally
                {
                    afterExecution(previous);
                }
            }
        };
    }
If you love us, you can donate to us via PayPal or buy us a coffee so we can maintain and grow. Thank you!
Donate Us With