From 27e34024f187696e87b58e6fc6c6b60d02416002 Mon Sep 17 00:00:00 2001 From: mikestefanello Date: Fri, 17 Dec 2021 20:58:51 -0500 Subject: [PATCH] Moved auth to container. Added tests for auth. --- auth/auth_test.go | 117 ------------------------------------ auth/errors.go | 17 ------ {auth => container}/auth.go | 53 ++++++++++------ container/auth_test.go | 105 ++++++++++++++++++++++++++++++++ container/container.go | 5 +- container/container_test.go | 35 +++++++++++ 6 files changed, 177 insertions(+), 155 deletions(-) delete mode 100644 auth/auth_test.go delete mode 100644 auth/errors.go rename {auth => container}/auth.go (73%) create mode 100644 container/auth_test.go diff --git a/auth/auth_test.go b/auth/auth_test.go deleted file mode 100644 index 5b1f725..0000000 --- a/auth/auth_test.go +++ /dev/null @@ -1,117 +0,0 @@ -package auth - -import ( - "context" - "net/http" - "net/http/httptest" - "os" - "strings" - "testing" - - "goweb/config" - "goweb/container" - "goweb/ent" - - "github.com/labstack/echo/v4" - - "github.com/stretchr/testify/require" - - "github.com/stretchr/testify/assert" -) - -var ( - authClient *Client - c *container.Container - ctx echo.Context - usr *ent.User -) - -func TestMain(m *testing.M) { - // Set the environment to test - config.SwitchEnvironment(config.EnvTest) - - // Create an auth client - c := container.NewContainer() - authClient = c.Auth - - // Create a web context - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("")) - rec := httptest.NewRecorder() - ctx = c.Web.NewContext(req, rec) - - // Create a test uset - var err error - usr, err = c.ORM.User. - Create(). - SetEmail("test@test.dev"). - SetPassword("abc"). - SetName("Test User"). - Save(context.Background()) - - if err != nil { - panic(err) - } - - // Run tests - exitVal := m.Run() - os.Exit(exitVal) -} - -func TestLogin(t *testing.T) { - -} - -func TestLogout(t *testing.T) { - -} - -func TestGetAuthenticatedUserID(t *testing.T) { - -} - -func TestGetAuthenticatedUser(t *testing.T) { - -} - -func TestHashPassword(t *testing.T) { - pw := "abcdef" - hash, err := authClient.HashPassword(pw) - assert.NoError(t, err) - assert.NotEqual(t, hash, pw) -} - -func TestCheckPassword(t *testing.T) { - pw := "testcheckpassword" - hash, err := authClient.HashPassword(pw) - assert.NoError(t, err) - err = authClient.CheckPassword(pw, hash) - assert.NoError(t, err) -} - -func TestGeneratePasswordResetToken(t *testing.T) { - token, pt, err := authClient.GeneratePasswordResetToken(ctx, usr.ID) - require.NoError(t, err) - hash, err := authClient.HashPassword(token) - require.NoError(t, err) - assert.Len(t, token, c.Config.App.PasswordToken.Length) - assert.Equal(t, hash, pt.Hash) - assert.Equal(t, usr.ID, pt.Edges.User.ID) -} - -func TestGetValidPasswordToken(t *testing.T) { - -} - -func TestDeletePasswordTokens(t *testing.T) { -} - -func TestRandomToken(t *testing.T) { - length := 64 - a, err := authClient.RandomToken(length) - require.NoError(t, err) - b, err := authClient.RandomToken(length) - require.NoError(t, err) - assert.Len(t, a, 64) - assert.Len(t, b, 64) - assert.NotEqual(t, a, b) -} diff --git a/auth/errors.go b/auth/errors.go deleted file mode 100644 index 514115c..0000000 --- a/auth/errors.go +++ /dev/null @@ -1,17 +0,0 @@ -package auth - -// NotAuthenticatedError is an error returned when a user is not authenticated -type NotAuthenticatedError struct{} - -// Error implements the error interface. -func (e NotAuthenticatedError) Error() string { - return "user not authenticated" -} - -// InvalidTokenError is an error returned when an invalid token is provided -type InvalidTokenError struct{} - -// Error implements the error interface. -func (e InvalidTokenError) Error() string { - return "invalid token" -} diff --git a/auth/auth.go b/container/auth.go similarity index 73% rename from auth/auth.go rename to container/auth.go index 3e6953c..7853d68 100644 --- a/auth/auth.go +++ b/container/auth.go @@ -1,4 +1,4 @@ -package auth +package container import ( "crypto/rand" @@ -26,22 +26,38 @@ const ( sessionKeyAuthenticated = "authenticated" ) -// Client is the client that handles authentication requests -type Client struct { +// NotAuthenticatedError is an error returned when a user is not authenticated +type NotAuthenticatedError struct{} + +// Error implements the error interface. +func (e NotAuthenticatedError) Error() string { + return "user not authenticated" +} + +// InvalidTokenError is an error returned when an invalid token is provided +type InvalidTokenError struct{} + +// Error implements the error interface. +func (e InvalidTokenError) Error() string { + return "invalid token" +} + +// AuthClient is the AuthClient that handles authentication requests +type AuthClient struct { config *config.Config orm *ent.Client } -// NewClient creates a new authentication client -func NewClient(cfg *config.Config, orm *ent.Client) *Client { - return &Client{ +// NewAuthClient creates a new authentication AuthClient +func NewAuthClient(cfg *config.Config, orm *ent.Client) *AuthClient { + return &AuthClient{ config: cfg, orm: orm, } } // Login logs in a user of a given ID -func (c *Client) Login(ctx echo.Context, userID int) error { +func (c *AuthClient) Login(ctx echo.Context, userID int) error { sess, err := session.Get(sessionName, ctx) if err != nil { return err @@ -52,7 +68,7 @@ func (c *Client) Login(ctx echo.Context, userID int) error { } // Logout logs the requesting user out -func (c *Client) Logout(ctx echo.Context) error { +func (c *AuthClient) Logout(ctx echo.Context) error { sess, err := session.Get(sessionName, ctx) if err != nil { return err @@ -62,7 +78,7 @@ func (c *Client) Logout(ctx echo.Context) error { } // GetAuthenticatedUserID returns the authenticated user's ID, if the user is logged in -func (c *Client) GetAuthenticatedUserID(ctx echo.Context) (int, error) { +func (c *AuthClient) GetAuthenticatedUserID(ctx echo.Context) (int, error) { sess, err := session.Get(sessionName, ctx) if err != nil { return 0, err @@ -76,7 +92,7 @@ func (c *Client) GetAuthenticatedUserID(ctx echo.Context) (int, error) { } // GetAuthenticatedUser returns the authenticated user if the user is logged in -func (c *Client) GetAuthenticatedUser(ctx echo.Context) (*ent.User, error) { +func (c *AuthClient) GetAuthenticatedUser(ctx echo.Context) (*ent.User, error) { if userID, err := c.GetAuthenticatedUserID(ctx); err == nil { return c.orm.User.Query(). Where(user.ID(userID)). @@ -87,7 +103,7 @@ func (c *Client) GetAuthenticatedUser(ctx echo.Context) (*ent.User, error) { } // HashPassword returns a hash of a given password -func (c *Client) HashPassword(password string) (string, error) { +func (c *AuthClient) HashPassword(password string) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return "", err @@ -96,7 +112,7 @@ func (c *Client) HashPassword(password string) (string, error) { } // CheckPassword check if a given password matches a given hash -func (c *Client) CheckPassword(password, hash string) error { +func (c *AuthClient) CheckPassword(password, hash string) error { return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) } @@ -104,7 +120,7 @@ func (c *Client) CheckPassword(password, hash string) error { // For security purposes, the token itself is not stored in the database but rather // a hash of the token, exactly how passwords are handled. This method returns both // the generated token as well as the token entity which only contains the hash. -func (c *Client) GeneratePasswordResetToken(ctx echo.Context, userID int) (string, *ent.PasswordToken, error) { +func (c *AuthClient) GeneratePasswordResetToken(ctx echo.Context, userID int) (string, *ent.PasswordToken, error) { // Generate the token, which is what will go in the URL, but not the database token, err := c.RandomToken(c.config.App.PasswordToken.Length) if err != nil { @@ -131,7 +147,7 @@ func (c *Client) GeneratePasswordResetToken(ctx echo.Context, userID int) (strin // Since the actual token is not stored in the database for security purposes, all non-expired token entities // are fetched from the database belonging to the requesting user and a hash of the provided token is compared // with the hash stored in the database. -func (c *Client) GetValidPasswordToken(ctx echo.Context, token string, userID int) (*ent.PasswordToken, error) { +func (c *AuthClient) GetValidPasswordToken(ctx echo.Context, token string, userID int) (*ent.PasswordToken, error) { // Ensure expired tokens are never returned expiration := time.Now().Add(-c.config.App.PasswordToken.Expiration) @@ -160,7 +176,7 @@ func (c *Client) GetValidPasswordToken(ctx echo.Context, token string, userID in // DeletePasswordTokens deletes all password tokens in the database for a belonging to a given user. // This should be called after a successful password reset. -func (c *Client) DeletePasswordTokens(ctx echo.Context, userID int) error { +func (c *AuthClient) DeletePasswordTokens(ctx echo.Context, userID int) error { _, err := c.orm.PasswordToken. Delete(). Where(passwordtoken.HasUserWith(user.ID(userID))). @@ -170,10 +186,11 @@ func (c *Client) DeletePasswordTokens(ctx echo.Context, userID int) error { } // RandomToken generates a random token string of a given length -func (c *Client) RandomToken(length int) (string, error) { - b := make([]byte, length) +func (c *AuthClient) RandomToken(length int) (string, error) { + b := make([]byte, (length/2)+1) if _, err := rand.Read(b); err != nil { return "", err } - return hex.EncodeToString(b), nil + token := hex.EncodeToString(b) + return token[:length], nil } diff --git a/container/auth_test.go b/container/auth_test.go new file mode 100644 index 0000000..cb984ae --- /dev/null +++ b/container/auth_test.go @@ -0,0 +1,105 @@ +package container + +import ( + "context" + "errors" + "testing" + + "goweb/ent/passwordtoken" + "goweb/ent/user" + + "github.com/gorilla/sessions" + "github.com/labstack/echo-contrib/session" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/require" + + "github.com/stretchr/testify/assert" +) + +func TestAuth(t *testing.T) { + // Simulate an HTTP request through the session middleware to initiate the session + mw := session.Middleware(sessions.NewCookieStore([]byte("secret"))) + handler := mw(echo.NotFoundHandler) + assert.Error(t, handler(ctx)) + + assertNoAuth := func() { + _, err := c.Auth.GetAuthenticatedUserID(ctx) + assert.True(t, errors.Is(err, NotAuthenticatedError{})) + _, err = c.Auth.GetAuthenticatedUser(ctx) + assert.True(t, errors.Is(err, NotAuthenticatedError{})) + } + + assertNoAuth() + + err := c.Auth.Login(ctx, usr.ID) + require.NoError(t, err) + + uid, err := c.Auth.GetAuthenticatedUserID(ctx) + require.NoError(t, err) + assert.Equal(t, usr.ID, uid) + + u, err := c.Auth.GetAuthenticatedUser(ctx) + require.NoError(t, err) + assert.Equal(t, u.ID, usr.ID) + + err = c.Auth.Logout(ctx) + require.NoError(t, err) + + assertNoAuth() +} + +func TestPasswordHashing(t *testing.T) { + pw := "testcheckpassword" + hash, err := c.Auth.HashPassword(pw) + assert.NoError(t, err) + assert.NotEqual(t, hash, pw) + err = c.Auth.CheckPassword(pw, hash) + assert.NoError(t, err) +} + +func TestGeneratePasswordResetToken(t *testing.T) { + token, pt, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID) + require.NoError(t, err) + assert.Len(t, token, c.Config.App.PasswordToken.Length) + assert.NoError(t, c.Auth.CheckPassword(token, pt.Hash)) +} + +func TestGetValidPasswordToken(t *testing.T) { + _, err := c.Auth.GetValidPasswordToken(ctx, "faketoken", usr.ID) + assert.Error(t, err) + + token, pt, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID) + require.NoError(t, err) + pt2, err := c.Auth.GetValidPasswordToken(ctx, token, usr.ID) + require.NoError(t, err) + assert.Equal(t, pt.ID, pt2.ID) +} + +func TestDeletePasswordTokens(t *testing.T) { + for i := 0; i < 3; i++ { + _, _, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID) + require.NoError(t, err) + } + + err := c.Auth.DeletePasswordTokens(ctx, usr.ID) + require.NoError(t, err) + + count, err := c.ORM.PasswordToken. + Query(). + Where(passwordtoken.HasUserWith(user.ID(usr.ID))). + Count(context.Background()) + + require.NoError(t, err) + assert.Equal(t, 0, count) +} + +func TestRandomToken(t *testing.T) { + length := 64 + a, err := c.Auth.RandomToken(length) + require.NoError(t, err) + b, err := c.Auth.RandomToken(length) + require.NoError(t, err) + assert.Len(t, a, 64) + assert.Len(t, b, 64) + assert.NotEqual(t, a, b) +} diff --git a/container/container.go b/container/container.go index 7f4aa68..4016a82 100644 --- a/container/container.go +++ b/container/container.go @@ -5,7 +5,6 @@ import ( "database/sql" "fmt" - "goweb/auth" "goweb/mail" "entgo.io/ent/dialect" @@ -28,7 +27,7 @@ type Container struct { Database *sql.DB ORM *ent.Client Mail *mail.Client - Auth *auth.Client + Auth *AuthClient } func NewContainer() *Container { @@ -130,5 +129,5 @@ func (c *Container) initMail() { } func (c *Container) initAuth() { - c.Auth = auth.NewClient(c.Config, c.ORM) + c.Auth = NewAuthClient(c.Config, c.ORM) } diff --git a/container/container_test.go b/container/container_test.go index 6f1af4d..ec1bce7 100644 --- a/container/container_test.go +++ b/container/container_test.go @@ -1,18 +1,53 @@ package container import ( + "context" + "net/http" + "net/http/httptest" "os" + "strings" "testing" "goweb/config" + "goweb/ent" + + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) +var ( + c *Container + ctx echo.Context + usr *ent.User + rec *httptest.ResponseRecorder +) + func TestMain(m *testing.M) { // Set the environment to test config.SwitchEnvironment(config.EnvTest) + // Create a new container + c = NewContainer() + + // Create a web context + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("")) + rec = httptest.NewRecorder() + ctx = c.Web.NewContext(req, rec) + + // Create a test user + var err error + usr, err = c.ORM.User. + Create(). + SetEmail("test@test.dev"). + SetPassword("abc"). + SetName("Test User"). + Save(context.Background()) + + if err != nil { + panic(err) + } + // Run tests exitVal := m.Run() os.Exit(exitVal)