Saturday, 18 February 2012

JAXB and Commons pool

Recently I made a mistake that went unnoticed for quite a long time. In an effort to improve the performance of JAXB (Java Architecture for XML Binding) operations, I cached instances of javax.xml.bind.Marshaller and javax.xml.bind.Unmarshaller. This article explains why that was not a good idea (neither class is guaranteed to be thread-safe, so a cached instance must not be shared between threads) and describes how pooling with Apache Commons Pool can be used instead to improve overall JAXB performance.


JAXB

The JAXB API is fairly verbose. When working with XML we generally do not want to deal with SchemaFactory or JAXBContext directly; what we really need is one method to marshal an object into an XML string and another to unmarshal a string into an object. The JaxbHelper interface captures this goal. When the schemaLocation parameter is present, the XML is validated against the XSD schema; when it is null, no validation is performed.

import org.jetbrains.annotations.Nullable;

/**
 * Custom interface, which simplifies JAXB API.
 */
public interface JaxbHelper {

    public <T> String marshal(T instance, @Nullable String schemaLocation) throws Exception;

    public <T> T unmarshal(String xml, Class<T> clazz, @Nullable String schemaLocation) throws Exception;
}

SimpleJaxbHelper

Let's start with the simplest possible implementation. It will later serve as the baseline for comparing performance. This and all following JaxbHelper implementations are thread-safe and re-entrant. The schemaLocation parameter is resolved relative to the class, so if the XSD schema is in the same package as the class, the schema file name is sufficient and the path can be omitted.

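import java.io.StringReader;
import java.io.StringWriter;

import javax.xml.XMLConstants;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.validation.Schema;
import javax.xml.validation.SchemaFactory;

import org.jetbrains.annotations.Nullable;
import org.xml.sax.SAXException;

/**
 * Simplest possible implementation; it uses neither caching nor pooling. It is thread-safe and re-entrant.
 */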
public class SimpleJaxbHelper implements JaxbHelper {

    @Override
    public <T> String marshal(T instance, @Nullable String schemaLocation) throws JAXBException, SAXException {
        StringWriter result = new StringWriter();

        JAXBContext jaxbContext = JAXBContext.newInstance(instance.getClass());

        Marshaller marshaller = jaxbContext.createMarshaller();

        if (schemaLocation != null) {
            SchemaFactory schemaFactory = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI);
            Schema schema = schemaFactory.newSchema(instance.getClass().getResource(schemaLocation));
            marshaller.setSchema(schema);
        }

        marshaller.marshal(instance, result);

        return result.toString();
    }

    @Override
    public <T> T unmarshal(String xml, Class<T> clazz, @Nullable String schemaLocation) throws JAXBException, SAXException {
        JAXBContext jaxbContext = JAXBContext.newInstance(clazz);

        Unmarshaller unmarshaller = jaxbContext.createUnmarshaller();

        if (schemaLocation != null) {
            SchemaFactory schemaFactory = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI);
            Schema schema = schemaFactory.newSchema(clazz.getResource(schemaLocation));
            unmarshaller.setSchema(schema);
        }

        //noinspection unchecked
        return (T) unmarshaller.unmarshal(new StringReader(xml));
    }
}

CachedJaxbHelper

The second implementation caches javax.xml.bind.JAXBContext instances, which are thread-safe, at least in the JAXB RI implementation. It also caches instances of javax.xml.validation.Schema, because these are immutable and there is no reason not to. Note that java.util.concurrent.ConcurrentHashMap is used here because its get method generally does not block and may overlap with put.

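import java.io.StringReader;
import java.io.StringWriter;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import javax.xml.XMLConstants;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.validation.Schema;
import javax.xml.validation.SchemaFactory;

import org.jetbrains.annotations.Nullable;
import org.xml.sax.SAXException;

/**
 * Implementation which holds its JAXBContext and Schema instances in a map. It is thread-safe and re-entrant.
 */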
public class CachedJaxbHelper implements JaxbHelper {

    private static Map<Class, JAXBContext> jaxbContextMap = new ConcurrentHashMap<Class, JAXBContext>();
    private static Map<String, Schema> schemaMap = new ConcurrentHashMap<String, Schema>();

    private static JAXBContext getJaxbContext(Class clazz) throws JAXBException {
        JAXBContext jaxbContext = jaxbContextMap.get(clazz);

        if (jaxbContext == null) {
            jaxbContext = JAXBContext.newInstance(clazz);
            jaxbContextMap.put(clazz, jaxbContext);
        }
        return jaxbContext;
    }

    private static Schema getSchema(Class clazz, String schemaLocation) throws JAXBException, SAXException {
        Schema schema = schemaMap.get(schemaLocation);

        if (schema == null) {
            SchemaFactory schemaFactory = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI);
            schema = schemaFactory.newSchema(clazz.getResource(schemaLocation));
            schemaMap.put(schemaLocation, schema);
        }
        return schema;
    }

    @Override
    public <T> String marshal(T instance, @Nullable String schemaLocation) throws JAXBException, SAXException {
        StringWriter result = new StringWriter();

        JAXBContext jaxbContext = getJaxbContext(instance.getClass());

        Marshaller marshaller = jaxbContext.createMarshaller();

        if (schemaLocation != null) {
            Schema schema = getSchema(instance.getClass(), schemaLocation);
            marshaller.setSchema(schema);
        }

        marshaller.marshal(instance, result);

        return result.toString();
    }

    @Override
    public <T> T unmarshal(String xml, Class<T> clazz, @Nullable String schemaLocation) throws JAXBException, SAXException {
        JAXBContext jaxbContext = getJaxbContext(clazz);

        Unmarshaller unmarshaller = jaxbContext.createUnmarshaller();

        if (schemaLocation != null) {
            Schema schema = getSchema(clazz, schemaLocation);
            unmarshaller.setSchema(schema);
        }

        //noinspection unchecked
        return (T) unmarshaller.unmarshal(new StringReader(xml));
    }
}
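
One detail of the cache above: the get followed by put is not atomic, so two threads that miss at the same moment may each build the same object. That is harmless here, only wasteful. On Java 8 or later the lookup could instead be written with ConcurrentHashMap.computeIfAbsent, which runs the loader at most once per key. The following is only a sketch under assumptions: the class name SchemaCacheSketch and the helper getSchema(key, xsdText) are hypothetical, and the schema is compiled from a string rather than from a classpath resource so that the example is self-contained.

```java
import java.io.StringReader;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import javax.xml.XMLConstants;
import javax.xml.transform.stream.StreamSource;
import javax.xml.validation.Schema;
import javax.xml.validation.SchemaFactory;

import org.xml.sax.SAXException;

/** Hypothetical sketch, not part of the original project. */
final class SchemaCacheSketch {

    private static final Map<String, Schema> CACHE = new ConcurrentHashMap<>();

    /** Returns the cached Schema for the key, compiling xsdText only on the first request. */
    static Schema getSchema(String key, String xsdText) {
        return CACHE.computeIfAbsent(key, k -> {
            try {
                // SchemaFactory is not thread-safe, so a fresh one is created per compilation.
                SchemaFactory factory = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI);
                return factory.newSchema(new StreamSource(new StringReader(xsdText)));
            } catch (SAXException e) {
                // Lambdas cannot throw checked exceptions, so wrap it.
                throw new IllegalStateException("Cannot compile schema " + k, e);
            }
        });
    }

    public static void main(String[] args) {
        String xsd = "<xs:schema xmlns:xs=\"http://www.w3.org/2001/XMLSchema\">"
                + "<xs:element name=\"sample\" type=\"xs:string\"/>"
                + "</xs:schema>";

        Schema first = getSchema("sample.xsd", xsd);
        Schema second = getSchema("sample.xsd", xsd);

        System.out.println(first == second); // true: the second call is served from the cache
    }
}
```

The same idea would apply to the JAXBContext map, with JAXBException wrapped in the same way.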

PooledJaxbHelper

The third implementation keeps Schema instances cached as before, but uses pooling for javax.xml.bind.JAXBContext, javax.xml.bind.Marshaller and javax.xml.bind.Unmarshaller. The first thing to notice in PooledJaxbHelper is PoolKey; this inner class serves as the key for the marshaller and unmarshaller pools. It encapsulates the class and the schema location. The schema location may be null, so the equals and hashCode methods must handle that case.
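
On Java 7 or later, the null handling that such a key needs can be delegated to java.util.Objects. This is a minimal sketch only; the class name PoolKeySketch is hypothetical and it is written as a standalone class so that it can be compiled on its own.

```java
import java.util.Objects;

/** Hypothetical standalone variant of a pool key with a nullable schema location. */
final class PoolKeySketch {

    private final Class<?> clazz;
    private final String schemaLocation; // may be null, meaning "no validation"

    PoolKeySketch(Class<?> clazz, String schemaLocation) {
        this.clazz = Objects.requireNonNull(clazz, "clazz");
        this.schemaLocation = schemaLocation;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (!(o instanceof PoolKeySketch)) return false;
        PoolKeySketch other = (PoolKeySketch) o;
        // Objects.equals treats two nulls as equal and never throws on null.
        return clazz.equals(other.clazz) && Objects.equals(schemaLocation, other.schemaLocation);
    }

    @Override
    public int hashCode() {
        // Objects.hash contributes 0 for a null element.
        return Objects.hash(clazz, schemaLocation);
    }

    public static void main(String[] args) {
        PoolKeySketch a = new PoolKeySketch(String.class, null);
        PoolKeySketch b = new PoolKeySketch(String.class, null);
        PoolKeySketch c = new PoolKeySketch(String.class, "sample.xsd");

        System.out.println(a.equals(b));                  // true
        System.out.println(a.hashCode() == b.hashCode()); // true
        System.out.println(a.equals(c));                  // false
    }
}
```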

Commons Pool is very easy to work with. It contains org.apache.commons.pool.impl.GenericKeyedObjectPool, which holds instances per key. It must be provided with a factory that creates the pooled objects. JaxbContextFactory is responsible for creating new JAXBContext instances, while MarshallerFactory and UnmarshallerFactory are responsible for creating Marshaller and Unmarshaller instances. MarshallerFactory and UnmarshallerFactory borrow the JAXBContext instance they need from jaxbContextPool. Every object borrowed from a pool with borrowObject must be returned with returnObject. A pool may be given an optional GenericKeyedObjectPool.Config to change the default configuration. I decided to invalidate Marshaller and Unmarshaller instances with invalidateObject when an exception occurs, so that the affected instance is never used again.

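import java.io.StringReader;
import java.io.StringWriter;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import javax.xml.XMLConstants;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.validation.Schema;
import javax.xml.validation.SchemaFactory;

import org.apache.commons.pool.BaseKeyedPoolableObjectFactory;
import org.apache.commons.pool.impl.GenericKeyedObjectPool;
import org.jetbrains.annotations.Nullable;
import org.xml.sax.SAXException;

/**
 * Implementation which holds its Schema instances in a map, and uses pooling for JAXBContext, Marshaller and Unmarshaller instances.
 * It is thread-safe and re-entrant.
 */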
public class PooledJaxbHelper implements JaxbHelper {

    private static class PoolKey {
        private Class clazz;
        private String schemaLocation;

        private PoolKey(Class clazz, @Nullable String schemaLocation) {
            this.clazz = clazz;
            this.schemaLocation = schemaLocation;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;

            PoolKey poolKey = (PoolKey) o;

            return clazz.equals(poolKey.clazz) && !(schemaLocation != null ? !schemaLocation.equals(poolKey.schemaLocation) : poolKey.schemaLocation != null);

        }

        @Override
        public int hashCode() {
            int result = clazz.hashCode();
            result = 31 * result + (schemaLocation != null ? schemaLocation.hashCode() : 0);
            return result;
        }

        public Class getClazz() {
            return clazz;
        }

        @Nullable
        public String getSchemaLocation() {
            return schemaLocation;
        }
    }

    private static class MarshallerFactory extends BaseKeyedPoolableObjectFactory<PoolKey, Marshaller> {
        @Override
        public Marshaller makeObject(PoolKey key) throws Exception {
            JAXBContext jaxbContext = jaxbContextPool.borrowObject(key.getClazz());

            Marshaller marshaller = jaxbContext.createMarshaller();

            if (key.getSchemaLocation() != null) {
                Schema schema = getSchema(key);
                marshaller.setSchema(schema);
            }

            jaxbContextPool.returnObject(key.getClazz(), jaxbContext);

            return marshaller;
        }
    }

    private static class UnmarshallerFactory extends BaseKeyedPoolableObjectFactory<PoolKey, Unmarshaller> {
        @Override
        public Unmarshaller makeObject(PoolKey key) throws Exception {
            JAXBContext jaxbContext = jaxbContextPool.borrowObject(key.getClazz());

            Unmarshaller unmarshaller = jaxbContext.createUnmarshaller();

            if (key.getSchemaLocation() != null) {
                Schema schema = getSchema(key);
                unmarshaller.setSchema(schema);
            }

            jaxbContextPool.returnObject(key.getClazz(), jaxbContext);

            return unmarshaller;
        }
    }

    private static class JaxbContextFactory extends BaseKeyedPoolableObjectFactory<Class, JAXBContext> {
        @Override
        public JAXBContext makeObject(Class clazz) throws Exception {
            return JAXBContext.newInstance(clazz);
        }
    }

    private static class CustomPoolConfig extends GenericKeyedObjectPool.Config {
        {
            maxIdle = 3;
            maxActive = 10;
            maxTotal = 100;
            minIdle = 1;
            whenExhaustedAction = GenericKeyedObjectPool.WHEN_EXHAUSTED_GROW;
            timeBetweenEvictionRunsMillis = 1000L * 60L * 10L;
            numTestsPerEvictionRun = 50;
            minEvictableIdleTimeMillis = 1000L * 60L * 5L; // 5 min.
        }
    }

    private static Schema getSchema(PoolKey poolKey) throws JAXBException, SAXException {
        Schema schema = schemaMap.get(poolKey.getSchemaLocation());

        if (schema == null) {
            SchemaFactory schemaFactory = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI);
            schema = schemaFactory.newSchema(poolKey.getClazz().getResource(poolKey.getSchemaLocation()));
            schemaMap.put(poolKey.getSchemaLocation(), schema);
        }
        return schema;
    }

    private static Map<String, Schema> schemaMap = new ConcurrentHashMap<String, Schema>();
    private static GenericKeyedObjectPool<Class, JAXBContext> jaxbContextPool = new GenericKeyedObjectPool<Class, JAXBContext>(new JaxbContextFactory(), new CustomPoolConfig());
    private static GenericKeyedObjectPool<PoolKey, Marshaller> marshallerPool = new GenericKeyedObjectPool<PoolKey, Marshaller>(new MarshallerFactory(), new CustomPoolConfig());
    private static GenericKeyedObjectPool<PoolKey, Unmarshaller> unmarshallerPool = new GenericKeyedObjectPool<PoolKey, Unmarshaller>(new UnmarshallerFactory(), new CustomPoolConfig());

    @Override
    public <T> String marshal(T instance, @Nullable String schemaLocation) throws Exception {
        StringWriter result = new StringWriter();

        PoolKey poolKey = new PoolKey(instance.getClass(), schemaLocation);
        Marshaller marshaller = marshallerPool.borrowObject(poolKey);

        try {
            marshaller.marshal(instance, result);

            marshallerPool.returnObject(poolKey, marshaller);

            return result.toString();
        } catch (Exception e) {
            marshallerPool.invalidateObject(poolKey, marshaller);
            throw new RuntimeException(e);
        }
    }

    @Override
    public <T> T unmarshal(String xml, Class<T> clazz, @Nullable String schemaLocation) throws Exception {
        T result;

        PoolKey poolKey = new PoolKey(clazz, schemaLocation);
        Unmarshaller unmarshaller = unmarshallerPool.borrowObject(poolKey);

        try {
            //noinspection unchecked
            result = (T) unmarshaller.unmarshal(new StringReader(xml));

            unmarshallerPool.returnObject(poolKey, unmarshaller);

            return result;
        } catch (Exception e) {
            unmarshallerPool.invalidateObject(poolKey, unmarshaller);
            throw new RuntimeException(e);
        }
    }
}
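
To make the borrow, return and invalidate contract used above easier to see in isolation, here is a deliberately tiny keyed pool built only from JDK collections. It is a sketch and not the Commons Pool API: the class TinyKeyedPool is hypothetical, it has no limits, eviction or validation, and StringBuilder merely stands in for an object that, like a Marshaller, should not be shared between threads.

```java
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Function;

/** Hypothetical minimal keyed pool; it only illustrates the borrow/return/invalidate contract. */
final class TinyKeyedPool<K, V> {

    private final ConcurrentMap<K, Queue<V>> idle = new ConcurrentHashMap<>();
    private final Function<K, V> factory;

    TinyKeyedPool(Function<K, V> factory) {
        this.factory = factory;
    }

    /** Hands out an idle instance for the key, or creates a new one if none is idle. */
    V borrowObject(K key) {
        Queue<V> queue = idle.get(key);
        V pooled = (queue == null) ? null : queue.poll();
        return (pooled != null) ? pooled : factory.apply(key);
    }

    /** Makes the instance available to the next borrower of the same key. */
    void returnObject(K key, V instance) {
        idle.computeIfAbsent(key, k -> new ConcurrentLinkedQueue<>()).add(instance);
    }

    /** Discards a possibly broken instance: it is simply never put back. */
    void invalidateObject(K key, V instance) {
        // Nothing to release in this sketch; the instance becomes garbage.
    }

    public static void main(String[] args) {
        TinyKeyedPool<String, StringBuilder> pool = new TinyKeyedPool<>(key -> new StringBuilder());

        StringBuilder borrowed = pool.borrowObject("a");
        boolean succeeded = false;
        try {
            borrowed.setLength(0);          // exclusive use while borrowed
            borrowed.append("hello");
            succeeded = true;
        } finally {
            if (succeeded) {
                pool.returnObject("a", borrowed);
            } else {
                pool.invalidateObject("a", borrowed);
            }
        }

        System.out.println(pool.borrowObject("a") == borrowed); // true: the instance was reused
    }
}
```

Deciding between return and invalidate in a finally block, as in main, guarantees that exactly one of the two is called for every borrowed instance, whichever way the work ends.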

Conclusion

As for performance, many factors must be taken into consideration, for example the complexity of the XML documents, pool configuration, number of processors, amount of memory and so on. Absolute numbers therefore have little meaning here, but here are some relative results that seem consistent enough.

INFO  JaxbHelperTest - testCompareAllToSimple
INFO  JaxbHelperTest - SimpleJaxbHelper / CachedJaxbHelper ratio: 5.813814804912555
INFO  JaxbHelperTest - SimpleJaxbHelper / PooledJaxbHelper ratio: 14.213429365043625

INFO  JaxbHelperTest - testCompareAllToSimpleMultipleThreads
INFO  JaxbHelperTest - SimpleJaxbHelper / CachedJaxbHelper ratio: 7.6223063332965735
INFO  JaxbHelperTest - SimpleJaxbHelper / PooledJaxbHelper ratio: 9.20335326080807

This basically says that CachedJaxbHelper is 6 to 8 times faster than SimpleJaxbHelper, and PooledJaxbHelper is 9 to 14 times faster than SimpleJaxbHelper.


Appendix

Project structure

JaxbPool/
|-- pom.xml
|-- src
    |-- main
    |   |-- java
    |   |   `-- eu
    |   |       `-- zont
    |   |           `-- jaxbpool
    |   |               |-- core
    |   |               |   |-- CachedJaxbHelper.java
    |   |               |   |-- JaxbHelper.java
    |   |               |   |-- PooledJaxbHelper.java
    |   |               |   `-- SimpleJaxbHelper.java
    |   |               `-- xml
    |   |                   |-- ObjectFactory.java
    |   |                   |-- PersonType.java
    |   |                   `-- SampleType.java
    |   `-- resources
    |       |-- eu
    |       |   `-- zont
    |       |       `-- jaxbpool
    |       |           `-- xml
    |       |               `-- sample.xsd
    |       |-- log4j.dtd
    |       `-- log4j.xml
    `-- test
        |-- java
        |   `-- eu
        |       `-- zont
        |           `-- jaxbpool
        |               `-- core
        |                   `-- JaxbHelperTest.java
        `-- resources
            `-- eu
                `-- zont
                    `-- jaxbpool
                        `-- xml
                            `-- sample.xml

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>eu.zont.jaxbpool</groupId>
    <artifactId>jaxb-pool</artifactId>
    <version>1.0-SNAPSHOT</version>

    <dependencies>
        <dependency>
            <groupId>commons-pool</groupId>
            <artifactId>commons-pool</artifactId>
            <version>1.6</version>
        </dependency>
        <dependency>
            <groupId>org.kohsuke.jetbrains</groupId>
            <artifactId>annotations</artifactId>
            <version>9.0</version>
        </dependency>
        <dependency>
            <groupId>log4j</groupId>
            <artifactId>log4j</artifactId>
            <version>1.2.16</version>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.10</version>
            <scope>test</scope>
        </dependency>
    </dependencies>
</project>

Tests

public class JaxbHelperTest {

    private static final Logger log = Logger.getLogger(JaxbHelperTest.class);

    private static final String SAMPLE_SCHEMA_LOCATION = "sample.xsd";
    private static final int TO_MILLISECONDS = 1000000;
    private static final int HEAVY_LOAD = 1000;
    private static final int NUM_THREADS = 10;


    private SampleType createSample() {
        ObjectFactory objectFactory = new ObjectFactory();

        SampleType sample = objectFactory.createSampleType();

        PersonType person = objectFactory.createPersonType();
        person.setFirstname("firstname");
        person.setSurname("surname");

        sample.getPerson().add(person);

        return sample;
    }

    private void testJaxbHelper(JaxbHelper jaxbHelper) throws Exception {
        SampleType sample = createSample();

        String sampleXml = jaxbHelper.marshal(sample, SAMPLE_SCHEMA_LOCATION);
        SampleType sampleCopy = jaxbHelper.unmarshal(sampleXml, SampleType.class, SAMPLE_SCHEMA_LOCATION);

        assertNotNull(sampleXml);
        assertNotNull(sampleCopy);
        assertEquals(sample.getPerson().size(), sampleCopy.getPerson().size());
        assertEquals(sample.getPerson().get(0).getFirstname(), sampleCopy.getPerson().get(0).getFirstname());
        assertEquals(sample.getPerson().get(0).getSurname(), sampleCopy.getPerson().get(0).getSurname());
    }

    @Test
    public void testSimpleJaxbHelper() throws Exception {
        testJaxbHelper(new SimpleJaxbHelper());
    }

    @Test
    public void testCachedJaxbHelper() throws Exception {
        testJaxbHelper(new CachedJaxbHelper());
    }

    @Test
    public void testPooledJaxbHelper() throws Exception {
        testJaxbHelper(new PooledJaxbHelper());
    }

    private long testJaxbHelperLoad(JaxbHelper jaxbHelper, int load) throws Exception {
        SampleType sample = createSample();

        long startTime = System.nanoTime();

        for (int i = 0; i < load; i++) {
            String sampleXml = jaxbHelper.marshal(sample, SAMPLE_SCHEMA_LOCATION);
            jaxbHelper.unmarshal(sampleXml, SampleType.class, SAMPLE_SCHEMA_LOCATION);
        }

        return System.nanoTime() - startTime;
    }

    @Test
    public void testSimpleJaxbHelperHeavyLoad() throws Exception {
        long estimatedTime = testJaxbHelperLoad(new SimpleJaxbHelper(), HEAVY_LOAD);
        log.info("SimpleJaxbHelper estimated time: " + estimatedTime / TO_MILLISECONDS);
    }

    @Test
    public void testCachedJaxbHelperHeavyLoad() throws Exception {
        long estimatedTime = testJaxbHelperLoad(new CachedJaxbHelper(), HEAVY_LOAD);
        log.info("CachedJaxbHelper estimated time: " + estimatedTime / TO_MILLISECONDS);
    }

    @Test
    public void testPooledJaxbHelperHeavyLoad() throws Exception {
        long estimatedTime = testJaxbHelperLoad(new PooledJaxbHelper(), HEAVY_LOAD);
        log.info("PooledJaxbHelper estimated time: " + estimatedTime / TO_MILLISECONDS);
    }

    @Test
    public void testCompareAllToSimple() throws Exception {
        long simpleEstimatedTime = testJaxbHelperLoad(new SimpleJaxbHelper(), HEAVY_LOAD);
        long cachedEstimatedTime = testJaxbHelperLoad(new CachedJaxbHelper(), HEAVY_LOAD);
        long pooledEstimatedTime = testJaxbHelperLoad(new PooledJaxbHelper(), HEAVY_LOAD);

        log.info("testCompareAllToSimple");
        log.info("SimpleJaxbHelper / CachedJaxbHelper ratio: " + (double) simpleEstimatedTime / (double) cachedEstimatedTime);
        log.info("SimpleJaxbHelper / PooledJaxbHelper ratio: " + (double) simpleEstimatedTime / (double) pooledEstimatedTime);
    }


    private class JaxbTask implements Callable<Long> {
        private JaxbHelper jaxbHelper;
        private int load;

        public JaxbTask(JaxbHelper jaxbHelper, int load) {
            this.jaxbHelper = jaxbHelper;
            this.load = load;
        }

        public Long call() throws Exception {
            return testJaxbHelperLoad(jaxbHelper, load);
        }
    }


    private long testJaxbHelperHeavyLoadMultipleThreads(JaxbHelper jaxbHelper, int numThreads, int load) throws Exception {
        long estimatedTime = 0;

        ExecutorService threadExecutor = Executors.newFixedThreadPool(numThreads);

        List<JaxbTask> taskList = new ArrayList<JaxbTask>(numThreads);

        for (int i = 0; i < numThreads; i++) {
            taskList.add(new JaxbTask(jaxbHelper, load / numThreads));
        }

        List<Future<Long>> results = threadExecutor.invokeAll(taskList);

        threadExecutor.shutdown();

        boolean finished = threadExecutor.awaitTermination(1, TimeUnit.MINUTES);

        if (finished) {
            for (Future<Long> result : results) {
                estimatedTime += result.get();
            }
        } else {
            fail("Some of the test threads failed to finish correctly.");
        }

        return estimatedTime;
    }

    @Test
    public void testCompareSimpleAndPooledMultipleThreads() throws Exception {
        long simpleEstimatedTime = testJaxbHelperHeavyLoadMultipleThreads(new SimpleJaxbHelper(), NUM_THREADS, HEAVY_LOAD);
        long pooledEstimatedTime = testJaxbHelperHeavyLoadMultipleThreads(new PooledJaxbHelper(), NUM_THREADS, HEAVY_LOAD);

        log.info("SimpleJaxbHelper / PooledJaxbHelper ratio: " + (double) simpleEstimatedTime / (double) pooledEstimatedTime);
    }

    @Test
    public void testCompareSimpleAndCachedMultipleThreads() throws Exception {
        long simpleEstimatedTime = testJaxbHelperHeavyLoadMultipleThreads(new SimpleJaxbHelper(), NUM_THREADS, HEAVY_LOAD);
        long cachedEstimatedTime = testJaxbHelperHeavyLoadMultipleThreads(new CachedJaxbHelper(), NUM_THREADS, HEAVY_LOAD);

        log.info("SimpleJaxbHelper / CachedJaxbHelper ratio: " + (double) simpleEstimatedTime / (double) cachedEstimatedTime);
    }

    @Test
    public void testCompareAllToSimpleMultipleThreads() throws Exception {
        long simpleEstimatedTime = testJaxbHelperHeavyLoadMultipleThreads(new SimpleJaxbHelper(), NUM_THREADS, HEAVY_LOAD);
        long cachedEstimatedTime = testJaxbHelperHeavyLoadMultipleThreads(new CachedJaxbHelper(), NUM_THREADS, HEAVY_LOAD);
        long pooledEstimatedTime = testJaxbHelperHeavyLoadMultipleThreads(new PooledJaxbHelper(), NUM_THREADS, HEAVY_LOAD);

        log.info("testCompareAllToSimpleMultipleThreads");
        log.info("SimpleJaxbHelper / CachedJaxbHelper ratio: " + (double) simpleEstimatedTime / (double) cachedEstimatedTime);
        log.info("SimpleJaxbHelper / PooledJaxbHelper ratio: " + (double) simpleEstimatedTime / (double) pooledEstimatedTime);
    }
}
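
A note on the multi-threaded tests above: testJaxbHelperHeavyLoadMultipleThreads adds up the durations reported by the individual tasks, so it measures total busy time across threads rather than elapsed wall-clock time. If wall-clock time is of interest, it could be measured around invokeAll instead. The following is a sketch under assumptions: WallClockSketch and measureWallClockNanos are hypothetical names, there is no warm-up phase, and a plain Runnable stands in for the JAXB work.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

/** Hypothetical helper for measuring elapsed time of concurrent work. */
final class WallClockSketch {

    /** Runs the work once on each of numThreads threads and returns the elapsed nanoseconds. */
    static long measureWallClockNanos(Runnable work, int numThreads) {
        ExecutorService executor = Executors.newFixedThreadPool(numThreads);
        try {
            List<Callable<Void>> tasks = new ArrayList<>(numThreads);
            for (int i = 0; i < numThreads; i++) {
                tasks.add(() -> {
                    work.run();
                    return null;
                });
            }

            long start = System.nanoTime();
            // invokeAll blocks until every task has completed.
            for (Future<Void> future : executor.invokeAll(tasks)) {
                future.get(); // rethrows a task failure as ExecutionException
            }
            return System.nanoTime() - start;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while measuring", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("A measured task failed", e.getCause());
        } finally {
            executor.shutdown();
        }
    }

    public static void main(String[] args) {
        AtomicInteger runs = new AtomicInteger();
        long nanos = measureWallClockNanos(runs::incrementAndGet, 4);
        System.out.println("tasks run: " + runs.get() + ", elapsed ms: " + nanos / 1_000_000);
    }
}
```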