aboutsummaryrefslogtreecommitdiff
path: root/pkg/origin/authmgr.go
blob: 2fb1d55c0d7f8352ff355a636ee65997af1145f0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
package origin

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"sync"
	"time"

	"github.com/cardigann/harhar"
	"github.com/r2northstar/atlas/pkg/juno"
)

// ErrAuthMgrBackoff is wrapped by the error OriginAuth returns when a previous
// refresh failed and the Backoff callback does not allow another attempt yet.
// Test for it with errors.Is.
var ErrAuthMgrBackoff = errors.New("not refreshing token due to backoff")

// AuthMgr manages Origin NucleusTokens. It is efficient and safe for concurrent
// use: at most one refresh runs at a time, and concurrent callers wait for and
// share its result.
//
// For persistence, load the credentials on startup using SetAuth, and store the
// credentials using Updated.
type AuthMgr struct {
	// Timeout is the timeout for refreshing tokens. If zero, a default of 15
	// seconds is used. If negative, there is no timeout.
	Timeout time.Duration

	// Updated, if provided, is called in a new goroutine after every refresh
	// performed by OriginAuth, whether it succeeded or failed. AuthState is
	// always set and should be saved, even if an error occurred.
	Updated func(AuthState, error)

	// Credentials, if provided, is called to get credentials when updating the
	// SID (i.e., when there is no SID, or the current one is rejected with
	// ErrAuthRequired).
	Credentials func() (email, password, otpsecret string, err error)

	// Backoff, if provided, checks if another refresh is allowed after a
	// failure. It is passed the last error, the time it occurred, and the
	// number of consecutive failures. If it returns false, ErrAuthMgrBackoff
	// will be returned immediately from OriginAuth.
	Backoff func(err error, time time.Time, count int) bool

	// SaveHAR, if provided, is called in a new goroutine after every attempt to
	// authenticate. It is passed a function which writes the recorded HTTP
	// archive as JSON to a writer, and the error from the attempt (if any).
	SaveHAR func(func(w io.Writer) error, error)

	authInit     sync.Once
	authPf       bool       // ensures only one update runs at a time
	authCv       *sync.Cond // allows other goroutines to wait for that update to complete
	authErr      error      // last auth error
	authErrTime  time.Time  // last auth error time
	authErrCount int        // consecutive auth errors
	auth         AuthState  // current auth tokens
}

// AuthState contains the current authentication tokens. It is the value passed
// to SetAuth and handed to the Updated callback, and is JSON-serializable for
// persistence.
type AuthState struct {
	// SID is the session ID used to request new NucleusTokens. If it is empty
	// or rejected, a new one is obtained by logging in with Credentials.
	SID                juno.SID     `json:"sid,omitempty"`
	// NucleusToken is the current token returned by OriginAuth.
	NucleusToken       NucleusToken `json:"nucleus_token,omitempty"`
	// NucleusTokenExpiry is the time from which NucleusToken is treated as
	// expired and refreshed on the next OriginAuth call.
	//
	// NOTE(review): encoding/json's omitempty does not omit struct values, so
	// a zero time.Time is still written out.
	NucleusTokenExpiry time.Time    `json:"nucleus_token_expiry,omitempty"`
}

// init creates the condition variable (and the mutex backing it) which guards
// the auth state. It only does so on the first call, and is called at the start
// of every exported method.
func (a *AuthMgr) init() {
	a.authInit.Do(func() {
		var mu sync.Mutex
		a.authCv = sync.NewCond(&mu)
	})
}

// SetAuth sets the current Origin credentials and clears any recorded refresh
// failure (the last error, its time, and the consecutive failure count). If
// authentication is in progress, it will block until it completes.
func (a *AuthMgr) SetAuth(auth AuthState) {
	a.init()

	a.authCv.L.Lock()
	defer a.authCv.L.Unlock()

	// Don't replace the state underneath a refresh which is currently using it.
	for a.authPf {
		a.authCv.Wait()
	}

	a.auth = auth

	// The failure count and time describe the error being cleared. Leaving
	// them set would carry stale values into the attempt numbering and the
	// Backoff arguments of the next refresh.
	a.authErr = nil
	a.authErrCount = 0
	a.authErrTime = time.Time{}
}

