diff options
-rw-r--r-- | pkg/origin/origin.go | 73 | ||||
-rw-r--r-- | pkg/origin/origin_test.go | 99 |
2 files changed, 142 insertions, 30 deletions
diff --git a/pkg/origin/origin.go b/pkg/origin/origin.go index 2638f3e..c354e27 100644 --- a/pkg/origin/origin.go +++ b/pkg/origin/origin.go @@ -5,14 +5,22 @@ package origin import ( "context" "encoding/xml" + "errors" "fmt" "io" + "mime" "net/http" "strconv" "strings" "sync/atomic" ) +var ( + ErrInvalidResponse = errors.New("invalid origin api response") + ErrOrigin = errors.New("origin api error") + ErrAuthRequired = errors.New("origin authentication required") +) + type SIDStore interface { GetSID(ctx context.Context) (string, error) SetSID(ctx context.Context, sid string) error @@ -96,41 +104,66 @@ func (c *Client) getUserInfo(retry bool, ctx context.Context, uid ...int) ([]Use } defer resp.Body.Close() - if needAuth, err := checkResponse(resp); err != nil { - if retry && needAuth { - if err := c.Login(ctx); err != nil { - return nil, err - } - return c.getUserInfo(false, ctx, uid...) - } + buf, root, err := checkResponseXML(resp) + if err != nil { return nil, err } - return parseUserInfo(resp.Body) + return parseUserInfo(buf, root) } -func checkResponse(resp *http.Response) (bool, error) { - // TODO: return true and err for auth required - if resp.StatusCode != http.StatusOK { - return false, fmt.Errorf("response status %q", resp.Status) +func checkResponseXML(resp *http.Response) ([]byte, xml.Name, error) { + var root xml.Name + buf, err := io.ReadAll(resp.Body) + if err != nil { + return buf, root, err + } + if mt, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type")); mt != "application/xml" && mt != "text/xml" { + if resp.StatusCode != http.StatusOK { + return buf, root, fmt.Errorf("%w: response status %d (%s)", ErrOrigin, resp.StatusCode, resp.Status) + } + return buf, root, fmt.Errorf("%w: expected xml, got %q", ErrOrigin, mt) + } + if err := xml.Unmarshal(buf, &root); err != nil { + return buf, root, fmt.Errorf("%w: invalid xml: %v", ErrInvalidResponse, err) } - return false, nil + if root.Local == "error" { + var obj struct { + Code int `xml:"code,attr"` + Failure []struct { + Field string `xml:"field,attr"` + Cause string `xml:"cause,attr"` + Value string `xml:"value,attr"` + } `xml:"failure"` + } + if err := xml.Unmarshal(buf, &obj); err != nil { + return buf, root, fmt.Errorf("%w: response %#q (unmarshal: %v)", ErrOrigin, string(buf), err) + } + for _, f := range obj.Failure { + if f.Cause == "invalid_token" { + return buf, root, fmt.Errorf("%w: invalid token", ErrAuthRequired) + } + } + if len(obj.Failure) == 1 { + return buf, root, fmt.Errorf("%w: error %d: %s (%s) %q", ErrOrigin, obj.Code, obj.Failure[0].Cause, obj.Failure[0].Field, obj.Failure[0].Value) + } + return buf, root, fmt.Errorf("%w: error %d: response %#q", ErrOrigin, obj.Code, string(buf)) + } + return buf, root, nil } -func parseUserInfo(r io.Reader) ([]UserInfo, error) { +func parseUserInfo(buf []byte, root xml.Name) ([]UserInfo, error) { var obj struct { - XMLName xml.Name `xml:"users"` - User []struct { + User []struct { UserID string `xml:"userId"` PersonaID string `xml:"personaId"` EAID string `xml:"EAID"` } `xml:"user"` } - buf, err := io.ReadAll(r) - if err != nil { - return nil, err + if root.Local != "users" { + return nil, fmt.Errorf("%w: unexpected %s response", ErrInvalidResponse, root.Local) } if err := xml.Unmarshal(buf, &obj); err != nil { - return nil, err + return nil, fmt.Errorf("%w: invalid xml: %v", ErrInvalidResponse, err) } res := make([]UserInfo, len(obj.User)) for i, x := range obj.User { diff --git a/pkg/origin/origin_test.go b/pkg/origin/origin_test.go index ac06d82..7735355 100644 --- a/pkg/origin/origin_test.go +++ b/pkg/origin/origin_test.go @@ -1,20 +1,99 @@ package origin import ( + "errors" + "io" + "net/http" "reflect" + "strconv" "strings" "testing" ) func TestUserInfoResponse(t *testing.T) { - ui, err := parseUserInfo(strings.NewReader(`<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"yes\"?><users><user><userId>1001111111111</userId><personaId>1001111111111</personaId><EAID>test</EAID></user><user><userId>1001111111112</userId><personaId>1001111111112</personaId><EAID>test1</EAID></user></users>`)) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(ui, []UserInfo{ - {UserID: 1001111111111, PersonaID: "1001111111111", EAID: "test"}, - {UserID: 1001111111112, PersonaID: "1001111111112", EAID: "test1"}, - }) { - t.Errorf("unexpected result %#v", ui) - } + testUserInfoResponse(t, + "SuccessNew", + 200, "text/xml", `<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"yes\"?><users><user><userId>1001111111111</userId><personaId>1001111111111</personaId><EAID>test</EAID></user><user><userId>1001111111112</userId><personaId>1001111111112</personaId><EAID>test1</EAID></user></users>`, + []UserInfo{ + {UserID: 1001111111111, PersonaID: "1001111111111", EAID: "test"}, + {UserID: 1001111111112, PersonaID: "1001111111112", EAID: "test1"}, + }, nil, + ) + testUserInfoResponse(t, + "SuccessOld", + 200, "text/xml", `<?xml version="1.0" encoding="UTF-8" standalone="yes"?><users><user><userId>2291234567</userId><personaId>328123456</personaId><EAID>blahblah</EAID></user></users>`, + []UserInfo{ + {UserID: 2291234567, PersonaID: "328123456", EAID: "blahblah"}, + }, nil, + ) + testUserInfoResponse(t, + "EmptyToken", + 200, "text/xml", `<?xml version="1.0" encoding="UTF-8" standalone="yes"?><error code="10044"><failure value="" field="authToken" cause="MISSING_AUTHTOKEN"/></error>`, + nil, ErrOrigin, + ) + testUserInfoResponse(t, + "InvalidExpiredToken", + 200, "text/xml", `<?xml version="1.0" encoding="UTF-8" standalone="yes"?><error code="10044"><failure value="" field="authToken" cause="invalid_token"/></error>`, + nil, ErrAuthRequired, + ) + testUserInfoResponse(t, + "FakeWrongRootElement", + 200, "text/xml", `<?xml version="1.0" encoding="UTF-8" standalone="yes"?><fake/></error>`, + nil, ErrInvalidResponse, + ) + testUserInfoResponse(t, + "FakeError", + 200, "text/xml", `<?xml version="1.0" encoding="UTF-8" standalone="yes"?><error code="12345"><failure value="" field="dummy" cause="fake"/></error>`, + nil, ErrOrigin, + ) + testUserInfoResponse(t, + "FakeBadResponse", + 500, "text/plain", `Fake Internal Server Error`, + nil, ErrOrigin, + ) + testUserInfoResponse(t, + "FakeInvalidXML", + 200, "text/xml", `fake`, + nil, ErrInvalidResponse, + ) +} + +func testUserInfoResponse(t *testing.T, name string, status int, mime, xml string, v []UserInfo, err error) { + t.Run(name, func(t *testing.T) { + buf, root, err1 := checkResponseXML(&http.Response{ + Status: strconv.Itoa(status) + " " + http.StatusText(status), + StatusCode: status, + Body: io.NopCloser(strings.NewReader(xml)), + Header: http.Header{ + "Content-Type": {mime}, + }, + }) + if err1 != nil { + if err == nil { + t.Fatalf("expected no error, got %q", err1) + } + if !errors.Is(err1, err) { + t.Fatalf("expected error %q, got %q", err, err1) + } + return + } + + ui, err1 := parseUserInfo(buf, root) + if err1 != nil { + if err == nil { + t.Fatalf("expected no error, got %q", err1) + } + if !errors.Is(err1, err) { + t.Fatalf("expected error %q, got %q", err, err1) + } + return + } + if err != nil { + t.Fatalf("expected error %q, got nothing", err) + } + + if !reflect.DeepEqual(ui, v) { + t.Errorf("unexpected result %#v", ui) + } + }) } |