aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pkg/origin/origin.go73
-rw-r--r--pkg/origin/origin_test.go99
2 files changed, 142 insertions, 30 deletions
diff --git a/pkg/origin/origin.go b/pkg/origin/origin.go
index 2638f3e..c354e27 100644
--- a/pkg/origin/origin.go
+++ b/pkg/origin/origin.go
@@ -5,14 +5,22 @@ package origin
import (
"context"
"encoding/xml"
+ "errors"
"fmt"
"io"
+ "mime"
"net/http"
"strconv"
"strings"
"sync/atomic"
)
+var (
+ ErrInvalidResponse = errors.New("invalid origin api response")
+ ErrOrigin = errors.New("origin api error")
+ ErrAuthRequired = errors.New("origin authentication required")
+)
+
type SIDStore interface {
GetSID(ctx context.Context) (string, error)
SetSID(ctx context.Context, sid string) error
@@ -96,41 +104,66 @@ func (c *Client) getUserInfo(retry bool, ctx context.Context, uid ...int) ([]Use
}
defer resp.Body.Close()
- if needAuth, err := checkResponse(resp); err != nil {
- if retry && needAuth {
- if err := c.Login(ctx); err != nil {
- return nil, err
- }
- return c.getUserInfo(false, ctx, uid...)
- }
+ buf, root, err := checkResponseXML(resp)
+ if err != nil {
return nil, err
}
- return parseUserInfo(resp.Body)
+ return parseUserInfo(buf, root)
}
-func checkResponse(resp *http.Response) (bool, error) {
- // TODO: return true and err for auth required
- if resp.StatusCode != http.StatusOK {
- return false, fmt.Errorf("response status %q", resp.Status)
+func checkResponseXML(resp *http.Response) ([]byte, xml.Name, error) {
+ var root xml.Name
+ buf, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return buf, root, err
+ }
+ if mt, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type")); mt != "application/xml" && mt != "text/xml" {
+ if resp.StatusCode != http.StatusOK {
+ return buf, root, fmt.Errorf("%w: response status %d (%s)", ErrOrigin, resp.StatusCode, resp.Status)
+ }
+ return buf, root, fmt.Errorf("%w: expected xml, got %q", ErrOrigin, mt)
+ }
+ if err := xml.Unmarshal(buf, &root); err != nil {
+ return buf, root, fmt.Errorf("%w: invalid xml: %v", ErrInvalidResponse, err)
}
- return false, nil
+ if root.Local == "error" {
+ var obj struct {
+ Code int `xml:"code,attr"`
+ Failure []struct {
+ Field string `xml:"field,attr"`
+ Cause string `xml:"cause,attr"`
+ Value string `xml:"value,attr"`
+ } `xml:"failure"`
+ }
+ if err := xml.Unmarshal(buf, &obj); err != nil {
+ return buf, root, fmt.Errorf("%w: response %#q (unmarshal: %v)", ErrOrigin, string(buf), err)
+ }
+ for _, f := range obj.Failure {
+ if f.Cause == "invalid_token" {
+ return buf, root, fmt.Errorf("%w: invalid token", ErrAuthRequired)
+ }
+ }
+ if len(obj.Failure) == 1 {
+ return buf, root, fmt.Errorf("%w: error %d: %s (%s) %q", ErrOrigin, obj.Code, obj.Failure[0].Cause, obj.Failure[0].Field, obj.Failure[0].Value)
+ }
+ return buf, root, fmt.Errorf("%w: error %d: response %#q", ErrOrigin, obj.Code, string(buf))
+ }
+ return buf, root, nil
}
-func parseUserInfo(r io.Reader) ([]UserInfo, error) {
+func parseUserInfo(buf []byte, root xml.Name) ([]UserInfo, error) {
var obj struct {
- XMLName xml.Name `xml:"users"`
- User []struct {
+ User []struct {
UserID string `xml:"userId"`
PersonaID string `xml:"personaId"`
EAID string `xml:"EAID"`
} `xml:"user"`
}
- buf, err := io.ReadAll(r)
- if err != nil {
- return nil, err
+ if root.Local != "users" {
+ return nil, fmt.Errorf("%w: unexpected %s response", ErrInvalidResponse, root.Local)
}
if err := xml.Unmarshal(buf, &obj); err != nil {
- return nil, err
+ return nil, fmt.Errorf("%w: invalid xml: %v", ErrInvalidResponse, err)
}
res := make([]UserInfo, len(obj.User))
for i, x := range obj.User {
diff --git a/pkg/origin/origin_test.go b/pkg/origin/origin_test.go
index ac06d82..7735355 100644
--- a/pkg/origin/origin_test.go
+++ b/pkg/origin/origin_test.go
@@ -1,20 +1,99 @@
package origin
import (
+ "errors"
+ "io"
+ "net/http"
"reflect"
+ "strconv"
"strings"
"testing"
)
func TestUserInfoResponse(t *testing.T) {
- ui, err := parseUserInfo(strings.NewReader(`<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"yes\"?><users><user><userId>1001111111111</userId><personaId>1001111111111</personaId><EAID>test</EAID></user><user><userId>1001111111112</userId><personaId>1001111111112</personaId><EAID>test1</EAID></user></users>`))
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(ui, []UserInfo{
- {UserID: 1001111111111, PersonaID: "1001111111111", EAID: "test"},
- {UserID: 1001111111112, PersonaID: "1001111111112", EAID: "test1"},
- }) {
- t.Errorf("unexpected result %#v", ui)
- }
+ testUserInfoResponse(t,
+ "SuccessNew",
+ 200, "text/xml", `<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"yes\"?><users><user><userId>1001111111111</userId><personaId>1001111111111</personaId><EAID>test</EAID></user><user><userId>1001111111112</userId><personaId>1001111111112</personaId><EAID>test1</EAID></user></users>`,
+ []UserInfo{
+ {UserID: 1001111111111, PersonaID: "1001111111111", EAID: "test"},
+ {UserID: 1001111111112, PersonaID: "1001111111112", EAID: "test1"},
+ }, nil,
+ )
+ testUserInfoResponse(t,
+ "SuccessOld",
+ 200, "text/xml", `<?xml version="1.0" encoding="UTF-8" standalone="yes"?><users><user><userId>2291234567</userId><personaId>328123456</personaId><EAID>blahblah</EAID></user></users>`,
+ []UserInfo{
+ {UserID: 2291234567, PersonaID: "328123456", EAID: "blahblah"},
+ }, nil,
+ )
+ testUserInfoResponse(t,
+ "EmptyToken",
+ 200, "text/xml", `<?xml version="1.0" encoding="UTF-8" standalone="yes"?><error code="10044"><failure value="" field="authToken" cause="MISSING_AUTHTOKEN"/></error>`,
+ nil, ErrOrigin,
+ )
+ testUserInfoResponse(t,
+ "InvalidExpiredToken",
+ 200, "text/xml", `<?xml version="1.0" encoding="UTF-8" standalone="yes"?><error code="10044"><failure value="" field="authToken" cause="invalid_token"/></error>`,
+ nil, ErrAuthRequired,
+ )
+ testUserInfoResponse(t,
+ "FakeWrongRootElement",
+ 200, "text/xml", `<?xml version="1.0" encoding="UTF-8" standalone="yes"?><fake/></error>`,
+ nil, ErrInvalidResponse,
+ )
+ testUserInfoResponse(t,
+ "FakeError",
+ 200, "text/xml", `<?xml version="1.0" encoding="UTF-8" standalone="yes"?><error code="12345"><failure value="" field="dummy" cause="fake"/></error>`,
+ nil, ErrOrigin,
+ )
+ testUserInfoResponse(t,
+ "FakeBadResponse",
+ 500, "text/plain", `Fake Internal Server Error`,
+ nil, ErrOrigin,
+ )
+ testUserInfoResponse(t,
+ "FakeInvalidXML",
+ 200, "text/xml", `fake`,
+ nil, ErrInvalidResponse,
+ )
+}
+
+func testUserInfoResponse(t *testing.T, name string, status int, mime, xml string, v []UserInfo, err error) {
+ t.Run(name, func(t *testing.T) {
+ buf, root, err1 := checkResponseXML(&http.Response{
+ Status: strconv.Itoa(status) + " " + http.StatusText(status),
+ StatusCode: status,
+ Body: io.NopCloser(strings.NewReader(xml)),
+ Header: http.Header{
+ "Content-Type": {mime},
+ },
+ })
+ if err1 != nil {
+ if err == nil {
+ t.Fatalf("expected no error, got %q", err1)
+ }
+ if !errors.Is(err1, err) {
+ t.Fatalf("expected error %q, got %q", err, err1)
+ }
+ return
+ }
+
+ ui, err1 := parseUserInfo(buf, root)
+ if err1 != nil {
+ if err == nil {
+ t.Fatalf("expected no error, got %q", err1)
+ }
+ if !errors.Is(err1, err) {
+ t.Fatalf("expected error %q, got %q", err, err1)
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("expected error %q, got nothing", err)
+ }
+
+ if !reflect.DeepEqual(ui, v) {
+ t.Errorf("unexpected result %#v", ui)
+ }
+ })
}