aboutsummaryrefslogtreecommitdiff
path: root/cmd/r2-a2s-probe/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'cmd/r2-a2s-probe/main.go')
-rw-r--r--cmd/r2-a2s-probe/main.go65
1 files changed, 61 insertions, 4 deletions
diff --git a/cmd/r2-a2s-probe/main.go b/cmd/r2-a2s-probe/main.go
index dd38b7d..9e60edc 100644
--- a/cmd/r2-a2s-probe/main.go
+++ b/cmd/r2-a2s-probe/main.go
@@ -2,25 +2,32 @@
package main
import (
+ "context"
"fmt"
+ "math/rand"
+ "net"
"net/netip"
"os"
"sync"
"time"
- "github.com/r2northstar/atlas/pkg/a2s"
+ "github.com/r2northstar/atlas/pkg/nspkt"
"github.com/spf13/pflag"
)
var opt struct {
+ Addr string
Connections int
Timeout time.Duration
+ Interval time.Duration
Silent bool
Help bool
}
func init() {
+ pflag.StringVarP(&opt.Addr, "listen", "a", "[::]:0", "UDP listen address")
pflag.DurationVarP(&opt.Timeout, "timeout", "t", time.Second*3, "Amount of time to wait for a response")
+ pflag.DurationVarP(&opt.Interval, "interval", "i", time.Second, "Interval to send packets at")
pflag.IntVarP(&opt.Connections, "connections", "c", 1, "Number of concurrent connections")
pflag.BoolVarP(&opt.Silent, "silent", "s", false, "Don't show the result")
pflag.BoolVarP(&opt.Help, "help", "h", false, "Show this help text")
@@ -37,8 +44,9 @@ func main() {
os.Exit(0)
}
- if opt.Connections < 1 {
- fmt.Fprintf(os.Stderr, "fatal: --connections must be at least 1\n")
+ uaddr, err := netip.ParseAddrPort(opt.Addr)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "fatal: invalid udp listen address: %v\n", err)
os.Exit(2)
}
@@ -48,6 +56,15 @@ func main() {
os.Exit(2)
}
+ conn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(uaddr))
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "fatal: %v\n", err)
+ os.Exit(2)
+ }
+
+ l := nspkt.NewListener()
+ go l.Serve(conn)
+
queue := make(chan int)
go func() {
defer close(queue)
@@ -66,9 +83,12 @@ func main() {
for n := 0; n < opt.Connections; n++ {
wg.Add(1)
go func() {
+ ctx, cancel := context.WithTimeout(context.Background(), opt.Timeout)
+ defer cancel()
+
defer wg.Done()
for i := range queue {
- res <- Result{i, a2s.Probe(addr[i], opt.Timeout)}
+ res <- Result{i, probe(ctx, addr[i], l)}
}
}()
}
@@ -95,6 +115,43 @@ func main() {
}
}
+func probe(ctx context.Context, addr netip.AddrPort, l *nspkt.Listener) error {
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ uid := rand.Uint64()
+
+ x := make(chan error, 1)
+ go func() {
+ t := time.NewTicker(opt.Interval)
+ defer t.Stop()
+
+ for {
+ if err := l.SendConnect(addr, uid); err != nil {
+ select {
+ case x <- err:
+ default:
+ }
+ }
+ select {
+ case <-ctx.Done():
+ return
+ case <-t.C:
+ }
+ }
+ }()
+
+ err := l.WaitConnectReply(ctx, addr, uid)
+ if err != nil {
+ select {
+ case err = <-x:
+ // error could be due to an issue sending the packet
+ default:
+ }
+ }
+ return err
+}
+
func parseAddrPorts(a []string) ([]netip.AddrPort, error) {
r := make([]netip.AddrPort, len(a))
for i, x := range a {