package io.taig.taigless.ws

import cats.effect.Resource
import cats.effect.kernel.{Deferred, Sync}
import cats.effect.std.Dispatcher
import cats.syntax.all._
import fs2.CompositeFailure
import fs2.concurrent.Topic
import org.java_websocket.client.{WebSocketClient => JWebSocketClient}
import org.java_websocket.handshake.ServerHandshake

import java.net.URI
import java.nio.ByteBuffer
import scala.concurrent.duration.FiniteDuration

/** Adapts the callback-based Java-WebSocket client to an fs2 `Topic`.
  *
  * Lifecycle callbacks are turned into topic elements:
  *  - `onOpen` publishes `Right(Message.Open)`;
  *  - `onClose` publishes `Right(Message.Close(code, reason))` and then closes the topic;
  *  - `onError` publishes `Left(exception)` and then closes the topic.
  *
  * Each publication is submitted through `dispatcher`, and the invoking thread blocks for at most
  * `enqueueTimeout`. If a publication throws, the failure is not rethrown. Instead it is used to
  * complete `interruption`, so that whoever awaits that `Deferred` can react.
  *
  * @param topic          destination for messages (`Right`) and errors (`Left`)
  * @param interruption   completed when publishing to `topic` fails
  * @param dispatcher     runs `F` effects from inside the client's callbacks
  * @param uri            endpoint passed to the underlying `JWebSocketClient`
  * @param enqueueTimeout upper bound on how long a single publication may block the caller
  */
abstract class WebSocketMessageHandler[F[_], A](
    topic: Topic[F, Either[Throwable, Message[A]]],
    interruption: Deferred[F, Throwable],
    dispatcher: Dispatcher[F]
)(uri: URI, enqueueTimeout: FiniteDuration)(implicit F: Sync[F])
    extends JWebSocketClient(uri) {
  // A non-positive interval disables Java-WebSocket's connection-lost check
  // (per the library's setConnectionLostTimeout contract — confirm for the version in use).
  setConnectionLostTimeout(0)

  /** Runs `publication` with the `enqueueTimeout` budget. If it throws, `interruption` is completed
    * with `describe(failure)` instead.
    *
    * `publication` is by-name, so the effect is still constructed inside the `try`.
    * NOTE(review): this catches every `Throwable`, fatal errors and `InterruptedException` included;
    * narrowing it to `NonFatal` would change interrupt handling, so it is intentionally left broad.
    */
  private def publishOrInterrupt(publication: => F[Unit])(describe: Throwable => Throwable): Unit =
    try dispatcher.unsafeRunTimed(publication, enqueueTimeout)
    catch {
      case failure: Throwable =>
        val signal = interruption.complete(describe(failure)).void
        dispatcher.unsafeRunSync(signal)
    }

  /** Publishes `value` as a `Right` element. If that fails, the failure itself completes `interruption`. */
  def message(value: Message[A]): Unit =
    publishOrInterrupt(topic.publish1(Right(value)).void)(identity)

  /** Publishes `cause` as a `Left` element and then closes the topic.
    *
    * If either step fails, `interruption` is completed with a `CompositeFailure` holding both
    * `cause` and the publication failure.
    */
  def error(cause: Throwable): Unit =
    publishOrInterrupt(topic.publish1(Left(cause)) *> topic.close.void) { (secondary: Throwable) =>
      CompositeFailure(cause, secondary)
    }

  final override def onOpen(handshake: ServerHandshake): Unit =
    message(Message.Open)

  final override def onError(exception: Exception): Unit =
    error(exception)

  /** Publishes the close code and reason, then unconditionally closes the topic.
    *
    * Closing the topic is not guarded by `enqueueTimeout` or the interruption fallback.
    */
  final override def onClose(code: Int, reason: String, remote: Boolean): Unit = {
    message(Message.Close(code, reason))
    dispatcher.unsafeRunSync(topic.close.void)
  }

  /** Calls `connectBlocking()` on acquisition and `closeBlocking()` on release, both as blocking effects.
    *
    * NOTE(review): the Boolean returned by `connectBlocking()` is discarded, so a connect that returns
    * `false` does not fail acquisition. Presumably such failures surface through `onError`; confirm.
    */
  final val start: Resource[F, Unit] = {
    val connect = F.blocking(connectBlocking())
    val disconnect = F.blocking(closeBlocking())
    Resource.eval(connect).onFinalize(disconnect).void
  }
}

/** Websocket handler for text frames.
  *
  * Every incoming text frame is published to the topic as `Message.Data`. This class does not
  * override `onMessage(ByteBuffer)`, so binary frames get the base client's default handling.
  */
final class StringWebSocketMessageHandler[F[_]: Sync](
    topic: Topic[F, Either[Throwable, Message[String]]],
    interruption: Deferred[F, Throwable],
    dispatcher: Dispatcher[F]
)(uri: URI, enqueueTimeout: FiniteDuration)
    extends WebSocketMessageHandler[F, String](
      topic = topic,
      interruption = interruption,
      dispatcher = dispatcher
    )(uri = uri, enqueueTimeout = enqueueTimeout) {

  /** Forwards a received text frame to the topic. */
  override def onMessage(value: String): Unit = {
    message(Message.Data(value))
  }
}

/** Websocket handler for binary frames.
  *
  * Every incoming binary frame is published to the topic as `Message.Data`. Text frames are
  * deliberately ignored: they are neither published nor reported as errors.
  */
final class ByteBufferWebSocketMessageHandler[F[_]: Sync](
    topic: Topic[F, Either[Throwable, Message[ByteBuffer]]],
    interruption: Deferred[F, Throwable],
    dispatcher: Dispatcher[F]
)(uri: URI, enqueueTimeout: FiniteDuration)
    extends WebSocketMessageHandler[F, ByteBuffer](
      topic = topic,
      interruption = interruption,
      dispatcher = dispatcher
    )(uri = uri, enqueueTimeout = enqueueTimeout) {

  /** Text frames are dropped. The override exists because the base client requires it. */
  override def onMessage(value: String): Unit = {}

  /** Forwards a received binary frame to the topic. */
  override def onMessage(bytes: ByteBuffer): Unit = {
    message(Message.Data(bytes))
  }
}
