aboutsummaryrefslogtreecommitdiff
path: root/pkg/origin
diff options
context:
space:
mode:
authorpg9182 <96569817+pg9182@users.noreply.github.com>2022-10-14 14:18:28 -0400
committerpg9182 <96569817+pg9182@users.noreply.github.com>2022-10-14 14:18:28 -0400
commit33a1bc3047de6a3f3184171a5e73abd2639bd261 (patch)
tree87f75cb0a41676080dd1aec1d2d1a5ec393bdda0 /pkg/origin
parent01155bde059e93700ea4b8bd47e6835b8232f57c (diff)
downloadAtlas-33a1bc3047de6a3f3184171a5e73abd2639bd261.tar.gz
Atlas-33a1bc3047de6a3f3184171a5e73abd2639bd261.zip
pkg/origin: Implement AuthMgr
Diffstat (limited to 'pkg/origin')
-rw-r--r--pkg/origin/authmgr.go136
1 files changed, 136 insertions, 0 deletions
diff --git a/pkg/origin/authmgr.go b/pkg/origin/authmgr.go
new file mode 100644
index 0000000..bd2f14e
--- /dev/null
+++ b/pkg/origin/authmgr.go
@@ -0,0 +1,136 @@
+package origin
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+)
+
+// AuthMgr manages Origin NucleusTokens. It is efficient and safe for concurrent
+// use.
+//
+// For persistence, load the credentials on startup using SetAuth, and store the
+// credentials using Updated.
+type AuthMgr struct {
+ // Timeout is the timeout for refreshing tokens. If zero, a reasonable
+ // default is used. If negative, there is no timeout.
+ Timeout time.Duration
+
+ // Updated, if provided, is called in a new goroutine when tokens have
+ // changed. AuthState is always set and should be saved, even if an error
+ // occured.
+ Updated func(AuthState, error)
+
+ // Credentials, if provided, is called to get credentials when updating the
+ // SID.
+ Credentials func() (email, password string, err error)
+
+ authMu sync.Mutex // guards authWg so only one goroutine can get it
+ authWg sync.WaitGroup // guards auth/authErr and allows waiting for updates
+ authErr error // stores the last auth error
+ auth AuthState // stores the current auth tokens
+}
+
+// AuthState contains the current authentication tokens.
+type AuthState struct {
+ SID SID `json:"sid,omitempty"`
+ NucleusToken NucleusToken `json:"nucleus_token,omitempty"`
+ NucleusTokenExpiry time.Time `json:"nucleus_token_expiry,omitempty"`
+}
+
+// SetAuth sets the current Origin credentials. If authentication is in
+// progress, it will block.
+func (a *AuthMgr) SetAuth(auth AuthState) {
+ a.authMu.Lock()
+ defer a.authMu.Unlock()
+ a.authWg.Add(1)
+ defer a.authWg.Done()
+ a.auth = auth
+ a.authErr = nil
+}
+
+// OriginAuth gets the current NucleusToken. If refresh is true or the nucleus
+// token is missing/expired, it generates a new NucleusToken, getting a new SID
+// if required. If another refresh is in progress, it waits for the result of
+// it. True is returned (on success or failure) if this call performed a
+// refresh. This function may block for up to Timeout.
+//
+// In general, OriginAuth(false) should be used first, then if an API call error
+// is ErrAuthRequired, try it again with the token from OriginAuth(true).
+func (a *AuthMgr) OriginAuth(refresh bool) (NucleusToken, bool, error) {
+ if a.auth.NucleusToken == "" || !time.Now().Before(a.auth.NucleusTokenExpiry) {
+ refresh = true
+ }
+ if !refresh {
+ // wait for an in-progress auth, if any, to complete
+ a.authWg.Wait()
+ return a.auth.NucleusToken, false, a.authErr
+ }
+ if a.authMu.TryLock() {
+ // refresh the auth
+ defer a.authMu.Unlock()
+ // if another goroutine gets scheduled in between us locking and adding
+ // to the waitgroup, they'll get outdated auth, but it isn't a big deal
+ // since if they try to refresh it right after, they'll end up waiting
+ // on us to complete
+ a.authWg.Add(1)
+ defer a.authWg.Done()
+ } else {
+ // another goroutine is refreshing
+ return a.OriginAuth(false)
+ }
+ a.authErr = func() (err error) {
+ defer func() {
+ if p := recover(); p != nil {
+ err = fmt.Errorf("panic: %v", p)
+ }
+ }()
+
+ var ctx context.Context
+ var cancel context.CancelFunc
+ if a.Timeout > 0 {
+ ctx, cancel = context.WithTimeout(context.Background(), a.Timeout)
+ } else if a.Timeout == 0 {
+ ctx, cancel = context.WithTimeout(context.Background(), time.Second*15)
+ } else {
+ ctx, cancel = context.WithCancel(context.Background())
+ }
+ defer cancel()
+
+ if a.auth.SID != "" {
+ if tok, exp, aerr := GetNucleusToken(ctx, a.auth.SID); aerr == nil {
+ a.auth.NucleusToken = tok
+ a.auth.NucleusTokenExpiry = exp
+ return
+ } else if !errors.Is(err, ErrAuthRequired) {
+ err = fmt.Errorf("refresh nucleus token: %w", aerr)
+ return
+ }
+ }
+ if a.Credentials == nil {
+ err = fmt.Errorf("no origin credentials to refresh sid with")
+ return
+ } else if email, password, aerr := a.Credentials(); aerr != nil {
+ err = fmt.Errorf("get origin credentials: %w", aerr)
+ return
+ } else if sid, aerr := Login(ctx, email, password); aerr != nil {
+ err = fmt.Errorf("refresh sid: %w", aerr)
+ return
+ } else {
+ a.auth.SID = sid
+ }
+ if tok, exp, aerr := GetNucleusToken(ctx, a.auth.SID); aerr != nil {
+ err = fmt.Errorf("refresh nucleus token with new sid: %w", aerr)
+ } else {
+ a.auth.NucleusToken = tok
+ a.auth.NucleusTokenExpiry = exp
+ }
+ return
+ }()
+ if a.Updated != nil {
+ go a.Updated(a.auth, a.authErr)
+ }
+ return a.auth.NucleusToken, true, a.authErr
+}