diff --git a/Makefile b/Makefile index 5be7fe4..827448c 100644 --- a/Makefile +++ b/Makefile @@ -26,4 +26,8 @@ up: .PHONY: run run: clear - go run main.go \ No newline at end of file + go run main.go + +.PHONY: test +run: + go test ./... \ No newline at end of file diff --git a/auth/auth.go b/auth/auth.go index b99580f..3e6953c 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -16,31 +16,23 @@ import ( ) const ( - sessionName = "ua" - sessionKeyUserID = "user_id" + // sessionName stores the name of the session which contains authentication data + sessionName = "ua" + + // sessionKeyUserID stores the key used to store the user ID in the session + sessionKeyUserID = "user_id" + + // sessionKeyAuthenticated stores the key used to store the authentication status in the session sessionKeyAuthenticated = "authenticated" - passwordTokenLength = 64 ) -type NotAuthenticatedError struct{} - -// Error implements the error interface. -func (e NotAuthenticatedError) Error() string { - return "user not authenticated" -} - -type InvalidTokenError struct{} - -// Error implements the error interface. -func (e InvalidTokenError) Error() string { - return "invalid token" -} - +// Client is the client that handles authentication requests type Client struct { config *config.Config orm *ent.Client } +// NewClient creates a new authentication client func NewClient(cfg *config.Config, orm *ent.Client) *Client { return &Client{ config: cfg, @@ -48,6 +40,7 @@ func NewClient(cfg *config.Config, orm *ent.Client) *Client { } } +// Login logs in a user of a given ID func (c *Client) Login(ctx echo.Context, userID int) error { sess, err := session.Get(sessionName, ctx) if err != nil { @@ -58,6 +51,7 @@ func (c *Client) Login(ctx echo.Context, userID int) error { return sess.Save(ctx.Request(), ctx.Response()) } +// Logout logs the requesting user out func (c *Client) Logout(ctx echo.Context) error { sess, err := session.Get(sessionName, ctx) if err != nil { @@ -67,6 +61,7 @@ func (c *Client) Logout(ctx echo.Context) error { return sess.Save(ctx.Request(), ctx.Response()) } +// GetAuthenticatedUserID returns the authenticated user's ID, if the user is logged in func (c *Client) GetAuthenticatedUserID(ctx echo.Context) (int, error) { sess, err := session.Get(sessionName, ctx) if err != nil { @@ -80,6 +75,7 @@ func (c *Client) GetAuthenticatedUserID(ctx echo.Context) (int, error) { return 0, NotAuthenticatedError{} } +// GetAuthenticatedUser returns the authenticated user if the user is logged in func (c *Client) GetAuthenticatedUser(ctx echo.Context) (*ent.User, error) { if userID, err := c.GetAuthenticatedUserID(ctx); err == nil { return c.orm.User.Query(). @@ -90,6 +86,7 @@ func (c *Client) GetAuthenticatedUser(ctx echo.Context) (*ent.User, error) { return nil, NotAuthenticatedError{} } +// HashPassword returns a hash of a given password func (c *Client) HashPassword(password string) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { @@ -98,13 +95,18 @@ func (c *Client) HashPassword(password string) (string, error) { return string(hash), nil } +// CheckPassword check if a given password matches a given hash func (c *Client) CheckPassword(password, hash string) error { return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) } +// GeneratePasswordResetToken generates a password reset token for a given user. +// For security purposes, the token itself is not stored in the database but rather +// a hash of the token, exactly how passwords are handled. This method returns both +// the generated token as well as the token entity which only contains the hash. func (c *Client) GeneratePasswordResetToken(ctx echo.Context, userID int) (string, *ent.PasswordToken, error) { // Generate the token, which is what will go in the URL, but not the database - token, err := c.RandomToken(passwordTokenLength) + token, err := c.RandomToken(c.config.App.PasswordToken.Length) if err != nil { return "", nil, err } @@ -125,9 +127,13 @@ func (c *Client) GeneratePasswordResetToken(ctx echo.Context, userID int) (strin return token, pt, err } +// GetValidPasswordToken returns a valid password token entity for a given user and a given token. +// Since the actual token is not stored in the database for security purposes, all non-expired token entities +// are fetched from the database belonging to the requesting user and a hash of the provided token is compared +// with the hash stored in the database. func (c *Client) GetValidPasswordToken(ctx echo.Context, token string, userID int) (*ent.PasswordToken, error) { // Ensure expired tokens are never returned - expiration := time.Now().Add(-c.config.App.PasswordTokenExpiration) + expiration := time.Now().Add(-c.config.App.PasswordToken.Expiration) // Query to find all tokens for te user that haven't expired // We need to get all of them in order to properly match the token to the hashes @@ -152,6 +158,8 @@ func (c *Client) GetValidPasswordToken(ctx echo.Context, token string, userID in return nil, InvalidTokenError{} } +// DeletePasswordTokens deletes all password tokens in the database for a belonging to a given user. +// This should be called after a successful password reset. func (c *Client) DeletePasswordTokens(ctx echo.Context, userID int) error { _, err := c.orm.PasswordToken. Delete(). @@ -161,6 +169,7 @@ func (c *Client) DeletePasswordTokens(ctx echo.Context, userID int) error { return err } +// RandomToken generates a random token string of a given length func (c *Client) RandomToken(length int) (string, error) { b := make([]byte, length) if _, err := rand.Read(b); err != nil { diff --git a/auth/auth_test.go b/auth/auth_test.go new file mode 100644 index 0000000..5b1f725 --- /dev/null +++ b/auth/auth_test.go @@ -0,0 +1,117 @@ +package auth + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "goweb/config" + "goweb/container" + "goweb/ent" + + "github.com/labstack/echo/v4" + + "github.com/stretchr/testify/require" + + "github.com/stretchr/testify/assert" +) + +var ( + authClient *Client + c *container.Container + ctx echo.Context + usr *ent.User +) + +func TestMain(m *testing.M) { + // Set the environment to test + config.SwitchEnvironment(config.EnvTest) + + // Create an auth client + c := container.NewContainer() + authClient = c.Auth + + // Create a web context + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("")) + rec := httptest.NewRecorder() + ctx = c.Web.NewContext(req, rec) + + // Create a test uset + var err error + usr, err = c.ORM.User. + Create(). + SetEmail("test@test.dev"). + SetPassword("abc"). + SetName("Test User"). + Save(context.Background()) + + if err != nil { + panic(err) + } + + // Run tests + exitVal := m.Run() + os.Exit(exitVal) +} + +func TestLogin(t *testing.T) { + +} + +func TestLogout(t *testing.T) { + +} + +func TestGetAuthenticatedUserID(t *testing.T) { + +} + +func TestGetAuthenticatedUser(t *testing.T) { + +} + +func TestHashPassword(t *testing.T) { + pw := "abcdef" + hash, err := authClient.HashPassword(pw) + assert.NoError(t, err) + assert.NotEqual(t, hash, pw) +} + +func TestCheckPassword(t *testing.T) { + pw := "testcheckpassword" + hash, err := authClient.HashPassword(pw) + assert.NoError(t, err) + err = authClient.CheckPassword(pw, hash) + assert.NoError(t, err) +} + +func TestGeneratePasswordResetToken(t *testing.T) { + token, pt, err := authClient.GeneratePasswordResetToken(ctx, usr.ID) + require.NoError(t, err) + hash, err := authClient.HashPassword(token) + require.NoError(t, err) + assert.Len(t, token, c.Config.App.PasswordToken.Length) + assert.Equal(t, hash, pt.Hash) + assert.Equal(t, usr.ID, pt.Edges.User.ID) +} + +func TestGetValidPasswordToken(t *testing.T) { + +} + +func TestDeletePasswordTokens(t *testing.T) { +} + +func TestRandomToken(t *testing.T) { + length := 64 + a, err := authClient.RandomToken(length) + require.NoError(t, err) + b, err := authClient.RandomToken(length) + require.NoError(t, err) + assert.Len(t, a, 64) + assert.Len(t, b, 64) + assert.NotEqual(t, a, b) +} diff --git a/auth/errors.go b/auth/errors.go new file mode 100644 index 0000000..514115c --- /dev/null +++ b/auth/errors.go @@ -0,0 +1,17 @@ +package auth + +// NotAuthenticatedError is an error returned when a user is not authenticated +type NotAuthenticatedError struct{} + +// Error implements the error interface. +func (e NotAuthenticatedError) Error() string { + return "user not authenticated" +} + +// InvalidTokenError is an error returned when an invalid token is provided +type InvalidTokenError struct{} + +// Error implements the error interface. +func (e InvalidTokenError) Error() string { + return "invalid token" +} diff --git a/config/config.go b/config/config.go index a9ec743..49695ad 100644 --- a/config/config.go +++ b/config/config.go @@ -46,11 +46,14 @@ type ( // AppConfig stores application configuration AppConfig struct { - Name string `env:"APP_NAME,default=Goweb"` - Environment Environment `env:"APP_ENVIRONMENT,default=local"` - EncryptionKey string `env:"APP_ENCRYPTION_KEY,default=?E(G+KbPeShVmYq3t6w9z$C&F)J@McQf"` - Timeout time.Duration `env:"APP_TIMEOUT,default=20s"` - PasswordTokenExpiration time.Duration `env:"APP_PASSWORD_TOKEN_EXPIRATION,default=60m"` + Name string `env:"APP_NAME,default=Goweb"` + Environment Environment `env:"APP_ENVIRONMENT,default=local"` + EncryptionKey string `env:"APP_ENCRYPTION_KEY,default=?E(G+KbPeShVmYq3t6w9z$C&F)J@McQf"` + Timeout time.Duration `env:"APP_TIMEOUT,default=20s"` + PasswordToken struct { + Expiration time.Duration `env:"APP_PASSWORD_TOKEN_EXPIRATION,default=60m"` + Length int `env:"APP_PASSWORD_TOKEN_LENGTH,default=64"` + } } CacheConfig struct {