package sslproxy

import (
	"errors"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"

	"github.com/curltech/go-colla-core/config"
	"github.com/curltech/go-colla-core/logger"
	"golang.org/x/crypto/acme/autocert"
)

// Simple HTTP(S) reverse proxy built on Go's net/http/httputil reverse-proxy support.

// URL scheme prefixes used by Start to normalize config.ProxyParams.Target:
// a target that starts with neither prefix gets HTTPPrefix prepended.
const (
	HTTPPrefix  = "http://"
	HTTPSPrefix = "https://"
)

// build creates a reverse proxy that forwards every request to toURL and
// adds forwarding headers (X-Forwarded-Proto: https, X-Forwarded-Port: 443)
// to the request sent downstream.
func build(toURL *url.URL) *httputil.ReverseProxy {
	// NOTE(review): proto and port are fixed values, independent of the mode or
	// listen address actually in use. TODO: inherit another port if needed.
	stampForwarded := func(r *http.Request) {
		h := r.Header
		// Header.Set canonicalizes the key itself.
		h.Set("X-Forwarded-Proto", "https")
		h.Set("X-Forwarded-Port", "443")
	}
	return &httputil.ReverseProxy{
		Director: newDirector(toURL, stampForwarded),
	}
}

// newDirector creates a base director equivalent to the one installed by
// httputil.NewSingleHostReverseProxy, and lets the caller supply an
// extraDirector function that further decorates the request sent to the
// downstream server.
//
// For every request the returned director:
//   - rewrites the scheme and host to target's,
//   - joins target's path in front of the request path (keeping the escaped
//     RawPath consistent, so encoded characters such as %2F survive),
//   - merges target's query string in front of the request's query string,
//   - suppresses the transport's default User-Agent when the client sent none,
//   - finally calls extraDirector, if it is non-nil.
//
// req.Host is not modified, so the incoming Host header is forwarded as-is,
// matching the standard-library behavior.
func newDirector(target *url.URL, extraDirector func(*http.Request)) func(*http.Request) {
	targetQuery := target.RawQuery
	return func(req *http.Request) {
		req.URL.Scheme = target.Scheme
		req.URL.Host = target.Host
		// Rewrite Path and RawPath together; updating only Path leaves a stale
		// RawPath that no longer matches, and URL.EscapedPath then discards it.
		req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
		if targetQuery == "" || req.URL.RawQuery == "" {
			req.URL.RawQuery = targetQuery + req.URL.RawQuery
		} else {
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
		}
		if _, ok := req.Header["User-Agent"]; !ok {
			// Explicitly disable User-Agent so it's not set to default value.
			req.Header.Set("User-Agent", "")
		}

		if extraDirector != nil {
			extraDirector(req)
		}
	}
}

// joinURLPath joins a's path in front of b's path with a single slash, returning
// both the decoded path and the escaped raw path. rawpath is "" when neither
// input carries a RawPath, in which case the plain singleJoiningSlash result is
// used. Mirrors the unexported helper of the same name in net/http/httputil.
func joinURLPath(a, b *url.URL) (path, rawpath string) {
	if a.RawPath == "" && b.RawPath == "" {
		return singleJoiningSlash(a.Path, b.Path), ""
	}
	// Same as singleJoiningSlash, but decide on the slash using the escaped
	// paths so that an encoded separator (%2F) is not mistaken for "/".
	apath := a.EscapedPath()
	bpath := b.EscapedPath()

	aslash := strings.HasSuffix(apath, "/")
	bslash := strings.HasPrefix(bpath, "/")

	switch {
	case aslash && bslash:
		return a.Path + b.Path[1:], apath + bpath[1:]
	case !aslash && !bslash:
		return a.Path + "/" + b.Path, apath + "/" + bpath
	}
	return a.Path + b.Path, apath + bpath
}

// singleJoiningSlash concatenates a and b so that exactly one slash separates
// them. It inserts a slash when neither side has one and drops b's leading
// slash when both do. It is a copy of the unexported helper of the same name
// in net/http/httputil.
// TODO: add test to ensure behavior does not diverge from httputil's implementation.
func singleJoiningSlash(a, b string) string {
	endsWithSlash := strings.HasSuffix(a, "/")
	startsWithSlash := strings.HasPrefix(b, "/")
	if endsWithSlash != startsWithSlash {
		// Exactly one side already provides the separator.
		return a + b
	}
	if endsWithSlash {
		// Both sides provide a slash; keep only a's.
		return a + strings.TrimPrefix(b, "/")
	}
	// Neither side provides a slash; add one.
	return a + "/" + b
}