// OriginAuth gets the current NucleusToken. If refresh is true or the nucleus
// token is missing/expired, it generates a new NucleusToken, getting a new SID
// if required. If another refresh is in progress, it waits for the result of
// it. True is returned (on success or failure) if this call performed a
// refresh. This function may block for up to Timeout.
//
// If the previous refresh failed and Backoff does not allow another attempt,
// no request is made, and the current token is returned along with an error
// wrapping ErrAuthMgrBackoff (the returned bool is still true in this case).
//
// In general, OriginAuth(false) should be used first, then if an API call error
// is ErrAuthRequired, try it again with the token from OriginAuth(true).
func (a *AuthMgr) OriginAuth(refresh bool) (NucleusToken, bool, error) {
	a.init()

	a.authCv.L.Lock()
	if a.authPf {
		// Another goroutine is already refreshing; wait and share its result.
		for a.authPf {
			a.authCv.Wait()
		}
		defer a.authCv.L.Unlock()
		return a.auth.NucleusToken, false, a.authErr
	}
	if !refresh && a.auth.NucleusToken != "" && time.Now().Before(a.auth.NucleusTokenExpiry) {
		// The current token is still usable.
		defer a.authCv.L.Unlock()
		return a.auth.NucleusToken, false, a.authErr
	}

	// Claim the refresh. While authPf is set, every other caller waits on the
	// condition variable, so the auth state can be used without holding the
	// lock for the duration of the network requests.
	a.authPf = true
	a.authCv.L.Unlock()
	defer func() {
		a.authCv.L.Lock()
		a.authPf = false
		a.authCv.Broadcast()
		a.authCv.L.Unlock()
	}()

	if a.authErr != nil && a.Backoff != nil && !a.Backoff(a.authErr, a.authErrTime, a.authErrCount) {
		return a.auth.NucleusToken, true, fmt.Errorf("%w (%d attempts, last error: %v)", ErrAuthMgrBackoff, a.authErrCount, a.authErr)
	}

	a.authErr = func() (err error) {
		t := http.DefaultClient.Transport
		if t == nil {
			t = http.DefaultTransport
		}
		if a.SaveHAR != nil {
			// Wrap the transport in a recorder. This defer is registered
			// first so that it runs last, and sees the final value of err
			// (including one set by the panic handler below).
			rec := harhar.NewRecorder()
			rec.RoundTripper, t = t, rec
			defer func() {
				go a.SaveHAR(func(w io.Writer) error {
					return json.NewEncoder(w).Encode(rec.HAR)
				}, err)
			}()
		}

		// A panic while authenticating is reported as a failed attempt rather
		// than leaving the manager stuck mid-refresh.
		defer func() {
			if p := recover(); p != nil {
				err = fmt.Errorf("panic: %v", p)
			}
		}()

		var ctx context.Context
		var cancel context.CancelFunc
		switch {
		case a.Timeout > 0:
			ctx, cancel = context.WithTimeout(context.Background(), a.Timeout)
		case a.Timeout == 0:
			ctx, cancel = context.WithTimeout(context.Background(), time.Second*15)
		default:
			ctx, cancel = context.WithCancel(context.Background())
		}
		defer cancel()

		// Try the existing SID first; only fall through to a full login if
		// it is missing or has been rejected.
		if a.auth.SID != "" {
			tok, exp, aerr := GetNucleusToken(ctx, t, a.auth.SID)
			if aerr == nil {
				a.auth.NucleusToken = tok
				a.auth.NucleusTokenExpiry = exp
				return nil
			}
			if !errors.Is(aerr, ErrAuthRequired) {
				return fmt.Errorf("refresh nucleus token: %w", aerr)
			}
		}

		if a.Credentials == nil {
			return errors.New("no origin credentials to refresh sid with")
		}
		email, password, otpsecret, aerr := a.Credentials()
		if aerr != nil {
			return fmt.Errorf("get origin credentials: %w", aerr)
		}
		res, aerr := juno.Login(ctx, t, email, password, otpsecret)
		if aerr != nil {
			return fmt.Errorf("refresh sid: %w", aerr)
		}
		a.auth.SID = res.SID

		tok, exp, aerr := GetNucleusToken(ctx, t, a.auth.SID)
		if aerr != nil {
			return fmt.Errorf("refresh nucleus token with new sid: %w", aerr)
		}
		a.auth.NucleusToken = tok
		a.auth.NucleusTokenExpiry = exp
		return nil
	}()

	if a.authErr != nil {
		// Only annotate actual failures: wrapping a nil error would turn a
		// successful refresh following a failed one into an error.
		if a.authErrCount != 0 {
			a.authErr = fmt.Errorf("%w (attempt %d)", a.authErr, a.authErrCount)
		}
		a.authErrCount++
		a.authErrTime = time.Now()
	} else {
		a.authErrCount = 0
		a.authErrTime = time.Time{}
	}

	if a.Updated != nil {
		go a.Updated(a.auth, a.authErr)
	}
	return a.auth.NucleusToken, true, a.authErr
}