package org.whispersystems.dispatch.redis;

import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;

import static org.junit.Assert.*;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.*;

public class PubSubConnectionTest {

  private static final String REPLY = "*3\r\n" +
      "$9\r\n" +
      "subscribe\r\n" +
      "$5\r\n" +
      "abcde\r\n" +
      ":1\r\n" +
      "*3\r\n" +
      "$9\r\n" +
      "subscribe\r\n" +
      "$5\r\n" +
      "fghij\r\n" +
      ":2\r\n" +
      "*3\r\n" +
      "$9\r\n" +
      "subscribe\r\n" +
      "$5\r\n" +
      "klmno\r\n" +
      ":2\r\n" +
      "*3\r\n" +
      "$7\r\n" +
      "message\r\n" +
      "$5\r\n" +
      "abcde\r\n" +
      "$10\r\n" +
      "1234567890\r\n" +
      "*3\r\n" +
      "$7\r\n" +
      "message\r\n" +
      "$5\r\n" +
      "klmno\r\n" +
      "$10\r\n" +
      "0987654321\r\n";


  @Test
  public void testSubscribe() throws IOException {
//    ByteChannel      byteChannel = mock(ByteChannel.class);
    OutputStream outputStream = mock(OutputStream.class);
    Socket       socket       = mock(Socket.class      );
    when(socket.getOutputStream()).thenReturn(outputStream);
    PubSubConnection connection  = new PubSubConnection(socket);

    connection.subscribe("foobar");

    ArgumentCaptor<byte[]> captor = ArgumentCaptor.forClass(byte[].class);
    verify(outputStream).write(captor.capture());

    assertArrayEquals(captor.getValue(), "SUBSCRIBE foobar\r\n".getBytes());
  }

  @Test
  public void testUnsubscribe() throws IOException {
    OutputStream outputStream = mock(OutputStream.class);
    Socket       socket       = mock(Socket.class      );
    when(socket.getOutputStream()).thenReturn(outputStream);
    PubSubConnection connection  = new PubSubConnection(socket);

    connection.unsubscribe("bazbar");

    ArgumentCaptor<byte[]> captor = ArgumentCaptor.forClass(byte[].class);
    verify(outputStream).write(captor.capture());

    assertArrayEquals(captor.getValue(), "UNSUBSCRIBE bazbar\r\n".getBytes());
  }

  @Test
  public void testTricklyResponse() throws Exception {
    InputStream  inputStream  = mockInputStreamFor(new TrickleInputStream(REPLY.getBytes()));
    OutputStream outputStream = mock(OutputStream.class);
    Socket       socket       = mock(Socket.class      );
    when(socket.getOutputStream()).thenReturn(outputStream);
    when(socket.getInputStream()).thenReturn(inputStream);

    PubSubConnection pubSubConnection = new PubSubConnection(socket);
    readResponses(pubSubConnection);
  }

  @Test
  public void testFullResponse() throws Exception {
    InputStream  inputStream  = mockInputStreamFor(new FullInputStream(REPLY.getBytes()));
    OutputStream outputStream = mock(OutputStream.class);
    Socket       socket       = mock(Socket.class      );
    when(socket.getOutputStream()).thenReturn(outputStream);
    when(socket.getInputStream()).thenReturn(inputStream);

    PubSubConnection pubSubConnection = new PubSubConnection(socket);
    readResponses(pubSubConnection);
  }

  @Test
  public void testRandomLengthResponse() throws Exception {
    InputStream  inputStream  = mockInputStreamFor(new RandomInputStream(REPLY.getBytes()));
    OutputStream outputStream = mock(OutputStream.class);
    Socket       socket       = mock(Socket.class      );
    when(socket.getOutputStream()).thenReturn(outputStream);
    when(socket.getInputStream()).thenReturn(inputStream);

    PubSubConnection pubSubConnection = new PubSubConnection(socket);
    readResponses(pubSubConnection);
  }

  private InputStream mockInputStreamFor(final MockInputStream stub) throws IOException {
    InputStream result = mock(InputStream.class);

    when(result.read()).thenAnswer(new Answer<Integer>() {
      @Override
      public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
        return stub.read();
      }
    });

    when(result.read(any(byte[].class))).thenAnswer(new Answer<Integer>() {
      @Override
      public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
        byte[] buffer = (byte[])invocationOnMock.getArguments()[0];
        return stub.read(buffer, 0, buffer.length);
      }
    });

    when(result.read(any(byte[].class), anyInt(), anyInt())).thenAnswer(new Answer<Integer>() {
      @Override
      public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
        byte[] buffer = (byte[]) invocationOnMock.getArguments()[0];
        int offset = (int) invocationOnMock.getArguments()[1];
        int length = (int) invocationOnMock.getArguments()[2];

        return stub.read(buffer, offset, length);
      }
    });

    return result;
  }

  private void readResponses(PubSubConnection pubSubConnection) throws Exception {
    PubSubReply reply = pubSubConnection.read();

    assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
    assertEquals(reply.getChannel(), "abcde");
    assertFalse(reply.getContent().isPresent());

    reply = pubSubConnection.read();

    assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
    assertEquals(reply.getChannel(), "fghij");
    assertFalse(reply.getContent().isPresent());

    reply = pubSubConnection.read();

    assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
    assertEquals(reply.getChannel(), "klmno");
    assertFalse(reply.getContent().isPresent());

    reply = pubSubConnection.read();

    assertEquals(reply.getType(), PubSubReply.Type.MESSAGE);
    assertEquals(reply.getChannel(), "abcde");
    assertArrayEquals(reply.getContent().get(), "1234567890".getBytes());

    reply = pubSubConnection.read();

    assertEquals(reply.getType(), PubSubReply.Type.MESSAGE);
    assertEquals(reply.getChannel(), "klmno");
    assertArrayEquals(reply.getContent().get(), "0987654321".getBytes());
  }

  private interface MockInputStream {
    public int read();
    public int read(byte[] input, int offset, int length);
  }

  private static class TrickleInputStream implements MockInputStream {

    private final byte[] data;
    private int index = 0;

    private TrickleInputStream(byte[] data) {
      this.data = data;
    }

    public int read() {
      return data[index++];
    }

    public int read(byte[] input, int offset, int length) {
      input[offset] = data[index++];
      return 1;
    }

  }

  private static class FullInputStream implements MockInputStream {

    private final byte[] data;
    private int index = 0;

    private FullInputStream(byte[] data) {
      this.data = data;
    }

    public int read() {
      return data[index++];
    }

    public int read(byte[] input, int offset, int length) {
      int amount = Math.min(data.length - index, length);
      System.arraycopy(data, index, input, offset, amount);
      index += length;

      return amount;
    }
  }

  private static class RandomInputStream implements MockInputStream {
    private final byte[] data;
    private int index = 0;

    private RandomInputStream(byte[] data) {
      this.data = data;
    }

    public int read() {
      return data[index++];
    }

    public int read(byte[] input, int offset, int length) {
      int maxCopy    = Math.min(data.length - index, length);
      int randomCopy = new SecureRandom().nextInt(maxCopy) + 1;
      int copyAmount = Math.min(maxCopy, randomCopy);

      System.arraycopy(data, index, input, offset, copyAmount);
      index += copyAmount;

      return copyAmount;
    }

  }

}
