Notes: simple UDP and SNI proxies are done

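Because I had already borrowed a lot of abstractions and optimizations from Kestrel, extending the project later turned out to be very convenient.

Implementing the two features, a simple UDP proxy and an SNI proxy, went much faster than expected (cutting a few corners helped too, of course).

(PS: if you have a moment, could you star the project on GitHub at github.com/fs7744/NZOr...? Borrowing code is hard work too, hahaha.)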

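Simple UDP proxy

The UDP proxy here is deliberately simple: for every UDP packet the proxy receives, it finds an upstream through route matching and forwards the packet to that upstream.

UDP proxy configuration

The basic format is the same as the earlier TCP proxy.

The only differences are that Protocols must be set to UDP and that there is an extra UdpResponses setting, which controls how many UDP packets the upstream may send back to the requester. It defaults to 0, meaning no packets are returned.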

{
  "Logging": {
    "LogLevel": {
      "Default": "Information"
    }
  },
  "ReverseProxy": {
    "Routes": {
      "udpTest": {
        "Protocols": [ "UDP" ],
        "Match": {
          "Hosts": [ "*:5000" ]
        },
        "ClusterId": "udpTest",
        "RetryCount": 1,
        "UdpResponses": 1,
        "Timeout": "00:00:11"
      }
    },
    "Clusters": {
      "udpTest": {
        "LoadBalancingPolicy": "RoundRobin",
        "HealthCheck": {
          "Passive": {
            "Enable": true
          }
        },
        "Destinations": [
          {
            "Address": "127.0.0.1:11000"
          }
        ]
      }
    }
  }
}

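Implementation

Here are the pieces, to show how simple it is.

PS: because the goal is a very simple UDP proxy, it is built on IConnectionListener rather than IMultiplexedConnectionListener (yes, I was being lazy).

1. Implement UdpConnectionContext

To keep things lazy, the UDP packet data is stored directly on the context instead of in Parameters, which avoids an extra dictionary instance and saves some memory.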

public sealed class UdpConnectionContext : TransportConnection
{
    private readonly IMemoryOwner<byte> memory;
    public Socket Socket { get; }
    public int ReceivedBytesCount { get; }

    public Memory<byte> ReceivedBytes => memory.Memory.Slice(0, ReceivedBytesCount);

    public UdpConnectionContext(Socket socket, UdpReceiveFromResult result)
    {
        Socket = socket;
        ReceivedBytesCount = result.ReceivedBytesCount;
        this.memory = result.Buffer;
        LocalEndPoint = socket.LocalEndPoint;
        RemoteEndPoint = result.RemoteEndPoint;
    }

    public UdpConnectionContext(Socket socket, EndPoint remoteEndPoint, int receivedBytes, IMemoryOwner<byte> memory)
    {
        Socket = socket;
        ReceivedBytesCount = receivedBytes;
        this.memory = memory;
        LocalEndPoint = socket.LocalEndPoint;
        RemoteEndPoint = remoteEndPoint;
    }

    public override ValueTask DisposeAsync()
    {
        memory.Dispose();
        return default;
    }
}

2. Implement IConnectionListener

internal sealed class UdpConnectionListener : IConnectionListener
{
    private EndPoint? udpEndPoint;
    private readonly GatewayProtocols protocols;
    private OrzLogger _logger;
    private readonly IUdpConnectionFactory connectionFactory;
    private readonly Func<EndPoint, GatewayProtocols, Socket> createBoundListenSocket;
    private Socket? _listenSocket;

    public UdpConnectionListener(EndPoint? udpEndPoint, GatewayProtocols protocols, IRouteContractor contractor, OrzLogger logger, IUdpConnectionFactory connectionFactory)
    {
        this.udpEndPoint = udpEndPoint;
        this.protocols = protocols;
        _logger = logger;
        this.connectionFactory = connectionFactory;
        createBoundListenSocket = contractor.GetSocketTransportOptions().CreateBoundListenSocket;
    }

    public EndPoint EndPoint => udpEndPoint;

    internal void Bind()
    {
        if (_listenSocket != null)
        {
            throw new InvalidOperationException("Transport is already bound.");
        }

        Socket listenSocket;
        try
        {
            listenSocket = createBoundListenSocket(EndPoint, protocols);
        }
        catch (SocketException e) when (e.SocketErrorCode == SocketError.AddressAlreadyInUse)
        {
            throw new AddressInUseException(e.Message, e);
        }

        Debug.Assert(listenSocket.LocalEndPoint != null);

        _listenSocket = listenSocket;
    }

    public async ValueTask<ConnectionContext?> AcceptAsync(CancellationToken cancellationToken = default)
    {
        while (true)
        {
            try
            {
                Debug.Assert(_listenSocket != null, "Bind must be called first.");
                var r = await connectionFactory.ReceiveAsync(_listenSocket, cancellationToken);
                return new UdpConnectionContext(_listenSocket, r);
            }
            catch (ObjectDisposedException)
            {
                // A call was made to UnbindAsync/DisposeAsync just return null which signals we're done
                return null;
            }
            catch (SocketException e) when (e.SocketErrorCode == SocketError.OperationAborted)
            {
                // A call was made to UnbindAsync/DisposeAsync just return null which signals we're done
                return null;
            }
            catch (SocketException)
            {
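                // Comment inherited from the TCP listener. For UDP, a SocketException here (for example a reset reported after an earlier send could not be delivered) is logged and the receive is retried.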
                _logger.ConnectionReset("(null)");
            }
        }
    }

    public ValueTask DisposeAsync()
    {
        _listenSocket?.Dispose();

        return default;
    }

    public ValueTask UnbindAsync(CancellationToken cancellationToken = default)
    {
        _listenSocket?.Dispose();
        return default;
    }
}

3. Implement IConnectionListenerFactory

public sealed class UdpTransportFactory : IConnectionListenerFactory, IConnectionListenerFactorySelector
{
    private readonly IRouteContractor contractor;
    private readonly OrzLogger logger;
    private readonly IUdpConnectionFactory connectionFactory;

    public UdpTransportFactory(
        IRouteContractor contractor,
        OrzLogger logger,
        IUdpConnectionFactory connectionFactory)
    {
        ArgumentNullException.ThrowIfNull(contractor);
        ArgumentNullException.ThrowIfNull(logger);
        ArgumentNullException.ThrowIfNull(connectionFactory);

        this.contractor = contractor;
        this.logger = logger;
        this.connectionFactory = connectionFactory;
    }

    public ValueTask<IConnectionListener> BindAsync(EndPoint endpoint, GatewayProtocols protocols, CancellationToken cancellationToken = default)
    {
        var transport = new UdpConnectionListener(endpoint, GatewayProtocols.UDP, contractor, logger, connectionFactory);
        transport.Bind();
        return new ValueTask<IConnectionListener>(transport);
    }

    public bool CanBind(EndPoint endpoint, GatewayProtocols protocols)
    {
        if (!protocols.HasFlag(GatewayProtocols.UDP)) return false;
        return endpoint switch
        {
            IPEndPoint _ => true,
            _ => false
        };
    }
}

4. Implement the UDP proxy logic in L4ProxyMiddleware

Routing is shared with the earlier TCP proxy, so it is not listed here.

public class L4ProxyMiddleware : IOrderMiddleware
{    
    public async Task Invoke(ConnectionContext context, ConnectionDelegate next)
    {
        try
        {
            if (context.Protocols == GatewayProtocols.SNI)
            {
                await SNIProxyAsync(context);
            }
            else
            {
                var route = await router.MatchAsync(context);
                if (route is null)
                {
                    logger.NotFoundRouteL4(context.LocalEndPoint);
                }
                else
                {
                    context.Route = route;
                    logger.ProxyBegin(route.RouteId);
                    if (context.Protocols == GatewayProtocols.TCP)
                    {
                        await TcpProxyAsync(context, route);
                    }
                    else
                    {
                        await UdpProxyAsync((UdpConnectionContext)context, route);
                    }
                    logger.ProxyEnd(route.RouteId);
                }
            }
        }
        catch (Exception ex)
        {
            logger.UnexpectedException(ex.Message, ex);
        }
        finally
        {
            await next(context);
        }
    }

    private async Task UdpProxyAsync(UdpConnectionContext context, RouteConfig route)
    {
        try
        {
            // Dispose the upstream socket when this block exits so it is not leaked per packet.
            using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
            var cts = route.CreateTimeoutTokenSource(cancellationTokenSourcePool);
            var token = cts.Token;
            if (await DoUdpSendToAsync(socket, context, route, route.RetryCount, await reqUdp(context, context.ReceivedBytes, token), token))
            {
                var c = route.UdpResponses;
                while (c > 0)
                {
                    var r = await udp.ReceiveAsync(socket, token);
                    c--;
                    await udp.SendToAsync(context.Socket, context.RemoteEndPoint, await respUdp(context, r.GetReceivedBytes(), token), token);
                }
            }
            else
            {
                logger.NotFoundAvailableUpstream(route.ClusterId);
            }
        }
        catch (OperationCanceledException)
        {
            logger.ConnectUpstreamTimeout(route.RouteId);
        }
        catch (Exception ex)
        {
            logger.UnexpectedException(nameof(UdpProxyAsync), ex);
        }
        finally
        {
            context.SelectedDestination?.ConcurrencyCounter.Decrement();
        }
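    }

    // Fields such as router, logger and udp, plus TcpProxyAsync and SNIProxyAsync, are omitted from this excerpt.
}

So is it really that simple? In theory, an implementation built directly on Kestrel would look much the same.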

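Optimizations

Following Kestrel's TCP socket handling, there are also a few simple optimizations, for example:

  • Not using UdpClient (PS: not because its implementation is bad, but because it is general-purpose and gives us no way to change what happens inside)
  • Types built on SocketAsyncEventArgs plus IValueTaskSource<SocketReceiveFromResult>, and on SocketAsyncEventArgs plus IValueTaskSource<int>, which hand asynchronous reads and writes over to a PipeScheduler
  • A simple pool of UDP sender objects based on ConcurrentQueue<UdpSender>, which improves object reuse and trims memory usage slightly
  • A simple CancellationTokenSource pool based on ConcurrentQueue<PooledCancellationTokenSource>, which likewise improves object reuse and trims memory usage slightly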
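
To make the pooling idea concrete, here is a minimal sketch of a pool built on ConcurrentQueue. It is not the project's actual UdpSender or PooledCancellationTokenSource pool: the SimplePool name, the maxRetained limit and the dispose-on-overflow behavior are all assumptions made for illustration.

```csharp
using System;
using System.Collections.Concurrent;
using System.Threading;

// Hypothetical, simplified pool; the project's real pools may differ.
public sealed class SimplePool<T> where T : class
{
    private readonly ConcurrentQueue<T> queue = new();
    private readonly Func<T> factory;
    private readonly int maxRetained;
    private int count; // upper bound on the number of retained items

    public SimplePool(Func<T> factory, int maxRetained = 1024)
    {
        this.factory = factory ?? throw new ArgumentNullException(nameof(factory));
        this.maxRetained = maxRetained;
    }

    // Reuse a pooled instance when one is available, otherwise create one.
    public T Rent()
    {
        if (queue.TryDequeue(out var item))
        {
            Interlocked.Decrement(ref count);
            return item;
        }
        return factory();
    }

    // Keep the instance for reuse unless the pool is already full.
    public bool Return(T item)
    {
        if (Interlocked.Increment(ref count) > maxRetained)
        {
            Interlocked.Decrement(ref count);
            (item as IDisposable)?.Dispose();
            return false;
        }
        queue.Enqueue(item);
        return true;
    }
}
```

Rent never blocks: when the queue is empty it simply allocates, so the pool only saves allocations on the hot path and never becomes a bottleneck.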

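SNI proxy

Besides basic TCP and UDP proxying, I also tried an SNI proxy for TCP (for example HTTPS over HTTP/1 and HTTP/2).

For now only pass-through mode is implemented: the proxy does no SSL encryption or decryption and leaves that to the upstream. If the proxy should terminate SSL itself, in theory that could be built on the existing SslStream.

SNI proxy configuration

Configure a shared SNI listening port under Listen.

Each SNI name then gets its own route and upstream.

Each route can also restrict the TLS versions it accepts through SupportSslProtocols.

For example: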

{
  "Logging": {
    "LogLevel": {
      "Default": "Information"
    }
  },
  "ReverseProxy": {
    "Listen": {
      "snitest": {
        "Protocols": "SNI",
        "Address": [ "*:444" ]
      }
    },
    "Routes": {
      "snitestroute": {
        "Protocols": "SNI",
        "SupportSslProtocols": [ "Tls13", "Tls12" ],
        "Match": {
          "Hosts": [ "*google.com" ]
        },
        "ClusterId": "apidemo"
      },
      "snitestroute2": {
        "Protocols": "Tcp",
        "Match": {
          "Hosts": [ "*:448" ]
        },
        "ClusterId": "apidemo"
      }
    },
    "Clusters": {
      "apidemo": {
        "LoadBalancingPolicy": "RoundRobin",
        "HealthCheck": {
          "Active": {
            "Enable": true,
            "Policy": "Connect"
          }
        },
        "Destinations": [
          {
            "Address": "https://www.google.com"
          }
        ]
      }
    }
  }
}
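
The route above uses the host pattern *google.com, and the routing code later in this post reverses the SNI name before matching (h.TargetName.Reverse()). One plausible reason is that reversing both sides turns a suffix wildcard into a plain prefix comparison. The sketch below only illustrates that idea: HostPatternSketch, ReverseText and MatchesSuffixWildcard are hypothetical names, not the project's matcher.

```csharp
using System;

public static class HostPatternSketch
{
    // Reverse a host name so that a suffix test becomes a prefix test.
    public static string ReverseText(string value)
    {
        var chars = value.ToCharArray();
        Array.Reverse(chars);
        return new string(chars);
    }

    // Hypothetical matcher: "*google.com" matches any host ending in "google.com";
    // a pattern without a leading '*' must match the whole host.
    public static bool MatchesSuffixWildcard(string pattern, string host)
    {
        if (!pattern.StartsWith('*'))
        {
            return string.Equals(pattern, host, StringComparison.OrdinalIgnoreCase);
        }
        var reversedSuffix = ReverseText(pattern.Substring(1));
        return ReverseText(host).StartsWith(reversedSuffix, StringComparison.OrdinalIgnoreCase);
    }
}
```

Under this sketch, *google.com also matches a name such as notgoogle.com, because the wildcard is not anchored on a dot. A pattern like *.google.com would be stricter.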

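Implementation

The core of the implementation is really just the routing. The proxying itself is identical to the TCP proxy (it only moves TCP data between the client and the upstream).

Route handling

Read the ClientHello to find the requested domain name, match a route by that name to find the upstream, then move the TCP data.

The ClientHello parsing is taken directly from TlsFrameHelper.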

    /// Route matching
    public async ValueTask<(RouteConfig, ReadResult)> MatchSNIAsync(ConnectionContext context, CancellationToken token)
    {
        if (sniRoute is null) return (null, default);
        var (hello, rr) = await TryGetClientHelloAsync(context, token);
        if (hello.HasValue)
        {
            var h = hello.Value;
            var r = await sniRoute.MatchAsync(h.TargetName.Reverse(), h, MatchSNI);
            if (r is null)
            {
                logger.NotFoundRouteSni(h.TargetName);
            }
            return (r, rr);
        }
        else
        {
            logger.NotFoundRouteSni("client hello failed");
            return (null, rr);
        }
    }

    /// Match the TLS version
    private bool MatchSNI(RouteConfig config, TlsFrameInfo info)
    {
        if (!config.SupportSslProtocols.HasValue) return true;
        var v = config.SupportSslProtocols.Value;
        if (v == SslProtocols.None) return true;
        var t = info.SupportedVersions;
        if (v.HasFlag(SslProtocols.Tls13) && t.HasFlag(SslProtocols.Tls13)) return true;
        else if (v.HasFlag(SslProtocols.Tls12) && t.HasFlag(SslProtocols.Tls12)) return true;
        else if (v.HasFlag(SslProtocols.Tls11) && t.HasFlag(SslProtocols.Tls11)) return true;
        else if (v.HasFlag(SslProtocols.Tls) && t.HasFlag(SslProtocols.Tls)) return true;
        else if (v.HasFlag(SslProtocols.Ssl3) && t.HasFlag(SslProtocols.Ssl3)) return true;
        else if (v.HasFlag(SslProtocols.Ssl2) && t.HasFlag(SslProtocols.Ssl2)) return true;
        else if (v.HasFlag(SslProtocols.Default) && t.HasFlag(SslProtocols.Default)) return true;
        else return false;
    }

    /// Parse the ClientHello
    private static async ValueTask<(TlsFrameInfo?, ReadResult)> TryGetClientHelloAsync(ConnectionContext context, CancellationToken token)
    {
        var input = context.Transport.Input;
        TlsFrameInfo info = default;
        while (true)
        {
            var f = await input.ReadAsync(token).ConfigureAwait(false);
            if (f.IsCompleted)
            {
                return (null, f);
            }
            var buffer = f.Buffer;
            if (buffer.Length == 0)
            {
                continue;
            }

            var data = buffer.IsSingleSegment ? buffer.First.Span : buffer.ToArray();
            if (TlsFrameHelper.TryGetFrameInfo(data, ref info))
            {
                return (info, f);
            }
            else
            {
                input.AdvanceTo(buffer.Start, buffer.End);
                continue;
            }
        }
    }
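
MatchSNI above accepts a route when it has no restriction, or when the route and the ClientHello share at least one protocol version. Because SslProtocols is a flags enum, the same kind of overlap check can be sketched with a single bitwise AND. TlsVersionSketch and IsAllowed are hypothetical names, and this is a simplified illustration rather than a drop-in replacement for the method above.

```csharp
#nullable enable
using System.Security.Authentication;

public static class TlsVersionSketch
{
    // True when the route has no restriction, or when at least one protocol
    // flag is present both in the route's allow-list and in the ClientHello.
    public static bool IsAllowed(SslProtocols? allowed, SslProtocols offered)
    {
        if (allowed is null || allowed.Value == SslProtocols.None)
        {
            return true;
        }
        return (allowed.Value & offered) != SslProtocols.None;
    }
}
```

With a route restricted to Tls12 and Tls13, a client offering only Tls13 is accepted, while a client offering none of the allowed versions is rejected.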

Moving the TCP data

private async Task SNIProxyAsync(ConnectionContext context)
{
    var c = cancellationTokenSourcePool.Rent();
    c.CancelAfter(options.ConnectionTimeout);
    var (route, r) = await router.MatchSNIAsync(context, c.Token);
    if (route is not null)
    {
        context.Route = route;
        logger.ProxyBegin(route.RouteId);
        ConnectionContext upstream = null;
        try
        {
            upstream = await DoConnectionAsync(context, route, route.RetryCount);
            if (upstream is null)
            {
                logger.NotFoundAvailableUpstream(route.ClusterId);
            }
            else
            {
                context.SelectedDestination?.ConcurrencyCounter.Increment();
                var cts = route.CreateTimeoutTokenSource(cancellationTokenSourcePool);
                var t = cts.Token;
                await r.CopyToAsync(upstream.Transport.Output, t); // The only difference from plain TCP proxying: send the ClientHello bytes first, because we have already read them
                context.Transport.Input.AdvanceTo(r.Buffer.End);
                var task = hasMiddlewareTcp ?
                        await Task.WhenAny(
                        context.Transport.Input.CopyToAsync(new MiddlewarePipeWriter(upstream.Transport.Output, context, reqTcp), t)
                        , upstream.Transport.Input.CopyToAsync(new MiddlewarePipeWriter(context.Transport.Output, context, respTcp), t))
                        : await Task.WhenAny(
                        context.Transport.Input.CopyToAsync(upstream.Transport.Output, t)
                        , upstream.Transport.Input.CopyToAsync(context.Transport.Output, t));
                if (task.IsCanceled)
                {
                    logger.ProxyTimeout(route.RouteId, route.Timeout);
                }
            }
        }
        catch (OperationCanceledException)
        {
            logger.ConnectUpstreamTimeout(route.RouteId);
        }
        catch (Exception ex)
        {
            logger.UnexpectedException(nameof(SNIProxyAsync), ex);
        }
        finally
        {
            context.SelectedDestination?.ConcurrencyCounter.Decrement();
            upstream?.Abort();
        }
        logger.ProxyEnd(route.RouteId);
    }
}

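Every part is replaceable or extensible

The whole thing is built on IoC, so every component can be replaced or supplemented, which leaves plenty of room for customization.

The list of what is currently exposed can be found in this code: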

internal static HostApplicationBuilder UseOrzDefaults(this HostApplicationBuilder builder)
{
    var services = builder.Services;
    services.AddSingleton<IHostedService, HostedService>();
    services.AddSingleton(TimeProvider.System);
    services.AddSingleton<IMeterFactory, DummyMeterFactory>();
    services.AddSingleton<IServer, OrzServer>();
    services.AddSingleton<OrzLogger>();
    services.AddSingleton<OrzMetrics>();
    services.AddSingleton<IConnectionListenerFactory, SocketTransportFactory>();
    services.AddSingleton<IConnectionListenerFactory, UdpTransportFactory>();
    services.AddSingleton<IUdpConnectionFactory, UdpConnectionFactory>();
    services.AddSingleton<IConnectionFactory, SocketConnectionFactory>();
    services.AddSingleton<IRouteContractorValidator, RouteContractorValidator>();
    services.AddSingleton<IEndPointConvertor, CommonEndPointConvertor>();
    services.AddSingleton<IL4Router, L4Router>();
    services.AddSingleton<IOrderMiddleware, L4ProxyMiddleware>();
    services.AddSingleton<ILoadBalancingPolicyFactory, LoadBalancingPolicy>();
    services.AddSingleton<IClusterConfigValidator, ClusterConfigValidator>();
    services.AddSingleton<IDestinationResolver, DnsDestinationResolver>();

    services.AddSingleton<ILoadBalancingPolicy, RandomLoadBalancingPolicy>();
    services.AddSingleton<ILoadBalancingPolicy, RoundRobinLoadBalancingPolicy>();
    services.AddSingleton<ILoadBalancingPolicy, LeastRequestsLoadBalancingPolicy>();
    services.AddSingleton<ILoadBalancingPolicy, PowerOfTwoChoicesLoadBalancingPolicy>();

    services.AddSingleton<IHealthReporter, PassiveHealthReporter>();
    services.AddSingleton<IHealthUpdater, HealthyAndUnknownDestinationsUpdater>();
    services.AddSingleton<IActiveHealthCheckMonitor, ActiveHealthCheckMonitor>();
    services.AddSingleton<IActiveHealthChecker, ConnectionActiveHealthChecker>();

    return builder;
}

For example, to add a load-balancing policy, implement:

public interface ILoadBalancingPolicy
{
    string Name { get; }

    DestinationState? PickDestination(ConnectionContext context, IReadOnlyList<DestinationState> availableDestinations);
}
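
As an illustration, a custom policy could look like the sketch below. ConnectionContext and DestinationState are reduced to stand-in types so the example compiles on its own, and the ClientAddress and Address members as well as the ClientHash policy are assumptions, not part of the project.

```csharp
#nullable enable
using System.Collections.Generic;

// Stand-in types so the sketch compiles on its own; the real project types
// have more members than the hypothetical ones shown here.
public class ConnectionContext { public string? ClientAddress { get; set; } }
public class DestinationState { public string Address { get; set; } = ""; }

public interface ILoadBalancingPolicy
{
    string Name { get; }

    DestinationState? PickDestination(ConnectionContext context, IReadOnlyList<DestinationState> availableDestinations);
}

// Hypothetical policy: the same client address always maps to the same destination.
public sealed class ClientHashLoadBalancingPolicy : ILoadBalancingPolicy
{
    public string Name => "ClientHash";

    public DestinationState? PickDestination(ConnectionContext context, IReadOnlyList<DestinationState> availableDestinations)
    {
        if (availableDestinations.Count == 0)
        {
            return null;
        }

        // FNV-1a hash, used because string.GetHashCode() is randomized per process.
        uint hash = 2166136261;
        foreach (var ch in context.ClientAddress ?? string.Empty)
        {
            hash = (hash ^ ch) * 16777619;
        }
        return availableDestinations[(int)(hash % (uint)availableDestinations.Count)];
    }
}
```

Registration would presumably follow the same pattern as the AddSingleton<ILoadBalancingPolicy, ...> lines in UseOrzDefaults, with a cluster then selecting the policy through its LoadBalancingPolicy setting.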

If none of the existing load-balancing policies fit, you can replace ILoadBalancingPolicyFactory directly:

public interface ILoadBalancingPolicyFactory
{
    DestinationState? PickDestination(ConnectionContext context, RouteConfig route);
}

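For example, with SNI you can forward requests that a development environment (or any other environment) cannot reach through a machine that does have the required access.

That is about all I did. Reinventing the wheel is fun, and it would be even more fun if you starred the project on GitHub at github.com/fs7744/NZOr...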
