An enhanced version of ReplayingDecoder for Netty

Published on — Filed under bare metal

One of the things you'll instantly fall in love with when using Netty is the ReplayingDecoder.

However, with time and increasingly complex requirements, you'll soon realise that it has a couple of shortcomings, namely when dealing with complex message structures. The typical usage scenario for ReplayingDecoder is a message with a structure like this:

TYPE (1B)
LENGTH (1B)
VALUE (xB, depends on value read at LENGTH)
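For comparison, here's roughly what reading that structure looks like when you do the bookkeeping by hand. This is only a sketch: it uses the JDK's java.nio.ByteBuffer rather than Netty's ChannelBuffer so it runs on its own, and TlvMessage and TlvReader are names I made up for illustration. The "is there enough data yet?" checks are exactly the part ReplayingDecoder takes off your hands.

```java
import java.nio.ByteBuffer;

/** Hypothetical holder for one TYPE/LENGTH/VALUE message. */
final class TlvMessage {
  final int type;
  final byte[] value;

  TlvMessage(int type, byte[] value) {
    this.type = type;
    this.value = value;
  }
}

final class TlvReader {
  /**
   * Reads one message, or returns null (leaving the buffer position
   * untouched) when the buffer does not hold a complete message yet.
   */
  static TlvMessage read(ByteBuffer buffer) {
    if (buffer.remaining() < 2) {
      return null; // TYPE and LENGTH are not both available yet
    }
    buffer.mark();
    int type = buffer.get() & 0xFF;   // TYPE (1B), read as unsigned
    int length = buffer.get() & 0xFF; // LENGTH (1B), read as unsigned
    if (buffer.remaining() < length) {
      buffer.reset(); // rewind and retry once more bytes have arrived
      return null;
    }
    byte[] value = new byte[length];  // VALUE (length bytes)
    buffer.get(value);
    return new TlvMessage(type, value);
  }
}
```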

You can even handle longer message structures this way, as long as they're straightforward. Where you'll run into problems is with repeating fields.

For instance, how would you deal with the following message structure?

TYPE (1B)
ID_SIZE (1B)
ID (xB)
PARAM_COUNT (1B)
[PARAM_LENGTH (1B) | PARAM_VALUE (xB)]*

* depends on PARAM_COUNT
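To make the layout concrete, here's a small sketch that writes one such message into a byte array. It's handy for producing test input for a decoder. SampleMessageEncoder is a hypothetical name, and I'm assuming ASCII strings for the id and the parameters, which the structure above doesn't actually specify.

```java
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;

final class SampleMessageEncoder {
  /** Builds TYPE | ID_SIZE | ID | PARAM_COUNT | (PARAM_LENGTH | PARAM_VALUE)*. */
  static byte[] encode(int type, String id, List<String> params) {
    if (params.size() > 255) {
      throw new IllegalArgumentException("PARAM_COUNT must fit in one byte");
    }
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    out.write(type);                                       // TYPE (1B)
    writeSized(out, id.getBytes(StandardCharsets.US_ASCII)); // ID_SIZE + ID
    out.write(params.size());                              // PARAM_COUNT (1B)
    for (String param : params) {
      // PARAM_LENGTH + PARAM_VALUE, repeated PARAM_COUNT times
      writeSized(out, param.getBytes(StandardCharsets.US_ASCII));
    }
    return out.toByteArray();
  }

  /** Writes a one-byte length followed by the bytes themselves. */
  private static void writeSized(ByteArrayOutputStream out, byte[] bytes) {
    if (bytes.length > 255) {
      throw new IllegalArgumentException("Field longer than 255 bytes");
    }
    out.write(bytes.length);
    out.write(bytes, 0, bytes.length);
  }
}
```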

I've steered a few people towards Netty and this decoder before, without really paying attention when they complained about complex message structures. "You're doing it wrong", I thought. Until I hit the same brick wall myself.

The solution is actually very simple: every time you read a field, instead of falling through the switch, exit indicating whether you should keep reading (and which is the next desired state) or if you're done with the message.
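Stripped of everything Netty-related, the idea boils down to something like the sketch below. It only handles the repeating parameter part of the message, it assumes all the bytes are already in the buffer (so there's no replaying going on), and LoopingParser is a name invented for this illustration. The point is just that each step reads one field and returns the next state, which lets the loop jump back to PARAM_LENGTH as many times as needed.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

final class LoopingParser {
  enum State { PARAM_COUNT, PARAM_LENGTH, PARAM_VALUE, DONE }

  private int remaining;
  private byte[] param;
  private final List<String> params = new ArrayList<String>();

  /** Reads exactly one field and reports which state comes next. */
  private State step(State state, ByteBuffer in) {
    switch (state) {
      case PARAM_COUNT:
        remaining = in.get() & 0xFF;
        return remaining > 0 ? State.PARAM_LENGTH : State.DONE;
      case PARAM_LENGTH:
        param = new byte[in.get() & 0xFF];
        return State.PARAM_VALUE;
      case PARAM_VALUE:
        in.get(param);
        params.add(new String(param, StandardCharsets.US_ASCII));
        remaining--;
        // Going back to an earlier state is what a fall-through
        // switch cannot do.
        return remaining > 0 ? State.PARAM_LENGTH : State.DONE;
      default:
        throw new IllegalStateException("Unexpected state: " + state);
    }
  }

  /** Runs steps until the message is complete. */
  List<String> parse(ByteBuffer in) {
    State state = State.PARAM_COUNT;
    while (state != State.DONE) {
      state = step(state, in);
    }
    return params;
  }
}
```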

Here's a proposal, the EnhancedReplayingDecoder:

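import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.handler.codec.replay.ReplayingDecoder;

public abstract class EnhancedReplayingDecoder<T extends Enum<T>>
      extends ReplayingDecoder<T> {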

  // internal vars ------------------------------

  private final T initialState;

  // constructors -------------------------------

  public EnhancedReplayingDecoder(T initialState) {
    this(initialState, false);
  }

  public EnhancedReplayingDecoder(T initialState, boolean unfold) {
    super(initialState, unfold);
    this.initialState = initialState;
  }

  // ReplayingDecoder ---------------------------

  @Override
  protected Object decode(ChannelHandlerContext ctx, Channel channel,
                          ChannelBuffer buffer, T state)
        throws Exception {
    for (;;) {
      // Request a decode with the provided buffer and current state. It
      // can only end in one of three ways:
      // 1. A field was read but the message is not complete yet
      //    (CONTINUE), so a checkpoint will be set and buffer draining
      //    will continue;
      // 2. There was enough data to build a message (FINISHED), so reset
      //    the decoder and return the result;
      // 3. There was insufficient data, in which case an exception will
      //    be thrown and handled by the ReplayingDecoder.
      DecodeResult<T> result = this.decode(buffer, this.getState());
      if (result == null) {
        throw new IllegalArgumentException("decode() returned null");
      }
      switch (result.getType()) {
        case FINISHED:
          // Final state, with composed object.
          try {
            return ((FinishedDecodeResult<T>) result).getResult();
          } finally {
            this.reset();
          }
        case CONTINUE:
          // Keep processing the message, setting a checkpoint to the
          // next state.
          this.checkpoint(((ContinueDecodeResult<T>) result)
                  .getNextState());
          break;
        default:
          // Never actually falls here...
          throw new IllegalArgumentException("Unsupported result: " +
                                             result.getType());
      }
    }
  }

  // protected helpers --------------------------

  protected DecodeResult<T> continueDecoding(T nextState) {
    return new ContinueDecodeResult<T>(nextState);
  }

  protected DecodeResult<T> finishedDecoding(Object result) {
    return new FinishedDecodeResult<T>(result);
  }

  protected void reset() {
    this.cleanup();
    this.setState(this.initialState);
  }

  protected abstract DecodeResult<T> decode(ChannelBuffer buffer,
                                            T currentState)
        throws Exception;

  protected abstract void cleanup();
}

So basically, every time you successfully read a field of the message, you call either continueDecoding() providing the next state or finishedDecoding() with a result to send upstream.

Here’s the rest of the code:

You'll need these but you'll never actually handle them directly. The generics are just to keep stuff type safe and in some cases avoid compiler warnings.

public interface DecodeResult<T> {
  enum Type {
    FINISHED,
    CONTINUE
  }

  Type getType();
}

public class ContinueDecodeResult<T extends Enum<T>>
    implements DecodeResult<T> {

  // internal vars ------------------------------

  private final T nextState;

  // constructors -------------------------------

  public ContinueDecodeResult(T nextState) {
    this.nextState = nextState;
  }

  // DecodeResult -------------------------------

  @Override
  public Type getType() { return Type.CONTINUE; }

  // getters & setters --------------------------

  public T getNextState() { return nextState; }
}

public class FinishedDecodeResult<T>
    implements DecodeResult<T> {

  // internal vars ------------------------------

  private final Object result;

  // constructors -------------------------------

  public FinishedDecodeResult(Object result) {
    this.result = result;
  }

  // DecodeResult -------------------------------

  @Override
  public Type getType() { return Type.FINISHED; }

  // getters & setters --------------------------

  public Object getResult() { return result; }
}

And since this wouldn't be complete without an example...

import java.util.ArrayList;
import java.util.List;

import org.jboss.netty.buffer.ChannelBuffer;

public class SampleDecoder
    extends EnhancedReplayingDecoder<SampleDecoder.DecodingState> {

  // internal vars ----------------------------

  private int type;
  private byte[] id;
  private int nParams;
  private List<String> params;
  private byte[] param;

  // constructors -----------------------------

  public SampleDecoder() {
    super(DecodingState.TYPE);
  }

  // EnhancedReplayingDecoder -----------------

  @Override
  protected DecodeResult<DecodingState> decode(ChannelBuffer buffer,
                                               DecodingState state)
      throws Exception {
    switch (state) {
      case TYPE:
        // TYPE is a single byte, so read one byte rather than an int.
        this.type = buffer.readByte();
        return this.continueDecoding(DecodingState.ID_SIZE);

      case ID_SIZE:
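        // Read the size as unsigned so values above 127 don't come out
        // negative. A size of 0 yields an empty id; reject it here if
        // your protocol forbids that.
        this.id = new byte[buffer.readUnsignedByte()];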
        return this.continueDecoding(DecodingState.ID);

      case ID:
        buffer.readBytes(this.id);
        if (this.type == 1) {
          // Let's assume type 1 messages only need the id.
          Message m = new Type1Message(new String(this.id));
          return this.finishedDecoding(m);
        } else {
          // Otherwise continue decoding.
          return this.continueDecoding(DecodingState.PARAM_COUNT);
        }

      case PARAM_COUNT:
        this.nParams = buffer.readUnsignedByte();
        // If there are parameters continue decoding, otherwise bail.
        if (this.nParams > 0) {
          this.params = new ArrayList<String>(this.nParams);
          return this.continueDecoding(DecodingState.PARAM_SIZE);
        } else {
          Message m = new OtherMessage(new String(this.id));
          return this.finishedDecoding(m);
        }

      case PARAM_SIZE:
        // Unsigned read, so lengths above 127 don't come out negative.
        this.param = new byte[buffer.readUnsignedByte()];
        return this.continueDecoding(DecodingState.PARAM_VALUE);

      case PARAM_VALUE:
        buffer.readBytes(this.param);
        this.params.add(new String(this.param));
        if (this.params.size() >= this.nParams) {
          // This was the last parameter, exit.
          Message m = new OtherMessage(new String(this.id));
          m.setParams(this.params);
          return this.finishedDecoding(m);
        } else {
          // Continue reading parameters.
          return this.continueDecoding(DecodingState.PARAM_SIZE);
        }

      default:
        throw new IllegalStateException("Unknown state: " + state);
    }
  }

  @Override
  protected void cleanup() {
    // cleanup pending resources allocated for decoding
    this.id = null;
    this.nParams = 0;
    this.type = -1;
    this.param = null;
    this.params = null;
  }

  // public classes --------------------------------------------------------

  public static enum DecodingState {
    TYPE,
    ID_SIZE,
    ID,
    PARAM_COUNT,
    PARAM_SIZE,
    PARAM_VALUE
  }
}

Easy as pie. The best part is that you don't need to clean up explicitly, since cleanup() will be called for you when you return finishedDecoding().

It could be slightly optimised for performance if, instead of creating a new DecodeResult on every call to continueDecoding() and finishedDecoding(), I simply reused the same instances...
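Here's a rough sketch of what that reuse could look like for the "continue" case: one pre-built result per enum constant, kept in an EnumMap. CachedContinue is a stand-in for ContinueDecodeResult so the snippet compiles on its own, and ContinueCache is a hypothetical helper, not something Netty provides. As far as I can tell the same trick doesn't carry over directly to the finished results, since each one wraps a different decoded object.

```java
import java.util.EnumMap;
import java.util.Map;

/** Stand-in for ContinueDecodeResult: immutable, so safe to share. */
final class CachedContinue<T extends Enum<T>> {
  final T nextState;

  CachedContinue(T nextState) {
    this.nextState = nextState;
  }
}

final class ContinueCache<T extends Enum<T>> {
  private final Map<T, CachedContinue<T>> cache;

  /** Builds one result per state up front. */
  ContinueCache(Class<T> stateType) {
    this.cache = new EnumMap<T, CachedContinue<T>>(stateType);
    for (T state : stateType.getEnumConstants()) {
      this.cache.put(state, new CachedContinue<T>(state));
    }
  }

  /** Returns the shared instance instead of allocating a new one. */
  CachedContinue<T> continueDecoding(T nextState) {
    return this.cache.get(nextState);
  }
}
```

In the decoder, the constructor would build the cache from the state enum's class and continueDecoding() would simply delegate to it.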