From acd38c8205e4443cafdf6bbb3d59827816f707f6 Mon Sep 17 00:00:00 2001 From: mikestefanello Date: Sun, 9 Jan 2022 00:23:26 -0500 Subject: [PATCH] Handle context cancellations and avoid logged errors. --- context/context.go | 10 ++++++++++ context/context_test.go | 24 ++++++++++++++++++++++++ controller/controller.go | 9 ++++++++- middleware/auth.go | 6 ++++++ middleware/cache.go | 8 ++++++-- middleware/entity.go | 3 +++ routes/error.go | 25 +++++++++++++------------ routes/verify_email.go | 33 +++++++++++++-------------------- 8 files changed, 83 insertions(+), 35 deletions(-) create mode 100644 context/context_test.go diff --git a/context/context.go b/context/context.go index 177bf30..71c465f 100644 --- a/context/context.go +++ b/context/context.go @@ -1,5 +1,10 @@ package context +import ( + "context" + "errors" +) + const ( // AuthenticatedUserKey is the key value used to store the authenticated user in context AuthenticatedUserKey = "auth_user" @@ -13,3 +18,8 @@ const ( // PasswordTokenKey is the key value used to store a password token in context PasswordTokenKey = "password_token" ) + +// IsCanceledError determines if an error is due to a context cancelation +func IsCanceledError(err error) bool { + return errors.Is(err, context.Canceled) +} diff --git a/context/context_test.go b/context/context_test.go new file mode 100644 index 0000000..77abb6a --- /dev/null +++ b/context/context_test.go @@ -0,0 +1,24 @@ +package context + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestIsCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + assert.False(t, IsCanceledError(ctx.Err())) + cancel() + assert.True(t, IsCanceledError(ctx.Err())) + + ctx, cancel = context.WithTimeout(context.Background(), time.Microsecond) + time.Sleep(time.Microsecond * 2) + cancel() + assert.False(t, IsCanceledError(ctx.Err())) + + assert.False(t, IsCanceledError(errors.New("test error"))) +} diff --git a/controller/controller.go b/controller/controller.go index 19e43c4..738755f 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" + "github.com/mikestefanello/pagoda/context" "github.com/mikestefanello/pagoda/middleware" "github.com/mikestefanello/pagoda/services" @@ -143,7 +144,10 @@ func (c *Controller) cachePage(ctx echo.Context, page Page, html *bytes.Buffer) err := marshaler.New(c.Container.Cache).Set(ctx.Request().Context(), key, cp, opts) if err != nil { - ctx.Logger().Errorf("failed to cache page: %v", err) + if !context.IsCanceledError(err) { + ctx.Logger().Errorf("failed to cache page: %v", err) + } + return } @@ -158,6 +162,9 @@ func (c *Controller) Redirect(ctx echo.Context, route string, routeParams ...int // Fail is a helper to fail a request by returning a 500 error and logging the error func (c *Controller) Fail(ctx echo.Context, err error, log string) error { + if context.IsCanceledError(err) { + return nil + } ctx.Logger().Errorf("%s: %v", log, err) return echo.NewHTTPError(http.StatusInternalServerError) } diff --git a/middleware/auth.go b/middleware/auth.go index 26bb812..657cea3 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -24,6 +24,9 @@ func LoadAuthenticatedUser(authClient *services.AuthClient) echo.MiddlewareFunc c.Set(context.AuthenticatedUserKey, u) c.Logger().Infof("auth user loaded in to context: %d", u.ID) default: + if context.IsCanceledError(err) { + return nil + } c.Logger().Errorf("error querying for authenticated user: %v", err) } @@ -55,6 +58,9 @@ func LoadValidPasswordToken(authClient *services.AuthClient) echo.MiddlewareFunc msg.Warning(c, "The link is either invalid or has expired. Please request a new one.") return c.Redirect(http.StatusFound, c.Echo().Reverse("forgot_password")) default: + if context.IsCanceledError(err) { + return nil + } c.Logger().Error(err) return echo.NewHTTPError(http.StatusInternalServerError) } diff --git a/middleware/cache.go b/middleware/cache.go index 5bd6bfc..47b6ba5 100644 --- a/middleware/cache.go +++ b/middleware/cache.go @@ -51,11 +51,15 @@ func ServeCachedPage(ch *cache.Cache) echo.MiddlewareFunc { new(CachedPage), ) if err != nil { - if err == redis.Nil { + switch { + case err == redis.Nil: c.Logger().Info("no cached page found") - } else { + case context.IsCanceledError(err): + return nil + default: c.Logger().Errorf("failed getting cached page: %v", err) } + return next(c) } diff --git a/middleware/entity.go b/middleware/entity.go index 8137fbe..803f24e 100644 --- a/middleware/entity.go +++ b/middleware/entity.go @@ -32,6 +32,9 @@ func LoadUser(orm *ent.Client) echo.MiddlewareFunc { case *ent.NotFoundError: return echo.NewHTTPError(http.StatusNotFound) default: + if context.IsCanceledError(err) { + return nil + } c.Logger().Error(err) return echo.NewHTTPError(http.StatusInternalServerError) } diff --git a/routes/error.go b/routes/error.go index 0dc28bc..50e2bfc 100644 --- a/routes/error.go +++ b/routes/error.go @@ -3,6 +3,7 @@ package routes import ( "net/http" + "github.com/mikestefanello/pagoda/context" "github.com/mikestefanello/pagoda/controller" "github.com/labstack/echo/v4" @@ -12,8 +13,8 @@ type Error struct { controller.Controller } -func (e *Error) Get(err error, c echo.Context) { - if c.Response().Committed { +func (e *Error) Get(err error, ctx echo.Context) { + if ctx.Response().Committed || context.IsCanceledError(err) { return } @@ -23,19 +24,19 @@ func (e *Error) Get(err error, c echo.Context) { } if code >= 500 { - c.Logger().Error(err) + ctx.Logger().Error(err) } else { - c.Logger().Info(err) + ctx.Logger().Info(err) } - p := controller.NewPage(c) - p.Layout = "main" - p.Title = http.StatusText(code) - p.Name = "error" - p.StatusCode = code - p.HTMX.Request.Enabled = false + page := controller.NewPage(ctx) + page.Layout = "main" + page.Title = http.StatusText(code) + page.Name = "error" + page.StatusCode = code + page.HTMX.Request.Enabled = false - if err = e.RenderPage(c, p); err != nil { - c.Logger().Error(err) + if err = e.RenderPage(ctx, page); err != nil { + ctx.Logger().Error(err) } } diff --git a/routes/verify_email.go b/routes/verify_email.go index afa3706..3f247b5 100644 --- a/routes/verify_email.go +++ b/routes/verify_email.go @@ -14,12 +14,6 @@ type VerifyEmail struct { } func (c *VerifyEmail) Get(ctx echo.Context) error { - c.verifyToken(ctx) - - return c.Redirect(ctx, "home") -} - -func (c *VerifyEmail) verifyToken(ctx echo.Context) { var usr *ent.User // Validate the token @@ -27,7 +21,7 @@ func (c *VerifyEmail) verifyToken(ctx echo.Context) { email, err := c.Container.Auth.ValidateEmailVerificationToken(token) if err != nil { msg.Warning(ctx, "The link is either invalid or has expired.") - return + return c.Redirect(ctx, "home") } // Check if it matches the authenticated user @@ -47,24 +41,23 @@ func (c *VerifyEmail) verifyToken(ctx echo.Context) { Only(ctx.Request().Context()) if err != nil { - ctx.Logger().Errorf("error querying user during email verification: %v", err) - msg.Danger(ctx, "An error occurred. Please try again.") - return + return c.Fail(ctx, err, "query failed loading email verification token user") } } - // Verify the user - err = c.Container.ORM.User. - Update(). - SetVerified(true). - Where(user.ID(usr.ID)). - Exec(ctx.Request().Context()) + // Verify the user, if needed + if !usr.Verified { + err = c.Container.ORM.User. + Update(). + SetVerified(true). + Where(user.ID(usr.ID)). + Exec(ctx.Request().Context()) - if err != nil { - ctx.Logger().Errorf("error setting user as verified: %v", err) - msg.Danger(ctx, "An error occurred. Please try again.") - return + if err != nil { + return c.Fail(ctx, err, "failed to set user as verified") + } } msg.Success(ctx, "Your email has been successfully verified.") + return c.Redirect(ctx, "home") }