格式化
This commit is contained in:
+183
-110
@@ -21,11 +21,12 @@ import (
|
||||
|
||||
"mmw-agent/internal/collector"
|
||||
"mmw-agent/internal/config"
|
||||
"mmw-agent/internal/constants"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// ConnectionMode represents the current connection mode
|
||||
// ConnectionMode 表示当前连接模式。
|
||||
type ConnectionMode string
|
||||
|
||||
const (
|
||||
@@ -35,7 +36,7 @@ const (
|
||||
ModeAuto ConnectionMode = "auto"
|
||||
)
|
||||
|
||||
// Client represents an agent client that connects to a master server
|
||||
// Client 表示连接主控端的 agent 客户端。
|
||||
type Client struct {
|
||||
config *config.Config
|
||||
collector *collector.Collector
|
||||
@@ -47,20 +48,20 @@ type Client struct {
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Connection state
|
||||
// 连接状态
|
||||
currentMode ConnectionMode
|
||||
httpClient *http.Client
|
||||
httpAvailable bool
|
||||
modeMu sync.RWMutex
|
||||
|
||||
// Speed calculation (from system network interface)
|
||||
// 速率计算(基于系统网卡统计)
|
||||
lastRxBytes int64
|
||||
lastTxBytes int64
|
||||
lastSampleTime time.Time
|
||||
speedMu sync.Mutex
|
||||
}
|
||||
|
||||
// NewClient creates a new agent client
|
||||
// 创建 agent 客户端。
|
||||
func NewClient(cfg *config.Config) *Client {
|
||||
return &Client{
|
||||
config: cfg,
|
||||
@@ -68,20 +69,20 @@ func NewClient(cfg *config.Config) *Client {
|
||||
xrayServers: cfg.XrayServers,
|
||||
stopCh: make(chan struct{}),
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: constants.DefaultHTTPClientTimeout,
|
||||
},
|
||||
currentMode: ModePull, // Default to pull mode
|
||||
currentMode: ModePull, // 默认使用拉取模式
|
||||
}
|
||||
}
|
||||
|
||||
// wsHeaders returns HTTP headers for WebSocket handshake
|
||||
// 生成 WebSocket 握手请求头。
|
||||
func (c *Client) wsHeaders() http.Header {
|
||||
h := http.Header{}
|
||||
h.Set("User-Agent", config.AgentUserAgent)
|
||||
h.Set(constants.HeaderUserAgent, constants.AgentUserAgent)
|
||||
return h
|
||||
}
|
||||
|
||||
// newRequest creates an HTTP request with standard headers (Content-Type, Authorization, User-Agent)
|
||||
// 创建带标准请求头的 HTTP 请求。
|
||||
func (c *Client) newRequest(ctx context.Context, method, urlStr string, body []byte) (*http.Request, error) {
|
||||
var req *http.Request
|
||||
var err error
|
||||
@@ -93,13 +94,13 @@ func (c *Client) newRequest(ctx context.Context, method, urlStr string, body []b
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.Token)
|
||||
req.Header.Set("User-Agent", config.AgentUserAgent)
|
||||
req.Header.Set(constants.HeaderContentType, constants.ContentTypeJSON)
|
||||
req.Header.Set(constants.HeaderAuthorization, constants.BearerPrefix+c.config.Token)
|
||||
req.Header.Set(constants.HeaderUserAgent, constants.AgentUserAgent)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Start starts the agent client with automatic mode selection
|
||||
// 按配置启动客户端。
|
||||
func (c *Client) Start(ctx context.Context) {
|
||||
log.Printf("[Agent] Starting in %s mode", c.config.ConnectionMode)
|
||||
|
||||
@@ -117,7 +118,7 @@ func (c *Client) Start(ctx context.Context) {
|
||||
case ModePull:
|
||||
c.setCurrentMode(ModePull)
|
||||
log.Printf("[Agent] Pull mode enabled - API will be served at /api/child/traffic and /api/child/speed")
|
||||
// Report agent info immediately via HTTP heartbeat
|
||||
// 启动后先通过 HTTP 上报一次心跳信息
|
||||
if err := c.sendHeartbeatHTTP(ctx); err != nil {
|
||||
log.Printf("[Agent] Failed to send initial heartbeat in pull mode: %v", err)
|
||||
}
|
||||
@@ -130,7 +131,7 @@ func (c *Client) Start(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the agent client
|
||||
// 停止客户端。
|
||||
func (c *Client) Stop() {
|
||||
close(c.stopCh)
|
||||
c.wg.Wait()
|
||||
@@ -144,33 +145,33 @@ func (c *Client) Stop() {
|
||||
log.Printf("[Agent] Stopped")
|
||||
}
|
||||
|
||||
// IsConnected returns whether the WebSocket is connected
|
||||
// 返回 WebSocket 连接状态。
|
||||
func (c *Client) IsConnected() bool {
|
||||
c.wsMu.Lock()
|
||||
defer c.wsMu.Unlock()
|
||||
return c.connected
|
||||
}
|
||||
|
||||
// GetCurrentMode returns the current connection mode
|
||||
// 返回当前连接模式。
|
||||
func (c *Client) GetCurrentMode() ConnectionMode {
|
||||
c.modeMu.RLock()
|
||||
defer c.modeMu.RUnlock()
|
||||
return c.currentMode
|
||||
}
|
||||
|
||||
// setCurrentMode sets the current connection mode
|
||||
// 设置当前连接模式。
|
||||
func (c *Client) setCurrentMode(mode ConnectionMode) {
|
||||
c.modeMu.Lock()
|
||||
defer c.modeMu.Unlock()
|
||||
c.currentMode = mode
|
||||
}
|
||||
|
||||
// runWebSocket manages the WebSocket connection lifecycle with fallback to auto mode
|
||||
// 维护 WebSocket 连接,并在失败时回退自动模式。
|
||||
func (c *Client) runWebSocket(ctx context.Context) {
|
||||
defer c.wg.Done()
|
||||
|
||||
maxConsecutiveFailures := 5
|
||||
maxAuthFailures := 10
|
||||
maxConsecutiveFailures := constants.WebSocketMaxConsecutiveFailures
|
||||
maxAuthFailures := constants.WebSocketMaxAuthFailures
|
||||
consecutiveFailures := 0
|
||||
authFailures := 0
|
||||
|
||||
@@ -190,22 +191,22 @@ func (c *Client) runWebSocket(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is an authentication error
|
||||
// 判断是否为鉴权错误
|
||||
if authErr, ok := err.(*AuthError); ok {
|
||||
authFailures++
|
||||
if authErr.IsTokenInvalid() {
|
||||
log.Printf("[Agent] Authentication failed (invalid token): %v", err)
|
||||
if authFailures >= maxAuthFailures {
|
||||
log.Printf("[Agent] Too many auth failures (%d), entering sleep mode (30 min backoff)", authFailures)
|
||||
c.waitWithTrafficReport(ctx, 30*time.Minute)
|
||||
c.waitWithTrafficReport(ctx, constants.AuthFailureSleepBackoff)
|
||||
authFailures = 0
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Use longer backoff for auth errors
|
||||
backoff := time.Duration(authFailures) * 30 * time.Second
|
||||
if backoff > 10*time.Minute {
|
||||
backoff = 10 * time.Minute
|
||||
// 鉴权错误使用更长退避时间
|
||||
backoff := time.Duration(authFailures) * constants.AuthFailureBackoffStep
|
||||
if backoff > constants.AuthFailureMaxBackoff {
|
||||
backoff = constants.AuthFailureMaxBackoff
|
||||
}
|
||||
log.Printf("[Agent] Auth error, reconnecting in %v...", backoff)
|
||||
c.waitWithTrafficReport(ctx, backoff)
|
||||
@@ -233,21 +234,21 @@ func (c *Client) runWebSocket(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// calculateBackoff calculates the reconnection backoff duration with exponential increase
|
||||
// 计算重连退避时长。
|
||||
func (c *Client) calculateBackoff() time.Duration {
|
||||
c.reconnects++
|
||||
// Exponential backoff: 5s, 10s, 20s, 40s, 80s, 160s, 300s(cap)
|
||||
backoff := 5 * time.Second
|
||||
for i := 1; i < c.reconnects && backoff < 5*time.Minute; i++ {
|
||||
// 指数退避: 5s, 10s, 20s, 40s, 80s, 160s, 300s(上限)
|
||||
backoff := constants.ReconnectBaseBackoff
|
||||
for i := 1; i < c.reconnects && backoff < constants.ReconnectMaxBackoff; i++ {
|
||||
backoff *= 2
|
||||
}
|
||||
if backoff > 5*time.Minute {
|
||||
backoff = 5 * time.Minute
|
||||
if backoff > constants.ReconnectMaxBackoff {
|
||||
backoff = constants.ReconnectMaxBackoff
|
||||
}
|
||||
return backoff
|
||||
}
|
||||
|
||||
// connectAndRun establishes and maintains a WebSocket connection
|
||||
// 建立并维持 WebSocket 连接。
|
||||
func (c *Client) connectAndRun(ctx context.Context) error {
|
||||
masterURL := c.config.MasterURL
|
||||
u, err := url.Parse(masterURL)
|
||||
@@ -262,12 +263,12 @@ func (c *Client) connectAndRun(ctx context.Context) error {
|
||||
u.Scheme = "wss"
|
||||
}
|
||||
|
||||
u.Path = "/api/remote/ws"
|
||||
u.Path = constants.PathRemoteWebSocket
|
||||
|
||||
log.Printf("[Agent] Connecting to %s", u.String())
|
||||
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
HandshakeTimeout: constants.WebSocketHandshakeTimeout,
|
||||
}
|
||||
|
||||
conn, _, err := dialer.DialContext(ctx, u.String(), c.wsHeaders())
|
||||
@@ -298,18 +299,18 @@ func (c *Client) connectAndRun(ctx context.Context) error {
|
||||
|
||||
log.Printf("[Agent] Connected and authenticated")
|
||||
|
||||
// Report agent info (listen_port) immediately after connection
|
||||
// 连接成功后立即上报 agent 信息(listen_port)
|
||||
if err := c.sendHeartbeat(conn); err != nil {
|
||||
log.Printf("[Agent] Failed to send initial heartbeat: %v", err)
|
||||
}
|
||||
|
||||
// Send scan result to master for auto-sync
|
||||
// 异步上报扫描结果,供主控端自动同步
|
||||
go c.sendScanResult(conn)
|
||||
|
||||
return c.runMessageLoop(ctx, conn)
|
||||
}
|
||||
|
||||
// authenticate sends the authentication message
|
||||
// 发送鉴权消息。
|
||||
func (c *Client) authenticate(conn *websocket.Conn) error {
|
||||
authPayload, _ := json.Marshal(map[string]string{
|
||||
"token": c.config.Token,
|
||||
@@ -324,7 +325,7 @@ func (c *Client) authenticate(conn *websocket.Conn) error {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
conn.SetReadDeadline(time.Now().Add(constants.WebSocketReadDeadline))
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -349,11 +350,11 @@ func (c *Client) authenticate(conn *websocket.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// runMessageLoop handles sending traffic data, speed data, and heartbeats
|
||||
// 处理流量、速率和心跳上报。
|
||||
func (c *Client) runMessageLoop(ctx context.Context, conn *websocket.Conn) error {
|
||||
trafficTicker := time.NewTicker(c.config.TrafficReportInterval)
|
||||
speedTicker := time.NewTicker(c.config.SpeedReportInterval)
|
||||
heartbeatTicker := time.NewTicker(30 * time.Second)
|
||||
heartbeatTicker := time.NewTicker(constants.WebSocketHeartbeatInterval)
|
||||
defer trafficTicker.Stop()
|
||||
defer speedTicker.Stop()
|
||||
defer heartbeatTicker.Stop()
|
||||
@@ -362,13 +363,13 @@ func (c *Client) runMessageLoop(ctx context.Context, conn *websocket.Conn) error
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
for {
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
|
||||
conn.SetReadDeadline(time.Now().Add(constants.WebSocketIdleDeadline))
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
// Send message to processing channel
|
||||
// 投递到消息处理通道
|
||||
select {
|
||||
case msgCh <- message:
|
||||
default:
|
||||
@@ -406,7 +407,7 @@ func (c *Client) runMessageLoop(ctx context.Context, conn *websocket.Conn) error
|
||||
}
|
||||
}
|
||||
|
||||
// sendTrafficData collects and sends traffic data to the master
|
||||
// 采集并发送流量数据。
|
||||
func (c *Client) sendTrafficData(conn *websocket.Conn) error {
|
||||
stats, err := c.collectLocalMetrics()
|
||||
if err != nil {
|
||||
@@ -437,7 +438,7 @@ func (c *Client) sendTrafficData(conn *websocket.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendHeartbeat sends a heartbeat message
|
||||
// 发送心跳消息。
|
||||
func (c *Client) sendHeartbeat(conn *websocket.Conn) error {
|
||||
now := time.Now()
|
||||
listenPort, _ := strconv.Atoi(c.config.ListenPort)
|
||||
@@ -458,7 +459,7 @@ func (c *Client) sendHeartbeat(conn *websocket.Conn) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// collectLocalMetrics collects traffic metrics from local Xray servers
|
||||
// 采集本机 Xray 流量指标。
|
||||
func (c *Client) collectLocalMetrics() (*collector.XrayStats, error) {
|
||||
stats := &collector.XrayStats{
|
||||
Inbound: make(map[string]collector.TrafficData),
|
||||
@@ -487,23 +488,23 @@ func (c *Client) collectLocalMetrics() (*collector.XrayStats, error) {
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetStats returns the current traffic stats (for pull mode)
|
||||
// 返回当前流量统计(拉取模式)。
|
||||
func (c *Client) GetStats() (*collector.XrayStats, error) {
|
||||
return c.collectLocalMetrics()
|
||||
}
|
||||
|
||||
// GetSpeed returns the current speed data (for pull mode)
|
||||
// 返回当前速率(拉取模式)。
|
||||
func (c *Client) GetSpeed() (uploadSpeed, downloadSpeed int64) {
|
||||
return c.collectSpeed()
|
||||
}
|
||||
|
||||
// runAutoMode implements the three-tier fallback: WebSocket -> HTTP -> Pull
|
||||
// 使用三层回退:WebSocket -> HTTP -> Pull。
|
||||
func (c *Client) runAutoMode(ctx context.Context) {
|
||||
defer c.wg.Done()
|
||||
c.runAutoModeLoop(ctx)
|
||||
}
|
||||
|
||||
// runAutoModeLoop is the internal loop for auto mode fallback
|
||||
// 是自动模式的内部循环。
|
||||
func (c *Client) runAutoModeLoop(ctx context.Context) {
|
||||
autoRetries := 0
|
||||
for {
|
||||
@@ -548,14 +549,14 @@ func (c *Client) runAutoModeLoop(ctx context.Context) {
|
||||
log.Printf("[Agent] Falling back to pull mode - API available at /api/child/traffic and /api/child/speed")
|
||||
c.sendHeartbeatHTTP(ctx)
|
||||
|
||||
// Exponential backoff for pull mode: 30s, 60s, 120s, 240s, 300s(cap)
|
||||
// 拉取模式退避: 30s, 60s, 120s, 240s, 300s(上限)
|
||||
autoRetries++
|
||||
pullDuration := 30 * time.Second
|
||||
for i := 1; i < autoRetries && pullDuration < 5*time.Minute; i++ {
|
||||
pullDuration := constants.AutoModePullFallbackBackoff
|
||||
for i := 1; i < autoRetries && pullDuration < constants.ReconnectMaxBackoff; i++ {
|
||||
pullDuration *= 2
|
||||
}
|
||||
if pullDuration > 5*time.Minute {
|
||||
pullDuration = 5 * time.Minute
|
||||
if pullDuration > constants.ReconnectMaxBackoff {
|
||||
pullDuration = constants.ReconnectMaxBackoff
|
||||
}
|
||||
|
||||
c.runPullModeWithTrafficReport(ctx, pullDuration)
|
||||
@@ -567,7 +568,7 @@ func (c *Client) runAutoModeLoop(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// tryWebSocketOnce attempts a single WebSocket connection test
|
||||
// 执行一次 WebSocket 可用性探测。
|
||||
func (c *Client) tryWebSocketOnce(ctx context.Context) error {
|
||||
masterURL := c.config.MasterURL
|
||||
u, err := url.Parse(masterURL)
|
||||
@@ -581,10 +582,10 @@ func (c *Client) tryWebSocketOnce(ctx context.Context) error {
|
||||
case "https":
|
||||
u.Scheme = "wss"
|
||||
}
|
||||
u.Path = "/api/remote/ws"
|
||||
u.Path = constants.PathRemoteWebSocket
|
||||
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
HandshakeTimeout: constants.WebSocketHandshakeTimeout,
|
||||
}
|
||||
|
||||
conn, _, err := dialer.DialContext(ctx, u.String(), c.wsHeaders())
|
||||
@@ -595,13 +596,13 @@ func (c *Client) tryWebSocketOnce(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// tryHTTPOnce tests if HTTP push is available
|
||||
// 探测 HTTP 推送是否可用。
|
||||
func (c *Client) tryHTTPOnce(ctx context.Context) bool {
|
||||
u, err := url.Parse(c.config.MasterURL)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
u.Path = "/api/remote/heartbeat"
|
||||
u.Path = constants.PathRemoteHeartbeat
|
||||
|
||||
req, err := c.newRequest(ctx, http.MethodPost, u.String(), []byte("{}"))
|
||||
if err != nil {
|
||||
@@ -619,18 +620,18 @@ func (c *Client) tryHTTPOnce(ctx context.Context) bool {
|
||||
return c.httpAvailable
|
||||
}
|
||||
|
||||
// runHTTPReporter runs the HTTP push reporter
|
||||
// 运行 HTTP 推送上报器。
|
||||
func (c *Client) runHTTPReporter(ctx context.Context) {
|
||||
defer c.wg.Done()
|
||||
c.setCurrentMode(ModeHTTP)
|
||||
c.runHTTPReporterLoop(ctx)
|
||||
}
|
||||
|
||||
// runHTTPReporterLoop runs the HTTP reporting loop
|
||||
// 执行 HTTP 上报循环。
|
||||
func (c *Client) runHTTPReporterLoop(ctx context.Context) {
|
||||
trafficTicker := time.NewTicker(c.config.TrafficReportInterval)
|
||||
speedTicker := time.NewTicker(c.config.SpeedReportInterval)
|
||||
heartbeatTicker := time.NewTicker(30 * time.Second)
|
||||
heartbeatTicker := time.NewTicker(constants.WebSocketHeartbeatInterval)
|
||||
defer trafficTicker.Stop()
|
||||
defer speedTicker.Stop()
|
||||
defer heartbeatTicker.Stop()
|
||||
@@ -676,7 +677,7 @@ func (c *Client) runHTTPReporterLoop(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// sendTrafficHTTP sends traffic data via HTTP POST
|
||||
// 通过 HTTP POST 发送流量数据。
|
||||
func (c *Client) sendTrafficHTTP(ctx context.Context) error {
|
||||
stats, err := c.collectLocalMetrics()
|
||||
if err != nil {
|
||||
@@ -691,7 +692,7 @@ func (c *Client) sendTrafficHTTP(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.Path = "/api/remote/traffic"
|
||||
u.Path = constants.PathRemoteTraffic
|
||||
|
||||
req, err := c.newRequest(ctx, http.MethodPost, u.String(), payload)
|
||||
if err != nil {
|
||||
@@ -714,7 +715,7 @@ func (c *Client) sendTrafficHTTP(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendSpeedHTTP sends speed data via HTTP POST
|
||||
// 通过 HTTP POST 发送速率数据。
|
||||
func (c *Client) sendSpeedHTTP(ctx context.Context) error {
|
||||
uploadSpeed, downloadSpeed := c.collectSpeed()
|
||||
|
||||
@@ -727,7 +728,7 @@ func (c *Client) sendSpeedHTTP(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.Path = "/api/remote/speed"
|
||||
u.Path = constants.PathRemoteSpeed
|
||||
|
||||
req, err := c.newRequest(ctx, http.MethodPost, u.String(), payload)
|
||||
if err != nil {
|
||||
@@ -749,7 +750,7 @@ func (c *Client) sendSpeedHTTP(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendHeartbeatHTTP sends heartbeat via HTTP POST
|
||||
// 通过 HTTP POST 发送心跳。
|
||||
func (c *Client) sendHeartbeatHTTP(ctx context.Context) error {
|
||||
now := time.Now()
|
||||
listenPort, _ := strconv.Atoi(c.config.ListenPort)
|
||||
@@ -762,7 +763,7 @@ func (c *Client) sendHeartbeatHTTP(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.Path = "/api/remote/heartbeat"
|
||||
u.Path = constants.PathRemoteHeartbeat
|
||||
|
||||
req, err := c.newRequest(ctx, http.MethodPost, u.String(), payload)
|
||||
if err != nil {
|
||||
@@ -783,7 +784,7 @@ func (c *Client) sendHeartbeatHTTP(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// runPullModeWithTrafficReport runs pull mode while sending traffic data to keep server online
|
||||
// 在拉取模式下持续上报流量,保持在线状态。
|
||||
func (c *Client) runPullModeWithTrafficReport(ctx context.Context, duration time.Duration) {
|
||||
trafficTicker := time.NewTicker(c.config.TrafficReportInterval)
|
||||
defer trafficTicker.Stop()
|
||||
@@ -810,13 +811,13 @@ func (c *Client) runPullModeWithTrafficReport(ctx context.Context, duration time
|
||||
}
|
||||
}
|
||||
|
||||
// waitWithTrafficReport waits for the specified duration while sending traffic data
|
||||
// 在等待期间继续上报流量。
|
||||
func (c *Client) waitWithTrafficReport(ctx context.Context, duration time.Duration) {
|
||||
if duration <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if duration > 30*time.Second {
|
||||
if duration > constants.PullModeTrafficReportThreshold {
|
||||
if err := c.sendTrafficHTTP(ctx); err != nil {
|
||||
log.Printf("[Agent] Traffic report during backoff failed: %v", err)
|
||||
}
|
||||
@@ -843,7 +844,7 @@ func (c *Client) waitWithTrafficReport(ctx context.Context, duration time.Durati
|
||||
}
|
||||
}
|
||||
|
||||
// sendSpeedData sends speed data via WebSocket
|
||||
// 通过 WebSocket 发送速率数据。
|
||||
func (c *Client) sendSpeedData(conn *websocket.Conn) error {
|
||||
uploadSpeed, downloadSpeed := c.collectSpeed()
|
||||
|
||||
@@ -869,7 +870,7 @@ func (c *Client) sendSpeedData(conn *websocket.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// collectSpeed calculates the current upload and download speed from system network interface
|
||||
// 基于系统网卡统计计算当前上下行速率。
|
||||
func (c *Client) collectSpeed() (uploadSpeed, downloadSpeed int64) {
|
||||
c.speedMu.Lock()
|
||||
defer c.speedMu.Unlock()
|
||||
@@ -900,7 +901,7 @@ func (c *Client) collectSpeed() (uploadSpeed, downloadSpeed int64) {
|
||||
return uploadSpeed, downloadSpeed
|
||||
}
|
||||
|
||||
// getSystemNetworkStats reads network statistics from /proc/net/dev
|
||||
// 从 /proc/net/dev 读取网卡统计。
|
||||
func (c *Client) getSystemNetworkStats() (rxBytes, txBytes int64) {
|
||||
data, err := os.ReadFile("/proc/net/dev")
|
||||
if err != nil {
|
||||
@@ -936,7 +937,7 @@ func (c *Client) getSystemNetworkStats() (rxBytes, txBytes int64) {
|
||||
return rxBytes, txBytes
|
||||
}
|
||||
|
||||
// AuthError represents an authentication error
|
||||
// AuthError 表示鉴权失败错误。
|
||||
type AuthError struct {
|
||||
Message string
|
||||
Code string // "token_expired", "token_invalid", "server_error"
|
||||
@@ -946,12 +947,12 @@ func (e *AuthError) Error() string {
|
||||
return "authentication failed: " + e.Message
|
||||
}
|
||||
|
||||
// IsTokenInvalid returns true if the error indicates an invalid token
|
||||
// 判断是否为 token 无效错误。
|
||||
func (e *AuthError) IsTokenInvalid() bool {
|
||||
return e.Code == "token_invalid" || e.Message == "Invalid token"
|
||||
}
|
||||
|
||||
// WebSocket message types
|
||||
// WebSocket 消息类型
|
||||
const (
|
||||
WSMsgTypeCertDeploy = "cert_deploy"
|
||||
WSMsgTypeTokenUpdate = "token_update"
|
||||
@@ -960,7 +961,7 @@ const (
|
||||
WSMsgTypeDomainLatencyResult = "domain_latency_result"
|
||||
)
|
||||
|
||||
// WSCertDeployPayload represents a certificate deploy command from master
|
||||
// WSCertDeployPayload 是主控端下发的证书部署指令。
|
||||
type WSCertDeployPayload struct {
|
||||
Domain string `json:"domain"`
|
||||
CertPEM string `json:"cert_pem"`
|
||||
@@ -970,20 +971,20 @@ type WSCertDeployPayload struct {
|
||||
Reload string `json:"reload"`
|
||||
}
|
||||
|
||||
// WSTokenUpdatePayload represents a token update from master
|
||||
// WSTokenUpdatePayload 是主控端下发的 token 更新指令。
|
||||
type WSTokenUpdatePayload struct {
|
||||
ServerToken string `json:"server_token"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// WSDomainLatencyProbePayload is received from master
|
||||
// WSDomainLatencyProbePayload 是主控端下发的域名延迟探测请求。
|
||||
type WSDomainLatencyProbePayload struct {
|
||||
RequestID string `json:"request_id"`
|
||||
Domains []string `json:"domains"`
|
||||
TimeoutMs int `json:"timeout_ms"`
|
||||
}
|
||||
|
||||
// handleMessage processes incoming messages from master
|
||||
// 处理主控端下发的消息。
|
||||
func (c *Client) handleMessage(conn *websocket.Conn, message []byte) {
|
||||
var msg struct {
|
||||
Type string `json:"type"`
|
||||
@@ -1018,11 +1019,11 @@ func (c *Client) handleMessage(conn *websocket.Conn, message []byte) {
|
||||
}
|
||||
go c.handleDomainLatencyProbe(conn, payload)
|
||||
default:
|
||||
// Ignore unknown message types
|
||||
// 忽略未知消息类型
|
||||
}
|
||||
}
|
||||
|
||||
// handleCertDeploy deploys a certificate received from master
|
||||
// 处理主控端下发的证书部署。
|
||||
func (c *Client) handleCertDeploy(payload WSCertDeployPayload) {
|
||||
log.Printf("[Agent] Received cert_deploy for domain: %s, target: %s", payload.Domain, payload.Reload)
|
||||
|
||||
@@ -1065,7 +1066,7 @@ func deployCert(certPEM, keyPEM, certPath, keyPath, reloadTarget string) error {
|
||||
}
|
||||
|
||||
func reloadNginxCmd() error {
|
||||
for _, bin := range []string{"/usr/local/nginx/sbin/nginx", "nginx"} {
|
||||
for _, bin := range constants.NginxBinarySearchPaths {
|
||||
if path, err := exec.LookPath(bin); err == nil {
|
||||
return runCmd(path, "-s", "reload")
|
||||
}
|
||||
@@ -1080,43 +1081,47 @@ func runCmd(name string, args ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleTokenUpdate processes a token update from master
|
||||
// 处理主控端下发的 token 更新。
|
||||
func (c *Client) handleTokenUpdate(payload WSTokenUpdatePayload) {
|
||||
log.Printf("[Agent] Received token update from master, new token expires at %s", payload.ExpiresAt.Format(time.RFC3339))
|
||||
|
||||
// Update the token in memory
|
||||
// 更新内存中的 token
|
||||
c.config.Token = payload.ServerToken
|
||||
|
||||
log.Printf("[Agent] Token updated successfully in memory")
|
||||
}
|
||||
|
||||
// handleDomainLatencyProbe probes domain latency locally and sends results back via WebSocket
|
||||
// 在本机探测域名延迟并回传结果。
|
||||
func (c *Client) handleDomainLatencyProbe(conn *websocket.Conn, payload WSDomainLatencyProbePayload) {
|
||||
log.Printf("[Agent] Received domain_latency_probe: %d domains, timeout=%dms", len(payload.Domains), payload.TimeoutMs)
|
||||
|
||||
timeoutMs := payload.TimeoutMs
|
||||
if timeoutMs <= 0 {
|
||||
timeoutMs = 2000
|
||||
timeoutMs = constants.DomainProbeDefaultTimeoutMS
|
||||
}
|
||||
if timeoutMs < 200 {
|
||||
timeoutMs = 200
|
||||
if timeoutMs < constants.DomainProbeMinTimeoutMS {
|
||||
timeoutMs = constants.DomainProbeMinTimeoutMS
|
||||
}
|
||||
if timeoutMs > 10000 {
|
||||
timeoutMs = 10000
|
||||
if timeoutMs > constants.DomainProbeMaxTimeoutMS {
|
||||
timeoutMs = constants.DomainProbeMaxTimeoutMS
|
||||
}
|
||||
timeout := time.Duration(timeoutMs) * time.Millisecond
|
||||
|
||||
type probeResult struct {
|
||||
Domain string `json:"domain"`
|
||||
Target string `json:"target"`
|
||||
Success bool `json:"success"`
|
||||
LatencyMs int64 `json:"latency_ms,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Domain string `json:"domain"`
|
||||
Target string `json:"target"`
|
||||
Success bool `json:"success"`
|
||||
LatencyMs int64 `json:"latency_ms,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
NginxSSLPort int `json:"nginx_ssl_port,omitempty"`
|
||||
}
|
||||
|
||||
// 读取本机 nginx 配置,构造 domain -> ssl 端口映射
|
||||
nginxPortMap := readNginxSSLPorts(payload.Domains)
|
||||
|
||||
results := make([]probeResult, 0, len(payload.Domains))
|
||||
resultCh := make(chan probeResult, len(payload.Domains))
|
||||
sem := make(chan struct{}, 16)
|
||||
sem := make(chan struct{}, constants.DomainProbeConcurrency)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, domain := range payload.Domains {
|
||||
@@ -1145,7 +1150,7 @@ func (c *Client) handleDomainLatencyProbe(conn *websocket.Conn, payload WSDomain
|
||||
return
|
||||
}
|
||||
_ = tcpConn.Close()
|
||||
resultCh <- probeResult{Domain: host, Target: target, Success: true, LatencyMs: time.Since(start).Milliseconds()}
|
||||
resultCh <- probeResult{Domain: host, Target: target, Success: true, LatencyMs: time.Since(start).Milliseconds(), NginxSSLPort: nginxPortMap[host]}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -1155,7 +1160,7 @@ func (c *Client) handleDomainLatencyProbe(conn *websocket.Conn, payload WSDomain
|
||||
results = append(results, r)
|
||||
}
|
||||
|
||||
// Sort: success first, then by latency
|
||||
// 排序:成功优先,再按延迟升序
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
if results[i].Success != results[j].Success {
|
||||
return results[i].Success
|
||||
@@ -1201,9 +1206,80 @@ func (c *Client) handleDomainLatencyProbe(conn *websocket.Conn, payload WSDomain
|
||||
log.Printf("[Agent] Sent domain_latency_result: %d results", len(results))
|
||||
}
|
||||
|
||||
// sendScanResult scans local xray status and sends results to master
|
||||
// readNginxSSLPorts 读取 nginx 配置并返回 domain -> SSL 端口映射。
|
||||
// 会在常见 nginx 配置目录下查找 servers/{domain}.conf。
|
||||
func readNginxSSLPorts(domains []string) map[string]int {
|
||||
result := make(map[string]int)
|
||||
if len(domains) == 0 {
|
||||
return result
|
||||
}
|
||||
|
||||
confDirs := constants.NginxSSLServerDirPaths
|
||||
|
||||
for _, domain := range domains {
|
||||
host := domain
|
||||
if h, _, err := net.SplitHostPort(domain); err == nil && h != "" {
|
||||
host = h
|
||||
}
|
||||
for _, dir := range confDirs {
|
||||
confPath := filepath.Join(dir, host+".conf")
|
||||
data, err := os.ReadFile(confPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if port := extractSSLListenPort(string(data)); port > 0 {
|
||||
result[host] = port
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// 提取 nginx 配置块中第一个 "listen <port> ssl" 端口。
|
||||
func extractSSLListenPort(conf string) int {
|
||||
// 匹配示例: listen 58443 ssl
|
||||
for _, line := range strings.Split(conf, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if !strings.HasPrefix(line, "listen") {
|
||||
continue
|
||||
}
|
||||
// 去掉 "listen" 前缀和结尾分号
|
||||
rest := strings.TrimPrefix(line, "listen")
|
||||
rest = strings.TrimRight(rest, ";")
|
||||
rest = strings.TrimSpace(rest)
|
||||
fields := strings.Fields(rest)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
// 判断字段中是否包含 "ssl"
|
||||
hasSSL := false
|
||||
for _, f := range fields[1:] {
|
||||
if f == "ssl" {
|
||||
hasSSL = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasSSL {
|
||||
continue
|
||||
}
|
||||
// 第一个字段是端口(或 [::]:port)
|
||||
portStr := fields[0]
|
||||
// 兼容 [::]:port 形式
|
||||
if idx := strings.LastIndex(portStr, ":"); idx >= 0 {
|
||||
portStr = portStr[idx+1:]
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err == nil && port > 0 {
|
||||
return port
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// 扫描本机 xray 状态并上报主控端。
|
||||
func (c *Client) sendScanResult(conn *websocket.Conn) {
|
||||
// Check xray running status
|
||||
// 检查 xray 运行状态
|
||||
xrayRunning := false
|
||||
xrayVersion := ""
|
||||
cmd := exec.Command("xray", "version")
|
||||
@@ -1214,12 +1290,9 @@ func (c *Client) sendScanResult(conn *websocket.Conn) {
|
||||
xrayRunning = true
|
||||
}
|
||||
|
||||
// Read inbounds from config
|
||||
// 从配置读取入站列表
|
||||
var inbounds []map[string]interface{}
|
||||
configPaths := []string{
|
||||
"/usr/local/etc/xray/config.json",
|
||||
"/etc/xray/config.json",
|
||||
}
|
||||
configPaths := constants.DefaultXrayConfigPaths
|
||||
for _, cfgPath := range configPaths {
|
||||
data, err := os.ReadFile(cfgPath)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user