// Start runs the reverse proxy described by config.ProxyParams and
// config.TlsParams. It blocks while the proxy is serving and returns the
// error that stopped it.
//
// Behavior by config.ProxyParams.Mode:
//   - "none": does nothing and returns nil.
//   - "http": serves plain HTTP on config.ProxyParams.Address.
//   - any other mode, with config.TlsParams.Domain set: serves HTTPS on
//     config.ProxyParams.Address using Let's Encrypt certificates obtained
//     through autocert and cached in the local "certs" directory.
//   - "tls", without a domain: serves HTTPS on config.ProxyParams.Address using
//     config.TlsParams.Cert / config.TlsParams.Key. It returns an error if
//     either is empty.
//   - any other mode, without a domain: serves nothing and returns nil.
//
// When config.ProxyParams.Redirect is true, an extra background server on :80
// redirects every request to https://<domain or address><request URI>.
//
// Side effect: a config.ProxyParams.Target with neither an http:// nor an
// https:// prefix is rewritten in place to start with http://.
func Start() error {
	if config.ProxyParams.Mode == "none" {
		return nil
	}

	// Set the proxy target address; default to plain HTTP when no scheme is given.
	if !strings.HasPrefix(config.ProxyParams.Target, HTTPPrefix) && !strings.HasPrefix(config.ProxyParams.Target, HTTPSPrefix) {
		config.ProxyParams.Target = HTTPPrefix + config.ProxyParams.Target
	}

	toURL, err := url.Parse(config.ProxyParams.Target)
	if err != nil {
		logger.Sugar.Errorf("%v", err.Error())

		return err
	}

	// Setup reverse proxy ServeMux.
	mux := http.NewServeMux()
	mux.Handle("/", build(toURL))

	// Redirect: also accept plain-HTTP requests on port 80 and redirect them to HTTPS.
	if config.ProxyParams.Redirect {
		// Redirect to the listen address by default, or to the public-facing
		// domain when one is specified.
		// NOTE(review): a host-less Address such as ":443" yields a redirect
		// URL without a host; confirm Address always includes a host here.
		redirectURL := config.ProxyParams.Address
		if config.TlsParams.Domain != "" {
			redirectURL = config.TlsParams.Domain
		}
		redirectTLS := func(w http.ResponseWriter, r *http.Request) {
			http.Redirect(w, r, "https://"+redirectURL+r.RequestURI, http.StatusMovedPermanently)
		}
		go func() {
			logger.Sugar.Infof("Also redirecting http requests on port 80 to https requests on %s", redirectURL)
			// Goroutine-local err: writing Start's err from here would race
			// with the main serving path below.
			if err := http.ListenAndServe(":80", http.HandlerFunc(redirectTLS)); err != nil {
				logger.Sugar.Errorf("HTTP redirection server failure: %v", err)
			}
		}()
	}

	switch {
	case config.ProxyParams.Mode == "http":
		logger.Sugar.Infof("Proxying calls from http://%s to %s started!", config.ProxyParams.Address, toURL)
		err = http.ListenAndServe(config.ProxyParams.Address, mux)
		if err != nil {
			logger.Sugar.Errorf("%v", err.Error())
		}
	case config.TlsParams.Domain != "":
		// A domain is configured: use LetsEncrypt certificates.
		logger.Sugar.Infof("Domain specified, using LetsEncrypt to autogenerate and serve certs for %s\n", config.TlsParams.Domain)
		// Port 443 is required for the autogenerated certificates.
		if !strings.HasSuffix(config.ProxyParams.Address, ":443") {
			logger.Sugar.Infof("WARN: Right now, you must serve on port :443 to use autogenerated LetsEncrypt certs using the -domain flag, this may NOT WORK")
		}
		m := &autocert.Manager{
			Cache:      autocert.DirCache("certs"),
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(config.TlsParams.Domain),
		}
		// Addr is the local listen address; the domain only drives certificate
		// issuance via HostPolicy above.
		server := &http.Server{
			Addr:      config.ProxyParams.Address,
			Handler:   mux,
			TLSConfig: m.TLSConfig(),
		}
		logger.Sugar.Infof("Proxying calls from https://%s to %s with LetsEncrypt started!", config.ProxyParams.Address, toURL)
		err = server.ListenAndServeTLS("", "")
		if err != nil {
			logger.Sugar.Errorf("failed to server.ListenAndServeTLS: %v", err.Error())
		}
	case config.ProxyParams.Mode == "tls":
		// No domain: use the configured (self-provided) certificate and key.
		if config.TlsParams.Cert == "" || config.TlsParams.Key == "" {
			// Report rather than panic: Start normally runs in a goroutine
			// launched from init, where a panic would kill the whole process.
			err = errors.New("NoTLSCertKey")
			logger.Sugar.Errorf("%v", err.Error())
			return err
		}
		logger.Sugar.Infof("Proxying calls from https://%s to %s started!", config.ProxyParams.Address, toURL)
		err = http.ListenAndServeTLS(config.ProxyParams.Address, config.TlsParams.Cert, config.TlsParams.Key, mux)
		if err != nil {
			logger.Sugar.Errorf("failed to http.ListenAndServeTLS: %v", err.Error())
		}
	}
	return err
}

// init starts the proxy in a background goroutine as soon as the package is
// imported. Start's returned error is discarded here; Start is expected to
// log its own failures.
func init() {
	go func() {
		_ = Start()
	}()
}
