diff options
Diffstat (limited to 'cmd/r2-a2s-probe')
-rw-r--r-- | cmd/r2-a2s-probe/main.go | 65 |
1 files changed, 61 insertions, 4 deletions
diff --git a/cmd/r2-a2s-probe/main.go b/cmd/r2-a2s-probe/main.go index dd38b7d..9e60edc 100644 --- a/cmd/r2-a2s-probe/main.go +++ b/cmd/r2-a2s-probe/main.go @@ -2,25 +2,32 @@ package main import ( + "context" "fmt" + "math/rand" + "net" "net/netip" "os" "sync" "time" - "github.com/r2northstar/atlas/pkg/a2s" + "github.com/r2northstar/atlas/pkg/nspkt" "github.com/spf13/pflag" ) var opt struct { + Addr string Connections int Timeout time.Duration + Interval time.Duration Silent bool Help bool } func init() { + pflag.StringVarP(&opt.Addr, "listen", "a", "[::]:0", "UDP listen address") pflag.DurationVarP(&opt.Timeout, "timeout", "t", time.Second*3, "Amount of time to wait for a response") + pflag.DurationVarP(&opt.Interval, "interval", "i", time.Second, "Interval to send packets at") pflag.IntVarP(&opt.Connections, "connections", "c", 1, "Number of concurrent connections") pflag.BoolVarP(&opt.Silent, "silent", "s", false, "Don't show the result") pflag.BoolVarP(&opt.Help, "help", "h", false, "Show this help text") @@ -37,8 +44,9 @@ func main() { os.Exit(0) } - if opt.Connections < 1 { - fmt.Fprintf(os.Stderr, "fatal: --connections must be at least 1\n") + uaddr, err := netip.ParseAddrPort(opt.Addr) + if err != nil { + fmt.Fprintf(os.Stderr, "fatal: invalid udp listen address: %v\n", err) os.Exit(2) } @@ -48,6 +56,15 @@ func main() { os.Exit(2) } + conn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(uaddr)) + if err != nil { + fmt.Fprintf(os.Stderr, "fatal: %v\n", err) + os.Exit(2) + } + + l := nspkt.NewListener() + go l.Serve(conn) + queue := make(chan int) go func() { defer close(queue) @@ -66,9 +83,12 @@ func main() { for n := 0; n < opt.Connections; n++ { wg.Add(1) go func() { + ctx, cancel := context.WithTimeout(context.Background(), opt.Timeout) + defer cancel() + defer wg.Done() for i := range queue { - res <- Result{i, a2s.Probe(addr[i], opt.Timeout)} + res <- Result{i, probe(ctx, addr[i], l)} } }() } @@ -95,6 +115,43 @@ func main() { } } +func probe(ctx context.Context, addr netip.AddrPort, l *nspkt.Listener) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + uid := rand.Uint64() + + x := make(chan error, 1) + go func() { + t := time.NewTicker(opt.Interval) + defer t.Stop() + + for { + if err := l.SendConnect(addr, uid); err != nil { + select { + case x <- err: + default: + } + } + select { + case <-ctx.Done(): + return + case <-t.C: + } + } + }() + + err := l.WaitConnectReply(ctx, addr, uid) + if err != nil { + select { + case err = <-x: + // error could be due to an issue sending the packet + default: + } + } + return err +} + func parseAddrPorts(a []string) ([]netip.AddrPort, error) { r := make([]netip.AddrPort, len(a)) for i, x := range a { |