Handshaking tutorial with Netty

Edit: The code for this project is now hosted at GitHub, along with other tutorials.

Like most of us, I first started using plain NIO for highly scalable network-based apps (mostly server-side). Then I jumped to Apache MINA, which I used for just about anything network-related until, a couple of months ago, I discovered Netty.

While MINA took Java asynchronous I/O a step further, Netty pushed it two steps beyond that. I must admit MINA is simpler to understand at first and really quick to get into. The fact that it completely abstracts your coding logic from the underlying protocol is great. Netty, on the other hand, might not be so simple at first, and the learning curve is definitely steeper (it feels closer to NIO than MINA does). But as you advance and delve into its APIs, it starts to make more and more sense, not to mention the performance boost, which really pays off. Oh boy, it sure does. But more on that in a future post...

Today I'll be addressing an interesting client-server challenge with Netty: handshaking.

I'll assume you're familiar with Netty: you should already know what a Channel, a ChannelHandler, a ChannelBuffer and a ChannelPipeline are.

This tutorial will have a client and a server. The client will connect to the server and, upon connection, it will initiate the handshake. Once the handshake phase has been completed, all the data the client sends will simply be reflected (echoed) back at it. To keep things even simpler, I'll be using plain text — both for handshaking and subsequent messaging.

Roughly speaking, here are the steps:

  • Client opens socket to server;
  • As soon as socket connects, both client and server fire up a handshake timeout timer;
  • Client sends a message that contains its own id, the (asserted) server's id and a challenge;
  • The server validates the handshake message and replies to the client with the response to the challenge;
  • Client sends messages to the server, server reflects them back, everyone is happy.

If any of the above steps fails (for instance, the client opens a socket to the server with id "A" but asserts that the server's id is "B", the handshake times out, or the challenge or challenge response is invalid), either the server or the client simply closes the socket.
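The handshake message from the steps above is a single line of plain text. Based on the format the handshake handlers parse later in this post ("clientId:serverId:challenge", newline-terminated), here's a minimal Netty-free sketch of building and parsing it. HandshakeMessage and its method names are hypothetical; they're not part of the tutorial's code.

```java
import java.util.Objects;

/** Hypothetical value class for the "clientId:serverId:challenge" line. */
final class HandshakeMessage {
    private final String clientId;
    private final String serverId;
    private final String challenge;

    public HandshakeMessage(String clientId, String serverId, String challenge) {
        this.clientId = Objects.requireNonNull(clientId);
        this.serverId = Objects.requireNonNull(serverId);
        this.challenge = Objects.requireNonNull(challenge);
    }

    /** Encodes as one newline-terminated line, matching line-based framing. */
    public String encode() {
        return clientId + ':' + serverId + ':' + challenge + '\n';
    }

    /** Parses a received frame; returns null unless it has exactly 3 fields. */
    public static HandshakeMessage parse(String line) {
        String[] params = line.trim().split(":");
        if (params.length != 3) {
            return null;
        }
        return new HandshakeMessage(params[0], params[1], params[2]);
    }

    public String clientId() { return clientId; }
    public String serverId() { return serverId; }
    public String challenge() { return challenge; }
}
```

Since ':' is the field separator, neither the ids nor the challenge may contain a colon; a real protocol would either forbid it or escape it.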

An interesting issue raised by the introduction of a handshake phase is that, upon a successful socket connection, the channel won't be instantly ready to send messages: the handshake must first complete successfully. So all messages sent during the handshake period must be queued and then flushed in exactly the same order they were queued. Not a big deal... but the application must also ensure that messages sent while the queue is being flushed are themselves sent in order, and only after all the queued messages have been flushed!
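That requirement boils down to one rule: while the handshake is pending, writes go into a FIFO queue; once it completes, the queue is drained and later writes go straight through, with a single lock guarding both paths so that no write can overtake the flush. Here's a Netty-free sketch of that rule. OrderedWriteGate and its Consumer sink are illustrative stand-ins (roughly, for the handshake handler and ctx.sendDownstream()), not classes from the tutorial.

```java
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.function.Consumer;

/** Hypothetical gate: queues writes until open(), then forwards them directly. */
final class OrderedWriteGate<T> {
    private final Object lock = new Object();
    private final Queue<T> pending = new ArrayDeque<>();
    private final Consumer<T> sink;
    private boolean open; // guarded by lock

    OrderedWriteGate(Consumer<T> sink) {
        this.sink = sink;
    }

    /** May be called from any thread. */
    void write(T message) {
        synchronized (lock) {
            if (open) {
                sink.accept(message);   // handshake done: forward directly
            } else {
                pending.offer(message); // handshake pending: keep FIFO order
            }
        }
    }

    /** Called once the handshake succeeds; drains the queue in order. */
    void open() {
        synchronized (lock) {
            T message;
            while ((message = pending.poll()) != null) {
                sink.accept(message);
            }
            open = true;
        }
    }

    int pendingCount() {
        synchronized (lock) {
            return pending.size();
        }
    }
}
```

Because write() and open() take the same lock, a write that arrives mid-flush blocks until the drain finishes and is then sent after every queued message. The failure path (discarding queued and later writes) is left out here.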

Finally, to keep concerns separated, the handshaking logic will be placed within its own channel handler — which will be removed from the pipeline upon handshake completion.

Here's how the pipeline looks when it's created:

[Figure: pipeline after creation]

In this first stage, no messages will reach the ClientHandler/ServerHandler.

And after the handshake has completed:

[Figure: pipeline after handshake]

In this stage, the ClientHandshakeHandler/ServerHandshakeHandler will have been removed from the pipeline, and messages will reach the ClientHandler/ServerHandler. When the handshake completes (or fails), a custom HandshakeEvent is fired upstream through the pipeline to notify the ClientHandler/ServerHandler.
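The HandshakeEvent class itself isn't shown in this post. Judging only from how the handlers use it (handshakeSucceeded(), handshakeFailed(), isSuccessful(), getRemoteId()), a Netty-free sketch could look like this. The type parameter C is a hypothetical stand-in for org.jboss.netty.channel.Channel; a real version would implement ChannelEvent, whose getFuture() an upstream event would presumably satisfy with Channels.succeededFuture(channel).

```java
/** Hypothetical, Netty-free shape of the custom handshake event. */
final class HandshakeEvent<C> {
    private final boolean successful;
    private final String remoteId; // null when the handshake failed
    private final C channel;       // stand-in for Netty's Channel

    private HandshakeEvent(boolean successful, String remoteId, C channel) {
        this.successful = successful;
        this.remoteId = remoteId;
        this.channel = channel;
    }

    static <C> HandshakeEvent<C> handshakeSucceeded(String remoteId, C channel) {
        return new HandshakeEvent<>(true, remoteId, channel);
    }

    static <C> HandshakeEvent<C> handshakeFailed(C channel) {
        return new HandshakeEvent<>(false, null, channel);
    }

    boolean isSuccessful() { return successful; }
    String getRemoteId() { return remoteId; }
    C getChannel() { return channel; }
}
```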

On to some code. Here's the client:

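import java.net.InetSocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;

import org.jboss.netty.bootstrap.ClientBootstrap;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFactory;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory;
import org.jboss.netty.handler.codec.frame.DelimiterBasedFrameDecoder;
import org.jboss.netty.handler.codec.frame.Delimiters;
import org.jboss.netty.handler.codec.string.StringDecoder;
import org.jboss.netty.handler.codec.string.StringEncoder;

// ByteCounter, MessageCounter, ClientHandshakeHandler, ClientHandler and
// ClientListener are this tutorial's own classes.

public class Client {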
 
    // internal vars ----------------------------------------------------------
 
    private final String id;
    private final String serverId;
    private final ClientListener listener;
    private ClientBootstrap bootstrap;
    private Channel connector;
 
    // constructors -----------------------------------------------------------
 
    public Client(String id, String serverId, ClientListener listener) {
        this.id = id;
        this.serverId = serverId;
        this.listener = listener;
    }
 
    // public methods ---------------------------------------------------------
 
    public boolean start() {
        // Standard netty bootstrapping stuff.
        Executor bossPool = Executors.newCachedThreadPool();
        Executor workerPool = Executors.newCachedThreadPool();
        ChannelFactory factory =
                new NioClientSocketChannelFactory(bossPool, workerPool);
        this.bootstrap = new ClientBootstrap(factory);
 
        this.bootstrap.setPipelineFactory(new ChannelPipelineFactory() {
            public ChannelPipeline getPipeline() throws Exception {
                ByteCounter byteCounter =
                        new ByteCounter("--- CLIENT-COUNTER :: ");
                // A frame decoder buffers partial frames per connection, so
                // each pipeline needs its own instance (it can't be shared).
                DelimiterBasedFrameDecoder frameDecoder =
                        new DelimiterBasedFrameDecoder(
                                Integer.MAX_VALUE, Delimiters.lineDelimiter());
                MessageCounter messageCounter =
                        new MessageCounter("--- CLIENT-MSGCOUNTER :: ");
                ClientHandshakeHandler handshakeHandler =
                        new ClientHandshakeHandler(id, serverId, 5000);
 
                return Channels.pipeline(byteCounter,
                                         frameDecoder,
                                         new StringDecoder(),
                                         new StringEncoder(),
                                         messageCounter,
                                         handshakeHandler,
                                         new ClientHandler(listener));
            }
        });
 
        ChannelFuture future = this.bootstrap
                .connect(new InetSocketAddress("localhost", 12345));
        if (!future.awaitUninterruptibly().isSuccess()) {
            System.out.println("--- CLIENT - Failed to connect to server at " +
                               "localhost:12345.");
            this.bootstrap.releaseExternalResources();
            return false;
        }
 
        this.connector = future.getChannel();
        return this.connector.isConnected();
    }
 
    public void stop() {
        if (this.connector != null) {
            this.connector.close().awaitUninterruptibly();
        }
        this.bootstrap.releaseExternalResources();
        System.out.println("--- CLIENT - Stopped.");
    }
 
    public boolean sendMessage(String message) {
        // connector is null until start() has connected successfully.
        if (this.connector != null && this.connector.isConnected()) {
            // Append \n if it's not present, because of the frame delimiter
            if (!message.endsWith("\n")) {
                this.connector.write(message + '\n');
            } else {
                this.connector.write(message);
            }
            return true;
        }
 
        return false;
    }
}

Nothing really fancy here. Notice the pipeline definition (the Channels.pipeline(...) call inside getPipeline()); never mind the ByteCounter and MessageCounter, those are purely informative. As I mentioned above, I'll be using plain text for both handshaking and messages. For a more serious application, I highly recommend considering custom binary codecs.
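Because the pipeline splits the incoming byte stream on line delimiters, every message has to end with a newline, which is why sendMessage() appends '\n'. TCP delivers bytes in arbitrary chunks, so a frame decoder has to accumulate them until a delimiter shows up. The following simplified, Netty-free sketch illustrates the idea; LineFramer is a hypothetical name, and in the real pipeline Netty's DelimiterBasedFrameDecoder (configured with Delimiters.lineDelimiter(), which covers both "\n" and "\r\n") does this work.

```java
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical line framer: splits a byte stream on '\n', dropping a preceding '\r'. */
final class LineFramer {
    // Bytes received so far that don't yet form a complete line.
    private final ByteArrayOutputStream partial = new ByteArrayOutputStream();

    /** Feeds one chunk as read from the socket; returns all complete frames. */
    List<String> feed(byte[] chunk) {
        List<String> frames = new ArrayList<>();
        for (byte b : chunk) {
            if (b == '\n') {
                byte[] line = partial.toByteArray();
                int len = line.length;
                if (len > 0 && line[len - 1] == '\r') {
                    len--; // treat "\r\n" like "\n"
                }
                // Decode only complete lines, so multi-byte UTF-8 characters
                // split across chunks are reassembled before decoding.
                frames.add(new String(line, 0, len, StandardCharsets.UTF_8));
                partial.reset();
            } else {
                partial.write(b);
            }
        }
        return frames;
    }
}
```

A message sent without the trailing newline would sit in the decoder's buffer and never reach the other side's handlers; the check in sendMessage() guards against exactly that.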

And here's what the server looks like:

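import java.net.InetSocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;

import org.jboss.netty.bootstrap.ServerBootstrap;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelException;
import org.jboss.netty.channel.ChannelFactory;
import org.jboss.netty.channel.ChannelHandler;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.group.ChannelGroup;
import org.jboss.netty.channel.group.DefaultChannelGroup;
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory;
import org.jboss.netty.handler.codec.frame.DelimiterBasedFrameDecoder;
import org.jboss.netty.handler.codec.frame.Delimiters;
import org.jboss.netty.handler.codec.string.StringDecoder;
import org.jboss.netty.handler.codec.string.StringEncoder;

// ByteCounter, MessageCounter, ServerHandshakeHandler, ServerHandler and
// ServerListener are this tutorial's own classes.

public class Server {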
 
    // internal vars ----------------------------------------------------------
 
    private final String id;
    private final ServerListener listener;
    private ServerBootstrap bootstrap;
    private ChannelGroup channelGroup;
 
    // constructors -----------------------------------------------------------
 
    public Server(String id, ServerListener listener) {
        this.id = id;
        this.listener = listener;
    }
 
    // public methods ---------------------------------------------------------
 
    public boolean start() {
        // Pretty standard Netty startup stuff...
        // boss/worker executors, channel factory, channel group, pipeline, ...
        Executor bossPool = Executors.newCachedThreadPool();
        Executor workerPool = Executors.newCachedThreadPool();
        ChannelFactory factory =
                new NioServerSocketChannelFactory(bossPool, workerPool);
        this.bootstrap = new ServerBootstrap(factory);
 
        this.channelGroup = new DefaultChannelGroup(this.id + "-all-channels");
 
        this.bootstrap.setPipelineFactory(new ChannelPipelineFactory() {
            public ChannelPipeline getPipeline() throws Exception {
                ByteCounter counter =
                        new ByteCounter("+++ SERVER-COUNTER :: ");
                // Frame decoders keep per-connection state (partial frames),
                // so every new connection must get its own instance; sharing
                // one across clients would mix their byte streams.
                ChannelHandler delimiter =
                        new DelimiterBasedFrameDecoder(
                                Integer.MAX_VALUE, Delimiters.lineDelimiter());
                MessageCounter messageCounter =
                        new MessageCounter("+++ SERVER-MSGCOUNTER :: ");
                ServerHandshakeHandler handshakeHandler =
                        new ServerHandshakeHandler(id, channelGroup, 5000);
                return Channels.pipeline(counter,
                                         delimiter,
                                         new StringDecoder(),
                                         new StringEncoder(),
                                         messageCounter,
                                         handshakeHandler,
                                         new ServerHandler(listener));
            }
        });
 
        // ServerBootstrap.bind() throws a ChannelException when it can't bind,
        // so the failure path lives in the catch block.
        try {
            Channel acceptor =
                    this.bootstrap.bind(new InetSocketAddress(12345));
            System.out.println("+++ SERVER - bound to *:12345");
            this.channelGroup.add(acceptor);
            return true;
        } catch (ChannelException e) {
            System.out.println("+++ SERVER - Failed to bind to *:12345");
            this.bootstrap.releaseExternalResources();
            return false;
        }
    }
 
    public void stop() {
        this.channelGroup.close().awaitUninterruptibly();
        this.bootstrap.releaseExternalResources();
        System.out.println("+++ SERVER - Stopped.");
    }
}

The only differences in the server pipeline are its last two ChannelHandlers.

Apart from the obvious difference of the last ChannelHandler (ServerHandler as opposed to ClientHandler), the handshake handler also needs to be different, as its behaviour is slightly different. While the ClientHandshakeHandler takes the initiative of starting the handshake, the ServerHandshakeHandler must wait for an initial handshake message in order to reply.

Before we head to the handshake handlers, let's make a quick stop at the final handlers.

The ServerHandler:

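import java.util.concurrent.atomic.AtomicInteger;

import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelEvent;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;

public class ServerHandler extends SimpleChannelUpstreamHandler {
 
    // internal vars ----------------------------------------------------------
 
    private final AtomicInteger counter;
    private final ServerListener listener;
    // Set from the I/O thread when the handshake succeeds, but read by
    // whichever thread calls sendMessage()/getRemoteId(), hence volatile.
    private volatile String remoteId;
    private volatile Channel channel;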
 
    // constructors -----------------------------------------------------------
 
    public ServerHandler(ServerListener listener) {
        this.listener = listener;
        this.counter = new AtomicInteger();
    }
 
    // SimpleChannelUpstreamHandler -------------------------------------------
 
    @Override
    public void handleUpstream(ChannelHandlerContext ctx, ChannelEvent e)
            throws Exception {
        if (e instanceof HandshakeEvent) {
            if (((HandshakeEvent) e).isSuccessful()) {
                out("+++ SERVER-HANDLER :: Handshake successful, connection " +
                    "to " + ((HandshakeEvent) e).getRemoteId() + " is up.");
                this.remoteId = ((HandshakeEvent) e).getRemoteId();
                this.channel = ctx.getChannel();
                // Notify the listener that a new connection is now READY
                this.listener.connectionOpen(this);
            } else {
                out("+++ SERVER-HANDLER :: Handshake failed.");
            }
            return;
        }
 
        super.handleUpstream(ctx, e);
    }
 
    @Override
    public void messageReceived(ChannelHandlerContext ctx, MessageEvent e)
            throws Exception {
        this.counter.incrementAndGet();
        this.listener.messageReceived(this, e.getMessage().toString());
    }
 
    @Override
    public void channelClosed(ChannelHandlerContext ctx, ChannelStateEvent e)
            throws Exception {
        super.channelClosed(ctx, e);
        out("+++ SERVER-HANDLER :: Channel closed, received " +
            this.counter.get() + " messages: " + e.getChannel());
    }
 
    // public methods ---------------------------------------------------------
 
    public void sendMessage(String message) {
        if (!message.endsWith("\n")) {
            this.channel.write(message + '\n');
        } else {
            this.channel.write(message);
        }
    }
 
    public String getRemoteId() {
        return remoteId;
    }
 
    // private static helpers -------------------------------------------------
 
    private static void out(String s) {
        System.err.println(s);
    }
}

It really doesn't do much besides notifying its listener every time a message is received.

You'll notice that a custom event is handled in handleUpstream(). When the handshake succeeds, the listener is notified through listener.connectionOpen(). This allows the listener to perform some specialized action, if needed.

An example of such behaviour: the listener adds each newly created ServerHandler to a map, keyed by the client's id. Whenever the server receives a message with the contents "say hi to X" (X being another client's id), it fetches the handler for that id and sends the message "hi from Y" to that client, Y being the sender's id. Silly example, but you get the picture...
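A minimal sketch of that example, assuming a listener with the same two callbacks the ServerHandler invokes (connectionOpen() and messageReceived()). To keep it self-contained and Netty-free, the hypothetical Peer interface stands in for ServerHandler's getRemoteId() and sendMessage(); GreetingListener is an illustrative name too, and echoing everything else back is an assumption matching the tutorial's behaviour.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** Stand-in for ServerHandler: only the two methods this example needs. */
interface Peer {
    String getRemoteId();
    void sendMessage(String message);
}

/** Hypothetical listener implementing the "say hi to X" example. */
final class GreetingListener {
    private static final String COMMAND = "say hi to ";

    // Handlers whose handshake succeeded, keyed by the client's id.
    // Concurrent map because callbacks arrive on different I/O threads.
    private final Map<String, Peer> peers = new ConcurrentHashMap<>();

    /** Mirrors connectionOpen(): called once a handshake succeeds. */
    void connectionOpen(Peer peer) {
        peers.put(peer.getRemoteId(), peer);
    }

    /** Mirrors messageReceived(): called for every application message. */
    void messageReceived(Peer from, String message) {
        String text = message.trim();
        if (text.startsWith(COMMAND)) {
            Peer target = peers.get(text.substring(COMMAND.length()));
            if (target != null) {
                target.sendMessage("hi from " + from.getRemoteId());
            }
        } else {
            from.sendMessage(text); // anything else is echoed back
        }
    }
}
```

A real version would also remove entries when connections close; the ServerHandler shown above doesn't notify its listener of that, so it would need an extra callback.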

Now for the ClientHandler:

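import java.util.concurrent.atomic.AtomicInteger;

import org.jboss.netty.channel.ChannelEvent;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;

public class ClientHandler extends SimpleChannelUpstreamHandler {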
 
    // internal vars ----------------------------------------------------------
 
    private final AtomicInteger counter;
    private final ClientListener listener;
 
    // constructors -----------------------------------------------------------
 
    public ClientHandler(ClientListener listener) {
        this.listener = listener;
        this.counter = new AtomicInteger();
    }
 
    // SimpleChannelUpstreamHandler -------------------------------------------
 
    @Override
    public void handleUpstream(ChannelHandlerContext ctx, ChannelEvent e)
            throws Exception {
        if (e instanceof HandshakeEvent) {
            if (((HandshakeEvent) e).isSuccessful()) {
                out("--- CLIENT-HANDLER :: Handshake successful, connection " +
                    "to " + ((HandshakeEvent) e).getRemoteId() + " is up.");
            } else {
                out("--- CLIENT-HANDLER :: Handshake failed.");
            }
            return;
        }
 
        super.handleUpstream(ctx, e);
    }
 
    @Override
    public void messageReceived(ChannelHandlerContext ctx, MessageEvent e)
            throws Exception {
        this.counter.incrementAndGet();
        this.listener.messageReceived(e.getMessage().toString());
    }
 
    @Override
    public void channelClosed(ChannelHandlerContext ctx, ChannelStateEvent e)
            throws Exception {
        super.channelClosed(ctx, e);
        out("--- CLIENT-HANDLER :: Channel closed, received " +
            this.counter.get() + " messages: " + e.getChannel());
    }
 
    // private static helpers -------------------------------------------------
 
    private static void out(String s) {
        System.out.println(s);
    }
}

It's nearly the same as the ServerHandler, except that its custom event handling is purely informative.

So with the Client, ClientHandler, Server and ServerHandler explained, it's now time for the important part of this tutorial: the handshaking logic itself.

I feel compelled to warn you that the handshake presented here is... well, dumb. Remember this is a tutorial and I do want to keep things as simple as possible, with a gentle touch of realism. With that in mind, here's the ClientHandshakeHandler:

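import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.DownstreamMessageEvent;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelHandler;

// HandshakeEvent and Challenge are this tutorial's own helper classes.

public class ClientHandshakeHandler extends SimpleChannelHandler {
 
    // internal vars ----------------------------------------------------------
 
    private final long timeoutInMillis;
    private final String localId;
    private final String remoteId;
    private final AtomicBoolean handshakeComplete;
    private final AtomicBoolean handshakeFailed;
    private final CountDownLatch latch = new CountDownLatch(1);
    // Writes requested before the handshake completes, kept in FIFO order.
    private final Queue<MessageEvent> messages = new ArrayDeque<MessageEvent>();
    private final Object handshakeMutex = new Object();
    private String challenge;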
 
    // constructors -----------------------------------------------------------
 
    public ClientHandshakeHandler(String localId, String remoteId,
                                  long timeoutInMillis) {
        this.localId = localId;
        this.remoteId = remoteId;
        this.timeoutInMillis = timeoutInMillis;
        this.handshakeComplete = new AtomicBoolean(false);
        this.handshakeFailed = new AtomicBoolean(false);
    }
 
    // SimpleChannelHandler ---------------------------------------------------
 
    @Override
    public void messageReceived(ChannelHandlerContext ctx, MessageEvent e)
            throws Exception {
        if (this.handshakeFailed.get()) {
            // Bail out fast if handshake already failed
            return;
        }
 
        if (this.handshakeComplete.get()) {
            // If handshake succeeded but message still came through this
            // handler, then immediately send it upwards.
            // Chances are it's the last time a message passes through
            // this handler...
            super.messageReceived(ctx, e);
            return;
        }
 
        synchronized (this.handshakeMutex) {
            // Recheck conditions after locking the mutex.
            // Things might have changed while waiting for the lock.
            if (this.handshakeFailed.get()) {
                return;
            }
 
            if (this.handshakeComplete.get()) {
                super.messageReceived(ctx, e);
                return;
            }
 
            // Parse the challenge.
            // Expected format is "clientId:serverId:challenge"
            String[] params = ((String) e.getMessage()).trim().split(":");
            if (params.length != 3) {
                out("--- CLIENT-HS :: Invalid handshake: expected 3 params, " +
                    "got " + params.length);
                this.fireHandshakeFailed(ctx);
                return;
            }
 
            // Silly validations...
            // 1. Validate that server replied correctly to this client's id.
            if (!params[0].equals(this.localId)) {
                out("--- CLIENT-HS :: Handshake failed: local id is " +
                    this.localId + " but challenge response is for '" +
                    params[0] + "'");
                this.fireHandshakeFailed(ctx);
                return;
            }
 
            // 2. Validate that asserted server id is its actual id.
            if (!params[1].equals(this.remoteId)) {
                out("--- CLIENT-HS :: Handshake failed: expecting remote id " +
                    this.remoteId + " but got " + params[1]);
                this.fireHandshakeFailed(ctx);
                return;
            }
 
            // 3. Ensure that challenge response is correct.
            if (!Challenge.isValidResponse(params[2], this.challenge)) {
                out("--- CLIENT-HS :: Handshake failed: '" + params[2] +
                    "' is not a valid response for challenge '" +
                    this.challenge + "'");
                this.fireHandshakeFailed(ctx);
                return;
            }
 
            // Everything went okay!
            out("--- CLIENT-HS :: Challenge validated, flushing messages & " +
                "removing handshake handler from pipeline.");
 
            // Flush messages *directly* downwards.
            // Calling ctx.getChannel().write() here would cause the messages
            // to be inserted at the top of the pipeline, thus causing them
            // to pass through this class's writeRequested() and be re-queued.
            out("--- CLIENT-HS :: " + this.messages.size() +
                " messages in queue to be flushed.");
            for (MessageEvent message : this.messages) {
                ctx.sendDownstream(message);
            }
 
            // Remove this handler from the pipeline; its job is finished.
            ctx.getPipeline().remove(this);
 
            // Finally fire success message upwards.
            this.fireHandshakeSucceeded(this.remoteId, ctx);
        }
    }
 
    @Override
    public void channelConnected(final ChannelHandlerContext ctx,
                                 ChannelStateEvent e) throws Exception {
        out("--- CLIENT-HS :: Outgoing connection established to: " +
            e.getChannel().getRemoteAddress());
 
        // Write the handshake & add a timeout listener.
        ChannelFuture f = Channels.future(ctx.getChannel());
        f.addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future)
                    throws Exception {
                // Once this message is sent, start the timeout checker.
                new Thread() {
                    @Override
                    public void run() {
                        // Wait until either handshake completes (releases the
                        // latch) or this latch times out.
                        try {
                            latch.await(timeoutInMillis, TimeUnit.MILLISECONDS);
                        } catch (InterruptedException e1) {
                            out("--- CLIENT-HS :: Handshake timeout checker: " +
                                "interrupted!");
                            e1.printStackTrace();
                        }
 
                        // Informative output, do nothing...
                        if (handshakeFailed.get()) {
                            out("--- CLIENT-HS :: (pre-synchro) Handshake " +
                                "timeout checker: discarded " +
                                "(handshake failed)");
                            return;
                        }
 
                        // More informative output, do nothing...
                        if (handshakeComplete.get()) {
                            out("--- CLIENT-HS :: (pre-synchro) Handshake " +
                                "timeout checker: discarded " +
                                "(handshake completed)");
                            return;
                        }
 
                        // Handshake has neither failed nor completed, time
                        // to do something! (trigger failure).
                        // Lock on the mutex first...
                        synchronized (handshakeMutex) {
                            // Same checks as before, conditions might have
                            // changed while waiting to get a lock on the
                            // mutex.
                            if (handshakeFailed.get()) {
                                out("--- CLIENT-HS :: (synchro) Handshake " +
                                    "timeout checker: already failed.");
                                return;
                            }
 
                            if (!handshakeComplete.get()) {
                                // If handshake wasn't completed meanwhile,
                                // time to mark the handshake as having failed.
                                out("--- CLIENT-HS :: (synchro) Handshake " +
                                    "timeout checker: timed out, " +
                                    "killing connection.");
                                fireHandshakeFailed(ctx);
                            } else {
                                // Informative output; the handshake was
                                // completed while this thread was waiting
                                // for a lock on the handshakeMutex.
                                // Do nothing...
                                out("--- CLIENT-HS :: (synchro) Handshake " +
                                    "timeout checker: discarded " +
                                    "(handshake OK)");
                            }
                        }
                    }
                }.start();
            }
        });
 
        this.challenge = Challenge.generateChallenge();
        String handshake =
                this.localId + ':' + this.remoteId + ':' + challenge + '\n';
        Channel c = ctx.getChannel();
        // Passing null as remoteAddress, since constructor in
        // DownstreamMessageEvent will use remote address from the channel if
        // remoteAddress is null.
        // Also, we need to send the data directly downstream rather than
        // call c.write() otherwise the message would pass through this
        // class's writeRequested() method defined below.
        ctx.sendDownstream(new DownstreamMessageEvent(c, f, handshake, null));
    }
 
    @Override
    public void channelClosed(ChannelHandlerContext ctx, ChannelStateEvent e)
            throws Exception {
        out("--- CLIENT-HS :: Channel closed.");
        if (!this.handshakeComplete.get()) {
            this.fireHandshakeFailed(ctx);
        }
    }
 
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e)
            throws Exception {
        out("--- CLIENT-HS :: Exception caught.");
        e.getCause().printStackTrace();
        if (e.getChannel().isConnected()) {
            // Closing the channel will trigger handshake failure.
            e.getChannel().close();
        } else {
            // Channel didn't open, so we must fire handshake failure directly.
            this.fireHandshakeFailed(ctx);
        }
    }
 
    @Override
    public void writeRequested(ChannelHandlerContext ctx, MessageEvent e)
            throws Exception {
        // Before doing anything, ensure that no one else is working by
        // acquiring a lock on the handshakeMutex.
        synchronized (this.handshakeMutex) {
            if (this.handshakeFailed.get()) {
                // If the handshake failed meanwhile, discard any messages.
                return;
            }
 
            // If the handshake hasn't failed but completed meanwhile and
            // messages still passed through this handler, then forward
            // them downwards.
            if (this.handshakeComplete.get()) {
                out("--- CLIENT-HS :: Handshake already completed, not " +
                    "appending '" + e.getMessage().toString().trim() +
                    "' to queue!");
                super.writeRequested(ctx, e);
            } else {
                // Otherwise, queue messages in order until the handshake
                // completes.
                this.messages.offer(e);
            }
        }
    }
 
    // private static helpers -------------------------------------------------
 
    private static void out(String s) {
        System.out.println(s);
    }
 
    // private helpers --------------------------------------------------------
 
    private void fireHandshakeFailed(ChannelHandlerContext ctx) {
        this.handshakeComplete.set(true);
        this.handshakeFailed.set(true);
        this.latch.countDown();
        ctx.getChannel().close();
        ctx.sendUpstream(HandshakeEvent.handshakeFailed(ctx.getChannel()));
    }
 
    private void fireHandshakeSucceeded(String server,
                                        ChannelHandlerContext ctx) {
        this.handshakeComplete.set(true);
        this.handshakeFailed.set(false);
        this.latch.countDown();
        ctx.sendUpstream(HandshakeEvent
                .handshakeSucceeded(server, ctx.getChannel()));
    }
}

The first thing you'll notice is that this ChannelHandler is heavily dependent on synchronization. While it does serialize access and cause threads to block for short periods of time, it is an acceptable trade-off to ensure absolute order when sending messages. Any alternatives to this cautious approach are welcome!
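Both handshake handlers also lean on a Challenge helper (generateChallenge(), isValidChallenge(), generateResponse() and isValidResponse(response, challenge)) whose code isn't shown in this post. Here's a hypothetical stand-in in the same deliberately "dumb" spirit, with the method names and argument order mirroring the calls above: the challenge is a random hex string (so it never contains the ':' separator) and the expected response is simply that string reversed. This rule is an assumption chosen for illustration and offers no security at all; a real protocol would use something like an HMAC of the challenge under a shared secret.

```java
import java.security.SecureRandom;

/** Hypothetical Challenge helper: response = challenge reversed. NOT secure. */
final class Challenge {
    private static final SecureRandom RANDOM = new SecureRandom();
    private static final int LENGTH = 16; // number of hex characters

    private Challenge() {
    }

    /** Client side: a random lowercase hex string. */
    static String generateChallenge() {
        StringBuilder sb = new StringBuilder(LENGTH);
        for (int i = 0; i < LENGTH; i++) {
            sb.append(Character.forDigit(RANDOM.nextInt(16), 16));
        }
        return sb.toString();
    }

    /** Server side: checks the challenge has the expected shape. */
    static boolean isValidChallenge(String challenge) {
        return challenge != null
                && challenge.matches("[0-9a-f]{" + LENGTH + "}");
    }

    /** Server side: computes the response sent back to the client. */
    static String generateResponse(String challenge) {
        return new StringBuilder(challenge).reverse().toString();
    }

    /** Client side: checks the server's response to the original challenge. */
    static boolean isValidResponse(String response, String challenge) {
        return challenge != null && generateResponse(challenge).equals(response);
    }
}
```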

The ServerHandshakeHandler is pretty much the same, except for the challenge validation part.

public class ServerHandshakeHandler extends SimpleChannelHandler {
    // ...
 
    @Override
    public void messageReceived(ChannelHandlerContext ctx, MessageEvent e)
            throws Exception {
        if (this.handshakeFailed.get()) {
            // Bail out fast if handshake already failed
            return;
        }
 
        if (this.handshakeComplete.get()) {
            // If handshake succeeded but message still came through this
            // handler, then immediately send it upwards.
            super.messageReceived(ctx, e);
            return;
        }
 
        synchronized (this.handshakeMutex) {
            // Recheck conditions after locking the mutex.
            if (this.handshakeFailed.get()) {
                return;
            }
 
            if (this.handshakeComplete.get()) {
                super.messageReceived(ctx, e);
                return;
            }
 
            // Validate handshake
            String handshake = (String) e.getMessage();
            // 1. Validate expected clientId:serverId:challenge format
            String[] params = handshake.trim().split(":");
            if (params.length != 3) {
                out("+++ SERVER-HS :: Invalid handshake: expecting 3 params, " +
                    "got " + params.length + " -> '" + handshake + "'");
                this.fireHandshakeFailed(ctx);
                return;
            }
 
            // 2. Validate the asserted serverId = localId
            String client = params[0];
            if (!this.localId.equals(params[1])) {
                out("+++ SERVER-HS :: Invalid handshake: this is " +
                    this.localId + " and client thinks it's " + params[1]);
                this.fireHandshakeFailed(ctx);
                return;
            }
 
            // 3. Validate the challenge format.
            if (!Challenge.isValidChallenge(params[2])) {
                out("+++ SERVER-HS :: Invalid handshake: invalid challenge '" +
                    params[2] + "'");
                this.fireHandshakeFailed(ctx);
                return;
            }
 
            // Success! Write the challenge response.
            out("+++ SERVER-HS :: Challenge validated, flushing messages & " +
                "removing handshake handler from pipeline.");
            String response = params[0] + ':' + params[1] + ':' +
                              Challenge.generateResponse(params[2]) + '\n';
            this.writeDownstream(ctx, response);
 
            // Flush any pending messages (in this tutorial, no messages will
            // ever be queued, because the server never takes the initiative
            // of sending messages to clients on its own).
            out("+++ SERVER-HS :: " + this.messages.size() +
                " messages in queue to be flushed.");
            for (MessageEvent message : this.messages) {
                ctx.sendDownstream(message);
            }
 
            // Finally, remove this handler from the pipeline and fire success
            // event up the pipeline.
            out("+++ SERVER-HS :: Removing handshake handler from pipeline.");
            ctx.getPipeline().remove(this);
            this.fireHandshakeSucceeded(client, ctx);
        }
    }
 
    // ...
}

Another difference is that the channelOpen() implementation (obviously) does not send any messages; it just fires up the handshake timeout checker.

Other than those two, it's pretty much the same. In fact, you will probably notice that there is a lot of code that could be factored out from both ClientHandshakeHandler and ServerHandshakeHandler into an AbstractHandshakeHandler in order to keep things DRY.
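The handshake timeout checker is a good candidate for that shared code. Both handlers currently start one thread per connection that just waits on the latch. A common alternative, swapped in here and not what the tutorial does, is to schedule the check on a shared ScheduledExecutorService. The Netty-free sketch below keeps the handlers' key idea of re-checking the completed/failed state under the handshake mutex before declaring failure; HandshakeTimeout and its onTimeout callback are hypothetical names.

```java
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

/** Hypothetical handshake state with a timeout on a shared scheduler. */
final class HandshakeTimeout {
    private final Object handshakeMutex = new Object();
    private boolean complete; // guarded by handshakeMutex
    private boolean failed;   // guarded by handshakeMutex
    private final ScheduledFuture<?> timeoutTask;

    HandshakeTimeout(ScheduledExecutorService timer, long timeoutMillis,
                     Runnable onTimeout) {
        // One scheduled task per connection instead of one waiting thread.
        this.timeoutTask = timer.schedule(() -> {
            synchronized (handshakeMutex) {
                if (complete || failed) {
                    return; // handshake already settled, nothing to do
                }
                failed = true;
            }
            onTimeout.run(); // e.g. close the channel, fire a failure event
        }, timeoutMillis, TimeUnit.MILLISECONDS);
    }

    /** Marks success; returns false if the handshake already timed out. */
    boolean succeed() {
        synchronized (handshakeMutex) {
            if (failed) {
                return false;
            }
            complete = true;
        }
        timeoutTask.cancel(false); // the check is no longer needed
        return true;
    }

    boolean hasFailed() {
        synchronized (handshakeMutex) {
            return failed;
        }
    }
}
```

A single scheduler thread can then serve every pending handshake, and a completed handshake cancels its check instead of leaving a thread parked until the latch is released.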

Protip™: A final acknowledgement from the client would be a very good idea. Since it's the client who actually validates the challenge response, without one the server never learns whether the handshake really succeeded on the other end.

Before wrapping up, let's just take a brief look at the classes that make this thing run:

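import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
 
// Client and ClientListener are this project's own classes.
public class ClientRunner {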
 
    public static void runClient(final String id, final String serverId,
                                 final int nMessages)
            throws InterruptedException {
 
        final AtomicInteger cLast = new AtomicInteger();
        final AtomicInteger clientCounter = new AtomicInteger();
        final CountDownLatch latch = new CountDownLatch(1);
 
        // Create a client with custom id, that connects to a server with given
        // id and has a message listener that ensures that ALL messages are
        // received in perfect order.
        Client c = new Client(id, serverId, new ClientListener() {
            @Override
            public void messageReceived(String message) {
                int num = Integer.parseInt(message.trim());
                if (num != (cLast.get() + 1)) {
                    System.err.println("--- CLIENT-LISTENER(" + id + ") " +
                                       ":: OUT OF ORDER!!! expecting " +
                                       (cLast.get() + 1) + " and got " +
                                       message);
                } else {
                    cLast.set(num);
                }
 
                if (clientCounter.incrementAndGet() >= nMessages) {
                    latch.countDown();
                }
            }
        });
 
        if (!c.start()) {
            return;
        }
 
        for (int i = 0; i < nMessages; i++) {
            // This sleep prevents all messages from being instantly queued
            // in the handshake message queue. Since handshake takes some time,
            // all messages sent during handshake will be queued (and later on
            // flushed).
            // Since we want to test the effect of removing the handshake
            // handler from the pipeline (and ensure that message order is
            // preserved), this sleep helps us accomplish that with a random
            // factor.
            // If lucky, a couple of messages will even hit the handshake
            // handler *after* the handshake has been completed but right
            // before the handshake handler is removed from the pipeline.
            // Worry not, that case is also covered :)
            Thread.sleep(1L);
            c.sendMessage((i + 1) + "\n");
        }
 
        // Wait until all echoes have arrived (or 10 seconds pass), then shut down.
        latch.await(10, TimeUnit.SECONDS);
        c.stop();
    }
 
    public static void main(String[] args) throws InterruptedException {
        // More clients will test robustness of the server, but output becomes
        // more confusing.
        int nClients = 1;
        final int nMessages = 10000;
        // Changing this value to something different than the server's id
        // will cause handshaking to fail.
        final String serverId = "server1";
        ExecutorService threadPool = Executors.newCachedThreadPool();
        for (int i = 0; i < nClients; i++) {
            final int finalI = i;
            threadPool.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        ClientRunner.runClient("client" + finalI, serverId,
                                               nMessages);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            });
        }
        // No more clients will be submitted; let the running ones finish.
        threadPool.shutdown();
    }
}

This runner launches nClients clients (just one by default), each of which sends the numbers 1 to nMessages as separate messages. Each client also listens for the echoed messages and verifies that they arrive in the same (ascending) order in which they were sent.

Each client starts sending messages as soon as the connection is established, so a dozen or so messages will typically be queued before the handshake actually completes. This exercises the synchronization, as well as the queuing and flushing of messages while the handshake is in progress.
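The property being exercised boils down to one rule: while the handshake is running, outgoing writes are held in a queue, and the switch to direct writes happens under the same lock as the flush, so no new message can overtake a queued one. Here is a minimal sketch of just that rule, outside Netty. HandshakeWriteQueue is a hypothetical class, and its Consumer<String> sink stands in for passing each MessageEvent downstream with ctx.sendDownstream(), as the server handler above does.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.function.Consumer;

// Hypothetical sketch: hold writes while the handshake runs, then flush them
// in order. The sink stands in for sending the message downstream.
class HandshakeWriteQueue {
    private final Object lock = new Object();
    private final Queue<String> pending = new ArrayDeque<>();
    private final Consumer<String> sink;
    private boolean handshakeComplete; // guarded by lock

    HandshakeWriteQueue(Consumer<String> sink) {
        this.sink = sink;
    }

    void write(String message) {
        synchronized (lock) {
            if (handshakeComplete) {
                sink.accept(message); // handshake done: pass straight through
            } else {
                pending.add(message); // handshake in progress: hold it back
            }
        }
    }

    void handshakeSucceeded() {
        synchronized (lock) {
            // Flush and flip the flag under the same lock as write(), so no
            // new message can slip out ahead of the queued ones.
            while (!pending.isEmpty()) {
                sink.accept(pending.poll());
            }
            handshakeComplete = true;
        }
    }
}

class WriteQueueDemo {
    public static void main(String[] args) {
        List<String> wire = new ArrayList<>();
        HandshakeWriteQueue queue = new HandshakeWriteQueue(wire::add);
        queue.write("1");           // queued: handshake still in progress
        queue.write("2");           // queued
        queue.handshakeSucceeded(); // flushes "1" and "2"
        queue.write("3");           // sent directly
        System.out.println(wire);   // prints [1, 2, 3]
    }
}
```

In this sketch, a message that reaches the handler after the handshake has completed but before the handler is removed from the pipeline (the lucky case mentioned in ClientRunner's comments) simply takes the pass-through branch.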

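import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
 
// Server, ServerHandler and ServerListener are this project's own classes.
public class ServerRunner {
 
    public static void main(String[] args) {
        final Map<ServerHandler, AtomicInteger> lastMap =
                new ConcurrentHashMap<ServerHandler, AtomicInteger>();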
 
        // Create a new server with id "server1" and a listener that checks,
        // for each handler, that messages arrive in perfect order.
        final Server s = new Server("server1", new ServerListener() {
 
            @Override
            public void messageReceived(ServerHandler handler,
                                        String message) {
                AtomicInteger last = lastMap.get(handler);
                int num = Integer.parseInt(message.trim());
                if (num != (last.get() + 1)) {
                    System.err.println("+++ SERVER-LISTENER(" +
                                       handler.getRemoteId() + ") :: " +
                                       "OUT OF ORDER!!! expecting " +
                                       (last.get() + 1) + " and got " +
                                       message);
                } else {
                    last.set(num);
                }
 
                handler.sendMessage(message);
            }
 
            @Override
            public void connectionOpen(ServerHandler handler) {
                System.err.println("+++ SERVER-LISTENER(" +
                                   handler.getRemoteId() +
                                   ") :: Connection with " +
                                   handler.getRemoteId() +
                                   " opened & ready to send/receive data.");
                AtomicInteger counter = new AtomicInteger();
                lastMap.put(handler, counter);
            }
        });
 
        if (!s.start()) {
            return;
        }
 
        Runtime.getRuntime().addShutdownHook(new Thread() {
            @Override
            public void run() {
                s.stop();
            }
        });
    }
}

The server runner, on the other hand, simply launches a server that echoes every message back. For each incoming connection, it also checks that messages arrive in order. This is somewhat redundant, since the client already checks the order of the echoed messages, but it doesn't hurt.

I invite you to download the project and tweak a few things to see how it behaves. One simple change is to set a different server id in the client runner: this will cause the handshake to fail.

That's pretty much it, I hope you find this useful!

A couple of final notes

While this tutorial is very simplistic, it presents a very powerful concept for authenticating and validating incoming connections. You can include all sorts of useful information in the handshake apart from the IDs of the entities connecting, such as supported versions or timestamps; you can also increase the number of steps according to your needs.

In my case, I've deemed that 3 messages are enough for the vast majority of scenarios where this is applicable:

  • client sends handshake with version and challenge;
  • server accepts version, replies with handshake challenge response;
  • client sends acknowledgement and handshake phase is complete.
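The tutorial's code covers only a version-less form of the first two of these messages. A hedged sketch of the server side of the three-message variant could look like the state machine below, in which the handshake only counts as complete once the acknowledgement arrives. The line format clientId:serverId:version:challenge, the literal ACK, the version strings and the responder function (a stand-in for Challenge.generateResponse()) are all assumptions made for illustration.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.function.UnaryOperator;

// Hypothetical server side of a three-message handshake:
// client handshake (with version) -> server response -> client "ACK".
class ThreeStepServerHandshake {
    enum State { AWAITING_HANDSHAKE, AWAITING_ACK, COMPLETE, FAILED }

    private final String localId;
    private final Set<String> supportedVersions;
    private final UnaryOperator<String> responder; // stand-in for Challenge.generateResponse()
    private State state = State.AWAITING_HANDSHAKE;

    ThreeStepServerHandshake(String localId, Set<String> supportedVersions,
                             UnaryOperator<String> responder) {
        this.localId = localId;
        this.supportedVersions = supportedVersions;
        this.responder = responder;
    }

    synchronized State getState() {
        return state;
    }

    // Feeds one incoming line; returns the reply to send, or null if none.
    synchronized String onMessage(String line) {
        switch (state) {
            case AWAITING_HANDSHAKE: {
                // Assumed format: clientId:serverId:version:challenge
                String[] p = line.split(":");
                if (p.length != 4 || !p[1].equals(localId)
                        || !supportedVersions.contains(p[2])) {
                    state = State.FAILED;
                    return null;
                }
                state = State.AWAITING_ACK;
                return p[0] + ':' + p[1] + ':' + p[2] + ':' + responder.apply(p[3]);
            }
            case AWAITING_ACK:
                // Only now does the server know the client accepted its
                // response; this is where queued messages could be flushed.
                state = "ACK".equals(line) ? State.COMPLETE : State.FAILED;
                return null;
            default:
                // COMPLETE or FAILED: not a handshake message any more.
                return null;
        }
    }
}

class ThreeStepDemo {
    public static void main(String[] args) {
        ThreeStepServerHandshake server = new ThreeStepServerHandshake(
                "server1", new HashSet<>(Arrays.asList("1", "2")),
                String::toUpperCase); // toy responder, for illustration only
        System.out.println(server.onMessage("client0:server1:2:abc"));
        System.out.println(server.getState());
        server.onMessage("ACK");
        System.out.println(server.getState());
    }
}
```

Keeping the server in AWAITING_ACK (rather than declaring success right after replying) is what closes the gap described in the protip: the server only opens the link once the client has confirmed the challenge response.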

Pair up handshaking with custom binary protocols and you've got yourself some powerful, fast, safe data exchange links. Handshaking is perfect for connections between servers; it avoids redundant connections and periodic polling to a central discovery service.
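As a taste of the binary side, here is a sketch of the client's first handshake message encoded with java.nio.ByteBuffer instead of text. The layout (4-byte magic number, 2-byte version, two length-prefixed UTF-8 ids and an 8-byte numeric challenge) is invented purely for illustration and is not a format used by the tutorial. The decoder also assumes the whole frame has already arrived, which a real protocol decoder cannot take for granted.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

// Hypothetical binary layout for the client's first handshake message:
// [magic:int][version:short][clientId:u16 len + UTF-8][serverId:u16 len + UTF-8][challenge:long]
final class BinaryHandshake {
    static final int MAGIC = 0x48534B31; // arbitrary marker, "HSK1" in ASCII

    final short version;
    final String clientId;
    final String serverId;
    final long challenge;

    BinaryHandshake(short version, String clientId, String serverId, long challenge) {
        this.version = version;
        this.clientId = clientId;
        this.serverId = serverId;
        this.challenge = challenge;
    }

    ByteBuffer encode() {
        byte[] client = utf8(clientId);
        byte[] server = utf8(serverId);
        ByteBuffer buf = ByteBuffer.allocate(4 + 2 + 2 + client.length + 2 + server.length + 8);
        buf.putInt(MAGIC).putShort(version);
        buf.putShort((short) client.length).put(client);
        buf.putShort((short) server.length).put(server);
        buf.putLong(challenge);
        buf.flip(); // switch from writing to reading
        return buf;
    }

    static BinaryHandshake decode(ByteBuffer buf) {
        if (buf.getInt() != MAGIC) {
            throw new IllegalArgumentException("not a handshake frame");
        }
        short version = buf.getShort();
        String clientId = readString(buf);
        String serverId = readString(buf);
        return new BinaryHandshake(version, clientId, serverId, buf.getLong());
    }

    private static byte[] utf8(String s) {
        byte[] bytes = s.getBytes(StandardCharsets.UTF_8);
        if (bytes.length > 0xFFFF) {
            throw new IllegalArgumentException("id too long for a 16-bit length");
        }
        return bytes;
    }

    private static String readString(ByteBuffer buf) {
        byte[] bytes = new byte[buf.getShort() & 0xFFFF]; // length is unsigned
        buf.get(bytes);
        return new String(bytes, StandardCharsets.UTF_8);
    }
}

class BinaryHandshakeDemo {
    public static void main(String[] args) {
        ByteBuffer frame = new BinaryHandshake((short) 1, "client0", "server1", 42L).encode();
        System.out.println("frame size: " + frame.remaining() + " bytes");
        BinaryHandshake received = BinaryHandshake.decode(frame);
        System.out.println(received.clientId + " -> " + received.serverId
                + ", version " + received.version);
    }
}
```

With the ids used in the demo, the frame is 32 bytes: 4 + 2 + (2 + 7) + (2 + 7) + 8.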