diff --git a/internal/agent/client.go b/internal/agent/client.go index 15ff871..7d767c7 100644 --- a/internal/agent/client.go +++ b/internal/agent/client.go @@ -7,11 +7,13 @@ import ( "fmt" "io" "log" + "net" "net/http" "net/url" "os" "os/exec" "path/filepath" + "sort" "strconv" "strings" "sync" @@ -951,9 +953,11 @@ func (e *AuthError) IsTokenInvalid() bool { // WebSocket message types const ( - WSMsgTypeCertDeploy = "cert_deploy" - WSMsgTypeTokenUpdate = "token_update" - WSMsgTypeScanResult = "scan_result" + WSMsgTypeCertDeploy = "cert_deploy" + WSMsgTypeTokenUpdate = "token_update" + WSMsgTypeScanResult = "scan_result" + WSMsgTypeDomainLatencyProbe = "domain_latency_probe" + WSMsgTypeDomainLatencyResult = "domain_latency_result" ) // WSCertDeployPayload represents a certificate deploy command from master @@ -972,6 +976,13 @@ type WSTokenUpdatePayload struct { ExpiresAt time.Time `json:"expires_at"` } +// WSDomainLatencyProbePayload is received from master +type WSDomainLatencyProbePayload struct { + RequestID string `json:"request_id"` + Domains []string `json:"domains"` + TimeoutMs int `json:"timeout_ms"` +} + // handleMessage processes incoming messages from master func (c *Client) handleMessage(conn *websocket.Conn, message []byte) { var msg struct { @@ -999,6 +1010,13 @@ func (c *Client) handleMessage(conn *websocket.Conn, message []byte) { return } c.handleTokenUpdate(payload) + case WSMsgTypeDomainLatencyProbe: + var payload WSDomainLatencyProbePayload + if err := json.Unmarshal(msg.Payload, &payload); err != nil { + log.Printf("[Agent] Failed to parse domain_latency_probe payload: %v", err) + return + } + go c.handleDomainLatencyProbe(conn, payload) default: // Ignore unknown message types } @@ -1072,6 +1090,117 @@ func (c *Client) handleTokenUpdate(payload WSTokenUpdatePayload) { log.Printf("[Agent] Token updated successfully in memory") } +// handleDomainLatencyProbe probes domain latency locally and sends results back via WebSocket +func (c *Client) handleDomainLatencyProbe(conn *websocket.Conn, payload WSDomainLatencyProbePayload) { + log.Printf("[Agent] Received domain_latency_probe: %d domains, timeout=%dms", len(payload.Domains), payload.TimeoutMs) + + timeoutMs := payload.TimeoutMs + if timeoutMs <= 0 { + timeoutMs = 2000 + } + if timeoutMs < 200 { + timeoutMs = 200 + } + if timeoutMs > 10000 { + timeoutMs = 10000 + } + timeout := time.Duration(timeoutMs) * time.Millisecond + + type probeResult struct { + Domain string `json:"domain"` + Target string `json:"target"` + Success bool `json:"success"` + LatencyMs int64 `json:"latency_ms,omitempty"` + Error string `json:"error,omitempty"` + } + + results := make([]probeResult, 0, len(payload.Domains)) + resultCh := make(chan probeResult, len(payload.Domains)) + sem := make(chan struct{}, 16) + var wg sync.WaitGroup + + for _, domain := range payload.Domains { + wg.Add(1) + domain := domain + go func() { + defer wg.Done() + sem <- struct{}{} + defer func() { <-sem }() + + host := domain + port := "443" + if h, p, err := net.SplitHostPort(domain); err == nil && h != "" && p != "" { + host = h + port = p + } + if host == "" { + resultCh <- probeResult{Domain: domain, Target: domain, Success: false, Error: "empty host"} + return + } + target := net.JoinHostPort(host, port) + start := time.Now() + tcpConn, err := net.DialTimeout("tcp", target, timeout) + if err != nil { + resultCh <- probeResult{Domain: host, Target: target, Success: false, Error: err.Error()} + return + } + _ = tcpConn.Close() + resultCh <- probeResult{Domain: host, Target: target, Success: true, LatencyMs: time.Since(start).Milliseconds()} + }() + } + + wg.Wait() + close(resultCh) + for r := range resultCh { + results = append(results, r) + } + + // Sort: success first, then by latency + sort.Slice(results, func(i, j int) bool { + if results[i].Success != results[j].Success { + return results[i].Success + } + if !results[i].Success { + return results[i].Domain < results[j].Domain + } + if results[i].LatencyMs == results[j].LatencyMs { + return results[i].Domain < results[j].Domain + } + return results[i].LatencyMs < results[j].LatencyMs + }) + + response := map[string]any{ + "request_id": payload.RequestID, + "success": true, + "results": results, + } + respBytes, err := json.Marshal(response) + if err != nil { + log.Printf("[Agent] Failed to marshal domain_latency_result: %v", err) + return + } + + msg := map[string]any{ + "type": WSMsgTypeDomainLatencyResult, + "payload": json.RawMessage(respBytes), + } + msgBytes, err := json.Marshal(msg) + if err != nil { + log.Printf("[Agent] Failed to marshal WS message: %v", err) + return + } + + c.wsMu.Lock() + err = conn.WriteMessage(websocket.TextMessage, msgBytes) + c.wsMu.Unlock() + if err != nil { + log.Printf("[Agent] Failed to send domain_latency_result: %v", err) + return + } + + log.Printf("[Agent] Sent domain_latency_result: %d results", len(results)) +} + // sendScanResult scans local xray status and sends results to master func (c *Client) sendScanResult(conn *websocket.Conn) { // Check xray running status