From d036d981286a04a64df86ba9ddeacba12d533047 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Sun, 18 May 2025 22:32:25 +0800 Subject: [PATCH] fix: http server does not handle http2 logic correctly --- component/tls/httpserver.go | 68 +++++++++++++++++++++++++++++++++ hub/route/server.go | 7 +--- listener/inbound/common_test.go | 27 ++++++++----- listener/inbound/vless_test.go | 18 ++------- listener/sing_vless/server.go | 18 +++++---- listener/sing_vmess/server.go | 18 +++++---- listener/trojan/server.go | 18 +++++---- 7 files changed, 125 insertions(+), 49 deletions(-) create mode 100644 component/tls/httpserver.go diff --git a/component/tls/httpserver.go b/component/tls/httpserver.go new file mode 100644 index 000000000..a8c9ed7f0 --- /dev/null +++ b/component/tls/httpserver.go @@ -0,0 +1,68 @@ +package tls + +import ( + "context" + "net" + "net/http" + "time" + + N "github.com/metacubex/mihomo/common/net" + "github.com/metacubex/mihomo/log" + + "golang.org/x/net/http2" +) + +func extractTlsHandshakeTimeoutFromServer(s *http.Server) time.Duration { + var ret time.Duration + for _, v := range [...]time.Duration{ + s.ReadHeaderTimeout, + s.ReadTimeout, + s.WriteTimeout, + } { + if v <= 0 { + continue + } + if ret == 0 || v < ret { + ret = v + } + } + return ret +} + +// NewListenerForHttps returns a net.Listener for (*http.Server).Serve() +// the "func (c *conn) serve(ctx context.Context)" in http\server.go +// only do tls handshake and check NegotiatedProtocol with std's *tls.Conn +// so we do the same logic to let http2 (not h2c) work fine +func NewListenerForHttps(l net.Listener, httpServer *http.Server, tlsConfig *Config) net.Listener { + http2Server := &http2.Server{} + _ = http2.ConfigureServer(httpServer, http2Server) + return N.NewHandleContextListener(context.Background(), l, func(ctx context.Context, conn net.Conn) (net.Conn, error) { + c := Server(conn, tlsConfig) + + tlsTO := extractTlsHandshakeTimeoutFromServer(httpServer) + if tlsTO > 0 { + dl := time.Now().Add(tlsTO) + _ = conn.SetReadDeadline(dl) + _ = conn.SetWriteDeadline(dl) + } + + err := c.HandshakeContext(ctx) + if err != nil { + return nil, err + } + + // Restore Conn-level deadlines. + if tlsTO > 0 { + _ = conn.SetReadDeadline(time.Time{}) + _ = conn.SetWriteDeadline(time.Time{}) + } + + if c.ConnectionState().NegotiatedProtocol == http2.NextProtoTLS { + http2Server.ServeConn(c, &http2.ServeConnOpts{BaseConfig: httpServer}) + return nil, net.ErrClosed + } + return c, nil + }, func(a any) { + log.Errorln("https server panic: %s", a) + }) +} diff --git a/hub/route/server.go b/hub/route/server.go index c41f51a5e..7f3977ed6 100644 --- a/hub/route/server.go +++ b/hub/route/server.go @@ -28,8 +28,6 @@ import ( "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" "github.com/sagernet/cors" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" ) var ( @@ -215,11 +213,10 @@ func startTLS(cfg *Config) { } } server := &http.Server{ - // using h2c.NewHandler to ensure we can work in plain http2 and some tls conn is not *tls.Conn - Handler: h2c.NewHandler(router(cfg.IsDebug, cfg.Secret, cfg.DohServer, cfg.Cors), &http2.Server{}), + Handler: router(cfg.IsDebug, cfg.Secret, cfg.DohServer, cfg.Cors), } tlsServer = server - if err = server.Serve(tlsC.NewListener(l, tlsConfig)); err != nil { + if err = server.Serve(tlsC.NewListenerForHttps(l, server, tlsConfig)); err != nil { log.Errorln("External controller tls serve error: %s", err) } } diff --git a/listener/inbound/common_test.go b/listener/inbound/common_test.go index 14bf52507..5b8a6f17e 100644 --- a/listener/inbound/common_test.go +++ b/listener/inbound/common_test.go @@ -20,11 +20,13 @@ import ( "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/generater" + tlsC "github.com/metacubex/mihomo/component/tls" C "github.com/metacubex/mihomo/constant" "github.com/go-chi/chi/v5" "github.com/go-chi/render" "github.com/stretchr/testify/assert" + "golang.org/x/net/http2" ) var httpPath = "/inbound_test" @@ -134,7 +136,10 @@ func NewHttpTestTunnel() *TestTunnel { r.Get(httpPath, func(w http.ResponseWriter, r *http.Request) { render.Data(w, r, httpData) }) - go http.Serve(ln, r) + h2Server := &http2.Server{} + server := http.Server{Handler: r} + _ = http2.ConfigureServer(&server, h2Server) + go server.Serve(ln) testFn := func(t *testing.T, proxy C.ProxyAdapter, proto string) { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s://%s%s", proto, remoteAddr, httpPath), nil) if !assert.NoError(t, err) { @@ -208,23 +213,27 @@ func NewHttpTestTunnel() *TestTunnel { ch: make(chan struct{}), } if metadata.DstPort == 443 { - tlsConn := tls.Server(c, tlsConfig.Clone()) + tlsConn := tlsC.Server(c, tlsC.UConfig(tlsConfig)) if metadata.Host == realityDest { // ignore the tls handshake error for realityDest if realityRealDial { rconn, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress()) if err != nil { panic(err) } - N.Relay(rconn, tlsConn) - return - } - ctx, cancel := context.WithTimeout(ctx, C.DefaultTLSTimeout) - defer cancel() - if err := tlsConn.HandshakeContext(ctx); err != nil { + N.Relay(rconn, conn) return } } - ln.ch <- tlsConn + ctx, cancel := context.WithTimeout(ctx, C.DefaultTLSTimeout) + defer cancel() + if err := tlsConn.HandshakeContext(ctx); err != nil { + return + } + if tlsConn.ConnectionState().NegotiatedProtocol == http2.NextProtoTLS { + h2Server.ServeConn(tlsConn, &http2.ServeConnOpts{BaseConfig: &server}) + } else { + ln.ch <- tlsConn + } } else { ln.ch <- c } diff --git a/listener/inbound/vless_test.go b/listener/inbound/vless_test.go index d2056c05e..1bec3d557 100644 --- a/listener/inbound/vless_test.go +++ b/listener/inbound/vless_test.go @@ -237,14 +237,9 @@ func TestInboundVless_Reality(t *testing.T) { outboundOptions.Flow = "xtls-rprx-vision" testInboundVless(t, inboundOptions, outboundOptions) }) - t.Run("ECH", func(t *testing.T) { - inboundOptions := inboundOptions + t.Run("X25519MLKEM768", func(t *testing.T) { outboundOptions := outboundOptions - inboundOptions.EchKey = echKeyPem - outboundOptions.ECHOpts = outbound.ECHOptions{ - Enable: true, - Config: echConfigBase64, - } + outboundOptions.RealityOpts.SupportX25519MLKEM768 = true testInboundVless(t, inboundOptions, outboundOptions) t.Run("xtls-rprx-vision", func(t *testing.T) { outboundOptions := outboundOptions @@ -276,14 +271,9 @@ func TestInboundVless_Reality_Grpc(t *testing.T) { GrpcOpts: outbound.GrpcOptions{GrpcServiceName: "GunService"}, } testInboundVless(t, inboundOptions, outboundOptions) - t.Run("ECH", func(t *testing.T) { - inboundOptions := inboundOptions + t.Run("X25519MLKEM768", func(t *testing.T) { outboundOptions := outboundOptions - inboundOptions.EchKey = echKeyPem - outboundOptions.ECHOpts = outbound.ECHOptions{ - Enable: true, - Config: echConfigBase64, - } + outboundOptions.RealityOpts.SupportX25519MLKEM768 = true testInboundVless(t, inboundOptions, outboundOptions) }) } diff --git a/listener/sing_vless/server.go b/listener/sing_vless/server.go index 90cff6572..16aa1c654 100644 --- a/listener/sing_vless/server.go +++ b/listener/sing_vless/server.go @@ -84,7 +84,7 @@ func New(config LC.VlessServer, tunnel C.Tunnel, additions ...inbound.Addition) tlsConfig := &tlsC.Config{} var realityBuilder *reality.Builder - var httpHandler http.Handler + var httpServer http.Server if config.Certificate != "" && config.PrivateKey != "" { cert, err := ca.LoadTLSKeyPair(config.Certificate, config.PrivateKey, C.Path) @@ -119,16 +119,16 @@ func New(config LC.VlessServer, tunnel C.Tunnel, additions ...inbound.Addition) } sl.HandleConn(conn, tunnel, additions...) }) - httpHandler = httpMux + httpServer.Handler = httpMux tlsConfig.NextProtos = append(tlsConfig.NextProtos, "http/1.1") } if config.GrpcServiceName != "" { - httpHandler = gun.NewServerHandler(gun.ServerOption{ + httpServer.Handler = gun.NewServerHandler(gun.ServerOption{ ServiceName: config.GrpcServiceName, ConnHandler: func(conn net.Conn) { sl.HandleConn(conn, tunnel, additions...) }, - HttpHandler: httpHandler, + HttpHandler: httpServer.Handler, }) tlsConfig.NextProtos = append([]string{"h2"}, tlsConfig.NextProtos...) // h2 must before http/1.1 } @@ -144,15 +144,19 @@ func New(config LC.VlessServer, tunnel C.Tunnel, additions ...inbound.Addition) if realityBuilder != nil { l = realityBuilder.NewListener(l) } else if len(tlsConfig.Certificates) > 0 { - l = tlsC.NewListener(l, tlsConfig) + if httpServer.Handler != nil { + l = tlsC.NewListenerForHttps(l, &httpServer, tlsConfig) + } else { + l = tlsC.NewListener(l, tlsConfig) + } } else { return nil, errors.New("disallow using Vless without both certificates/reality config") } sl.listeners = append(sl.listeners, l) go func() { - if httpHandler != nil { - _ = http.Serve(l, httpHandler) + if httpServer.Handler != nil { + _ = httpServer.Serve(l) return } for { diff --git a/listener/sing_vmess/server.go b/listener/sing_vmess/server.go index b65a294bd..0b4d013a8 100644 --- a/listener/sing_vmess/server.go +++ b/listener/sing_vmess/server.go @@ -78,7 +78,7 @@ func New(config LC.VmessServer, tunnel C.Tunnel, additions ...inbound.Addition) tlsConfig := &tlsC.Config{} var realityBuilder *reality.Builder - var httpHandler http.Handler + var httpServer http.Server if config.Certificate != "" && config.PrivateKey != "" { cert, err := ca.LoadTLSKeyPair(config.Certificate, config.PrivateKey, C.Path) @@ -113,16 +113,16 @@ func New(config LC.VmessServer, tunnel C.Tunnel, additions ...inbound.Addition) } sl.HandleConn(conn, tunnel, additions...) }) - httpHandler = httpMux + httpServer.Handler = httpMux tlsConfig.NextProtos = append(tlsConfig.NextProtos, "http/1.1") } if config.GrpcServiceName != "" { - httpHandler = gun.NewServerHandler(gun.ServerOption{ + httpServer.Handler = gun.NewServerHandler(gun.ServerOption{ ServiceName: config.GrpcServiceName, ConnHandler: func(conn net.Conn) { sl.HandleConn(conn, tunnel, additions...) }, - HttpHandler: httpHandler, + HttpHandler: httpServer.Handler, }) tlsConfig.NextProtos = append([]string{"h2"}, tlsConfig.NextProtos...) // h2 must before http/1.1 } @@ -138,13 +138,17 @@ func New(config LC.VmessServer, tunnel C.Tunnel, additions ...inbound.Addition) if realityBuilder != nil { l = realityBuilder.NewListener(l) } else if len(tlsConfig.Certificates) > 0 { - l = tlsC.NewListener(l, tlsConfig) + if httpServer.Handler != nil { + l = tlsC.NewListenerForHttps(l, &httpServer, tlsConfig) + } else { + l = tlsC.NewListener(l, tlsConfig) + } } sl.listeners = append(sl.listeners, l) go func() { - if httpHandler != nil { - _ = http.Serve(l, httpHandler) + if httpServer.Handler != nil { + _ = httpServer.Serve(l) return } for { diff --git a/listener/trojan/server.go b/listener/trojan/server.go index 780273f21..3ea7c3879 100644 --- a/listener/trojan/server.go +++ b/listener/trojan/server.go @@ -72,7 +72,7 @@ func New(config LC.TrojanServer, tunnel C.Tunnel, additions ...inbound.Addition) tlsConfig := &tlsC.Config{} var realityBuilder *reality.Builder - var httpHandler http.Handler + var httpServer http.Server if config.Certificate != "" && config.PrivateKey != "" { cert, err := ca.LoadTLSKeyPair(config.Certificate, config.PrivateKey, C.Path) @@ -107,16 +107,16 @@ func New(config LC.TrojanServer, tunnel C.Tunnel, additions ...inbound.Addition) } sl.HandleConn(conn, tunnel, additions...) }) - httpHandler = httpMux + httpServer.Handler = httpMux tlsConfig.NextProtos = append(tlsConfig.NextProtos, "http/1.1") } if config.GrpcServiceName != "" { - httpHandler = gun.NewServerHandler(gun.ServerOption{ + httpServer.Handler = gun.NewServerHandler(gun.ServerOption{ ServiceName: config.GrpcServiceName, ConnHandler: func(conn net.Conn) { sl.HandleConn(conn, tunnel, additions...) }, - HttpHandler: httpHandler, + HttpHandler: httpServer.Handler, }) tlsConfig.NextProtos = append([]string{"h2"}, tlsConfig.NextProtos...) // h2 must before http/1.1 } @@ -132,15 +132,19 @@ func New(config LC.TrojanServer, tunnel C.Tunnel, additions ...inbound.Addition) if realityBuilder != nil { l = realityBuilder.NewListener(l) } else if len(tlsConfig.Certificates) > 0 { - l = tlsC.NewListener(l, tlsConfig) + if httpServer.Handler != nil { + l = tlsC.NewListenerForHttps(l, &httpServer, tlsConfig) + } else { + l = tlsC.NewListener(l, tlsConfig) + } } else if !config.TrojanSSOption.Enabled { return nil, errors.New("disallow using Trojan without both certificates/reality/ss config") } sl.listeners = append(sl.listeners, l) go func() { - if httpHandler != nil { - _ = http.Serve(l, httpHandler) + if httpServer.Handler != nil { + _ = httpServer.Serve(l) return } for {