Added redirect package.

This commit is contained in:
mikestefanello 2024-06-16 11:30:11 -04:00
parent 8cae6e6beb
commit a70003d290
6 changed files with 208 additions and 82 deletions

View file

@ -14,6 +14,7 @@ import (
"github.com/mikestefanello/pagoda/pkg/middleware"
"github.com/mikestefanello/pagoda/pkg/msg"
"github.com/mikestefanello/pagoda/pkg/page"
"github.com/mikestefanello/pagoda/pkg/redirect"
"github.com/mikestefanello/pagoda/pkg/services"
"github.com/mikestefanello/pagoda/templates"
)
@ -223,7 +224,10 @@ func (h *Auth) LoginSubmit(ctx echo.Context) error {
}
msg.Success(ctx, fmt.Sprintf("Welcome back, <strong>%s</strong>. You are now logged in.", u.Name))
return redirect(ctx, routeNameHome)
return redirect.New(ctx).
Route(routeNameHome).
Go()
}
func (h *Auth) Logout(ctx echo.Context) error {
@ -232,7 +236,9 @@ func (h *Auth) Logout(ctx echo.Context) error {
} else {
msg.Danger(ctx, "An error occurred. Please try again.")
}
return redirect(ctx, routeNameHome)
return redirect.New(ctx).
Route(routeNameHome).
Go()
}
func (h *Auth) RegisterPage(ctx echo.Context) error {
@ -280,7 +286,9 @@ func (h *Auth) RegisterSubmit(ctx echo.Context) error {
)
case *ent.ConstraintError:
msg.Warning(ctx, "A user with this email address already exists. Please log in.")
return redirect(ctx, routeNameLogin)
return redirect.New(ctx).
Route(routeNameLogin).
Go()
default:
return fail(err, "unable to create user")
}
@ -293,7 +301,9 @@ func (h *Auth) RegisterSubmit(ctx echo.Context) error {
"user_id", u.ID,
)
msg.Info(ctx, "Your account has been created.")
return redirect(ctx, routeNameLogin)
return redirect.New(ctx).
Route(routeNameLogin).
Go()
}
msg.Success(ctx, "Your account has been created. You are now logged in.")
@ -301,7 +311,9 @@ func (h *Auth) RegisterSubmit(ctx echo.Context) error {
// Send the verification email
h.sendVerificationEmail(ctx, u)
return redirect(ctx, routeNameHome)
return redirect.New(ctx).
Route(routeNameHome).
Go()
}
func (h *Auth) sendVerificationEmail(ctx echo.Context, usr *ent.User) {
@ -384,7 +396,9 @@ func (h *Auth) ResetPasswordSubmit(ctx echo.Context) error {
}
msg.Success(ctx, "Your password has been updated.")
return redirect(ctx, routeNameLogin)
return redirect.New(ctx).
Route(routeNameLogin).
Go()
}
func (h *Auth) VerifyEmail(ctx echo.Context) error {
@ -395,7 +409,9 @@ func (h *Auth) VerifyEmail(ctx echo.Context) error {
email, err := h.auth.ValidateEmailVerificationToken(token)
if err != nil {
msg.Warning(ctx, "The link is either invalid or has expired.")
return redirect(ctx, routeNameHome)
return redirect.New(ctx).
Route(routeNameHome).
Go()
}
// Check if it matches the authenticated user
@ -432,5 +448,7 @@ func (h *Auth) VerifyEmail(ctx echo.Context) error {
}
msg.Success(ctx, "Your email has been successfully verified.")
return redirect(ctx, routeNameHome)
return redirect.New(ctx).
Route(routeNameHome).
Go()
}

View file

@ -3,10 +3,8 @@ package handlers
import (
"fmt"
"net/http"
"net/url"
"github.com/labstack/echo/v4"
"github.com/mikestefanello/pagoda/pkg/htmx"
"github.com/mikestefanello/pagoda/pkg/services"
)
@ -31,30 +29,6 @@ func GetHandlers() []Handler {
return handlers
}
// redirect redirects to a given route by name with optional route parameters
func redirect(ctx echo.Context, routeName string, routeParams ...any) error {
return doRedirect(ctx, ctx.Echo().Reverse(routeName, routeParams...))
}
// redirectWithQuery redirects to a given route by name with query parameters and optional route parameters
func redirectWithQuery(ctx echo.Context, query url.Values, routeName string, routeParams ...any) error {
dest := fmt.Sprintf("%s?%s", ctx.Echo().Reverse(routeName, routeParams...), query.Encode())
return doRedirect(ctx, dest)
}
// doRedirect performs a redirect to a given URL
func doRedirect(ctx echo.Context, url string) error {
if htmx.GetRequest(ctx).Boosted {
htmx.Response{
Redirect: url,
}.Apply(ctx)
return nil
} else {
return ctx.Redirect(http.StatusFound, url)
}
}
// fail is a helper to fail a request by returning a 500 error and logging the error
func fail(err error, log string) error {
// The error handler will handle logging

View file

@ -3,12 +3,9 @@ package handlers
import (
"errors"
"net/http"
"net/url"
"testing"
"github.com/labstack/echo/v4"
"github.com/mikestefanello/pagoda/pkg/htmx"
"github.com/mikestefanello/pagoda/pkg/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -23,51 +20,6 @@ func TestGetSetHandlers(t *testing.T) {
assert.Equal(t, h, got[0])
}
func TestRedirect(t *testing.T) {
c.Web.GET("/path/:first/and/:second", func(c echo.Context) error {
return nil
}).Name = "redirect-test"
t.Run("no query", func(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/abc")
err := redirect(ctx, "redirect-test", "one", "two")
require.NoError(t, err)
assert.Equal(t, "/path/one/and/two", ctx.Response().Header().Get(echo.HeaderLocation))
assert.Equal(t, http.StatusFound, ctx.Response().Status)
})
t.Run("no query htmx", func(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/abc")
ctx.Request().Header.Set(htmx.HeaderBoosted, "true")
err := redirect(ctx, "redirect-test", "one", "two")
require.NoError(t, err)
assert.Equal(t, "/path/one/and/two", ctx.Response().Header().Get(htmx.HeaderRedirect))
})
t.Run("query", func(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/abc")
q := url.Values{}
q.Add("a", "1")
q.Add("b", "2")
err := redirectWithQuery(ctx, q, "redirect-test", "one", "two")
require.NoError(t, err)
assert.Equal(t, "/path/one/and/two?a=1&b=2", ctx.Response().Header().Get(echo.HeaderLocation))
assert.Equal(t, http.StatusFound, ctx.Response().Status)
})
t.Run("query htmx", func(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/abc")
ctx.Request().Header.Set(htmx.HeaderBoosted, "true")
q := url.Values{}
q.Add("a", "1")
q.Add("b", "2")
err := redirectWithQuery(ctx, q, "redirect-test", "one", "two")
require.NoError(t, err)
assert.Equal(t, "/path/one/and/two?a=1&b=2", ctx.Response().Header().Get(htmx.HeaderRedirect))
assert.Equal(t, http.StatusFound, ctx.Response().Status)
})
}
func TestFail(t *testing.T) {
err := fail(errors.New("err message"), "log message")
require.IsType(t, new(echo.HTTPError), err)