From fc3fee130670572d089b8993bddeab04e12abdd5 Mon Sep 17 00:00:00 2001 From: mikestefanello Date: Tue, 21 Dec 2021 21:02:25 -0500 Subject: [PATCH] Added tests for entity and log middleware. --- middleware/entity.go | 13 ++++++------- middleware/entity_test.go | 23 +++++++++++++++++++++++ middleware/log.go | 1 + middleware/log_test.go | 27 +++++++++++++++++++++++++++ 4 files changed, 57 insertions(+), 7 deletions(-) create mode 100644 middleware/entity_test.go create mode 100644 middleware/log_test.go diff --git a/middleware/entity.go b/middleware/entity.go index e7e4608..a9fa2af 100644 --- a/middleware/entity.go +++ b/middleware/entity.go @@ -11,12 +11,13 @@ import ( "github.com/labstack/echo/v4" ) +// LoadUser loads the user based on the ID provided as a path parameter func LoadUser(orm *ent.Client) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { userID, err := strconv.Atoi(c.Param("user")) if err != nil { - return echo.NewHTTPError(http.StatusNotFound, "Not found") + return echo.NewHTTPError(http.StatusNotFound) } u, err := orm.User. @@ -26,16 +27,14 @@ func LoadUser(orm *ent.Client) echo.MiddlewareFunc { switch err.(type) { case nil: + c.Set(context.UserKey, u) + return next(c) case *ent.NotFoundError: - return echo.NewHTTPError(http.StatusNotFound, "Not found") + return echo.NewHTTPError(http.StatusNotFound) default: c.Logger().Error(err) - return echo.NewHTTPError(http.StatusInternalServerError, "Internal server error") + return echo.NewHTTPError(http.StatusInternalServerError) } - - c.Set(context.UserKey, u) - - return next(c) } } } diff --git a/middleware/entity_test.go b/middleware/entity_test.go new file mode 100644 index 0000000..b368293 --- /dev/null +++ b/middleware/entity_test.go @@ -0,0 +1,23 @@ +package middleware + +import ( + "fmt" + "testing" + + "goweb/context" + "goweb/ent" + "goweb/tests" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLoadUser(t *testing.T) { + ctx, _ := tests.NewContext(c.Web, "/") + ctx.SetParamNames("user") + ctx.SetParamValues(fmt.Sprintf("%d", usr.ID)) + _ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM)) + ctxUsr, ok := ctx.Get(context.UserKey).(*ent.User) + require.True(t, ok) + assert.Equal(t, usr.ID, ctxUsr.ID) +} diff --git a/middleware/log.go b/middleware/log.go index 0bd0be7..2c36431 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -7,6 +7,7 @@ import ( ) // LogRequestID includes the request ID in all logs for the given request +// This requires that middleware that includes the request ID first execute func LogRequestID() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { diff --git a/middleware/log_test.go b/middleware/log_test.go new file mode 100644 index 0000000..270fc5e --- /dev/null +++ b/middleware/log_test.go @@ -0,0 +1,27 @@ +package middleware + +import ( + "bytes" + "fmt" + "testing" + + "goweb/tests" + + "github.com/labstack/echo/v4" + + "github.com/stretchr/testify/assert" + + echomw "github.com/labstack/echo/v4/middleware" +) + +func TestLogRequestID(t *testing.T) { + ctx, _ := tests.NewContext(c.Web, "/") + _ = tests.ExecuteMiddleware(ctx, echomw.RequestID()) + _ = tests.ExecuteMiddleware(ctx, LogRequestID()) + + var buf bytes.Buffer + ctx.Logger().SetOutput(&buf) + ctx.Logger().Info("test") + rID := ctx.Response().Header().Get(echo.HeaderXRequestID) + assert.Contains(t, buf.String(), fmt.Sprintf(`id":"%s"`, rID)) +}