Files
mmw-agent/internal/agent/client.go
T

1339 lines
33 KiB
Go
Raw Normal View History

2026-01-28 13:13:58 +08:00
package agent
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
2026-04-08 21:10:56 +08:00
"net"
2026-01-28 13:13:58 +08:00
"net/http"
"net/url"
"os"
2026-03-12 17:20:16 +08:00
"os/exec"
"path/filepath"
2026-04-08 21:10:56 +08:00
"sort"
2026-01-28 13:13:58 +08:00
"strconv"
"strings"
"sync"
"time"
"mmw-agent/internal/collector"
"mmw-agent/internal/config"
2026-04-10 15:25:21 +08:00
"mmw-agent/internal/constants"
2026-01-28 13:13:58 +08:00
"github.com/gorilla/websocket"
)
2026-04-10 15:25:21 +08:00
// ConnectionMode 表示当前连接模式。
2026-01-28 13:13:58 +08:00
// ConnectionMode names the transport the agent uses to exchange data with the
// master.
type ConnectionMode string

// Connection modes accepted in the agent configuration.
const (
	// ModeWebSocket keeps a persistent WebSocket connection to the master.
	ModeWebSocket = ConnectionMode("websocket")
	// ModeHTTP pushes reports to the master with HTTP POST requests.
	ModeHTTP = ConnectionMode("http")
	// ModePull leaves the agent passive; the master pulls from the agent's API.
	ModePull = ConnectionMode("pull")
	// ModeAuto tries WebSocket, then HTTP, then falls back to pull.
	ModeAuto = ConnectionMode("auto")
)
2026-04-10 15:25:21 +08:00
// Client 表示连接主控端的 agent 客户端。
2026-01-28 13:13:58 +08:00
type Client struct {
config *config.Config
collector *collector.Collector
xrayServers []config.XrayServer
wsConn *websocket.Conn
wsMu sync.Mutex
connected bool
reconnects int
stopCh chan struct{}
wg sync.WaitGroup
2026-04-10 15:25:21 +08:00
// 连接状态
2026-01-28 13:13:58 +08:00
currentMode ConnectionMode
httpClient *http.Client
httpAvailable bool
modeMu sync.RWMutex
2026-04-10 15:25:21 +08:00
// 速率计算(基于系统网卡统计)
2026-01-28 13:13:58 +08:00
lastRxBytes int64
lastTxBytes int64
lastSampleTime time.Time
speedMu sync.Mutex
}
2026-04-10 15:25:21 +08:00
// 创建 agent 客户端。
2026-01-28 13:13:58 +08:00
func NewClient(cfg *config.Config) *Client {
return &Client{
config: cfg,
collector: collector.NewCollector(),
xrayServers: cfg.XrayServers,
stopCh: make(chan struct{}),
httpClient: &http.Client{
2026-04-10 15:25:21 +08:00
Timeout: constants.DefaultHTTPClientTimeout,
2026-01-28 13:13:58 +08:00
},
2026-04-10 15:25:21 +08:00
currentMode: ModePull, // 默认使用拉取模式
2026-01-28 13:13:58 +08:00
}
}
2026-04-10 15:25:21 +08:00
// 生成 WebSocket 握手请求头。
2026-03-12 16:13:50 +08:00
func (c *Client) wsHeaders() http.Header {
h := http.Header{}
2026-04-10 15:25:21 +08:00
h.Set(constants.HeaderUserAgent, constants.AgentUserAgent)
2026-03-12 16:13:50 +08:00
return h
}
2026-04-10 15:25:21 +08:00
// 创建带标准请求头的 HTTP 请求。
2026-03-12 16:13:50 +08:00
func (c *Client) newRequest(ctx context.Context, method, urlStr string, body []byte) (*http.Request, error) {
var req *http.Request
var err error
if body != nil {
req, err = http.NewRequestWithContext(ctx, method, urlStr, bytes.NewReader(body))
} else {
req, err = http.NewRequestWithContext(ctx, method, urlStr, nil)
}
if err != nil {
return nil, err
}
2026-04-10 15:25:21 +08:00
req.Header.Set(constants.HeaderContentType, constants.ContentTypeJSON)
req.Header.Set(constants.HeaderAuthorization, constants.BearerPrefix+c.config.Token)
req.Header.Set(constants.HeaderUserAgent, constants.AgentUserAgent)
2026-03-12 16:13:50 +08:00
return req, nil
}
2026-04-10 15:25:21 +08:00
// 按配置启动客户端。
2026-01-28 13:13:58 +08:00
func (c *Client) Start(ctx context.Context) {
log.Printf("[Agent] Starting in %s mode", c.config.ConnectionMode)
mode := ConnectionMode(c.config.ConnectionMode)
switch mode {
case ModeWebSocket:
c.wg.Add(1)
go c.runWebSocket(ctx)
case ModeHTTP:
c.wg.Add(1)
go c.runHTTPReporter(ctx)
case ModePull:
c.setCurrentMode(ModePull)
log.Printf("[Agent] Pull mode enabled - API will be served at /api/child/traffic and /api/child/speed")
2026-04-10 15:25:21 +08:00
// 启动后先通过 HTTP 上报一次心跳信息
2026-03-12 21:32:36 +08:00
if err := c.sendHeartbeatHTTP(ctx); err != nil {
log.Printf("[Agent] Failed to send initial heartbeat in pull mode: %v", err)
}
2026-01-28 13:13:58 +08:00
case ModeAuto:
fallthrough
default:
c.wg.Add(1)
go c.runAutoMode(ctx)
}
}
2026-04-10 15:25:21 +08:00
// 停止客户端。
2026-01-28 13:13:58 +08:00
// Stop signals every loop to exit by closing stopCh, waits for the goroutines
// tracked by c.wg, closes any WebSocket connection that is still set, and logs
// completion.
//
// NOTE(review): stopCh is closed unconditionally, so a second call to Stop
// would panic — confirm callers invoke it only once.
func (c *Client) Stop() {
	close(c.stopCh)
	c.wg.Wait()

	func() {
		c.wsMu.Lock()
		defer c.wsMu.Unlock()
		if conn := c.wsConn; conn != nil {
			conn.Close()
		}
	}()

	log.Printf("[Agent] Stopped")
}
2026-04-10 15:25:21 +08:00
// 返回 WebSocket 连接状态。
2026-01-28 13:13:58 +08:00
// IsConnected reports whether the WebSocket connection is currently marked as
// connected; the flag is read under wsMu.
func (c *Client) IsConnected() bool {
	c.wsMu.Lock()
	state := c.connected
	c.wsMu.Unlock()
	return state
}
2026-04-10 15:25:21 +08:00
// 返回当前连接模式。
2026-01-28 13:13:58 +08:00
// GetCurrentMode returns the connection mode the client is using right now,
// read under the mode read-lock.
func (c *Client) GetCurrentMode() ConnectionMode {
	c.modeMu.RLock()
	mode := c.currentMode
	c.modeMu.RUnlock()
	return mode
}
2026-04-10 15:25:21 +08:00
// 设置当前连接模式。
2026-01-28 13:13:58 +08:00
// setCurrentMode records mode as the active connection mode under the mode
// write-lock.
func (c *Client) setCurrentMode(mode ConnectionMode) {
	c.modeMu.Lock()
	c.currentMode = mode
	c.modeMu.Unlock()
}
2026-04-10 15:25:21 +08:00
// 维护 WebSocket 连接,并在失败时回退自动模式。
2026-01-28 13:13:58 +08:00
func (c *Client) runWebSocket(ctx context.Context) {
defer c.wg.Done()
2026-04-10 15:25:21 +08:00
maxConsecutiveFailures := constants.WebSocketMaxConsecutiveFailures
maxAuthFailures := constants.WebSocketMaxAuthFailures
2026-01-28 13:13:58 +08:00
consecutiveFailures := 0
2026-01-29 18:01:08 +08:00
authFailures := 0
2026-01-28 13:13:58 +08:00
for {
select {
case <-ctx.Done():
return
case <-c.stopCh:
return
default:
}
c.setCurrentMode(ModeWebSocket)
if err := c.connectAndRun(ctx); err != nil {
if ctx.Err() != nil {
log.Printf("[Agent] Context canceled, stopping gracefully")
return
}
2026-04-10 15:25:21 +08:00
// 判断是否为鉴权错误
2026-01-29 18:01:08 +08:00
if authErr, ok := err.(*AuthError); ok {
authFailures++
if authErr.IsTokenInvalid() {
log.Printf("[Agent] Authentication failed (invalid token): %v", err)
if authFailures >= maxAuthFailures {
log.Printf("[Agent] Too many auth failures (%d), entering sleep mode (30 min backoff)", authFailures)
2026-04-10 15:25:21 +08:00
c.waitWithTrafficReport(ctx, constants.AuthFailureSleepBackoff)
2026-01-29 18:01:08 +08:00
authFailures = 0
continue
}
}
2026-04-10 15:25:21 +08:00
// 鉴权错误使用更长退避时间
backoff := time.Duration(authFailures) * constants.AuthFailureBackoffStep
if backoff > constants.AuthFailureMaxBackoff {
backoff = constants.AuthFailureMaxBackoff
2026-01-29 18:01:08 +08:00
}
log.Printf("[Agent] Auth error, reconnecting in %v...", backoff)
c.waitWithTrafficReport(ctx, backoff)
continue
}
2026-01-28 13:13:58 +08:00
log.Printf("[Agent] WebSocket error: %v", err)
consecutiveFailures++
2026-01-29 18:01:08 +08:00
authFailures = 0 // Reset auth failures on non-auth errors
2026-01-28 13:13:58 +08:00
if consecutiveFailures >= maxConsecutiveFailures {
log.Printf("[Agent] Too many WebSocket failures (%d), switching to auto mode for fallback...", consecutiveFailures)
c.runAutoModeLoop(ctx)
consecutiveFailures = 0
continue
}
} else {
consecutiveFailures = 0
2026-01-29 18:01:08 +08:00
authFailures = 0
2026-01-28 13:13:58 +08:00
}
backoff := c.calculateBackoff()
log.Printf("[Agent] Reconnecting in %v...", backoff)
c.waitWithTrafficReport(ctx, backoff)
}
}
2026-04-10 15:25:21 +08:00
// 计算重连退避时长。
2026-01-28 13:13:58 +08:00
func (c *Client) calculateBackoff() time.Duration {
c.reconnects++
2026-04-10 15:25:21 +08:00
// 指数退避: 5s, 10s, 20s, 40s, 80s, 160s, 300s(上限)
backoff := constants.ReconnectBaseBackoff
for i := 1; i < c.reconnects && backoff < constants.ReconnectMaxBackoff; i++ {
2026-03-12 21:32:36 +08:00
backoff *= 2
}
2026-04-10 15:25:21 +08:00
if backoff > constants.ReconnectMaxBackoff {
backoff = constants.ReconnectMaxBackoff
2026-01-28 13:13:58 +08:00
}
return backoff
}
2026-04-10 15:25:21 +08:00
// 建立并维持 WebSocket 连接。
2026-01-28 13:13:58 +08:00
func (c *Client) connectAndRun(ctx context.Context) error {
masterURL := c.config.MasterURL
u, err := url.Parse(masterURL)
if err != nil {
return err
}
switch u.Scheme {
case "http":
u.Scheme = "ws"
case "https":
u.Scheme = "wss"
}
2026-04-10 15:25:21 +08:00
u.Path = constants.PathRemoteWebSocket
2026-01-28 13:13:58 +08:00
log.Printf("[Agent] Connecting to %s", u.String())
dialer := websocket.Dialer{
2026-04-10 15:25:21 +08:00
HandshakeTimeout: constants.WebSocketHandshakeTimeout,
2026-01-28 13:13:58 +08:00
}
2026-03-12 16:13:50 +08:00
conn, _, err := dialer.DialContext(ctx, u.String(), c.wsHeaders())
2026-01-28 13:13:58 +08:00
if err != nil {
return err
}
c.wsMu.Lock()
c.wsConn = conn
c.wsMu.Unlock()
defer func() {
c.wsMu.Lock()
c.wsConn = nil
c.connected = false
c.wsMu.Unlock()
conn.Close()
}()
if err := c.authenticate(conn); err != nil {
return err
}
c.wsMu.Lock()
c.connected = true
c.reconnects = 0
c.wsMu.Unlock()
log.Printf("[Agent] Connected and authenticated")
2026-04-10 15:25:21 +08:00
// 连接成功后立即上报 agent 信息(listen_port
2026-03-12 21:32:36 +08:00
if err := c.sendHeartbeat(conn); err != nil {
log.Printf("[Agent] Failed to send initial heartbeat: %v", err)
}
2026-04-10 15:25:21 +08:00
// 异步上报扫描结果,供主控端自动同步
2026-04-07 16:35:45 +08:00
go c.sendScanResult(conn)
2026-01-28 13:13:58 +08:00
return c.runMessageLoop(ctx, conn)
}
2026-04-10 15:25:21 +08:00
// 发送鉴权消息。
2026-01-28 13:13:58 +08:00
func (c *Client) authenticate(conn *websocket.Conn) error {
authPayload, _ := json.Marshal(map[string]string{
"token": c.config.Token,
})
msg := map[string]interface{}{
"type": "auth",
"payload": json.RawMessage(authPayload),
}
if err := conn.WriteJSON(msg); err != nil {
return err
}
2026-04-10 15:25:21 +08:00
conn.SetReadDeadline(time.Now().Add(constants.WebSocketReadDeadline))
2026-01-28 13:13:58 +08:00
_, message, err := conn.ReadMessage()
if err != nil {
return err
}
var result struct {
Type string `json:"type"`
Payload struct {
Success bool `json:"success"`
Message string `json:"message"`
} `json:"payload"`
}
if err := json.Unmarshal(message, &result); err != nil {
return err
}
if result.Type != "auth_result" || !result.Payload.Success {
return &AuthError{Message: result.Payload.Message}
}
return nil
}
2026-04-10 15:25:21 +08:00
// 处理流量、速率和心跳上报。
2026-01-28 13:13:58 +08:00
func (c *Client) runMessageLoop(ctx context.Context, conn *websocket.Conn) error {
trafficTicker := time.NewTicker(c.config.TrafficReportInterval)
speedTicker := time.NewTicker(c.config.SpeedReportInterval)
2026-04-10 15:25:21 +08:00
heartbeatTicker := time.NewTicker(constants.WebSocketHeartbeatInterval)
2026-01-28 13:13:58 +08:00
defer trafficTicker.Stop()
defer speedTicker.Stop()
defer heartbeatTicker.Stop()
msgCh := make(chan []byte, 10)
errCh := make(chan error, 1)
go func() {
for {
2026-04-10 15:25:21 +08:00
conn.SetReadDeadline(time.Now().Add(constants.WebSocketIdleDeadline))
2026-01-28 13:13:58 +08:00
_, message, err := conn.ReadMessage()
if err != nil {
errCh <- err
return
}
2026-04-10 15:25:21 +08:00
// 投递到消息处理通道
2026-01-28 13:13:58 +08:00
select {
case msgCh <- message:
default:
log.Printf("[Agent] Message queue full, dropping message")
}
}
}()
c.sendTrafficData(conn)
c.sendSpeedData(conn)
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-c.stopCh:
return nil
case err := <-errCh:
return err
case msg := <-msgCh:
c.handleMessage(conn, msg)
case <-trafficTicker.C:
if err := c.sendTrafficData(conn); err != nil {
return err
}
case <-speedTicker.C:
if err := c.sendSpeedData(conn); err != nil {
return err
}
case <-heartbeatTicker.C:
if err := c.sendHeartbeat(conn); err != nil {
return err
}
}
}
}
2026-04-10 15:25:21 +08:00
// 采集并发送流量数据。
2026-01-28 13:13:58 +08:00
// sendTrafficData collects the local Xray metrics and writes them to conn as a
// "traffic" message.
//
// A collection failure is logged and an empty stats object is sent instead;
// only a write failure is returned.
func (c *Client) sendTrafficData(conn *websocket.Conn) error {
	stats, collectErr := c.collectLocalMetrics()
	if collectErr != nil {
		log.Printf("[Agent] Failed to collect metrics: %v", collectErr)
		stats = &collector.XrayStats{}
	}

	body, _ := json.Marshal(map[string]any{"stats": stats})
	envelope := map[string]any{
		"type":    "traffic",
		"payload": json.RawMessage(body),
	}

	// Writes on the shared connection are serialized through wsMu.
	c.wsMu.Lock()
	writeErr := conn.WriteJSON(envelope)
	c.wsMu.Unlock()
	if writeErr != nil {
		return writeErr
	}

	log.Printf("[Agent] Sent traffic data: %d inbounds, %d outbounds, %d users",
		len(stats.Inbound), len(stats.Outbound), len(stats.User))
	return nil
}
2026-04-10 15:25:21 +08:00
// 发送心跳消息。
2026-01-28 13:13:58 +08:00
// sendHeartbeat writes a "heartbeat" message to conn carrying boot_time and
// the configured listen port, and returns the write error, if any.
//
// NOTE(review): boot_time is filled with the current time on every call, not
// with a fixed start time — confirm this is what the master expects.
func (c *Client) sendHeartbeat(conn *websocket.Conn) error {
	// A non-numeric ListenPort is reported as 0.
	port, _ := strconv.Atoi(c.config.ListenPort)
	body, _ := json.Marshal(map[string]any{
		"boot_time":   time.Now(),
		"listen_port": port,
	})
	envelope := map[string]any{
		"type":    "heartbeat",
		"payload": json.RawMessage(body),
	}

	// Writes on the shared connection are serialized through wsMu.
	c.wsMu.Lock()
	defer c.wsMu.Unlock()
	return conn.WriteJSON(envelope)
}
2026-04-10 15:25:21 +08:00
// 采集本机 Xray 流量指标。
2026-01-28 13:13:58 +08:00
// collectLocalMetrics fetches metrics from every configured Xray server and
// merges them into a single stats object.
//
// A server whose metrics endpoint cannot be resolved or fetched is logged and
// skipped, so the returned error is always nil.
func (c *Client) collectLocalMetrics() (*collector.XrayStats, error) {
	merged := &collector.XrayStats{
		Inbound:  make(map[string]collector.TrafficData),
		Outbound: make(map[string]collector.TrafficData),
		User:     make(map[string]collector.TrafficData),
	}

	for i := range c.xrayServers {
		srv := &c.xrayServers[i]

		host, port, err := c.collector.GetMetricsPortFromConfig(srv.ConfigPath)
		if err != nil {
			log.Printf("[Agent] Failed to get metrics config for %s: %v", srv.Name, err)
			continue
		}

		fetched, err := c.collector.FetchMetrics(host, port)
		if err != nil {
			log.Printf("[Agent] Failed to fetch metrics for %s: %v", srv.Name, err)
			continue
		}

		if fetched.Stats == nil {
			continue
		}
		collector.MergeStats(merged, fetched.Stats)
	}

	return merged, nil
}
2026-04-10 15:25:21 +08:00
// 返回当前流量统计(拉取模式)。
2026-01-28 13:13:58 +08:00
// GetStats returns the current merged traffic statistics; it is the accessor
// used when the master pulls data from the agent.
func (c *Client) GetStats() (*collector.XrayStats, error) {
	stats, err := c.collectLocalMetrics()
	return stats, err
}
2026-04-10 15:25:21 +08:00
// 返回当前速率(拉取模式)。
2026-01-28 13:13:58 +08:00
// GetSpeed returns the current upload and download rates as computed by
// collectSpeed; it is the accessor used when the master pulls data from the
// agent.
func (c *Client) GetSpeed() (uploadSpeed, downloadSpeed int64) {
	uploadSpeed, downloadSpeed = c.collectSpeed()
	return uploadSpeed, downloadSpeed
}
2026-04-10 15:25:21 +08:00
// 使用三层回退:WebSocket -> HTTP -> Pull。
2026-01-28 13:13:58 +08:00
// runAutoMode is the goroutine entry point for auto mode: it runs the
// WebSocket -> HTTP -> pull fallback loop and marks the goroutine done on c.wg
// when that loop returns.
func (c *Client) runAutoMode(ctx context.Context) {
	defer func() { c.wg.Done() }()
	c.runAutoModeLoop(ctx)
}
2026-04-10 15:25:21 +08:00
// 是自动模式的内部循环。
2026-01-28 13:13:58 +08:00
func (c *Client) runAutoModeLoop(ctx context.Context) {
2026-03-12 21:32:36 +08:00
autoRetries := 0
2026-01-28 13:13:58 +08:00
for {
select {
case <-ctx.Done():
return
case <-c.stopCh:
return
default:
}
log.Printf("[Agent] Trying WebSocket connection...")
if err := c.tryWebSocketOnce(ctx); err == nil {
c.setCurrentMode(ModeWebSocket)
log.Printf("[Agent] WebSocket mode active")
if err := c.connectAndRun(ctx); err != nil {
if ctx.Err() != nil {
log.Printf("[Agent] Context canceled, stopping gracefully")
return
}
log.Printf("[Agent] WebSocket disconnected: %v", err)
}
c.reconnects = 0
2026-03-12 21:32:36 +08:00
autoRetries = 0
2026-01-28 13:13:58 +08:00
continue
} else {
log.Printf("[Agent] WebSocket failed: %v, trying HTTP...", err)
}
if c.tryHTTPOnce(ctx) {
c.setCurrentMode(ModeHTTP)
log.Printf("[Agent] HTTP mode active")
c.runHTTPReporterLoop(ctx)
if ctx.Err() != nil {
return
}
2026-03-12 21:32:36 +08:00
autoRetries = 0
2026-01-28 13:13:58 +08:00
continue
}
c.setCurrentMode(ModePull)
log.Printf("[Agent] Falling back to pull mode - API available at /api/child/traffic and /api/child/speed")
2026-03-12 21:32:36 +08:00
c.sendHeartbeatHTTP(ctx)
2026-04-10 15:25:21 +08:00
// 拉取模式退避: 30s, 60s, 120s, 240s, 300s(上限)
2026-03-12 21:32:36 +08:00
autoRetries++
2026-04-10 15:25:21 +08:00
pullDuration := constants.AutoModePullFallbackBackoff
for i := 1; i < autoRetries && pullDuration < constants.ReconnectMaxBackoff; i++ {
2026-03-12 21:32:36 +08:00
pullDuration *= 2
}
2026-04-10 15:25:21 +08:00
if pullDuration > constants.ReconnectMaxBackoff {
pullDuration = constants.ReconnectMaxBackoff
2026-03-12 21:32:36 +08:00
}
2026-01-28 13:13:58 +08:00
2026-03-12 21:32:36 +08:00
c.runPullModeWithTrafficReport(ctx, pullDuration)
2026-01-28 13:13:58 +08:00
if ctx.Err() != nil {
return
}
log.Printf("[Agent] Retrying higher-priority connection modes...")
}
}
2026-04-10 15:25:21 +08:00
// 执行一次 WebSocket 可用性探测。
2026-01-28 13:13:58 +08:00
func (c *Client) tryWebSocketOnce(ctx context.Context) error {
masterURL := c.config.MasterURL
u, err := url.Parse(masterURL)
if err != nil {
return err
}
switch u.Scheme {
case "http":
u.Scheme = "ws"
case "https":
u.Scheme = "wss"
}
2026-04-10 15:25:21 +08:00
u.Path = constants.PathRemoteWebSocket
2026-01-28 13:13:58 +08:00
dialer := websocket.Dialer{
2026-04-10 15:25:21 +08:00
HandshakeTimeout: constants.WebSocketHandshakeTimeout,
2026-01-28 13:13:58 +08:00
}
2026-03-12 16:13:50 +08:00
conn, _, err := dialer.DialContext(ctx, u.String(), c.wsHeaders())
2026-01-28 13:13:58 +08:00
if err != nil {
return err
}
conn.Close()
return nil
}
2026-04-10 15:25:21 +08:00
// 探测 HTTP 推送是否可用。
2026-01-28 13:13:58 +08:00
func (c *Client) tryHTTPOnce(ctx context.Context) bool {
u, err := url.Parse(c.config.MasterURL)
if err != nil {
return false
}
2026-04-10 15:25:21 +08:00
u.Path = constants.PathRemoteHeartbeat
2026-01-28 13:13:58 +08:00
2026-03-12 16:13:50 +08:00
req, err := c.newRequest(ctx, http.MethodPost, u.String(), []byte("{}"))
2026-01-28 13:13:58 +08:00
if err != nil {
return false
}
resp, err := c.httpClient.Do(req)
if err != nil {
log.Printf("[Agent] HTTP test failed: %v", err)
return false
}
defer resp.Body.Close()
c.httpAvailable = resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusUnauthorized
return c.httpAvailable
}
2026-04-10 15:25:21 +08:00
// 运行 HTTP 推送上报器。
2026-01-28 13:13:58 +08:00
// runHTTPReporter is the goroutine entry point for HTTP mode: it records the
// mode, runs the HTTP reporting loop, and marks the goroutine done on c.wg
// when the loop returns.
func (c *Client) runHTTPReporter(ctx context.Context) {
	defer func() { c.wg.Done() }()

	c.setCurrentMode(ModeHTTP)
	c.runHTTPReporterLoop(ctx)
}
2026-04-10 15:25:21 +08:00
// 执行 HTTP 上报循环。
2026-01-28 13:13:58 +08:00
func (c *Client) runHTTPReporterLoop(ctx context.Context) {
trafficTicker := time.NewTicker(c.config.TrafficReportInterval)
speedTicker := time.NewTicker(c.config.SpeedReportInterval)
2026-04-10 15:25:21 +08:00
heartbeatTicker := time.NewTicker(constants.WebSocketHeartbeatInterval)
2026-01-28 13:13:58 +08:00
defer trafficTicker.Stop()
defer speedTicker.Stop()
defer heartbeatTicker.Stop()
2026-03-12 21:32:36 +08:00
c.sendHeartbeatHTTP(ctx)
2026-01-28 13:13:58 +08:00
c.sendTrafficHTTP(ctx)
c.sendSpeedHTTP(ctx)
consecutiveErrors := 0
maxErrors := 5
for {
select {
case <-ctx.Done():
return
case <-c.stopCh:
return
case <-trafficTicker.C:
if err := c.sendTrafficHTTP(ctx); err != nil {
consecutiveErrors++
if consecutiveErrors >= maxErrors {
log.Printf("[Agent] Too many HTTP errors, will retry connection modes")
return
}
} else {
consecutiveErrors = 0
}
case <-speedTicker.C:
if err := c.sendSpeedHTTP(ctx); err != nil {
log.Printf("[Agent] Failed to send speed via HTTP: %v", err)
}
case <-heartbeatTicker.C:
if err := c.sendHeartbeatHTTP(ctx); err != nil {
consecutiveErrors++
if consecutiveErrors >= maxErrors {
log.Printf("[Agent] Too many HTTP errors, will retry connection modes")
return
}
} else {
consecutiveErrors = 0
}
}
}
}
2026-04-10 15:25:21 +08:00
// 通过 HTTP POST 发送流量数据。
2026-01-28 13:13:58 +08:00
func (c *Client) sendTrafficHTTP(ctx context.Context) error {
stats, err := c.collectLocalMetrics()
if err != nil {
stats = &collector.XrayStats{}
}
payload, _ := json.Marshal(map[string]interface{}{
"stats": stats,
})
u, err := url.Parse(c.config.MasterURL)
if err != nil {
return err
}
2026-04-10 15:25:21 +08:00
u.Path = constants.PathRemoteTraffic
2026-01-28 13:13:58 +08:00
2026-03-12 16:13:50 +08:00
req, err := c.newRequest(ctx, http.MethodPost, u.String(), payload)
2026-01-28 13:13:58 +08:00
if err != nil {
return err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
log.Printf("[Agent] Sent traffic data via HTTP: %d inbounds, %d outbounds, %d users",
len(stats.Inbound), len(stats.Outbound), len(stats.User))
return nil
}
2026-04-10 15:25:21 +08:00
// 通过 HTTP POST 发送速率数据。
2026-01-28 13:13:58 +08:00
func (c *Client) sendSpeedHTTP(ctx context.Context) error {
uploadSpeed, downloadSpeed := c.collectSpeed()
payload, _ := json.Marshal(map[string]interface{}{
"upload_speed": uploadSpeed,
"download_speed": downloadSpeed,
})
u, err := url.Parse(c.config.MasterURL)
if err != nil {
return err
}
2026-04-10 15:25:21 +08:00
u.Path = constants.PathRemoteSpeed
2026-01-28 13:13:58 +08:00
2026-03-12 16:13:50 +08:00
req, err := c.newRequest(ctx, http.MethodPost, u.String(), payload)
2026-01-28 13:13:58 +08:00
if err != nil {
return err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
log.Printf("[Agent] Sent speed via HTTP: ↑%d B/s ↓%d B/s", uploadSpeed, downloadSpeed)
return nil
}
2026-04-10 15:25:21 +08:00
// 通过 HTTP POST 发送心跳。
2026-01-28 13:13:58 +08:00
func (c *Client) sendHeartbeatHTTP(ctx context.Context) error {
now := time.Now()
listenPort, _ := strconv.Atoi(c.config.ListenPort)
payload, _ := json.Marshal(map[string]interface{}{
"boot_time": now,
"listen_port": listenPort,
})
u, err := url.Parse(c.config.MasterURL)
if err != nil {
return err
}
2026-04-10 15:25:21 +08:00
u.Path = constants.PathRemoteHeartbeat
2026-01-28 13:13:58 +08:00
2026-03-12 16:13:50 +08:00
req, err := c.newRequest(ctx, http.MethodPost, u.String(), payload)
2026-01-28 13:13:58 +08:00
if err != nil {
return err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
return nil
}
2026-04-10 15:25:21 +08:00
// 在拉取模式下持续上报流量,保持在线状态。
2026-01-28 13:13:58 +08:00
// runPullModeWithTrafficReport stays in pull mode for duration while still
// pushing traffic reports over HTTP, once immediately and then on every
// traffic-report tick, so the master keeps seeing the agent.
//
// It returns when duration elapses, ctx is canceled, or the client is stopped.
// Report failures are logged and otherwise ignored.
func (c *Client) runPullModeWithTrafficReport(ctx context.Context, duration time.Duration) {
	tick := time.NewTicker(c.config.TrafficReportInterval)
	defer tick.Stop()
	deadline := time.After(duration)

	report := func() {
		if err := c.sendTrafficHTTP(ctx); err != nil {
			log.Printf("[Agent] Pull mode traffic report failed: %v", err)
		}
	}

	report()
	for {
		select {
		case <-tick.C:
			report()
		case <-deadline:
			return
		case <-c.stopCh:
			return
		case <-ctx.Done():
			return
		}
	}
}
2026-04-10 15:25:21 +08:00
// 在等待期间继续上报流量。
2026-01-28 13:13:58 +08:00
func (c *Client) waitWithTrafficReport(ctx context.Context, duration time.Duration) {
if duration <= 0 {
return
}
2026-04-10 15:25:21 +08:00
if duration > constants.PullModeTrafficReportThreshold {
2026-01-28 13:13:58 +08:00
if err := c.sendTrafficHTTP(ctx); err != nil {
log.Printf("[Agent] Traffic report during backoff failed: %v", err)
}
}
trafficTicker := time.NewTicker(c.config.TrafficReportInterval)
defer trafficTicker.Stop()
timeout := time.After(duration)
for {
select {
case <-ctx.Done():
return
case <-c.stopCh:
return
case <-timeout:
return
case <-trafficTicker.C:
if err := c.sendTrafficHTTP(ctx); err != nil {
log.Printf("[Agent] Traffic report during backoff failed: %v", err)
}
}
}
}
2026-04-10 15:25:21 +08:00
// 通过 WebSocket 发送速率数据。
2026-01-28 13:13:58 +08:00
// sendSpeedData samples the current upload/download rates and writes them to
// conn as a "speed" message. Only a write failure is returned.
func (c *Client) sendSpeedData(conn *websocket.Conn) error {
	up, down := c.collectSpeed()

	body, _ := json.Marshal(map[string]any{
		"upload_speed":   up,
		"download_speed": down,
	})
	envelope := map[string]any{
		"type":    "speed",
		"payload": json.RawMessage(body),
	}

	// Writes on the shared connection are serialized through wsMu.
	c.wsMu.Lock()
	writeErr := conn.WriteJSON(envelope)
	c.wsMu.Unlock()
	if writeErr != nil {
		return writeErr
	}

	log.Printf("[Agent] Sent speed data: ↑%d B/s ↓%d B/s", up, down)
	return nil
}
2026-04-10 15:25:21 +08:00
// 基于系统网卡统计计算当前上下行速率。
2026-01-28 13:13:58 +08:00
// collectSpeed returns the upload and download rates in bytes per second,
// derived from the change in the system network counters since the previous
// call.
//
// The first call, and any call made while the stored receive counter is not
// positive, returns 0 for both rates. Negative rates (counters went
// backwards) are clamped to 0. Every call stores the new sample.
func (c *Client) collectSpeed() (uploadSpeed, downloadSpeed int64) {
	c.speedMu.Lock()
	defer c.speedMu.Unlock()

	curRx, curTx := c.getSystemNetworkStats()
	sampledAt := time.Now()

	hasBaseline := !c.lastSampleTime.IsZero() && c.lastRxBytes > 0
	if hasBaseline {
		if secs := sampledAt.Sub(c.lastSampleTime).Seconds(); secs > 0 {
			// perSecond converts a byte delta into a non-negative rate.
			perSecond := func(delta int64) int64 {
				rate := int64(float64(delta) / secs)
				if rate < 0 {
					return 0
				}
				return rate
			}
			uploadSpeed = perSecond(curTx - c.lastTxBytes)
			downloadSpeed = perSecond(curRx - c.lastRxBytes)
		}
	}

	c.lastSampleTime = sampledAt
	c.lastTxBytes = curTx
	c.lastRxBytes = curRx
	return uploadSpeed, downloadSpeed
}
2026-04-10 15:25:21 +08:00
// 从 /proc/net/dev 读取网卡统计。
2026-01-28 13:13:58 +08:00
// getSystemNetworkStats reads /proc/net/dev and returns the received and
// transmitted byte counters summed over every listed interface except "lo".
//
// If the file cannot be read the failure is logged and 0, 0 is returned.
// Lines that do not look like "<iface>: <at least 10 fields>" or whose byte
// counters do not parse are skipped.
func (c *Client) getSystemNetworkStats() (rxBytes, txBytes int64) {
	raw, readErr := os.ReadFile("/proc/net/dev")
	if readErr != nil {
		log.Printf("[Agent] Failed to read /proc/net/dev: %v", readErr)
		return 0, 0
	}

	for _, entry := range strings.Split(string(raw), "\n") {
		entry = strings.TrimSpace(entry)

		// Skip the two header lines and the loopback interface.
		switch {
		case strings.HasPrefix(entry, "Inter"),
			strings.HasPrefix(entry, "face"),
			strings.HasPrefix(entry, "lo:"):
			continue
		}

		_, counters, found := strings.Cut(entry, ":")
		if !found {
			continue
		}

		cols := strings.Fields(counters)
		if len(cols) < 10 {
			continue
		}

		// Column 0 is taken as the received byte count and column 8 as the
		// transmitted byte count.
		rx, rxErr := strconv.ParseInt(cols[0], 10, 64)
		tx, txErr := strconv.ParseInt(cols[8], 10, 64)
		if rxErr != nil || txErr != nil {
			continue
		}
		rxBytes += rx
		txBytes += tx
	}

	return rxBytes, txBytes
}
2026-04-10 15:25:21 +08:00
// AuthError 表示鉴权失败错误。
2026-01-28 13:13:58 +08:00
// AuthError reports that the master rejected the agent's authentication.
type AuthError struct {
	// Message is the human-readable reason supplied by the master.
	Message string
	// Code is a machine-readable reason: "token_expired", "token_invalid" or
	// "server_error".
	Code string
}

// Error implements the error interface.
func (e *AuthError) Error() string {
	return "authentication failed: " + e.Message
}

// IsTokenInvalid reports whether the failure means the token itself is
// invalid, identified either by the "token_invalid" code or by the exact
// message "Invalid token".
func (e *AuthError) IsTokenInvalid() bool {
	return e.Code == "token_invalid" || e.Message == "Invalid token"
}
2026-04-10 15:25:21 +08:00
// WebSocket 消息类型
2026-01-28 13:13:58 +08:00
// WebSocket message type identifiers carried in the "type" field of a message.
const (
	WSMsgTypeCertDeploy          = "cert_deploy"
	WSMsgTypeTokenUpdate         = "token_update"
	WSMsgTypeScanResult          = "scan_result"
	WSMsgTypeDomainLatencyProbe  = "domain_latency_probe"
	WSMsgTypeDomainLatencyResult = "domain_latency_result"
)
2026-04-10 15:25:21 +08:00
// WSCertDeployPayload 是主控端下发的证书部署指令。
2026-03-12 16:13:50 +08:00
// WSCertDeployPayload is the certificate deployment instruction sent by the
// master.
type WSCertDeployPayload struct {
	Domain   string `json:"domain"`    // domain the certificate is for (used in logs)
	CertPEM  string `json:"cert_pem"`  // certificate content written to CertPath
	KeyPEM   string `json:"key_pem"`   // private key content written to KeyPath
	CertPath string `json:"cert_path"` // destination file for the certificate
	KeyPath  string `json:"key_path"`  // destination file for the private key
	Reload   string `json:"reload"`    // service to reload afterwards: "nginx", "xray" or "both"
}
2026-04-10 15:25:21 +08:00
// WSTokenUpdatePayload 是主控端下发的 token 更新指令。
2026-01-29 18:01:08 +08:00
// WSTokenUpdatePayload is the token rotation instruction sent by the master.
type WSTokenUpdatePayload struct {
	// ServerToken is the replacement token the agent should use.
	ServerToken string `json:"server_token"`
	// ExpiresAt is when the replacement token expires.
	ExpiresAt time.Time `json:"expires_at"`
}
2026-04-10 15:25:21 +08:00
// WSDomainLatencyProbePayload 是主控端下发的域名延迟探测请求。
2026-04-08 21:10:56 +08:00
// WSDomainLatencyProbePayload is the domain latency probe request sent by the
// master.
type WSDomainLatencyProbePayload struct {
	// RequestID is echoed back in the result so the master can match it.
	RequestID string `json:"request_id"`
	// Domains lists the targets to probe.
	Domains []string `json:"domains"`
	// TimeoutMs is the per-probe timeout in milliseconds.
	TimeoutMs int `json:"timeout_ms"`
}
2026-04-10 15:25:21 +08:00
// 处理主控端下发的消息。
2026-01-28 13:13:58 +08:00
func (c *Client) handleMessage(conn *websocket.Conn, message []byte) {
var msg struct {
Type string `json:"type"`
Payload json.RawMessage `json:"payload"`
}
if err := json.Unmarshal(message, &msg); err != nil {
log.Printf("[Agent] Failed to parse message: %v", err)
return
}
switch msg.Type {
2026-03-12 16:13:50 +08:00
case WSMsgTypeCertDeploy:
var payload WSCertDeployPayload
if err := json.Unmarshal(msg.Payload, &payload); err != nil {
log.Printf("[Agent] Failed to parse cert_deploy payload: %v", err)
return
}
go c.handleCertDeploy(payload)
2026-01-29 18:01:08 +08:00
case WSMsgTypeTokenUpdate:
var payload WSTokenUpdatePayload
if err := json.Unmarshal(msg.Payload, &payload); err != nil {
log.Printf("[Agent] Failed to parse token_update payload: %v", err)
return
}
c.handleTokenUpdate(payload)
2026-04-08 21:10:56 +08:00
case WSMsgTypeDomainLatencyProbe:
var payload WSDomainLatencyProbePayload
if err := json.Unmarshal(msg.Payload, &payload); err != nil {
log.Printf("[Agent] Failed to parse domain_latency_probe payload: %v", err)
return
}
go c.handleDomainLatencyProbe(conn, payload)
2026-01-28 13:13:58 +08:00
default:
2026-04-10 15:25:21 +08:00
// 忽略未知消息类型
2026-01-28 13:13:58 +08:00
}
}
2026-04-10 15:25:21 +08:00
// 处理主控端下发的证书部署。
2026-03-12 17:20:16 +08:00
func (c *Client) handleCertDeploy(payload WSCertDeployPayload) {
log.Printf("[Agent] Received cert_deploy for domain: %s, target: %s", payload.Domain, payload.Reload)
2026-01-28 13:13:58 +08:00
2026-03-12 17:20:16 +08:00
if err := deployCert(payload.CertPEM, payload.KeyPEM, payload.CertPath, payload.KeyPath, payload.Reload); err != nil {
log.Printf("[Agent] cert_deploy failed for %s: %v", payload.Domain, err)
2026-01-28 13:13:58 +08:00
} else {
2026-03-12 17:20:16 +08:00
log.Printf("[Agent] cert_deploy succeeded for %s", payload.Domain)
2026-01-28 13:13:58 +08:00
}
}
2026-03-12 17:20:16 +08:00
func deployCert(certPEM, keyPEM, certPath, keyPath, reloadTarget string) error {
if certPath == "" || keyPath == "" {
return fmt.Errorf("deploy paths are required")
2026-01-28 13:13:58 +08:00
}
2026-03-12 17:20:16 +08:00
if err := os.MkdirAll(filepath.Dir(certPath), 0755); err != nil {
return fmt.Errorf("create cert dir: %w", err)
2026-01-28 13:13:58 +08:00
}
2026-03-12 17:20:16 +08:00
if err := os.MkdirAll(filepath.Dir(keyPath), 0755); err != nil {
return fmt.Errorf("create key dir: %w", err)
}
if err := os.WriteFile(certPath, []byte(certPEM), 0644); err != nil {
return fmt.Errorf("write cert: %w", err)
}
if err := os.WriteFile(keyPath, []byte(keyPEM), 0600); err != nil {
return fmt.Errorf("write key: %w", err)
2026-03-12 16:13:50 +08:00
}
2026-03-12 17:20:16 +08:00
switch reloadTarget {
case "nginx":
2026-04-07 16:35:45 +08:00
return reloadNginxCmd()
2026-03-12 17:20:16 +08:00
case "xray":
return runCmd("systemctl", "restart", "xray")
case "both":
2026-04-07 16:35:45 +08:00
if err := reloadNginxCmd(); err != nil {
2026-03-12 17:20:16 +08:00
return err
2026-03-12 16:13:50 +08:00
}
2026-03-12 17:20:16 +08:00
return runCmd("systemctl", "restart", "xray")
2026-03-12 16:13:50 +08:00
}
2026-03-12 17:20:16 +08:00
return nil
2026-01-28 13:13:58 +08:00
}
2026-01-29 18:01:08 +08:00
2026-04-07 16:35:45 +08:00
func reloadNginxCmd() error {
2026-04-10 15:25:21 +08:00
for _, bin := range constants.NginxBinarySearchPaths {
2026-04-07 16:35:45 +08:00
if path, err := exec.LookPath(bin); err == nil {
return runCmd(path, "-s", "reload")
}
}
return runCmd("systemctl", "reload", "nginx")
}
2026-03-12 17:20:16 +08:00
// runCmd runs the named program with args and waits for it to finish.
//
// On failure the returned error wraps the underlying error and is prefixed
// with the command name and, when the command printed anything, its combined
// stdout/stderr with surrounding whitespace trimmed.
func runCmd(name string, args ...string) error {
	output, err := exec.Command(name, args...).CombinedOutput()
	if err == nil {
		return nil
	}

	// Leave out the output segment when there is none so the message does not
	// contain an empty ": :" gap.
	if text := strings.TrimSpace(string(output)); text != "" {
		return fmt.Errorf("%s: %s: %w", name, text, err)
	}
	return fmt.Errorf("%s: %w", name, err)
}
2026-04-10 15:25:21 +08:00
// 处理主控端下发的 token 更新。
2026-01-29 18:01:08 +08:00
func (c *Client) handleTokenUpdate(payload WSTokenUpdatePayload) {
log.Printf("[Agent] Received token update from master, new token expires at %s", payload.ExpiresAt.Format(time.RFC3339))
2026-04-10 15:25:21 +08:00
// 更新内存中的 token
2026-01-29 18:01:08 +08:00
c.config.Token = payload.ServerToken
log.Printf("[Agent] Token updated successfully in memory")
}
2026-04-07 16:35:45 +08:00
2026-04-10 15:25:21 +08:00
// 在本机探测域名延迟并回传结果。
2026-04-08 21:10:56 +08:00
func (c *Client) handleDomainLatencyProbe(conn *websocket.Conn, payload WSDomainLatencyProbePayload) {
log.Printf("[Agent] Received domain_latency_probe: %d domains, timeout=%dms", len(payload.Domains), payload.TimeoutMs)
timeoutMs := payload.TimeoutMs
if timeoutMs <= 0 {
2026-04-10 15:25:21 +08:00
timeoutMs = constants.DomainProbeDefaultTimeoutMS
2026-04-08 21:10:56 +08:00
}
2026-04-10 15:25:21 +08:00
if timeoutMs < constants.DomainProbeMinTimeoutMS {
timeoutMs = constants.DomainProbeMinTimeoutMS
2026-04-08 21:10:56 +08:00
}
2026-04-10 15:25:21 +08:00
if timeoutMs > constants.DomainProbeMaxTimeoutMS {
timeoutMs = constants.DomainProbeMaxTimeoutMS
2026-04-08 21:10:56 +08:00
}
timeout := time.Duration(timeoutMs) * time.Millisecond
type probeResult struct {
2026-04-10 15:25:21 +08:00
Domain string `json:"domain"`
Target string `json:"target"`
Success bool `json:"success"`
LatencyMs int64 `json:"latency_ms,omitempty"`
Error string `json:"error,omitempty"`
NginxSSLPort int `json:"nginx_ssl_port,omitempty"`
2026-04-08 21:10:56 +08:00
}
2026-04-10 15:25:21 +08:00
// 读取本机 nginx 配置,构造 domain -> ssl 端口映射
nginxPortMap := readNginxSSLPorts(payload.Domains)
2026-04-08 21:10:56 +08:00
results := make([]probeResult, 0, len(payload.Domains))
resultCh := make(chan probeResult, len(payload.Domains))
2026-04-10 15:25:21 +08:00
sem := make(chan struct{}, constants.DomainProbeConcurrency)
2026-04-08 21:10:56 +08:00
var wg sync.WaitGroup
for _, domain := range payload.Domains {
wg.Add(1)
domain := domain
go func() {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
host := domain
port := "443"
if h, p, err := net.SplitHostPort(domain); err == nil && h != "" && p != "" {
host = h
port = p
}
if host == "" {
resultCh <- probeResult{Domain: domain, Target: domain, Success: false, Error: "empty host"}
return
}
target := net.JoinHostPort(host, port)
start := time.Now()
tcpConn, err := net.DialTimeout("tcp", target, timeout)
if err != nil {
resultCh <- probeResult{Domain: host, Target: target, Success: false, Error: err.Error()}
return
}
_ = tcpConn.Close()
2026-04-10 15:25:21 +08:00
resultCh <- probeResult{Domain: host, Target: target, Success: true, LatencyMs: time.Since(start).Milliseconds(), NginxSSLPort: nginxPortMap[host]}
2026-04-08 21:10:56 +08:00
}()
}
wg.Wait()
close(resultCh)
for r := range resultCh {
results = append(results, r)
}
2026-04-10 15:25:21 +08:00
// 排序:成功优先,再按延迟升序
2026-04-08 21:10:56 +08:00
sort.Slice(results, func(i, j int) bool {
if results[i].Success != results[j].Success {
return results[i].Success
}
if !results[i].Success {
return results[i].Domain < results[j].Domain
}
if results[i].LatencyMs == results[j].LatencyMs {
return results[i].Domain < results[j].Domain
}
return results[i].LatencyMs < results[j].LatencyMs
})
response := map[string]any{
"request_id": payload.RequestID,
"success": true,
"results": results,
}
respBytes, err := json.Marshal(response)
if err != nil {
log.Printf("[Agent] Failed to marshal domain_latency_result: %v", err)
return
}
msg := map[string]any{
"type": WSMsgTypeDomainLatencyResult,
"payload": json.RawMessage(respBytes),
}
msgBytes, err := json.Marshal(msg)
if err != nil {
log.Printf("[Agent] Failed to marshal WS message: %v", err)
return
}
c.wsMu.Lock()
err = conn.WriteMessage(websocket.TextMessage, msgBytes)
c.wsMu.Unlock()
if err != nil {
log.Printf("[Agent] Failed to send domain_latency_result: %v", err)
return
}
log.Printf("[Agent] Sent domain_latency_result: %d results", len(results))
}
2026-04-10 15:25:21 +08:00
// readNginxSSLPorts maps each requested domain to the SSL listen port found in
// its local nginx server config.
//
// For every entry of domains (a ":port" suffix, if present, is stripped) it
// looks for "<host>.conf" in each directory of
// constants.NginxSSLServerDirPaths, in order, and records the port from the
// first file that yields one. Hosts with no readable config, or whose config
// has no SSL listen directive, are absent from the returned map. The returned
// map is never nil.
func readNginxSSLPorts(domains []string) map[string]int {
	result := make(map[string]int, len(domains))
	for _, domain := range domains {
		host := domain
		if h, _, err := net.SplitHostPort(domain); err == nil && h != "" {
			host = h
		}
		// The host becomes part of a file name, so refuse anything that is
		// empty or could step outside the config directory.
		if host == "" || strings.ContainsAny(host, `/\`) {
			continue
		}
		// A host listed more than once only needs to be resolved once.
		if _, done := result[host]; done {
			continue
		}
		for _, dir := range constants.NginxSSLServerDirPaths {
			data, err := os.ReadFile(filepath.Join(dir, host+".conf"))
			if err != nil {
				// Missing or unreadable in this directory; try the next one.
				continue
			}
			if port := extractSSLListenPort(string(data)); port > 0 {
				result[host] = port
				break
			}
		}
	}
	return result
}
// extractSSLListenPort returns the port of the first nginx "listen" directive
// in conf that carries the "ssl" parameter, or 0 when there is none.
//
// Accepted address forms are a bare port ("443") and "address:port"
// ("127.0.0.1:443", "[::]:443"). Text after '#' on a line is treated as a
// comment and directives are delimited by ';', '{' and '}', so inputs such as
// "listen 443 ssl; # https" or "server { listen 443 ssl; }" are handled.
// Ports outside 1-65535 are ignored.
func extractSSLListenPort(conf string) int {
	isDelim := func(r rune) bool { return r == ';' || r == '{' || r == '}' }

	for _, line := range strings.Split(conf, "\n") {
		// Drop a trailing comment so it cannot hide the terminating ';'.
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}
		for _, stmt := range strings.FieldsFunc(line, isDelim) {
			fields := strings.Fields(stmt)
			// Need the "listen" keyword, an address and at least one parameter.
			if len(fields) < 3 || fields[0] != "listen" {
				continue
			}

			hasSSL := false
			for _, param := range fields[2:] {
				if param == "ssl" {
					hasSSL = true
					break
				}
			}
			if !hasSSL {
				continue
			}

			// The port is whatever follows the last ':' (covers "[::]:port"),
			// or the whole field when there is no ':'.
			portStr := fields[1]
			if i := strings.LastIndex(portStr, ":"); i >= 0 {
				portStr = portStr[i+1:]
			}
			port, err := strconv.Atoi(portStr)
			if err == nil && port > 0 && port <= 65535 {
				return port
			}
		}
	}
	return 0
}
// 扫描本机 xray 状态并上报主控端。
2026-04-07 16:35:45 +08:00
func (c *Client) sendScanResult(conn *websocket.Conn) {
2026-04-10 15:25:21 +08:00
// 检查 xray 运行状态
2026-04-07 16:35:45 +08:00
xrayRunning := false
xrayVersion := ""
cmd := exec.Command("xray", "version")
if out, err := cmd.Output(); err == nil {
xrayVersion = strings.TrimSpace(strings.Split(string(out), "\n")[0])
}
if exec.Command("systemctl", "is-active", "--quiet", "xray").Run() == nil {
xrayRunning = true
}
2026-04-10 15:25:21 +08:00
// 从配置读取入站列表
2026-04-07 16:35:45 +08:00
var inbounds []map[string]interface{}
2026-04-10 15:25:21 +08:00
configPaths := constants.DefaultXrayConfigPaths
2026-04-07 16:35:45 +08:00
for _, cfgPath := range configPaths {
data, err := os.ReadFile(cfgPath)
if err != nil {
continue
}
var config map[string]interface{}
if json.Unmarshal(data, &config) != nil {
continue
}
if ibs, ok := config["inbounds"].([]interface{}); ok {
for _, ib := range ibs {
if m, ok := ib.(map[string]interface{}); ok {
if tag, _ := m["tag"].(string); tag == "api" {
continue
}
inbounds = append(inbounds, m)
}
}
}
break
}
payload, _ := json.Marshal(map[string]interface{}{
"xray_running": xrayRunning,
"xray_version": xrayVersion,
"inbounds": inbounds,
})
msg := map[string]interface{}{
"type": WSMsgTypeScanResult,
"payload": json.RawMessage(payload),
}
c.wsMu.Lock()
err := conn.WriteJSON(msg)
c.wsMu.Unlock()
if err != nil {
log.Printf("[Agent] Failed to send scan_result: %v", err)
return
}
log.Printf("[Agent] Sent scan_result: xray_running=%v, inbounds=%d", xrayRunning, len(inbounds))
}