Reorganized directories and packages.

This commit is contained in:
mikestefanello 2022-11-02 19:23:26 -04:00
parent 965fb540c7
commit dceb232cb2
61 changed files with 83 additions and 83 deletions

25
pkg/context/context.go Normal file
View file

@ -0,0 +1,25 @@
package context
import (
"context"
"errors"
)
const (
// AuthenticatedUserKey is the key value used to store the authenticated user in context
AuthenticatedUserKey = "auth_user"
// UserKey is the key value used to store a user in context
UserKey = "user"
// FormKey is the key value used to store a form in context
FormKey = "form"
// PasswordTokenKey is the key value used to store a password token in context
PasswordTokenKey = "password_token"
)
// IsCanceledError determines if an error is due to a context cancelation
func IsCanceledError(err error) bool {
return errors.Is(err, context.Canceled)
}

View file

@ -0,0 +1,24 @@
package context
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestIsCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
assert.False(t, IsCanceledError(ctx.Err()))
cancel()
assert.True(t, IsCanceledError(ctx.Err()))
ctx, cancel = context.WithTimeout(context.Background(), time.Microsecond)
time.Sleep(time.Microsecond * 5)
cancel()
assert.False(t, IsCanceledError(ctx.Err()))
assert.False(t, IsCanceledError(errors.New("test error")))
}

View file

@ -0,0 +1,171 @@
package controller
import (
"bytes"
"fmt"
"net/http"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/mikestefanello/pagoda/pkg/htmx"
"github.com/mikestefanello/pagoda/pkg/middleware"
"github.com/mikestefanello/pagoda/pkg/services"
"github.com/labstack/echo/v4"
)
// Controller provides base functionality and dependencies to routes.
// The proposed pattern is to embed a Controller in each individual route struct and to use
// the router to inject the container so your routes have access to the services within the container
type Controller struct {
// Container stores a services container which contains dependencies
Container *services.Container
}
// NewController creates a new Controller
func NewController(c *services.Container) Controller {
return Controller{
Container: c,
}
}
// RenderPage renders a Page as an HTTP response
func (c *Controller) RenderPage(ctx echo.Context, page Page) error {
var buf *bytes.Buffer
var err error
// Page name is required
if page.Name == "" {
return echo.NewHTTPError(http.StatusInternalServerError, "page render failed due to missing name")
}
// Use the app name in configuration if a value was not set
if page.AppName == "" {
page.AppName = c.Container.Config.App.Name
}
// Check if this is an HTMX non-boosted request which indicates that only partial
// content should be rendered
if page.HTMX.Request.Enabled && !page.HTMX.Request.Boosted {
// Parse and execute the templates only for the content portion of the page
// The templates used for this partial request will be:
// 1. The base htmx template which omits the layout and only includes the content template
// 2. The content template specified in Page.Name
// 3. All templates within the components directory
// Also included is the function map provided by the funcmap package
buf, err = c.Container.TemplateRenderer.
Parse().
Group("page:htmx").
Key(page.Name).
Base("htmx").
Files(
"htmx",
fmt.Sprintf("pages/%s", page.Name),
).
Directories("components").
Execute(page)
} else {
// Parse and execute the templates for the Page
// As mentioned in the documentation for the Page struct, the templates used for the page will be:
// 1. The layout/base template specified in Page.Layout
// 2. The content template specified in Page.Name
// 3. All templates within the components directory
// Also included is the function map provided by the funcmap package
buf, err = c.Container.TemplateRenderer.
Parse().
Group("page").
Key(page.Name).
Base(page.Layout).
Files(
fmt.Sprintf("layouts/%s", page.Layout),
fmt.Sprintf("pages/%s", page.Name),
).
Directories("components").
Execute(page)
}
if err != nil {
return c.Fail(err, "failed to parse and execute templates")
}
// Set the status code
ctx.Response().Status = page.StatusCode
// Set any headers
for k, v := range page.Headers {
ctx.Response().Header().Set(k, v)
}
// Apply the HTMX response, if one
if page.HTMX.Response != nil {
page.HTMX.Response.Apply(ctx)
}
// Cache this page, if caching was enabled
c.cachePage(ctx, page, buf)
return ctx.HTMLBlob(ctx.Response().Status, buf.Bytes())
}
// cachePage caches the HTML for a given Page if the Page has caching enabled
func (c *Controller) cachePage(ctx echo.Context, page Page, html *bytes.Buffer) {
if !page.Cache.Enabled || page.IsAuth {
return
}
// If no expiration time was provided, default to the configuration value
if page.Cache.Expiration == 0 {
page.Cache.Expiration = c.Container.Config.Cache.Expiration.Page
}
// Extract the headers
headers := make(map[string]string)
for k, v := range ctx.Response().Header() {
headers[k] = v[0]
}
// The request URL is used as the cache key so the middleware can serve the
// cached page on matching requests
key := ctx.Request().URL.String()
cp := middleware.CachedPage{
URL: key,
HTML: html.Bytes(),
Headers: headers,
StatusCode: ctx.Response().Status,
}
err := c.Container.Cache.
Set().
Group(middleware.CachedPageGroup).
Key(key).
Tags(page.Cache.Tags...).
Expiration(page.Cache.Expiration).
Data(cp).
Save(ctx.Request().Context())
switch {
case err == nil:
ctx.Logger().Info("cached page")
case !context.IsCanceledError(err):
ctx.Logger().Errorf("failed to cache page: %v", err)
}
}
// Redirect redirects to a given route name with optional route parameters
func (c *Controller) Redirect(ctx echo.Context, route string, routeParams ...interface{}) error {
url := ctx.Echo().Reverse(route, routeParams...)
if htmx.GetRequest(ctx).Boosted {
htmx.Response{
Redirect: url,
}.Apply(ctx)
return nil
} else {
return ctx.Redirect(http.StatusFound, url)
}
}
// Fail is a helper to fail a request by returning a 500 error and logging the error
func (c *Controller) Fail(err error, log string) error {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("%s: %v", log, err))
}

View file

@ -0,0 +1,187 @@
package controller
import (
"context"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/mikestefanello/pagoda/config"
"github.com/mikestefanello/pagoda/pkg/htmx"
"github.com/mikestefanello/pagoda/pkg/middleware"
"github.com/mikestefanello/pagoda/pkg/services"
"github.com/mikestefanello/pagoda/pkg/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/labstack/echo/v4"
)
var (
c *services.Container
)
func TestMain(m *testing.M) {
// Set the environment to test
config.SwitchEnvironment(config.EnvTest)
// Create a new container
c = services.NewContainer()
// Run tests
exitVal := m.Run()
// Shutdown the container
if err := c.Shutdown(); err != nil {
panic(err)
}
os.Exit(exitVal)
}
func TestController_Redirect(t *testing.T) {
c.Web.GET("/path/:first/and/:second", func(c echo.Context) error {
return nil
}).Name = "redirect-test"
ctx, _ := tests.NewContext(c.Web, "/abc")
ctr := NewController(c)
err := ctr.Redirect(ctx, "redirect-test", "one", "two")
require.NoError(t, err)
assert.Equal(t, "/path/one/and/two", ctx.Response().Header().Get(echo.HeaderLocation))
assert.Equal(t, http.StatusFound, ctx.Response().Status)
}
func TestController_RenderPage(t *testing.T) {
setup := func() (echo.Context, *httptest.ResponseRecorder, Controller, Page) {
ctx, rec := tests.NewContext(c.Web, "/test/TestController_RenderPage")
tests.InitSession(ctx)
ctr := NewController(c)
p := NewPage(ctx)
p.Name = "home"
p.Layout = "main"
p.Cache.Enabled = false
p.Headers["A"] = "b"
p.Headers["C"] = "d"
p.StatusCode = http.StatusCreated
return ctx, rec, ctr, p
}
t.Run("missing name", func(t *testing.T) {
// Rendering should fail if the Page has no name
ctx, _, ctr, p := setup()
p.Name = ""
err := ctr.RenderPage(ctx, p)
assert.Error(t, err)
})
t.Run("no page cache", func(t *testing.T) {
ctx, _, ctr, p := setup()
err := ctr.RenderPage(ctx, p)
require.NoError(t, err)
// Check status code and headers
assert.Equal(t, http.StatusCreated, ctx.Response().Status)
for k, v := range p.Headers {
assert.Equal(t, v, ctx.Response().Header().Get(k))
}
// Check the template cache
parsed, err := c.TemplateRenderer.Load("page", p.Name)
assert.NoError(t, err)
// Check that all expected templates were parsed.
// This includes the name, layout and all components
expectedTemplates := make(map[string]bool)
expectedTemplates[p.Name+config.TemplateExt] = true
expectedTemplates[p.Layout+config.TemplateExt] = true
components, err := ioutil.ReadDir(c.TemplateRenderer.GetTemplatesPath() + "/components")
require.NoError(t, err)
for _, f := range components {
expectedTemplates[f.Name()] = true
}
for _, v := range parsed.Template.Templates() {
delete(expectedTemplates, v.Name())
}
assert.Empty(t, expectedTemplates)
})
t.Run("htmx rendering", func(t *testing.T) {
ctx, _, ctr, p := setup()
p.HTMX.Request.Enabled = true
p.HTMX.Response = &htmx.Response{
Trigger: "trigger",
}
err := ctr.RenderPage(ctx, p)
require.NoError(t, err)
// Check HTMX header
assert.Equal(t, "trigger", ctx.Response().Header().Get(htmx.HeaderTrigger))
// Check the template cache
parsed, err := c.TemplateRenderer.Load("page:htmx", p.Name)
assert.NoError(t, err)
// Check that all expected templates were parsed.
// This includes the name, htmx and all components
expectedTemplates := make(map[string]bool)
expectedTemplates[p.Name+config.TemplateExt] = true
expectedTemplates["htmx"+config.TemplateExt] = true
components, err := ioutil.ReadDir(c.TemplateRenderer.GetTemplatesPath() + "/components")
require.NoError(t, err)
for _, f := range components {
expectedTemplates[f.Name()] = true
}
for _, v := range parsed.Template.Templates() {
delete(expectedTemplates, v.Name())
}
assert.Empty(t, expectedTemplates)
})
t.Run("page cache", func(t *testing.T) {
ctx, rec, ctr, p := setup()
p.Cache.Enabled = true
p.Cache.Tags = []string{"tag1"}
err := ctr.RenderPage(ctx, p)
require.NoError(t, err)
// Fetch from the cache
res, err := c.Cache.
Get().
Group(middleware.CachedPageGroup).
Key(p.URL).
Type(new(middleware.CachedPage)).
Fetch(context.Background())
require.NoError(t, err)
// Compare the cached page
cp, ok := res.(*middleware.CachedPage)
require.True(t, ok)
assert.Equal(t, p.URL, cp.URL)
assert.Equal(t, p.Headers, cp.Headers)
assert.Equal(t, p.StatusCode, cp.StatusCode)
assert.Equal(t, rec.Body.Bytes(), cp.HTML)
// Clear the tag
err = c.Cache.
Flush().
Tags(p.Cache.Tags[0]).
Execute(context.Background())
require.NoError(t, err)
// Refetch from the cache and expect no results
_, err = c.Cache.
Get().
Group(middleware.CachedPageGroup).
Key(p.URL).
Type(new(middleware.CachedPage)).
Fetch(context.Background())
assert.Error(t, err)
})
}

104
pkg/controller/form.go Normal file
View file

@ -0,0 +1,104 @@
package controller
import (
"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
)
// FormSubmission represents the state of the submission of a form, not including the form itself
type FormSubmission struct {
// IsSubmitted indicates if the form has been submitted
IsSubmitted bool
// Errors stores a slice of error message strings keyed by form struct field name
Errors map[string][]string
}
// Process processes a submission for a form
func (f *FormSubmission) Process(ctx echo.Context, form interface{}) error {
f.Errors = make(map[string][]string)
f.IsSubmitted = true
// Validate the form
if err := ctx.Validate(form); err != nil {
f.setErrorMessages(err)
}
return nil
}
// HasErrors indicates if the submission has any validation errors
func (f FormSubmission) HasErrors() bool {
if f.Errors == nil {
return false
}
return len(f.Errors) > 0
}
// FieldHasErrors indicates if a given field on the form has any validation errors
func (f FormSubmission) FieldHasErrors(fieldName string) bool {
return len(f.GetFieldErrors(fieldName)) > 0
}
// SetFieldError sets an error message for a given field name
func (f *FormSubmission) SetFieldError(fieldName string, message string) {
if f.Errors == nil {
f.Errors = make(map[string][]string)
}
f.Errors[fieldName] = append(f.Errors[fieldName], message)
}
// GetFieldErrors gets the errors for a given field name
func (f FormSubmission) GetFieldErrors(fieldName string) []string {
if f.Errors == nil {
return []string{}
}
return f.Errors[fieldName]
}
// GetFieldStatusClass returns an HTML class based on the status of the field
func (f FormSubmission) GetFieldStatusClass(fieldName string) string {
if f.IsSubmitted {
if f.FieldHasErrors(fieldName) {
return "is-danger"
}
return "is-success"
}
return ""
}
// IsDone indicates if the submission is considered done which is when it has been submitted
// and there are no errors.
func (f FormSubmission) IsDone() bool {
return f.IsSubmitted && !f.HasErrors()
}
// setErrorMessages sets errors messages on the submission for all fields that failed validation
func (f *FormSubmission) setErrorMessages(err error) {
// Only this is supported right now
ves, ok := err.(validator.ValidationErrors)
if !ok {
return
}
for _, ve := range ves {
var message string
// Provide better error messages depending on the failed validation tag
// This should be expanded as you use additional tags in your validation
switch ve.Tag() {
case "required":
message = "This field is required."
case "email":
message = "Enter a valid email address."
case "eqfield":
message = "Does not match."
default:
message = "Invalid value."
}
// Add the error
f.SetFieldError(ve.Field(), message)
}
}

View file

@ -0,0 +1,36 @@
package controller
import (
"testing"
"github.com/mikestefanello/pagoda/pkg/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFormSubmission(t *testing.T) {
type formTest struct {
Name string `validate:"required"`
Email string `validate:"required,email"`
Submission FormSubmission
}
ctx, _ := tests.NewContext(c.Web, "/")
form := formTest{
Name: "",
Email: "a@a.com",
}
err := form.Submission.Process(ctx, form)
assert.NoError(t, err)
assert.True(t, form.Submission.HasErrors())
assert.True(t, form.Submission.FieldHasErrors("Name"))
assert.False(t, form.Submission.FieldHasErrors("Email"))
require.Len(t, form.Submission.GetFieldErrors("Name"), 1)
assert.Len(t, form.Submission.GetFieldErrors("Email"), 0)
assert.Equal(t, "This field is required.", form.Submission.GetFieldErrors("Name")[0])
assert.Equal(t, "is-danger", form.Submission.GetFieldStatusClass("Name"))
assert.Equal(t, "is-success", form.Submission.GetFieldStatusClass("Email"))
assert.False(t, form.Submission.IsDone())
}

161
pkg/controller/page.go Normal file
View file

@ -0,0 +1,161 @@
package controller
import (
"html/template"
"net/http"
"time"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/mikestefanello/pagoda/pkg/htmx"
"github.com/mikestefanello/pagoda/pkg/msg"
echomw "github.com/labstack/echo/v4/middleware"
"github.com/labstack/echo/v4"
)
// Page consists of all data that will be used to render a page response for a given controller.
// While it's not required for a controller to render a Page on a route, this is the common data
// object that will be passed to the templates, making it easy for all controllers to share
// functionality both on the back and frontend. The Page can be expanded to include anything else
// your app wants to support.
// Methods on this page also then become available in the templates, which can be more useful than
// the funcmap if your methods require data stored in the page, such as the context.
type Page struct {
// AppName stores the name of the application.
// If omitted, the configuration value will be used.
AppName string
// Title stores the title of the page
Title string
// Context stores the request context
Context echo.Context
// ToURL is a function to convert a route name and optional route parameters to a URL
ToURL func(name string, params ...interface{}) string
// Path stores the path of the current request
Path string
// URL stores the URL of the current request
URL string
// Data stores whatever additional data that needs to be passed to the templates.
// This is what the controller uses to pass the content of the page.
Data interface{}
// Form stores a struct that represents a form on the page.
// This should be a struct with fields for each form field, using both "form" and "validate" tags
// It should also contain a Submission field of type FormSubmission if you wish to have validation
// messagesa and markup presented to the user
Form interface{}
// Layout stores the name of the layout base template file which will be used when the page is rendered.
// This should match a template file located within the layouts directory inside the templates directory.
// The template extension should not be included in this value.
Layout string
// Name stores the name of the page as well as the name of the template file which will be used to render
// the content portion of the layout template.
// This should match a template file located within the pages directory inside the templates directory.
// The template extension should not be included in this value.
Name string
// IsHome stores whether the requested page is the home page or not
IsHome bool
// IsAuth stores whether or not the user is authenticated
IsAuth bool
// AuthUser stores the authenticated user
AuthUser *ent.User
// StatusCode stores the HTTP status code that will be returned
StatusCode int
// Metatags stores metatag values
Metatags struct {
// Description stores the description metatag value
Description string
// Keywords stores the keywords metatag values
Keywords []string
}
// Pager stores a pager which can be used to page lists of results
Pager Pager
// CSRF stores the CSRF token for the given request.
// This will only be populated if the CSRF middleware is in effect for the given request.
// If this is populated, all forms must include this value otherwise the requests will be rejected.
CSRF string
// Headers stores a list of HTTP headers and values to be set on the response
Headers map[string]string
// RequestID stores the ID of the given request.
// This will only be populated if the request ID middleware is in effect for the given request.
RequestID string
HTMX struct {
Request htmx.Request
Response *htmx.Response
}
// Cache stores values for caching the response of this page
Cache struct {
// Enabled dictates if the response of this page should be cached.
// Cached responses are served via middleware.
Enabled bool
// Expiration stores the amount of time that the cache entry should live for before expiring.
// If omitted, the configuration value will be used.
Expiration time.Duration
// Tags stores a list of tags to apply to the cache entry.
// These are useful when invalidating cache for dynamic events such as entity operations.
Tags []string
}
}
// NewPage creates and initiatizes a new Page for a given request context
func NewPage(ctx echo.Context) Page {
p := Page{
Context: ctx,
ToURL: ctx.Echo().Reverse,
Path: ctx.Request().URL.Path,
URL: ctx.Request().URL.String(),
StatusCode: http.StatusOK,
Pager: NewPager(ctx, DefaultItemsPerPage),
Headers: make(map[string]string),
RequestID: ctx.Response().Header().Get(echo.HeaderXRequestID),
}
p.IsHome = p.Path == "/"
if csrf := ctx.Get(echomw.DefaultCSRFConfig.ContextKey); csrf != nil {
p.CSRF = csrf.(string)
}
if u := ctx.Get(context.AuthenticatedUserKey); u != nil {
p.IsAuth = true
p.AuthUser = u.(*ent.User)
}
p.HTMX.Request = htmx.GetRequest(ctx)
return p
}
// GetMessages gets all flash messages for a given type.
// This allows for easy access to flash messages from the templates.
func (p Page) GetMessages(typ msg.Type) []template.HTML {
strs := msg.Get(p.Context, typ)
ret := make([]template.HTML, len(strs))
for k, v := range strs {
ret[k] = template.HTML(v)
}
return ret
}

View file

@ -0,0 +1,75 @@
package controller
import (
"net/http"
"testing"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/mikestefanello/pagoda/pkg/msg"
"github.com/mikestefanello/pagoda/pkg/tests"
echomw "github.com/labstack/echo/v4/middleware"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewPage(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/")
p := NewPage(ctx)
assert.Same(t, ctx, p.Context)
assert.NotNil(t, p.ToURL)
assert.Equal(t, "/", p.Path)
assert.Equal(t, "/", p.URL)
assert.Equal(t, http.StatusOK, p.StatusCode)
assert.Equal(t, NewPager(ctx, DefaultItemsPerPage), p.Pager)
assert.Empty(t, p.Headers)
assert.True(t, p.IsHome)
assert.False(t, p.IsAuth)
assert.Empty(t, p.CSRF)
assert.Empty(t, p.RequestID)
assert.False(t, p.Cache.Enabled)
ctx, _ = tests.NewContext(c.Web, "/abc?def=123")
usr, err := tests.CreateUser(c.ORM)
require.NoError(t, err)
ctx.Set(context.AuthenticatedUserKey, usr)
ctx.Set(echomw.DefaultCSRFConfig.ContextKey, "csrf")
p = NewPage(ctx)
assert.Equal(t, "/abc", p.Path)
assert.Equal(t, "/abc?def=123", p.URL)
assert.False(t, p.IsHome)
assert.True(t, p.IsAuth)
assert.Equal(t, usr, p.AuthUser)
assert.Equal(t, "csrf", p.CSRF)
}
func TestPage_GetMessages(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/")
tests.InitSession(ctx)
p := NewPage(ctx)
// Set messages
msgTests := make(map[msg.Type][]string)
msgTests[msg.TypeWarning] = []string{
"abc",
"def",
}
msgTests[msg.TypeInfo] = []string{
"123",
"456",
}
for typ, values := range msgTests {
for _, value := range values {
msg.Set(ctx, typ, value)
}
}
// Get the messages
for typ, values := range msgTests {
msgs := p.GetMessages(typ)
for i, message := range msgs {
assert.Equal(t, values[i], string(message))
}
}
}

80
pkg/controller/pager.go Normal file
View file

@ -0,0 +1,80 @@
package controller
import (
"math"
"strconv"
"github.com/labstack/echo/v4"
)
const (
// DefaultItemsPerPage stores the default amount of items per page
DefaultItemsPerPage = 20
// PageQueryKey stores the query key used to indicate the current page
PageQueryKey = "page"
)
// Pager provides a mechanism to allow a user to page results via a query parameter
type Pager struct {
// Items stores the total amount of items in the result set
Items int
// Page stores the current page number
Page int
// ItemsPerPage stores the amount of items to display per page
ItemsPerPage int
// Pages stores the total amount of pages in the result set
Pages int
}
// NewPager creates a new Pager
func NewPager(ctx echo.Context, itemsPerPage int) Pager {
p := Pager{
ItemsPerPage: itemsPerPage,
Page: 1,
}
if page := ctx.QueryParam(PageQueryKey); page != "" {
if pageInt, err := strconv.Atoi(page); err == nil {
if pageInt > 0 {
p.Page = pageInt
}
}
}
return p
}
// SetItems sets the amount of items in total for the pager and calculate the amount
// of total pages based off on the item per page.
// This should be used rather than setting either items or pages directly.
func (p *Pager) SetItems(items int) {
p.Items = items
p.Pages = int(math.Ceil(float64(items) / float64(p.ItemsPerPage)))
if p.Page > p.Pages {
p.Page = p.Pages
}
}
// IsBeginning determines if the pager is at the beginning of the pages
func (p Pager) IsBeginning() bool {
return p.Page == 1
}
// IsEnd determines if the pager is at the end of the pages
func (p Pager) IsEnd() bool {
return p.Page >= p.Pages
}
// GetOffset determines the offset of the results in order to get the items for
// the current page
func (p Pager) GetOffset() int {
if p.Page == 0 {
p.Page = 1
}
return (p.Page - 1) * p.ItemsPerPage
}

View file

@ -0,0 +1,67 @@
package controller
import (
"fmt"
"testing"
"github.com/mikestefanello/pagoda/pkg/tests"
"github.com/stretchr/testify/assert"
)
func TestNewPager(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/")
pgr := NewPager(ctx, 10)
assert.Equal(t, 10, pgr.ItemsPerPage)
assert.Equal(t, 1, pgr.Page)
assert.Equal(t, 0, pgr.Items)
assert.Equal(t, 0, pgr.Pages)
ctx, _ = tests.NewContext(c.Web, fmt.Sprintf("/abc?%s=%d", PageQueryKey, 2))
pgr = NewPager(ctx, 10)
assert.Equal(t, 2, pgr.Page)
ctx, _ = tests.NewContext(c.Web, fmt.Sprintf("/abc?%s=%d", PageQueryKey, -2))
pgr = NewPager(ctx, 10)
assert.Equal(t, 1, pgr.Page)
}
func TestPager_SetItems(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/")
pgr := NewPager(ctx, 20)
pgr.SetItems(100)
assert.Equal(t, 100, pgr.Items)
assert.Equal(t, 5, pgr.Pages)
}
func TestPager_IsBeginning(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/")
pgr := NewPager(ctx, 20)
pgr.Pages = 10
assert.True(t, pgr.IsBeginning())
pgr.Page = 2
assert.False(t, pgr.IsBeginning())
pgr.Page = 1
assert.True(t, pgr.IsBeginning())
}
func TestPager_IsEnd(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/")
pgr := NewPager(ctx, 20)
pgr.Pages = 10
assert.False(t, pgr.IsEnd())
pgr.Page = 10
assert.True(t, pgr.IsEnd())
pgr.Page = 1
assert.False(t, pgr.IsEnd())
}
func TestPager_GetOffset(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/")
pgr := NewPager(ctx, 20)
assert.Equal(t, 0, pgr.GetOffset())
pgr.Page = 2
assert.Equal(t, 20, pgr.GetOffset())
pgr.Page = 3
assert.Equal(t, 40, pgr.GetOffset())
}

66
pkg/funcmap/funcmap.go Normal file
View file

@ -0,0 +1,66 @@
package funcmap
import (
"fmt"
"html/template"
"reflect"
"strings"
"github.com/mikestefanello/pagoda/config"
"github.com/Masterminds/sprig"
"github.com/labstack/gommon/random"
)
var (
// CacheBuster stores a random string used as a cache buster for static files.
CacheBuster = random.String(10)
)
// GetFuncMap provides a template function map
func GetFuncMap() template.FuncMap {
// See http://masterminds.github.io/sprig/ for available funcs
funcMap := sprig.FuncMap()
// Provide a list of custom functions
// Expand this as you add more functions to this package
// Avoid using a name already in use by sprig
f := template.FuncMap{
"hasField": HasField,
"file": File,
"link": Link,
}
for k, v := range f {
funcMap[k] = v
}
return funcMap
}
// HasField checks if an interface contains a given field
func HasField(v interface{}, name string) bool {
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Ptr {
rv = rv.Elem()
}
if rv.Kind() != reflect.Struct {
return false
}
return rv.FieldByName(name).IsValid()
}
// File appends a cache buster to a given filepath so it can remain cached until the app is restarted
func File(filepath string) string {
return fmt.Sprintf("/%s/%s?v=%s", config.StaticPrefix, filepath, CacheBuster)
}
// Link outputs HTML for a link element, providing the ability to dynamically set the active class
func Link(url, text, currentPath string, classes ...string) template.HTML {
if currentPath == url {
classes = append(classes, "is-active")
}
html := fmt.Sprintf(`<a class="%s" href="%s">%s</a>`, strings.Join(classes, " "), url, text)
return template.HTML(html)
}

View file

@ -0,0 +1,39 @@
package funcmap
import (
"fmt"
"testing"
"github.com/mikestefanello/pagoda/config"
"github.com/stretchr/testify/assert"
)
func TestHasField(t *testing.T) {
type example struct {
name string
}
var e example
assert.True(t, HasField(e, "name"))
assert.False(t, HasField(e, "abcd"))
}
func TestLink(t *testing.T) {
link := string(Link("/abc", "Text", "/abc"))
expected := `<a class="is-active" href="/abc">Text</a>`
assert.Equal(t, expected, link)
link = string(Link("/abc", "Text", "/abc", "first", "second"))
expected = `<a class="first second is-active" href="/abc">Text</a>`
assert.Equal(t, expected, link)
link = string(Link("/abc", "Text", "/def"))
expected = `<a class="" href="/abc">Text</a>`
assert.Equal(t, expected, link)
}
func TestGetFuncMap(t *testing.T) {
file := File("test.png")
expected := fmt.Sprintf("/%s/test.png?v=%s", config.StaticPrefix, CacheBuster)
assert.Equal(t, expected, file)
}

82
pkg/htmx/htmx.go Normal file
View file

@ -0,0 +1,82 @@
package htmx
import (
"net/http"
"github.com/labstack/echo/v4"
)
// Headers (https://htmx.org/docs/#requests)
const (
HeaderRequest = "HX-Request"
HeaderBoosted = "HX-Boosted"
HeaderTrigger = "HX-Trigger"
HeaderTriggerName = "HX-Trigger-Name"
HeaderTriggerAfterSwap = "HX-Trigger-After-Swap"
HeaderTriggerAfterSettle = "HX-Trigger-After-Settle"
HeaderTarget = "HX-Target"
HeaderPrompt = "HX-Prompt"
HeaderPush = "HX-Push"
HeaderRedirect = "HX-Redirect"
HeaderRefresh = "HX-Refresh"
)
type (
// Request contains data that HTMX provides during requests
Request struct {
Enabled bool
Boosted bool
Trigger string
TriggerName string
Target string
Prompt string
}
// Response contain data that the server can communicate back to HTMX
Response struct {
Push string
Redirect string
Refresh bool
Trigger string
TriggerAfterSwap string
TriggerAfterSettle string
NoContent bool
}
)
// GetRequest extracts HTMX data from the request
func GetRequest(ctx echo.Context) Request {
return Request{
Enabled: ctx.Request().Header.Get(HeaderRequest) == "true",
Boosted: ctx.Request().Header.Get(HeaderBoosted) == "true",
Trigger: ctx.Request().Header.Get(HeaderTrigger),
TriggerName: ctx.Request().Header.Get(HeaderTriggerName),
Target: ctx.Request().Header.Get(HeaderTarget),
Prompt: ctx.Request().Header.Get(HeaderPrompt),
}
}
// Apply applies data from a Response to a server response
func (r Response) Apply(ctx echo.Context) {
if r.Push != "" {
ctx.Response().Header().Set(HeaderPush, r.Push)
}
if r.Redirect != "" {
ctx.Response().Header().Set(HeaderRedirect, r.Redirect)
}
if r.Refresh {
ctx.Response().Header().Set(HeaderRefresh, "true")
}
if r.Trigger != "" {
ctx.Response().Header().Set(HeaderTrigger, r.Trigger)
}
if r.TriggerAfterSwap != "" {
ctx.Response().Header().Set(HeaderTriggerAfterSwap, r.TriggerAfterSwap)
}
if r.TriggerAfterSettle != "" {
ctx.Response().Header().Set(HeaderTriggerAfterSettle, r.TriggerAfterSettle)
}
if r.NoContent {
ctx.Response().Status = http.StatusNoContent
}
}

52
pkg/htmx/htmx_test.go Normal file
View file

@ -0,0 +1,52 @@
package htmx
import (
"net/http"
"testing"
"github.com/mikestefanello/pagoda/pkg/tests"
"github.com/stretchr/testify/assert"
"github.com/labstack/echo/v4"
)
func TestSetRequest(t *testing.T) {
ctx, _ := tests.NewContext(echo.New(), "/")
ctx.Request().Header.Set(HeaderRequest, "true")
ctx.Request().Header.Set(HeaderBoosted, "true")
ctx.Request().Header.Set(HeaderTrigger, "a")
ctx.Request().Header.Set(HeaderTriggerName, "b")
ctx.Request().Header.Set(HeaderTarget, "c")
ctx.Request().Header.Set(HeaderPrompt, "d")
r := GetRequest(ctx)
assert.Equal(t, true, r.Enabled)
assert.Equal(t, true, r.Boosted)
assert.Equal(t, "a", r.Trigger)
assert.Equal(t, "b", r.TriggerName)
assert.Equal(t, "c", r.Target)
assert.Equal(t, "d", r.Prompt)
}
func TestResponse_Apply(t *testing.T) {
ctx, _ := tests.NewContext(echo.New(), "/")
r := Response{
Push: "a",
Redirect: "b",
Refresh: true,
Trigger: "c",
TriggerAfterSwap: "d",
TriggerAfterSettle: "e",
NoContent: true,
}
r.Apply(ctx)
assert.Equal(t, "a", ctx.Response().Header().Get(HeaderPush))
assert.Equal(t, "b", ctx.Response().Header().Get(HeaderRedirect))
assert.Equal(t, "true", ctx.Response().Header().Get(HeaderRefresh))
assert.Equal(t, "c", ctx.Response().Header().Get(HeaderTrigger))
assert.Equal(t, "d", ctx.Response().Header().Get(HeaderTriggerAfterSwap))
assert.Equal(t, "e", ctx.Response().Header().Get(HeaderTriggerAfterSettle))
assert.Equal(t, http.StatusNoContent, ctx.Response().Status)
}

108
pkg/middleware/auth.go Normal file
View file

@ -0,0 +1,108 @@
package middleware
import (
"fmt"
"net/http"
"strconv"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/mikestefanello/pagoda/pkg/msg"
"github.com/mikestefanello/pagoda/pkg/services"
"github.com/labstack/echo/v4"
)
// LoadAuthenticatedUser loads the authenticated user, if one, and stores in context
func LoadAuthenticatedUser(authClient *services.AuthClient) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
u, err := authClient.GetAuthenticatedUser(c)
switch err.(type) {
case *ent.NotFoundError:
c.Logger().Warn("auth user not found")
case services.NotAuthenticatedError:
case nil:
c.Set(context.AuthenticatedUserKey, u)
c.Logger().Infof("auth user loaded in to context: %d", u.ID)
default:
return echo.NewHTTPError(
http.StatusInternalServerError,
fmt.Sprintf("error querying for authenticated user: %v", err),
)
}
return next(c)
}
}
}
// LoadValidPasswordToken loads a valid password token entity that matches the user and token
// provided in path parameters
// If the token is invalid, the user will be redirected to the forgot password route
// This requires that the user owning the token is loaded in to context
func LoadValidPasswordToken(authClient *services.AuthClient) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Extract the user parameter
if c.Get(context.UserKey) == nil {
return echo.NewHTTPError(http.StatusInternalServerError)
}
usr := c.Get(context.UserKey).(*ent.User)
// Extract the token ID
tokenID, err := strconv.Atoi(c.Param("password_token"))
if err != nil {
return echo.NewHTTPError(http.StatusNotFound)
}
// Attempt to load a valid password token
token, err := authClient.GetValidPasswordToken(
c,
usr.ID,
tokenID,
c.Param("token"),
)
switch err.(type) {
case nil:
c.Set(context.PasswordTokenKey, token)
return next(c)
case services.InvalidPasswordTokenError:
msg.Warning(c, "The link is either invalid or has expired. Please request a new one.")
return c.Redirect(http.StatusFound, c.Echo().Reverse("forgot_password"))
default:
return echo.NewHTTPError(
http.StatusInternalServerError,
fmt.Sprintf("error loading password token: %v", err),
)
}
}
}
}
// RequireAuthentication requires that the user be authenticated in order to proceed
func RequireAuthentication() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if u := c.Get(context.AuthenticatedUserKey); u == nil {
return echo.NewHTTPError(http.StatusUnauthorized)
}
return next(c)
}
}
}
// RequireNoAuthentication requires that the user not be authenticated in order to proceed
func RequireNoAuthentication() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if u := c.Get(context.AuthenticatedUserKey); u != nil {
return echo.NewHTTPError(http.StatusForbidden)
}
return next(c)
}
}
}

111
pkg/middleware/auth_test.go Normal file
View file

@ -0,0 +1,111 @@
package middleware
import (
"fmt"
"net/http"
"testing"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/mikestefanello/pagoda/pkg/tests"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/assert"
)
func TestLoadAuthenticatedUser(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/")
tests.InitSession(ctx)
mw := LoadAuthenticatedUser(c.Auth)
// Not authenticated
_ = tests.ExecuteMiddleware(ctx, mw)
assert.Nil(t, ctx.Get(context.AuthenticatedUserKey))
// Login
err := c.Auth.Login(ctx, usr.ID)
require.NoError(t, err)
// Verify the midldeware returns the authenticated user
_ = tests.ExecuteMiddleware(ctx, mw)
require.NotNil(t, ctx.Get(context.AuthenticatedUserKey))
ctxUsr, ok := ctx.Get(context.AuthenticatedUserKey).(*ent.User)
require.True(t, ok)
assert.Equal(t, usr.ID, ctxUsr.ID)
}
func TestRequireAuthentication(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/")
tests.InitSession(ctx)
// Not logged in
err := tests.ExecuteMiddleware(ctx, RequireAuthentication())
tests.AssertHTTPErrorCode(t, err, http.StatusUnauthorized)
// Login
err = c.Auth.Login(ctx, usr.ID)
require.NoError(t, err)
_ = tests.ExecuteMiddleware(ctx, LoadAuthenticatedUser(c.Auth))
// Logged in
err = tests.ExecuteMiddleware(ctx, RequireAuthentication())
assert.Nil(t, err)
}
func TestRequireNoAuthentication(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/")
tests.InitSession(ctx)
// Not logged in
err := tests.ExecuteMiddleware(ctx, RequireNoAuthentication())
assert.Nil(t, err)
// Login
err = c.Auth.Login(ctx, usr.ID)
require.NoError(t, err)
_ = tests.ExecuteMiddleware(ctx, LoadAuthenticatedUser(c.Auth))
// Logged in
err = tests.ExecuteMiddleware(ctx, RequireNoAuthentication())
tests.AssertHTTPErrorCode(t, err, http.StatusForbidden)
}
func TestLoadValidPasswordToken(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/")
tests.InitSession(ctx)
// Missing user context
err := tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth))
tests.AssertHTTPErrorCode(t, err, http.StatusInternalServerError)
// Add user and password token context but no token and expect a redirect
ctx.SetParamNames("user", "password_token")
ctx.SetParamValues(fmt.Sprintf("%d", usr.ID), "1")
_ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM))
err = tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth))
assert.NoError(t, err)
assert.Equal(t, http.StatusFound, ctx.Response().Status)
// Add user context and invalid password token and expect a redirect
ctx.SetParamNames("user", "password_token", "token")
ctx.SetParamValues(fmt.Sprintf("%d", usr.ID), "1", "faketoken")
_ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM))
err = tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth))
assert.NoError(t, err)
assert.Equal(t, http.StatusFound, ctx.Response().Status)
// Create a valid token
token, pt, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID)
require.NoError(t, err)
// Add user and valid password token
ctx.SetParamNames("user", "password_token", "token")
ctx.SetParamValues(fmt.Sprintf("%d", usr.ID), fmt.Sprintf("%d", pt.ID), token)
_ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM))
err = tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth))
assert.Nil(t, err)
ctxPt, ok := ctx.Get(context.PasswordTokenKey).(*ent.PasswordToken)
require.True(t, ok)
assert.Equal(t, pt.ID, ctxPt.ID)
}

102
pkg/middleware/cache.go Normal file
View file

@ -0,0 +1,102 @@
package middleware
import (
"fmt"
"net/http"
"time"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/mikestefanello/pagoda/pkg/services"
"github.com/go-redis/redis/v8"
"github.com/labstack/echo/v4"
)
// CachedPageGroup stores the cache group for cached pages
const CachedPageGroup = "page"
// CachedPage is what is used to store a rendered Page in the cache
type CachedPage struct {
// URL stores the URL of the requested page
URL string
// HTML stores the complete HTML of the rendered Page
HTML []byte
// StatusCode stores the HTTP status code
StatusCode int
// Headers stores the HTTP headers
Headers map[string]string
}
// ServeCachedPage attempts to load a page from the cache by matching on the complete request URL
// If a page is cached for the requested URL, it will be served here and the request terminated.
// Any request made by an authenticated user or that is not a GET will be skipped.
func ServeCachedPage(ch *services.CacheClient) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Skip non GET requests
if c.Request().Method != http.MethodGet {
return next(c)
}
// Skip if the user is authenticated
if c.Get(context.AuthenticatedUserKey) != nil {
return next(c)
}
// Attempt to load from cache
res, err := ch.
Get().
Group(CachedPageGroup).
Key(c.Request().URL.String()).
Type(new(CachedPage)).
Fetch(c.Request().Context())
if err != nil {
switch {
case err == redis.Nil:
c.Logger().Info("no cached page found")
case context.IsCanceledError(err):
return nil
default:
c.Logger().Errorf("failed getting cached page: %v", err)
}
return next(c)
}
page, ok := res.(*CachedPage)
if !ok {
c.Logger().Errorf("failed casting cached page")
return next(c)
}
// Set any headers
if page.Headers != nil {
for k, v := range page.Headers {
c.Response().Header().Set(k, v)
}
}
c.Logger().Info("serving cached page")
return c.HTMLBlob(page.StatusCode, page.HTML)
}
}
}
// CacheControl sets a Cache-Control header with a given max age
func CacheControl(maxAge time.Duration) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
v := "no-cache, no-store"
if maxAge > 0 {
v = fmt.Sprintf("public, max-age=%.0f", maxAge.Seconds())
}
c.Response().Header().Set("Cache-Control", v)
return next(c)
}
}
}

View file

@ -0,0 +1,59 @@
package middleware
import (
"context"
"net/http"
"testing"
"time"
"github.com/mikestefanello/pagoda/pkg/tests"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/assert"
)
func TestServeCachedPage(t *testing.T) {
// Cache a page
cp := CachedPage{
URL: "/cache",
HTML: []byte("html"),
Headers: make(map[string]string),
StatusCode: http.StatusCreated,
}
cp.Headers["a"] = "b"
cp.Headers["c"] = "d"
err := c.Cache.
Set().
Group(CachedPageGroup).
Key(cp.URL).
Data(cp).
Save(context.Background())
require.NoError(t, err)
// Request the URL of the cached page
ctx, rec := tests.NewContext(c.Web, cp.URL)
err = tests.ExecuteMiddleware(ctx, ServeCachedPage(c.Cache))
assert.NoError(t, err)
assert.Equal(t, cp.StatusCode, ctx.Response().Status)
assert.Equal(t, cp.Headers["a"], ctx.Response().Header().Get("a"))
assert.Equal(t, cp.Headers["c"], ctx.Response().Header().Get("c"))
assert.Equal(t, cp.HTML, rec.Body.Bytes())
// Login and try again
tests.InitSession(ctx)
err = c.Auth.Login(ctx, usr.ID)
require.NoError(t, err)
_ = tests.ExecuteMiddleware(ctx, LoadAuthenticatedUser(c.Auth))
err = tests.ExecuteMiddleware(ctx, ServeCachedPage(c.Cache))
assert.Nil(t, err)
}
func TestCacheControl(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/")
_ = tests.ExecuteMiddleware(ctx, CacheControl(time.Second*5))
assert.Equal(t, "public, max-age=5", ctx.Response().Header().Get("Cache-Control"))
_ = tests.ExecuteMiddleware(ctx, CacheControl(0))
assert.Equal(t, "no-cache, no-store", ctx.Response().Header().Get("Cache-Control"))
}

43
pkg/middleware/entity.go Normal file
View file

@ -0,0 +1,43 @@
package middleware
import (
"fmt"
"net/http"
"strconv"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/ent/user"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/labstack/echo/v4"
)
// LoadUser loads the user based on the ID provided as a path parameter
func LoadUser(orm *ent.Client) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
userID, err := strconv.Atoi(c.Param("user"))
if err != nil {
return echo.NewHTTPError(http.StatusNotFound)
}
u, err := orm.User.
Query().
Where(user.ID(userID)).
Only(c.Request().Context())
switch err.(type) {
case nil:
c.Set(context.UserKey, u)
return next(c)
case *ent.NotFoundError:
return echo.NewHTTPError(http.StatusNotFound)
default:
return echo.NewHTTPError(
http.StatusInternalServerError,
fmt.Sprintf("error querying user: %v", err),
)
}
}
}
}

View file

@ -0,0 +1,23 @@
package middleware
import (
"fmt"
"testing"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/mikestefanello/pagoda/pkg/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLoadUser(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/")
ctx.SetParamNames("user")
ctx.SetParamValues(fmt.Sprintf("%d", usr.ID))
_ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM))
ctxUsr, ok := ctx.Get(context.UserKey).(*ent.User)
require.True(t, ok)
assert.Equal(t, usr.ID, ctxUsr.ID)
}

20
pkg/middleware/log.go Normal file
View file

@ -0,0 +1,20 @@
package middleware
import (
"fmt"
"github.com/labstack/echo/v4"
)
// LogRequestID includes the request ID in all logs for the given request
// This requires that middleware that includes the request ID first execute
func LogRequestID() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
rID := c.Response().Header().Get(echo.HeaderXRequestID)
format := `{"time":"${time_rfc3339_nano}","id":"%s","level":"${level}","prefix":"${prefix}","file":"${short_file}","line":"${line}"}`
c.Logger().SetHeader(fmt.Sprintf(format, rID))
return next(c)
}
}
}

View file

@ -0,0 +1,27 @@
package middleware
import (
"bytes"
"fmt"
"testing"
"github.com/mikestefanello/pagoda/pkg/tests"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
echomw "github.com/labstack/echo/v4/middleware"
)
func TestLogRequestID(t *testing.T) {
ctx, _ := tests.NewContext(c.Web, "/")
_ = tests.ExecuteMiddleware(ctx, echomw.RequestID())
_ = tests.ExecuteMiddleware(ctx, LogRequestID())
var buf bytes.Buffer
ctx.Logger().SetOutput(&buf)
ctx.Logger().Info("test")
rID := ctx.Response().Header().Get(echo.HeaderXRequestID)
assert.Contains(t, buf.String(), fmt.Sprintf(`id":"%s"`, rID))
}

View file

@ -0,0 +1,40 @@
package middleware
import (
"os"
"testing"
"github.com/mikestefanello/pagoda/config"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/pkg/services"
"github.com/mikestefanello/pagoda/pkg/tests"
)
var (
c *services.Container
usr *ent.User
)
func TestMain(m *testing.M) {
// Set the environment to test
config.SwitchEnvironment(config.EnvTest)
// Create a new container
c = services.NewContainer()
// Create a user
var err error
if usr, err = tests.CreateUser(c.ORM); err != nil {
panic(err)
}
// Run tests
exitVal := m.Run()
// Shutdown the container
if err = c.Shutdown(); err != nil {
panic(err)
}
os.Exit(exitVal)
}

92
pkg/msg/msg.go Normal file
View file

@ -0,0 +1,92 @@
package msg
import (
"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
)
// Type is a message type
type Type string
const (
// TypeSuccess represents a success message type
TypeSuccess Type = "success"
// TypeInfo represents a info message type
TypeInfo Type = "info"
// TypeWarning represents a warning message type
TypeWarning Type = "warning"
// TypeDanger represents a danger message type
TypeDanger Type = "danger"
)
const (
// sessionName stores the name of the session which contains flash messages
sessionName = "msg"
)
// Success sets a success flash message
func Success(ctx echo.Context, message string) {
Set(ctx, TypeSuccess, message)
}
// Info sets an info flash message
func Info(ctx echo.Context, message string) {
Set(ctx, TypeInfo, message)
}
// Warning sets a warning flash message
func Warning(ctx echo.Context, message string) {
Set(ctx, TypeWarning, message)
}
// Danger sets a danger flash message
func Danger(ctx echo.Context, message string) {
Set(ctx, TypeDanger, message)
}
// Set adds a new flash message of a given type into the session storage
// Errors will logged and not returned
func Set(ctx echo.Context, typ Type, message string) {
if sess, err := getSession(ctx); err == nil {
sess.AddFlash(message, string(typ))
save(ctx, sess)
}
}
// Get gets flash messages of a given type from the session storage
// Errors will logged and not returned
func Get(ctx echo.Context, typ Type) []string {
var msgs []string
if sess, err := getSession(ctx); err == nil {
if flash := sess.Flashes(string(typ)); len(flash) > 0 {
save(ctx, sess)
for _, m := range flash {
msgs = append(msgs, m.(string))
}
}
}
return msgs
}
// getSession gets the flash message session
func getSession(ctx echo.Context) (*sessions.Session, error) {
sess, err := session.Get(sessionName, ctx)
if err != nil {
ctx.Logger().Errorf("cannot load flash message session: %v", err)
}
return sess, err
}
// save saves the flash message session
func save(ctx echo.Context, sess *sessions.Session) {
if err := sess.Save(ctx.Request(), ctx.Response()); err != nil {
ctx.Logger().Errorf("failed to set flash message: %v", err)
}
}

46
pkg/msg/msg_test.go Normal file
View file

@ -0,0 +1,46 @@
package msg
import (
"testing"
"github.com/mikestefanello/pagoda/pkg/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/labstack/echo/v4"
)
func TestMsg(t *testing.T) {
e := echo.New()
ctx, _ := tests.NewContext(e, "/")
tests.InitSession(ctx)
assertMsg := func(typ Type, message string) {
ret := Get(ctx, typ)
require.Len(t, ret, 1)
assert.Equal(t, message, ret[0])
ret = Get(ctx, typ)
require.Len(t, ret, 0)
}
text := "aaa"
Success(ctx, text)
assertMsg(TypeSuccess, text)
text = "bbb"
Info(ctx, text)
assertMsg(TypeInfo, text)
text = "ccc"
Danger(ctx, text)
assertMsg(TypeDanger, text)
text = "ddd"
Warning(ctx, text)
assertMsg(TypeWarning, text)
text = "eee"
Set(ctx, TypeSuccess, text)
assertMsg(TypeSuccess, text)
}

69
pkg/routes/about.go Normal file
View file

@ -0,0 +1,69 @@
package routes
import (
"html/template"
"github.com/mikestefanello/pagoda/pkg/controller"
"github.com/labstack/echo/v4"
)
type (
about struct {
controller.Controller
}
aboutData struct {
ShowCacheWarning bool
FrontendTabs []aboutTab
BackendTabs []aboutTab
}
aboutTab struct {
Title string
Body template.HTML
}
)
func (c *about) Get(ctx echo.Context) error {
page := controller.NewPage(ctx)
page.Layout = "main"
page.Name = "about"
page.Title = "About"
// This page will be cached!
page.Cache.Enabled = true
page.Cache.Tags = []string{"page_about", "page:list"}
// A simple example of how the Data field can contain anything you want to send to the templates
// even though you wouldn't normally send markup like this
page.Data = aboutData{
ShowCacheWarning: true,
FrontendTabs: []aboutTab{
{
Title: "HTMX",
Body: template.HTML(`Completes HTML as a hypertext by providing attributes to AJAXify anything and much more. Visit <a href="https://htmx.org/">htmx.org</a> to learn more.`),
},
{
Title: "Alpine.js",
Body: template.HTML(`Drop-in, Vue-like functionality written directly in your markup. Visit <a href="https://alpinejs.dev/">alpinejs.dev</a> to learn more.`),
},
{
Title: "Bulma",
Body: template.HTML(`Ready-to-use frontend components that you can easily combine to build responsive web interfaces with no JavaScript requirements. Visit <a href="https://bulma.io/">bulma.io</a> to learn more.`),
},
},
BackendTabs: []aboutTab{
{
Title: "Echo",
Body: template.HTML(`High performance, extensible, minimalist Go web framework. Visit <a href="https://echo.labstack.com/">echo.labstack.com</a> to learn more.`),
},
{
Title: "Ent",
Body: template.HTML(`Simple, yet powerful ORM for modeling and querying data. Visit <a href="https://entgo.io/">entgo.io</a> to learn more.`),
},
},
}
return c.RenderPage(ctx, page)
}

23
pkg/routes/about_test.go Normal file
View file

@ -0,0 +1,23 @@
package routes
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
// Simple example of how to test routes and their markup using the test HTTP server spun up within
// this test package
func TestAbout_Get(t *testing.T) {
doc := request(t).
setRoute("about").
get().
assertStatusCode(http.StatusOK).
toDoc()
// Goquery is an excellent package to use for testing HTML markup
h1 := doc.Find("h1.title")
assert.Len(t, h1.Nodes, 1)
assert.Equal(t, "About", h1.Text())
}

65
pkg/routes/contact.go Normal file
View file

@ -0,0 +1,65 @@
package routes
import (
"fmt"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/mikestefanello/pagoda/pkg/controller"
"github.com/labstack/echo/v4"
)
type (
contact struct {
controller.Controller
}
contactForm struct {
Email string `form:"email" validate:"required,email"`
Message string `form:"message" validate:"required"`
Submission controller.FormSubmission
}
)
func (c *contact) Get(ctx echo.Context) error {
page := controller.NewPage(ctx)
page.Layout = "main"
page.Name = "contact"
page.Title = "Contact us"
page.Form = contactForm{}
if form := ctx.Get(context.FormKey); form != nil {
page.Form = form.(*contactForm)
}
return c.RenderPage(ctx, page)
}
func (c *contact) Post(ctx echo.Context) error {
var form contactForm
ctx.Set(context.FormKey, &form)
// Parse the form values
if err := ctx.Bind(&form); err != nil {
return c.Fail(err, "unable to bind form")
}
if err := form.Submission.Process(ctx, form); err != nil {
return c.Fail(err, "unable to process form submission")
}
if !form.Submission.HasErrors() {
err := c.Container.Mail.
Compose().
To(form.Email).
Subject("Contact form submitted").
Body(fmt.Sprintf("The message is: %s", form.Message)).
Send(ctx)
if err != nil {
return c.Fail(err, "unable to send email")
}
}
return c.Get(ctx)
}

42
pkg/routes/error.go Normal file
View file

@ -0,0 +1,42 @@
package routes
import (
"net/http"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/mikestefanello/pagoda/pkg/controller"
"github.com/labstack/echo/v4"
)
type errorHandler struct {
controller.Controller
}
func (e *errorHandler) Get(err error, ctx echo.Context) {
if ctx.Response().Committed || context.IsCanceledError(err) {
return
}
code := http.StatusInternalServerError
if he, ok := err.(*echo.HTTPError); ok {
code = he.Code
}
if code >= 500 {
ctx.Logger().Error(err)
} else {
ctx.Logger().Info(err)
}
page := controller.NewPage(ctx)
page.Layout = "main"
page.Title = http.StatusText(code)
page.Name = "error"
page.StatusCode = code
page.HTMX.Request.Enabled = false
if err = e.RenderPage(ctx, page); err != nil {
ctx.Logger().Error(err)
}
}

View file

@ -0,0 +1,100 @@
package routes
import (
"fmt"
"strings"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/ent/user"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/mikestefanello/pagoda/pkg/controller"
"github.com/mikestefanello/pagoda/pkg/msg"
"github.com/labstack/echo/v4"
)
type (
forgotPassword struct {
controller.Controller
}
forgotPasswordForm struct {
Email string `form:"email" validate:"required,email"`
Submission controller.FormSubmission
}
)
func (c *forgotPassword) Get(ctx echo.Context) error {
page := controller.NewPage(ctx)
page.Layout = "auth"
page.Name = "forgot-password"
page.Title = "Forgot password"
page.Form = forgotPasswordForm{}
if form := ctx.Get(context.FormKey); form != nil {
page.Form = form.(*forgotPasswordForm)
}
return c.RenderPage(ctx, page)
}
func (c *forgotPassword) Post(ctx echo.Context) error {
var form forgotPasswordForm
ctx.Set(context.FormKey, &form)
succeed := func() error {
ctx.Set(context.FormKey, nil)
msg.Success(ctx, "An email containing a link to reset your password will be sent to this address if it exists in our system.")
return c.Get(ctx)
}
// Parse the form values
if err := ctx.Bind(&form); err != nil {
return c.Fail(err, "unable to parse forgot password form")
}
if err := form.Submission.Process(ctx, form); err != nil {
return c.Fail(err, "unable to process form submission")
}
if form.Submission.HasErrors() {
return c.Get(ctx)
}
// Attempt to load the user
u, err := c.Container.ORM.User.
Query().
Where(user.Email(strings.ToLower(form.Email))).
Only(ctx.Request().Context())
switch err.(type) {
case *ent.NotFoundError:
return succeed()
case nil:
default:
return c.Fail(err, "error querying user during forgot password")
}
// Generate the token
token, pt, err := c.Container.Auth.GeneratePasswordResetToken(ctx, u.ID)
if err != nil {
return c.Fail(err, "error generating password reset token")
}
ctx.Logger().Infof("generated password reset token for user %d", u.ID)
// Email the user
url := ctx.Echo().Reverse("reset_password", u.ID, pt.ID, token)
err = c.Container.Mail.
Compose().
To(u.Email).
Subject("Reset your password").
Body(fmt.Sprintf("Go here to reset your password: %s", url)).
Send(ctx)
if err != nil {
return c.Fail(err, "error sending password reset email")
}
return succeed()
}

46
pkg/routes/home.go Normal file
View file

@ -0,0 +1,46 @@
package routes
import (
"fmt"
"github.com/mikestefanello/pagoda/pkg/controller"
"github.com/labstack/echo/v4"
)
type (
home struct {
controller.Controller
}
post struct {
Title string
Body string
}
)
func (c *home) Get(ctx echo.Context) error {
page := controller.NewPage(ctx)
page.Layout = "main"
page.Name = "home"
page.Metatags.Description = "Welcome to the homepage."
page.Metatags.Keywords = []string{"Go", "MVC", "Web", "Software"}
page.Pager = controller.NewPager(ctx, 4)
page.Data = c.fetchPosts(&page.Pager)
return c.RenderPage(ctx, page)
}
// fetchPosts is an mock example of fetching posts to illustrate how paging works
func (c *home) fetchPosts(pager *controller.Pager) []post {
pager.SetItems(20)
posts := make([]post, 20)
for k := range posts {
posts[k] = post{
Title: fmt.Sprintf("Post example #%d", k+1),
Body: fmt.Sprintf("Lorem ipsum example #%d ddolor sit amet, consectetur adipiscing elit. Nam elementum vulputate tristique.", k+1),
}
}
return posts[pager.GetOffset() : pager.GetOffset()+pager.ItemsPerPage]
}

94
pkg/routes/login.go Normal file
View file

@ -0,0 +1,94 @@
package routes
import (
"fmt"
"strings"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/ent/user"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/mikestefanello/pagoda/pkg/controller"
"github.com/mikestefanello/pagoda/pkg/msg"
"github.com/labstack/echo/v4"
)
type (
login struct {
controller.Controller
}
loginForm struct {
Email string `form:"email" validate:"required,email"`
Password string `form:"password" validate:"required"`
Submission controller.FormSubmission
}
)
func (c *login) Get(ctx echo.Context) error {
page := controller.NewPage(ctx)
page.Layout = "auth"
page.Name = "login"
page.Title = "Log in"
page.Form = loginForm{}
if form := ctx.Get(context.FormKey); form != nil {
page.Form = form.(*loginForm)
}
return c.RenderPage(ctx, page)
}
func (c *login) Post(ctx echo.Context) error {
var form loginForm
ctx.Set(context.FormKey, &form)
authFailed := func() error {
form.Submission.SetFieldError("Email", "")
form.Submission.SetFieldError("Password", "")
msg.Danger(ctx, "Invalid credentials. Please try again.")
return c.Get(ctx)
}
// Parse the form values
if err := ctx.Bind(&form); err != nil {
return c.Fail(err, "unable to parse login form")
}
if err := form.Submission.Process(ctx, form); err != nil {
return c.Fail(err, "unable to process form submission")
}
if form.Submission.HasErrors() {
return c.Get(ctx)
}
// Attempt to load the user
u, err := c.Container.ORM.User.
Query().
Where(user.Email(strings.ToLower(form.Email))).
Only(ctx.Request().Context())
switch err.(type) {
case *ent.NotFoundError:
return authFailed()
case nil:
default:
return c.Fail(err, "error querying user during login")
}
// Check if the password is correct
err = c.Container.Auth.CheckPassword(form.Password, u.Password)
if err != nil {
return authFailed()
}
// Log the user in
err = c.Container.Auth.Login(ctx, u.ID)
if err != nil {
return c.Fail(err, "unable to log in user")
}
msg.Success(ctx, fmt.Sprintf("Welcome back, <strong>%s</strong>. You are now logged in.", u.Name))
return c.Redirect(ctx, "home")
}

21
pkg/routes/logout.go Normal file
View file

@ -0,0 +1,21 @@
package routes
import (
"github.com/mikestefanello/pagoda/pkg/controller"
"github.com/mikestefanello/pagoda/pkg/msg"
"github.com/labstack/echo/v4"
)
type logout struct {
controller.Controller
}
func (l *logout) Get(c echo.Context) error {
if err := l.Container.Auth.Logout(c); err == nil {
msg.Success(c, "You have been logged out successfully.")
} else {
msg.Danger(c, "An error occurred. Please try again.")
}
return l.Redirect(c, "home")
}

122
pkg/routes/register.go Normal file
View file

@ -0,0 +1,122 @@
package routes
import (
"fmt"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/mikestefanello/pagoda/pkg/controller"
"github.com/mikestefanello/pagoda/pkg/msg"
"github.com/labstack/echo/v4"
)
type (
register struct {
controller.Controller
}
registerForm struct {
Name string `form:"name" validate:"required"`
Email string `form:"email" validate:"required,email"`
Password string `form:"password" validate:"required"`
ConfirmPassword string `form:"password-confirm" validate:"required,eqfield=Password"`
Submission controller.FormSubmission
}
)
func (c *register) Get(ctx echo.Context) error {
page := controller.NewPage(ctx)
page.Layout = "auth"
page.Name = "register"
page.Title = "Register"
page.Form = registerForm{}
if form := ctx.Get(context.FormKey); form != nil {
page.Form = form.(*registerForm)
}
return c.RenderPage(ctx, page)
}
func (c *register) Post(ctx echo.Context) error {
var form registerForm
ctx.Set(context.FormKey, &form)
// Parse the form values
if err := ctx.Bind(&form); err != nil {
return c.Fail(err, "unable to parse register form")
}
if err := form.Submission.Process(ctx, form); err != nil {
return c.Fail(err, "unable to process form submission")
}
if form.Submission.HasErrors() {
return c.Get(ctx)
}
// Hash the password
pwHash, err := c.Container.Auth.HashPassword(form.Password)
if err != nil {
return c.Fail(err, "unable to hash password")
}
// Attempt creating the user
u, err := c.Container.ORM.User.
Create().
SetName(form.Name).
SetEmail(form.Email).
SetPassword(pwHash).
Save(ctx.Request().Context())
switch err.(type) {
case nil:
ctx.Logger().Infof("user created: %s", u.Name)
case *ent.ConstraintError:
msg.Warning(ctx, "A user with this email address already exists. Please log in.")
return c.Redirect(ctx, "login")
default:
return c.Fail(err, "unable to create user")
}
// Log the user in
err = c.Container.Auth.Login(ctx, u.ID)
if err != nil {
ctx.Logger().Errorf("unable to log in: %v", err)
msg.Info(ctx, "Your account has been created.")
return c.Redirect(ctx, "login")
}
msg.Success(ctx, "Your account has been created. You are now logged in.")
// Send the verification email
c.sendVerificationEmail(ctx, u)
return c.Redirect(ctx, "home")
}
func (c *register) sendVerificationEmail(ctx echo.Context, usr *ent.User) {
// Generate a token
token, err := c.Container.Auth.GenerateEmailVerificationToken(usr.Email)
if err != nil {
ctx.Logger().Errorf("unable to generate email verification token: %v", err)
return
}
// Send the email
url := ctx.Echo().Reverse("verify_email", token)
err = c.Container.Mail.
Compose().
To(usr.Email).
Subject("Confirm your email address").
Body(fmt.Sprintf("Click here to confirm your email address: %s", url)).
Send(ctx)
if err != nil {
ctx.Logger().Errorf("unable to send email verification link: %v", err)
return
}
msg.Info(ctx, "An email was sent to you to verify your email address.")
}

View file

@ -0,0 +1,82 @@
package routes
import (
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/mikestefanello/pagoda/pkg/controller"
"github.com/mikestefanello/pagoda/pkg/msg"
"github.com/labstack/echo/v4"
)
type (
resetPassword struct {
controller.Controller
}
resetPasswordForm struct {
Password string `form:"password" validate:"required"`
ConfirmPassword string `form:"password-confirm" validate:"required,eqfield=Password"`
Submission controller.FormSubmission
}
)
func (c *resetPassword) Get(ctx echo.Context) error {
page := controller.NewPage(ctx)
page.Layout = "auth"
page.Name = "reset-password"
page.Title = "Reset password"
page.Form = resetPasswordForm{}
if form := ctx.Get(context.FormKey); form != nil {
page.Form = form.(*resetPasswordForm)
}
return c.RenderPage(ctx, page)
}
func (c *resetPassword) Post(ctx echo.Context) error {
var form resetPasswordForm
ctx.Set(context.FormKey, &form)
// Parse the form values
if err := ctx.Bind(&form); err != nil {
return c.Fail(err, "unable to parse password reset form")
}
if err := form.Submission.Process(ctx, form); err != nil {
return c.Fail(err, "unable to process form submission")
}
if form.Submission.HasErrors() {
return c.Get(ctx)
}
// Hash the new password
hash, err := c.Container.Auth.HashPassword(form.Password)
if err != nil {
return c.Fail(err, "unable to hash password")
}
// Get the requesting user
usr := ctx.Get(context.UserKey).(*ent.User)
// Update the user
_, err = usr.
Update().
SetPassword(hash).
Save(ctx.Request().Context())
if err != nil {
return c.Fail(err, "unable to update password")
}
// Delete all password tokens for this user
err = c.Container.Auth.DeletePasswordTokens(ctx, usr.ID)
if err != nil {
return c.Fail(err, "unable to delete password tokens")
}
msg.Success(ctx, "Your password has been updated.")
return c.Redirect(ctx, "login")
}

109
pkg/routes/router.go Normal file
View file

@ -0,0 +1,109 @@
package routes
import (
"net/http"
"github.com/mikestefanello/pagoda/config"
"github.com/mikestefanello/pagoda/pkg/controller"
"github.com/mikestefanello/pagoda/pkg/middleware"
"github.com/mikestefanello/pagoda/pkg/services"
"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
echomw "github.com/labstack/echo/v4/middleware"
)
// BuildRouter builds the router
func BuildRouter(c *services.Container) {
// Static files with proper cache control
// funcmap.File() should be used in templates to append a cache key to the URL in order to break cache
// after each server restart
c.Web.Group("", middleware.CacheControl(c.Config.Cache.Expiration.StaticFile)).
Static(config.StaticPrefix, config.StaticDir)
// Non static file route group
g := c.Web.Group("")
// Force HTTPS, if enabled
if c.Config.HTTP.TLS.Enabled {
g.Use(echomw.HTTPSRedirect())
}
g.Use(
echomw.RemoveTrailingSlashWithConfig(echomw.TrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently,
}),
echomw.Recover(),
echomw.Secure(),
echomw.RequestID(),
echomw.Gzip(),
echomw.Logger(),
middleware.LogRequestID(),
echomw.TimeoutWithConfig(echomw.TimeoutConfig{
Timeout: c.Config.App.Timeout,
}),
session.Middleware(sessions.NewCookieStore([]byte(c.Config.App.EncryptionKey))),
middleware.LoadAuthenticatedUser(c.Auth),
middleware.ServeCachedPage(c.Cache),
echomw.CSRFWithConfig(echomw.CSRFConfig{
TokenLookup: "form:csrf",
}),
)
// Base controller
ctr := controller.NewController(c)
// Error handler
err := errorHandler{Controller: ctr}
c.Web.HTTPErrorHandler = err.Get
// Example routes
navRoutes(c, g, ctr)
userRoutes(c, g, ctr)
}
func navRoutes(c *services.Container, g *echo.Group, ctr controller.Controller) {
home := home{Controller: ctr}
g.GET("/", home.Get).Name = "home"
search := search{Controller: ctr}
g.GET("/search", search.Get).Name = "search"
about := about{Controller: ctr}
g.GET("/about", about.Get).Name = "about"
contact := contact{Controller: ctr}
g.GET("/contact", contact.Get).Name = "contact"
g.POST("/contact", contact.Post).Name = "contact.post"
}
func userRoutes(c *services.Container, g *echo.Group, ctr controller.Controller) {
logout := logout{Controller: ctr}
g.GET("/logout", logout.Get, middleware.RequireAuthentication()).Name = "logout"
verifyEmail := verifyEmail{Controller: ctr}
g.GET("/email/verify/:token", verifyEmail.Get).Name = "verify_email"
noAuth := g.Group("/user", middleware.RequireNoAuthentication())
login := login{Controller: ctr}
noAuth.GET("/login", login.Get).Name = "login"
noAuth.POST("/login", login.Post).Name = "login.post"
register := register{Controller: ctr}
noAuth.GET("/register", register.Get).Name = "register"
noAuth.POST("/register", register.Post).Name = "register.post"
forgot := forgotPassword{Controller: ctr}
noAuth.GET("/password", forgot.Get).Name = "forgot_password"
noAuth.POST("/password", forgot.Post).Name = "forgot_password.post"
resetGroup := noAuth.Group("/password/reset",
middleware.LoadUser(c.ORM),
middleware.LoadValidPasswordToken(c.Auth),
)
reset := resetPassword{Controller: ctr}
resetGroup.GET("/token/:user/:password_token/:token", reset.Get).Name = "reset_password"
resetGroup.POST("/token/:user/:password_token/:token", reset.Post).Name = "reset_password.post"
}

136
pkg/routes/routes_test.go Normal file
View file

@ -0,0 +1,136 @@
package routes
import (
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"os"
"testing"
"github.com/mikestefanello/pagoda/config"
"github.com/mikestefanello/pagoda/pkg/services"
"github.com/PuerkitoBio/goquery"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
srv *httptest.Server
c *services.Container
)
func TestMain(m *testing.M) {
// Set the environment to test
config.SwitchEnvironment(config.EnvTest)
// Start a new container
c = services.NewContainer()
// Start a test HTTP server
BuildRouter(c)
srv = httptest.NewServer(c.Web)
// Run tests
exitVal := m.Run()
// Shutdown the container and test server
if err := c.Shutdown(); err != nil {
panic(err)
}
srv.Close()
os.Exit(exitVal)
}
type httpRequest struct {
route string
client http.Client
body url.Values
t *testing.T
}
func request(t *testing.T) *httpRequest {
jar, err := cookiejar.New(nil)
require.NoError(t, err)
r := httpRequest{
t: t,
body: url.Values{},
client: http.Client{
Jar: jar,
},
}
return &r
}
func (h *httpRequest) setClient(client http.Client) *httpRequest {
h.client = client
return h
}
func (h *httpRequest) setRoute(route string, params ...interface{}) *httpRequest {
h.route = srv.URL + c.Web.Reverse(route, params)
return h
}
func (h *httpRequest) setBody(body url.Values) *httpRequest {
h.body = body
return h
}
func (h *httpRequest) get() *httpResponse {
resp, err := h.client.Get(h.route)
require.NoError(h.t, err)
r := httpResponse{
t: h.t,
Response: resp,
}
return &r
}
func (h *httpRequest) post() *httpResponse {
// Make a get request to get the CSRF token
doc := h.get().
assertStatusCode(http.StatusOK).
toDoc()
// Extract the CSRF and include it in the POST request body
csrf := doc.Find(`input[name="csrf"]`).First()
token, exists := csrf.Attr("value")
assert.True(h.t, exists)
h.body["csrf"] = []string{token}
// Make the POST requests
resp, err := h.client.PostForm(h.route, h.body)
require.NoError(h.t, err)
r := httpResponse{
t: h.t,
Response: resp,
}
return &r
}
type httpResponse struct {
*http.Response
t *testing.T
}
func (h *httpResponse) assertStatusCode(code int) *httpResponse {
assert.Equal(h.t, code, h.Response.StatusCode)
return h
}
func (h *httpResponse) assertRedirect(t *testing.T, route string, params ...interface{}) *httpResponse {
assert.Equal(t, c.Web.Reverse(route, params), h.Header.Get("Location"))
return h
}
func (h *httpResponse) toDoc() *goquery.Document {
doc, err := goquery.NewDocumentFromReader(h.Body)
require.NoError(h.t, err)
err = h.Body.Close()
assert.NoError(h.t, err)
return doc
}

44
pkg/routes/search.go Normal file
View file

@ -0,0 +1,44 @@
package routes
import (
"fmt"
"math/rand"
"github.com/mikestefanello/pagoda/pkg/controller"
"github.com/labstack/echo/v4"
)
type (
search struct {
controller.Controller
}
searchResult struct {
Title string
URL string
}
)
func (c *search) Get(ctx echo.Context) error {
page := controller.NewPage(ctx)
page.Layout = "main"
page.Name = "search"
// Fake search results
var results []searchResult
if search := ctx.QueryParam("query"); search != "" {
for i := 0; i < 5; i++ {
title := "Lorem ipsum example ddolor sit amet"
index := rand.Intn(len(title))
title = title[:index] + search + title[index:]
results = append(results, searchResult{
Title: title,
URL: fmt.Sprintf("https://www.%s.com", search),
})
}
}
page.Data = results
return c.RenderPage(ctx, page)
}

View file

@ -0,0 +1,62 @@
package routes
import (
"github.com/labstack/echo/v4"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/ent/user"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/mikestefanello/pagoda/pkg/controller"
"github.com/mikestefanello/pagoda/pkg/msg"
)
type verifyEmail struct {
controller.Controller
}
func (c *verifyEmail) Get(ctx echo.Context) error {
var usr *ent.User
// Validate the token
token := ctx.Param("token")
email, err := c.Container.Auth.ValidateEmailVerificationToken(token)
if err != nil {
msg.Warning(ctx, "The link is either invalid or has expired.")
return c.Redirect(ctx, "home")
}
// Check if it matches the authenticated user
if u := ctx.Get(context.AuthenticatedUserKey); u != nil {
authUser := u.(*ent.User)
if authUser.Email == email {
usr = authUser
}
}
// Query to find a matching user, if needed
if usr == nil {
usr, err = c.Container.ORM.User.
Query().
Where(user.Email(email)).
Only(ctx.Request().Context())
if err != nil {
return c.Fail(err, "query failed loading email verification token user")
}
}
// Verify the user, if needed
if !usr.Verified {
usr, err = usr.
Update().
SetVerified(true).
Save(ctx.Request().Context())
if err != nil {
return c.Fail(err, "failed to set user as verified")
}
}
msg.Success(ctx, "Your email has been successfully verified.")
return c.Redirect(ctx, "home")
}

233
pkg/services/auth.go Normal file
View file

@ -0,0 +1,233 @@
package services
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"time"
"github.com/golang-jwt/jwt"
"github.com/mikestefanello/pagoda/config"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/ent/passwordtoken"
"github.com/mikestefanello/pagoda/ent/user"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"golang.org/x/crypto/bcrypt"
)
const (
// authSessionName stores the name of the session which contains authentication data
authSessionName = "ua"
// authSessionKeyUserID stores the key used to store the user ID in the session
authSessionKeyUserID = "user_id"
// authSessionKeyAuthenticated stores the key used to store the authentication status in the session
authSessionKeyAuthenticated = "authenticated"
)
// NotAuthenticatedError is an error returned when a user is not authenticated
type NotAuthenticatedError struct{}
// Error implements the error interface.
func (e NotAuthenticatedError) Error() string {
return "user not authenticated"
}
// InvalidPasswordTokenError is an error returned when an invalid token is provided
type InvalidPasswordTokenError struct{}
// Error implements the error interface.
func (e InvalidPasswordTokenError) Error() string {
return "invalid password token"
}
// AuthClient is the client that handles authentication requests
type AuthClient struct {
config *config.Config
orm *ent.Client
}
// NewAuthClient creates a new authentication client
func NewAuthClient(cfg *config.Config, orm *ent.Client) *AuthClient {
return &AuthClient{
config: cfg,
orm: orm,
}
}
// Login logs in a user of a given ID
func (c *AuthClient) Login(ctx echo.Context, userID int) error {
sess, err := session.Get(authSessionName, ctx)
if err != nil {
return err
}
sess.Values[authSessionKeyUserID] = userID
sess.Values[authSessionKeyAuthenticated] = true
return sess.Save(ctx.Request(), ctx.Response())
}
// Logout logs the requesting user out
func (c *AuthClient) Logout(ctx echo.Context) error {
sess, err := session.Get(authSessionName, ctx)
if err != nil {
return err
}
sess.Values[authSessionKeyAuthenticated] = false
return sess.Save(ctx.Request(), ctx.Response())
}
// GetAuthenticatedUserID returns the authenticated user's ID, if the user is logged in
func (c *AuthClient) GetAuthenticatedUserID(ctx echo.Context) (int, error) {
sess, err := session.Get(authSessionName, ctx)
if err != nil {
return 0, err
}
if sess.Values[authSessionKeyAuthenticated] == true {
return sess.Values[authSessionKeyUserID].(int), nil
}
return 0, NotAuthenticatedError{}
}
// GetAuthenticatedUser returns the authenticated user if the user is logged in
func (c *AuthClient) GetAuthenticatedUser(ctx echo.Context) (*ent.User, error) {
if userID, err := c.GetAuthenticatedUserID(ctx); err == nil {
return c.orm.User.Query().
Where(user.ID(userID)).
Only(ctx.Request().Context())
}
return nil, NotAuthenticatedError{}
}
// HashPassword returns a hash of a given password
func (c *AuthClient) HashPassword(password string) (string, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(hash), nil
}
// CheckPassword check if a given password matches a given hash
func (c *AuthClient) CheckPassword(password, hash string) error {
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
}
// GeneratePasswordResetToken generates a password reset token for a given user.
// For security purposes, the token itself is not stored in the database but rather
// a hash of the token, exactly how passwords are handled. This method returns both
// the generated token as well as the token entity which only contains the hash.
func (c *AuthClient) GeneratePasswordResetToken(ctx echo.Context, userID int) (string, *ent.PasswordToken, error) {
// Generate the token, which is what will go in the URL, but not the database
token, err := c.RandomToken(c.config.App.PasswordToken.Length)
if err != nil {
return "", nil, err
}
// Hash the token, which is what will be stored in the database
hash, err := c.HashPassword(token)
if err != nil {
return "", nil, err
}
// Create and save the password reset token
pt, err := c.orm.PasswordToken.
Create().
SetHash(hash).
SetUserID(userID).
Save(ctx.Request().Context())
return token, pt, err
}
// GetValidPasswordToken returns a valid, non-expired password token entity for a given user, token ID and token.
// Since the actual token is not stored in the database for security purposes, if a matching password token entity is
// found a hash of the provided token is compared with the hash stored in the database in order to validate.
func (c *AuthClient) GetValidPasswordToken(ctx echo.Context, userID, tokenID int, token string) (*ent.PasswordToken, error) {
// Ensure expired tokens are never returned
expiration := time.Now().Add(-c.config.App.PasswordToken.Expiration)
// Query to find a password token entity that matches the given user and token ID
pt, err := c.orm.PasswordToken.
Query().
Where(passwordtoken.ID(tokenID)).
Where(passwordtoken.HasUserWith(user.ID(userID))).
Where(passwordtoken.CreatedAtGTE(expiration)).
Only(ctx.Request().Context())
switch err.(type) {
case *ent.NotFoundError:
case nil:
// Check the token for a hash match
if err := c.CheckPassword(token, pt.Hash); err == nil {
return pt, nil
}
default:
if !context.IsCanceledError(err) {
return nil, err
}
}
return nil, InvalidPasswordTokenError{}
}
// DeletePasswordTokens deletes all password tokens in the database for a belonging to a given user.
// This should be called after a successful password reset.
func (c *AuthClient) DeletePasswordTokens(ctx echo.Context, userID int) error {
_, err := c.orm.PasswordToken.
Delete().
Where(passwordtoken.HasUserWith(user.ID(userID))).
Exec(ctx.Request().Context())
return err
}
// RandomToken generates a random token string of a given length
func (c *AuthClient) RandomToken(length int) (string, error) {
b := make([]byte, (length/2)+1)
if _, err := rand.Read(b); err != nil {
return "", err
}
token := hex.EncodeToString(b)
return token[:length], nil
}
// GenerateEmailVerificationToken generates an email verification token for a given email address using JWT which
// is set to expire based on the duration stored in configuration
func (c *AuthClient) GenerateEmailVerificationToken(email string) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"email": email,
"exp": time.Now().Add(c.config.App.EmailVerificationTokenExpiration).Unix(),
})
return token.SignedString([]byte(c.config.App.EncryptionKey))
}
// ValidateEmailVerificationToken validates an email verification token and returns the associated email address if
// the token is valid and has not expired
func (c *AuthClient) ValidateEmailVerificationToken(token string) (string, error) {
t, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return []byte(c.config.App.EncryptionKey), nil
})
if err != nil {
return "", err
}
if claims, ok := t.Claims.(jwt.MapClaims); ok && t.Valid {
return claims["email"].(string), nil
}
return "", errors.New("invalid or expired token")
}

146
pkg/services/auth_test.go Normal file
View file

@ -0,0 +1,146 @@
package services
import (
"context"
"errors"
"testing"
"time"
"github.com/mikestefanello/pagoda/ent/passwordtoken"
"github.com/mikestefanello/pagoda/ent/user"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/assert"
)
func TestAuthClient_Auth(t *testing.T) {
assertNoAuth := func() {
_, err := c.Auth.GetAuthenticatedUserID(ctx)
assert.True(t, errors.Is(err, NotAuthenticatedError{}))
_, err = c.Auth.GetAuthenticatedUser(ctx)
assert.True(t, errors.Is(err, NotAuthenticatedError{}))
}
assertNoAuth()
err := c.Auth.Login(ctx, usr.ID)
require.NoError(t, err)
uid, err := c.Auth.GetAuthenticatedUserID(ctx)
require.NoError(t, err)
assert.Equal(t, usr.ID, uid)
u, err := c.Auth.GetAuthenticatedUser(ctx)
require.NoError(t, err)
assert.Equal(t, u.ID, usr.ID)
err = c.Auth.Logout(ctx)
require.NoError(t, err)
assertNoAuth()
}
func TestAuthClient_PasswordHashing(t *testing.T) {
pw := "testcheckpassword"
hash, err := c.Auth.HashPassword(pw)
assert.NoError(t, err)
assert.NotEqual(t, hash, pw)
err = c.Auth.CheckPassword(pw, hash)
assert.NoError(t, err)
}
func TestAuthClient_GeneratePasswordResetToken(t *testing.T) {
token, pt, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID)
require.NoError(t, err)
assert.Len(t, token, c.Config.App.PasswordToken.Length)
assert.NoError(t, c.Auth.CheckPassword(token, pt.Hash))
}
func TestAuthClient_GetValidPasswordToken(t *testing.T) {
// Check that a fake token is not valid
_, err := c.Auth.GetValidPasswordToken(ctx, usr.ID, 1, "faketoken")
assert.Error(t, err)
// Generate a valid token and check that it is returned
token, pt, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID)
require.NoError(t, err)
pt2, err := c.Auth.GetValidPasswordToken(ctx, usr.ID, pt.ID, token)
require.NoError(t, err)
assert.Equal(t, pt.ID, pt2.ID)
// Expire the token by pushing the date far enough back
count, err := c.ORM.PasswordToken.
Update().
SetCreatedAt(time.Now().Add(-(c.Config.App.PasswordToken.Expiration + time.Hour))).
Where(passwordtoken.ID(pt.ID)).
Save(context.Background())
require.NoError(t, err)
require.Equal(t, 1, count)
// Expired tokens should not be valid
_, err = c.Auth.GetValidPasswordToken(ctx, usr.ID, pt.ID, token)
assert.Error(t, err)
}
func TestAuthClient_DeletePasswordTokens(t *testing.T) {
// Create three tokens for the user
for i := 0; i < 3; i++ {
_, _, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID)
require.NoError(t, err)
}
// Delete all tokens for the user
err := c.Auth.DeletePasswordTokens(ctx, usr.ID)
require.NoError(t, err)
// Check that no tokens remain
count, err := c.ORM.PasswordToken.
Query().
Where(passwordtoken.HasUserWith(user.ID(usr.ID))).
Count(context.Background())
require.NoError(t, err)
assert.Equal(t, 0, count)
}
func TestAuthClient_RandomToken(t *testing.T) {
length := c.Config.App.PasswordToken.Length
a, err := c.Auth.RandomToken(length)
require.NoError(t, err)
b, err := c.Auth.RandomToken(length)
require.NoError(t, err)
assert.Len(t, a, length)
assert.Len(t, b, length)
assert.NotEqual(t, a, b)
}
func TestAuthClient_EmailVerificationToken(t *testing.T) {
t.Run("valid token", func(t *testing.T) {
email := "test@localhost.com"
token, err := c.Auth.GenerateEmailVerificationToken(email)
require.NoError(t, err)
tokenEmail, err := c.Auth.ValidateEmailVerificationToken(token)
require.NoError(t, err)
assert.Equal(t, email, tokenEmail)
})
t.Run("invalid token", func(t *testing.T) {
badToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6InRlc3RAbG9jYWxob3N0LmNvbSIsImV4cCI6MTkxNzg2NDAwMH0.ScJCpfEEzlilKfRs_aVouzwPNKI28M3AIm-hyImQHUQ"
_, err := c.Auth.ValidateEmailVerificationToken(badToken)
assert.Error(t, err)
})
t.Run("expired token", func(t *testing.T) {
c.Config.App.EmailVerificationTokenExpiration = -time.Hour
email := "test@localhost.com"
token, err := c.Auth.GenerateEmailVerificationToken(email)
require.NoError(t, err)
_, err = c.Auth.ValidateEmailVerificationToken(token)
assert.Error(t, err)
c.Config.App.EmailVerificationTokenExpiration = time.Hour * 12
})
}

228
pkg/services/cache.go Normal file
View file

@ -0,0 +1,228 @@
package services
import (
"context"
"errors"
"fmt"
"time"
"github.com/eko/gocache/v2/cache"
"github.com/eko/gocache/v2/marshaler"
"github.com/eko/gocache/v2/store"
"github.com/go-redis/redis/v8"
"github.com/mikestefanello/pagoda/config"
)
type (
// CacheClient is the client that allows you to interact with the cache
CacheClient struct {
// Client stores the client to the underlying cache service
Client *redis.Client
// cache stores the cache interface
cache *cache.Cache
}
// cacheSet handles chaining a set operation
cacheSet struct {
client *CacheClient
key string
group string
data interface{}
expiration time.Duration
tags []string
}
// cacheGet handles chaining a get operation
cacheGet struct {
client *CacheClient
key string
group string
dataType interface{}
}
// cacheFlush handles chaining a flush operation
cacheFlush struct {
client *CacheClient
key string
group string
tags []string
}
)
// NewCacheClient creates a new cache client
func NewCacheClient(cfg *config.Config) (*CacheClient, error) {
// Determine the database based on the environment
db := cfg.Cache.Database
if cfg.App.Environment == config.EnvTest {
db = cfg.Cache.TestDatabase
}
// Connect to the cache
c := &CacheClient{}
c.Client = redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", cfg.Cache.Hostname, cfg.Cache.Port),
Password: cfg.Cache.Password,
DB: db,
})
if _, err := c.Client.Ping(context.Background()).Result(); err != nil {
return c, err
}
// Flush the database if this is the test environment
if cfg.App.Environment == config.EnvTest {
if err := c.Client.FlushDB(context.Background()).Err(); err != nil {
return c, err
}
}
cacheStore := store.NewRedis(c.Client, nil)
c.cache = cache.New(cacheStore)
return c, nil
}
// Close closes the connection to the cache
func (c *CacheClient) Close() error {
return c.Client.Close()
}
// Set creates a cache set operation
func (c *CacheClient) Set() *cacheSet {
return &cacheSet{
client: c,
}
}
// Get creates a cache get operation
func (c *CacheClient) Get() *cacheGet {
return &cacheGet{
client: c,
}
}
// Flush creates a cache flush operation
func (c *CacheClient) Flush() *cacheFlush {
return &cacheFlush{
client: c,
}
}
// cacheKey formats a cache key with an optional group
func (c *CacheClient) cacheKey(group, key string) string {
if group != "" {
return fmt.Sprintf("%s::%s", group, key)
}
return key
}
// Key sets the cache key
func (c *cacheSet) Key(key string) *cacheSet {
c.key = key
return c
}
// Group sets the cache group
func (c *cacheSet) Group(group string) *cacheSet {
c.group = group
return c
}
// Data sets the data to cache
func (c *cacheSet) Data(data interface{}) *cacheSet {
c.data = data
return c
}
// Expiration sets the expiration duration of the cached data
func (c *cacheSet) Expiration(expiration time.Duration) *cacheSet {
c.expiration = expiration
return c
}
// Tags sets the cache tags
func (c *cacheSet) Tags(tags ...string) *cacheSet {
c.tags = tags
return c
}
// Save saves the data in the cache
func (c *cacheSet) Save(ctx context.Context) error {
if c.key == "" {
return errors.New("no cache key specified")
}
opts := &store.Options{
Expiration: c.expiration,
Tags: c.tags,
}
return marshaler.
New(c.client.cache).
Set(ctx, c.client.cacheKey(c.group, c.key), c.data, opts)
}
// Key sets the cache key
func (c *cacheGet) Key(key string) *cacheGet {
c.key = key
return c
}
// Group sets the cache group
func (c *cacheGet) Group(group string) *cacheGet {
c.group = group
return c
}
// Type sets the expected Go type of the data being retrieved from the cache
func (c *cacheGet) Type(expectedType interface{}) *cacheGet {
c.dataType = expectedType
return c
}
// Fetch fetches the data from the cache
func (c *cacheGet) Fetch(ctx context.Context) (interface{}, error) {
if c.key == "" {
return nil, errors.New("no cache key specified")
}
return marshaler.New(c.client.cache).Get(
ctx,
c.client.cacheKey(c.group, c.key),
c.dataType,
)
}
// Key sets the cache key
func (c *cacheFlush) Key(key string) *cacheFlush {
c.key = key
return c
}
// Group sets the cache group
func (c *cacheFlush) Group(group string) *cacheFlush {
c.group = group
return c
}
// Tags sets the cache tags
func (c *cacheFlush) Tags(tags ...string) *cacheFlush {
c.tags = tags
return c
}
// Execute flushes the data from the cache
func (c *cacheFlush) Execute(ctx context.Context) error {
if len(c.tags) > 0 {
if err := c.client.cache.Invalidate(ctx, store.InvalidateOptions{
Tags: c.tags,
}); err != nil {
return err
}
}
if c.key != "" {
return c.client.cache.Delete(ctx, c.client.cacheKey(c.group, c.key))
}
return nil
}

105
pkg/services/cache_test.go Normal file
View file

@ -0,0 +1,105 @@
package services
import (
"context"
"testing"
"time"
"github.com/go-redis/redis/v8"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCacheClient(t *testing.T) {
type cacheTest struct {
Value string
}
// Cache some data
data := cacheTest{Value: "abcdef"}
group := "testgroup"
key := "testkey"
err := c.Cache.
Set().
Group(group).
Key(key).
Data(data).
Save(context.Background())
require.NoError(t, err)
// Get the data
fromCache, err := c.Cache.
Get().
Group(group).
Key(key).
Type(new(cacheTest)).
Fetch(context.Background())
require.NoError(t, err)
cast, ok := fromCache.(*cacheTest)
require.True(t, ok)
assert.Equal(t, data, *cast)
// The same key with the wrong group should fail
_, err = c.Cache.
Get().
Key(key).
Type(new(cacheTest)).
Fetch(context.Background())
assert.Error(t, err)
// Flush the data
err = c.Cache.
Flush().
Group(group).
Key(key).
Execute(context.Background())
require.NoError(t, err)
// The data should be gone
assertFlushed := func() {
// The data should be gone
_, err = c.Cache.
Get().
Group(group).
Key(key).
Type(new(cacheTest)).
Fetch(context.Background())
assert.Equal(t, redis.Nil, err)
}
assertFlushed()
// Set with tags
err = c.Cache.
Set().
Group(group).
Key(key).
Data(data).
Tags("tag1").
Save(context.Background())
require.NoError(t, err)
// Flush the tag
err = c.Cache.
Flush().
Tags("tag1").
Execute(context.Background())
require.NoError(t, err)
// The data should be gone
assertFlushed()
// Set with expiration
err = c.Cache.
Set().
Group(group).
Key(key).
Data(data).
Expiration(time.Millisecond).
Save(context.Background())
require.NoError(t, err)
// Wait for expiration
time.Sleep(time.Millisecond * 2)
// The data should be gone
assertFlushed()
}

201
pkg/services/container.go Normal file
View file

@ -0,0 +1,201 @@
package services
import (
"context"
"database/sql"
"fmt"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/schema"
// Required by ent
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
"github.com/mikestefanello/pagoda/config"
"github.com/mikestefanello/pagoda/ent"
// Require by ent
_ "github.com/mikestefanello/pagoda/ent/runtime"
)
// Container contains all services used by the application and provides an easy way to handle dependency
// injection including within tests
type Container struct {
// Validator stores a validator
Validator *Validator
// Web stores the web framework
Web *echo.Echo
// Config stores the application configuration
Config *config.Config
// Cache contains the cache client
Cache *CacheClient
// Database stores the connection to the database
Database *sql.DB
// ORM stores a client to the ORM
ORM *ent.Client
// Mail stores an email sending client
Mail *MailClient
// Auth stores an authentication client
Auth *AuthClient
// TemplateRenderer stores a service to easily render and cache templates
TemplateRenderer *TemplateRenderer
// Tasks stores the task client
Tasks *TaskClient
}
// NewContainer creates and initializes a new Container
func NewContainer() *Container {
c := new(Container)
c.initConfig()
c.initValidator()
c.initWeb()
c.initCache()
c.initDatabase()
c.initORM()
c.initAuth()
c.initTemplateRenderer()
c.initMail()
c.initTasks()
return c
}
// Shutdown shuts the Container down and disconnects all connections
func (c *Container) Shutdown() error {
if err := c.Tasks.Close(); err != nil {
return err
}
if err := c.Cache.Close(); err != nil {
return err
}
if err := c.ORM.Close(); err != nil {
return err
}
if err := c.Database.Close(); err != nil {
return err
}
return nil
}
// initConfig initializes configuration
func (c *Container) initConfig() {
cfg, err := config.GetConfig()
if err != nil {
panic(fmt.Sprintf("failed to load config: %v", err))
}
c.Config = &cfg
}
// initValidator initializes the validator
func (c *Container) initValidator() {
c.Validator = NewValidator()
}
// initWeb initializes the web framework
func (c *Container) initWeb() {
c.Web = echo.New()
// Configure logging
switch c.Config.App.Environment {
case config.EnvProduction:
c.Web.Logger.SetLevel(log.WARN)
default:
c.Web.Logger.SetLevel(log.DEBUG)
}
c.Web.Validator = c.Validator
}
// initCache initializes the cache
func (c *Container) initCache() {
var err error
if c.Cache, err = NewCacheClient(c.Config); err != nil {
panic(err)
}
}
// initDatabase initializes the database
// If the environment is set to test, the test database will be used and will be dropped, recreated and migrated
func (c *Container) initDatabase() {
var err error
getAddr := func(dbName string) string {
return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s",
c.Config.Database.User,
c.Config.Database.Password,
c.Config.Database.Hostname,
c.Config.Database.Port,
dbName,
)
}
c.Database, err = sql.Open("pgx", getAddr(c.Config.Database.Database))
if err != nil {
panic(fmt.Sprintf("failed to connect to database: %v", err))
}
// Check if this is a test environment
if c.Config.App.Environment == config.EnvTest {
// Drop the test database, ignoring errors in case it doesn't yet exist
_, _ = c.Database.Exec("DROP DATABASE " + c.Config.Database.TestDatabase)
// Create the test database
if _, err = c.Database.Exec("CREATE DATABASE " + c.Config.Database.TestDatabase); err != nil {
panic(fmt.Sprintf("failed to create test database: %v", err))
}
// Connect to the test database
if err = c.Database.Close(); err != nil {
panic(fmt.Sprintf("failed to close database connection: %v", err))
}
c.Database, err = sql.Open("pgx", getAddr(c.Config.Database.TestDatabase))
if err != nil {
panic(fmt.Sprintf("failed to connect to database: %v", err))
}
}
}
// initORM initializes the ORM
func (c *Container) initORM() {
drv := entsql.OpenDB(dialect.Postgres, c.Database)
c.ORM = ent.NewClient(ent.Driver(drv))
if err := c.ORM.Schema.Create(context.Background(), schema.WithAtlas(true)); err != nil {
panic(fmt.Sprintf("failed to create database schema: %v", err))
}
}
// initAuth initializes the authentication client
func (c *Container) initAuth() {
c.Auth = NewAuthClient(c.Config, c.ORM)
}
// initTemplateRenderer initializes the template renderer
func (c *Container) initTemplateRenderer() {
c.TemplateRenderer = NewTemplateRenderer(c.Config)
}
// initMail initialize the mail client
func (c *Container) initMail() {
var err error
c.Mail, err = NewMailClient(c.Config, c.TemplateRenderer)
if err != nil {
panic(fmt.Sprintf("failed to create mail client: %v", err))
}
}
// initTasks initializes the task client
func (c *Container) initTasks() {
c.Tasks = NewTaskClient(c.Config)
}

View file

@ -0,0 +1,20 @@
package services
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewContainer(t *testing.T) {
assert.NotNil(t, c.Web)
assert.NotNil(t, c.Config)
assert.NotNil(t, c.Validator)
assert.NotNil(t, c.Cache)
assert.NotNil(t, c.Database)
assert.NotNil(t, c.ORM)
assert.NotNil(t, c.Mail)
assert.NotNil(t, c.Auth)
assert.NotNil(t, c.TemplateRenderer)
assert.NotNil(t, c.Tasks)
}

139
pkg/services/mail.go Normal file
View file

@ -0,0 +1,139 @@
package services
import (
"errors"
"fmt"
"github.com/mikestefanello/pagoda/config"
"github.com/labstack/echo/v4"
)
type (
// MailClient provides a client for sending email
// This is purposely not completed because there are many different methods and services
// for sending email, many of which are very different. Choose what works best for you
// and populate the methods below
MailClient struct {
// config stores application configuration
config *config.Config
// templates stores the template renderer
templates *TemplateRenderer
}
// mail represents an email to be sent
mail struct {
client *MailClient
from string
to string
subject string
body string
template string
templateData interface{}
}
)
// NewMailClient creates a new MailClient
func NewMailClient(cfg *config.Config, templates *TemplateRenderer) (*MailClient, error) {
return &MailClient{
config: cfg,
templates: templates,
}, nil
}
// Compose creates a new email
func (m *MailClient) Compose() *mail {
return &mail{
client: m,
from: m.config.Mail.FromAddress,
}
}
// skipSend determines if mail sending should be skipped
func (m *MailClient) skipSend() bool {
return m.config.App.Environment != config.EnvProduction
}
// send attempts to send the email
func (m *MailClient) send(email *mail, ctx echo.Context) error {
switch {
case email.to == "":
return errors.New("email cannot be sent without a to address")
case email.body == "" && email.template == "":
return errors.New("email cannot be sent without a body or template")
}
// Check if a template was supplied
if email.template != "" {
// Parse and execute template
buf, err := m.templates.
Parse().
Group("mail").
Key(email.template).
Base(email.template).
Files(fmt.Sprintf("emails/%s", email.template)).
Execute(email.templateData)
if err != nil {
return err
}
email.body = buf.String()
}
// Check if mail sending should be skipped
if m.skipSend() {
ctx.Logger().Debugf("skipping email sent to: %s", email.to)
return nil
}
// TODO: Finish based on your mail sender of choice!
return nil
}
// From sets the email from address
func (m *mail) From(from string) *mail {
m.from = from
return m
}
// To sets the email address this email will be sent to
func (m *mail) To(to string) *mail {
m.to = to
return m
}
// Subject sets the subject line of the email
func (m *mail) Subject(subject string) *mail {
m.subject = subject
return m
}
// Body sets the body of the email
// This is not required and will be ignored if a template via Template()
func (m *mail) Body(body string) *mail {
m.body = body
return m
}
// Template sets the template to be used to produce the body of the email
// The template name should only include the filename without the extension or directory.
// The template must reside within the emails sub-directory.
// The funcmap will be automatically added to the template.
// Use TemplateData() to supply the data that will be passed in to the template.
func (m *mail) Template(template string) *mail {
m.template = template
return m
}
// TemplateData sets the data that will be passed to the template specified when calling Template()
func (m *mail) TemplateData(data interface{}) *mail {
m.templateData = data
return m
}
// Send attempts to send the email
func (m *mail) Send(ctx echo.Context) error {
return m.client.send(m, ctx)
}

View file

@ -0,0 +1,3 @@
package services
// Fill this in once you implement your mail client

View file

@ -0,0 +1,46 @@
package services
import (
"os"
"testing"
"github.com/mikestefanello/pagoda/config"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/pkg/tests"
"github.com/labstack/echo/v4"
)
var (
c *Container
ctx echo.Context
usr *ent.User
)
func TestMain(m *testing.M) {
// Set the environment to test
config.SwitchEnvironment(config.EnvTest)
// Create a new container
c = NewContainer()
// Create a web context
ctx, _ = tests.NewContext(c.Web, "/")
tests.InitSession(ctx)
// Create a test user
var err error
if usr, err = tests.CreateUser(c.ORM); err != nil {
panic(err)
}
// Run tests
exitVal := m.Run()
// Shutdown the container
if err = c.Shutdown(); err != nil {
panic(err)
}
os.Exit(exitVal)
}

179
pkg/services/tasks.go Normal file
View file

@ -0,0 +1,179 @@
package services
import (
"encoding/json"
"fmt"
"time"
"github.com/hibiken/asynq"
"github.com/mikestefanello/pagoda/config"
)
type (
// TaskClient is that client that allows you to queue or schedule task execution
TaskClient struct {
// client stores the asynq client
client *asynq.Client
// scheduler stores the asynq scheduler
scheduler *asynq.Scheduler
}
// task handles task creation operations
task struct {
client *TaskClient
typ string
payload interface{}
periodic *string
queue *string
maxRetries *int
timeout *time.Duration
deadline *time.Time
at *time.Time
wait *time.Duration
retain *time.Duration
}
)
// NewTaskClient creates a new task client
func NewTaskClient(cfg *config.Config) *TaskClient {
// Determine the database based on the environment
db := cfg.Cache.Database
if cfg.App.Environment == config.EnvTest {
db = cfg.Cache.TestDatabase
}
conn := asynq.RedisClientOpt{
Addr: fmt.Sprintf("%s:%d", cfg.Cache.Hostname, cfg.Cache.Port),
Password: cfg.Cache.Password,
DB: db,
}
return &TaskClient{
client: asynq.NewClient(conn),
scheduler: asynq.NewScheduler(conn, nil),
}
}
// Close closes the connection to the task service
func (t *TaskClient) Close() error {
return t.client.Close()
}
// StartScheduler starts the scheduler service which adds scheduled tasks to the queue
// This must be running in order to queue tasks set for periodic execution
func (t *TaskClient) StartScheduler() error {
return t.scheduler.Run()
}
// New starts a task creation operation
func (t *TaskClient) New(typ string) *task {
return &task{
client: t,
typ: typ,
}
}
// Payload sets the task payload data which will be sent to the task handler
func (t *task) Payload(payload interface{}) *task {
t.payload = payload
return t
}
// Periodic sets the task to execute periodically according to a given interval
// The interval can be either in cron form ("*/5 * * * *") or "@every 30s"
func (t *task) Periodic(interval string) *task {
t.periodic = &interval
return t
}
// Queue specifies the name of the queue to add the task to
// The default queue will be used if this is not set
func (t *task) Queue(queue string) *task {
t.queue = &queue
return t
}
// Timeout sets the task timeout, meaning the task must execute within a given duration
func (t *task) Timeout(timeout time.Duration) *task {
t.timeout = &timeout
return t
}
// Deadline sets the task execution deadline to a specific date and time
func (t *task) Deadline(deadline time.Time) *task {
t.deadline = &deadline
return t
}
// At sets the exact date and time the task should be executed
func (t *task) At(processAt time.Time) *task {
t.at = &processAt
return t
}
// Wait instructs the task to wait a given duration before it is executed
func (t *task) Wait(duration time.Duration) *task {
t.wait = &duration
return t
}
// Retain instructs the task service to retain the task data for a given duration after execution is complete
func (t *task) Retain(duration time.Duration) *task {
t.retain = &duration
return t
}
// MaxRetries sets the maximum amount of times to retry executing the task in the event of a failure
func (t *task) MaxRetries(retries int) *task {
t.maxRetries = &retries
return t
}
// Save saves the task so it can be executed
func (t *task) Save() error {
var err error
// Build the payload
var payload []byte
if t.payload != nil {
if payload, err = json.Marshal(t.payload); err != nil {
return err
}
}
// Build the task options
opts := make([]asynq.Option, 0)
if t.queue != nil {
opts = append(opts, asynq.Queue(*t.queue))
}
if t.maxRetries != nil {
opts = append(opts, asynq.MaxRetry(*t.maxRetries))
}
if t.timeout != nil {
opts = append(opts, asynq.Timeout(*t.timeout))
}
if t.deadline != nil {
opts = append(opts, asynq.Deadline(*t.deadline))
}
if t.wait != nil {
opts = append(opts, asynq.ProcessIn(*t.wait))
}
if t.retain != nil {
opts = append(opts, asynq.Retention(*t.retain))
}
if t.at != nil {
opts = append(opts, asynq.ProcessAt(*t.at))
}
// Build the task
task := asynq.NewTask(t.typ, payload, opts...)
// Schedule, if needed
if t.periodic != nil {
_, err = t.client.scheduler.Register(*t.periodic, task)
} else {
_, err = t.client.client.Enqueue(task)
}
return err
}

View file

@ -0,0 +1,35 @@
package services
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestTaskClient_New(t *testing.T) {
now := time.Now()
tk := c.Tasks.
New("task1").
Payload("payload").
Queue("queue").
Periodic("@every 5s").
MaxRetries(5).
Timeout(5 * time.Second).
Deadline(now).
At(now).
Wait(6 * time.Second).
Retain(7 * time.Second)
assert.Equal(t, "task1", tk.typ)
assert.Equal(t, "payload", tk.payload.(string))
assert.Equal(t, "queue", *tk.queue)
assert.Equal(t, "@every 5s", *tk.periodic)
assert.Equal(t, 5, *tk.maxRetries)
assert.Equal(t, 5*time.Second, *tk.timeout)
assert.Equal(t, now, *tk.deadline)
assert.Equal(t, now, *tk.at)
assert.Equal(t, 6*time.Second, *tk.wait)
assert.Equal(t, 7*time.Second, *tk.retain)
assert.NoError(t, tk.Save())
}

View file

@ -0,0 +1,233 @@
package services
import (
"bytes"
"errors"
"fmt"
"html/template"
"path"
"path/filepath"
"runtime"
"sync"
"github.com/mikestefanello/pagoda/config"
"github.com/mikestefanello/pagoda/pkg/funcmap"
)
type (
// TemplateRenderer provides a flexible and easy to use method of rendering simple templates or complex sets of
// templates while also providing caching and/or hot-reloading depending on your current environment
TemplateRenderer struct {
// templateCache stores a cache of parsed page templates
templateCache sync.Map
// funcMap stores the template function map
funcMap template.FuncMap
// templatePath stores the complete path to the templates directory
templatesPath string
// config stores application configuration
config *config.Config
}
// TemplateParsed is a wrapper around parsed templates which are stored in the TemplateRenderer cache
TemplateParsed struct {
// Template is the parsed template
Template *template.Template
// build stores the build data used to parse the template
build *templateBuild
}
// templateBuild stores the build data used to parse a template
templateBuild struct {
group string
key string
base string
files []string
directories []string
}
// templateBuilder handles chaining a template parse operation
templateBuilder struct {
build *templateBuild
renderer *TemplateRenderer
}
)
// NewTemplateRenderer creates a new TemplateRenderer
func NewTemplateRenderer(cfg *config.Config) *TemplateRenderer {
t := &TemplateRenderer{
templateCache: sync.Map{},
funcMap: funcmap.GetFuncMap(),
config: cfg,
}
// Gets the complete templates directory path
// This is needed in case this is called from a package outside of main, such as within tests
_, b, _, _ := runtime.Caller(0)
d := path.Join(path.Dir(b))
t.templatesPath = filepath.Join(filepath.Dir(d), config.TemplateDir)
return t
}
// Parse creates a template build operation
func (t *TemplateRenderer) Parse() *templateBuilder {
return &templateBuilder{
renderer: t,
build: &templateBuild{},
}
}
// GetTemplatesPath gets the complete path to the templates directory
func (t *TemplateRenderer) GetTemplatesPath() string {
return t.templatesPath
}
// getCacheKey gets a cache key for a given group and ID
func (t *TemplateRenderer) getCacheKey(group, key string) string {
if group != "" {
return fmt.Sprintf("%s:%s", group, key)
}
return key
}
// parse parses a set of templates and caches them for quick execution
// If the application environment is set to local, the cache will be bypassed and templates will be
// parsed upon each request so hot-reloading is possible without restarts.
// Also included will be the function map provided by the funcmap package.
func (t *TemplateRenderer) parse(build *templateBuild) (*TemplateParsed, error) {
var tp *TemplateParsed
var err error
switch {
case build.key == "":
return nil, errors.New("cannot parse template without key")
case len(build.files) == 0 && len(build.directories) == 0:
return nil, errors.New("cannot parse template without files or directories")
case build.base == "":
return nil, errors.New("cannot parse template without base")
}
// Generate the cache key
cacheKey := t.getCacheKey(build.group, build.key)
// Check if the template has not yet been parsed or if the app environment is local, so that
// templates reflect changes without having the restart the server
if tp, err = t.Load(build.group, build.key); err != nil || t.config.App.Environment == config.EnvLocal {
// Initialize the parsed template with the function map
parsed := template.New(build.base + config.TemplateExt).
Funcs(t.funcMap)
// Parse all files provided
if len(build.files) > 0 {
for k, v := range build.files {
build.files[k] = fmt.Sprintf("%s/%s%s", t.templatesPath, v, config.TemplateExt)
}
parsed, err = parsed.ParseFiles(build.files...)
if err != nil {
return nil, err
}
}
// Parse all templates within the provided directories
for _, dir := range build.directories {
dir = fmt.Sprintf("%s/%s/*%s", t.templatesPath, dir, config.TemplateExt)
parsed, err = parsed.ParseGlob(dir)
if err != nil {
return nil, err
}
}
// Store the template so this process only happens once
tp = &TemplateParsed{
Template: parsed,
build: build,
}
t.templateCache.Store(cacheKey, tp)
}
return tp, nil
}
// Load loads a template from the cache
func (t *TemplateRenderer) Load(group, key string) (*TemplateParsed, error) {
load, ok := t.templateCache.Load(t.getCacheKey(group, key))
if !ok {
return nil, errors.New("uncached page template requested")
}
tmpl, ok := load.(*TemplateParsed)
if !ok {
return nil, errors.New("unable to cast cached template")
}
return tmpl, nil
}
// Execute executes a template with the given data and provides the output
func (t *TemplateParsed) Execute(data interface{}) (*bytes.Buffer, error) {
if t.Template == nil {
return nil, errors.New("cannot execute template: template not initialized")
}
buf := new(bytes.Buffer)
err := t.Template.ExecuteTemplate(buf, t.build.base+config.TemplateExt, data)
if err != nil {
return nil, err
}
return buf, nil
}
// Group sets the cache group for the template being built
func (t *templateBuilder) Group(group string) *templateBuilder {
t.build.group = group
return t
}
// Key sets the cache key for the template being built
func (t *templateBuilder) Key(key string) *templateBuilder {
t.build.key = key
return t
}
// Base sets the name of the base template to be used during template parsing and execution.
// This should be only the file name without a directory or extension.
func (t *templateBuilder) Base(base string) *templateBuilder {
t.build.base = base
return t
}
// Files sets a list of template files to include in the parse.
// This should not include the file extension and the paths should be relative to the templates directory.
func (t *templateBuilder) Files(files ...string) *templateBuilder {
t.build.files = files
return t
}
// Directories sets a list of directories that all template files within will be parsed.
// The paths should be relative to the templates directory.
func (t *templateBuilder) Directories(directories ...string) *templateBuilder {
t.build.directories = directories
return t
}
// Store parsed the templates and stores them in the cache
func (t *templateBuilder) Store() (*TemplateParsed, error) {
return t.renderer.parse(t.build)
}
// Execute executes the template with the given data.
// If the template has not already been cached, this will parse and cache the template
func (t *templateBuilder) Execute(data interface{}) (*bytes.Buffer, error) {
tp, err := t.Store()
if err != nil {
return nil, err
}
return tp.Execute(data)
}

View file

@ -0,0 +1,72 @@
package services
import (
"io/ioutil"
"testing"
"github.com/mikestefanello/pagoda/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTemplateRenderer(t *testing.T) {
group := "test"
id := "parse"
// Should not exist yet
_, err := c.TemplateRenderer.Load(group, id)
assert.Error(t, err)
// Parse in to the cache
tpl, err := c.TemplateRenderer.
Parse().
Group(group).
Key(id).
Base("htmx").
Files("htmx", "pages/error").
Directories("components").
Store()
require.NoError(t, err)
// Should exist now
parsed, err := c.TemplateRenderer.Load(group, id)
require.NoError(t, err)
// Check that all expected templates are included
expectedTemplates := make(map[string]bool)
expectedTemplates["htmx"+config.TemplateExt] = true
expectedTemplates["error"+config.TemplateExt] = true
components, err := ioutil.ReadDir(c.TemplateRenderer.GetTemplatesPath() + "/components")
require.NoError(t, err)
for _, f := range components {
expectedTemplates[f.Name()] = true
}
for _, v := range parsed.Template.Templates() {
delete(expectedTemplates, v.Name())
}
assert.Empty(t, expectedTemplates)
data := struct {
StatusCode int
}{
StatusCode: 500,
}
buf, err := tpl.Execute(data)
require.NoError(t, err)
require.NotNil(t, buf)
assert.Contains(t, buf.String(), "Please try again")
buf, err = c.TemplateRenderer.
Parse().
Group(group).
Key(id).
Base("htmx").
Files("htmx", "pages/error").
Directories("components").
Execute(data)
require.NoError(t, err)
require.NotNil(t, buf)
assert.Contains(t, buf.String(), "Please try again")
}

26
pkg/services/validator.go Normal file
View file

@ -0,0 +1,26 @@
package services
import (
"github.com/go-playground/validator/v10"
)
// Validator provides validation mainly validating structs within the web context
type Validator struct {
// validator stores the underlying validator
validator *validator.Validate
}
// NewValidator creats a new Validator
func NewValidator() *Validator {
return &Validator{
validator: validator.New(),
}
}
// Validate validates a struct
func (v *Validator) Validate(i interface{}) error {
if err := v.validator.Struct(i); err != nil {
return err
}
return nil
}

View file

@ -0,0 +1,19 @@
package services
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestValidator(t *testing.T) {
type example struct {
Value string `validate:"required"`
}
e := example{}
err := c.Validator.Validate(e)
assert.Error(t, err)
e.Value = "a"
err = c.Validator.Validate(e)
assert.NoError(t, err)
}

22
pkg/tasks/example.go Normal file
View file

@ -0,0 +1,22 @@
package tasks
import (
"context"
"log"
"github.com/hibiken/asynq"
)
// TypeExample is the type for the example task.
// This is what is passed in to TaskClient.New() when creating a new task
const TypeExample = "example_task"
// ExampleProcessor processes example tasks
type ExampleProcessor struct {
}
// ProcessTask handles the processing of the task
func (p *ExampleProcessor) ProcessTask(ctx context.Context, t *asynq.Task) error {
log.Printf("executing task: %s", t.Type())
return nil
}

60
pkg/tests/tests.go Normal file
View file

@ -0,0 +1,60 @@
package tests
import (
"context"
"fmt"
"math/rand"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/mikestefanello/pagoda/ent"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
)
// NewContext creates a new Echo context for tests using an HTTP test request and response recorder
func NewContext(e *echo.Echo, url string) (echo.Context, *httptest.ResponseRecorder) {
req := httptest.NewRequest(http.MethodGet, url, strings.NewReader(""))
rec := httptest.NewRecorder()
return e.NewContext(req, rec), rec
}
// InitSession initializes a session for a given Echo context
func InitSession(ctx echo.Context) {
mw := session.Middleware(sessions.NewCookieStore([]byte("secret")))
_ = ExecuteMiddleware(ctx, mw)
}
// ExecuteMiddleware executes a middleware function on a given Echo context
func ExecuteMiddleware(ctx echo.Context, mw echo.MiddlewareFunc) error {
handler := mw(func(c echo.Context) error {
return nil
})
return handler(ctx)
}
// AssertHTTPErrorCode asserts an HTTP status code on a given Echo HTTP error
func AssertHTTPErrorCode(t *testing.T, err error, code int) {
httpError, ok := err.(*echo.HTTPError)
require.True(t, ok)
assert.Equal(t, code, httpError.Code)
}
// CreateUser creates a random user entity
func CreateUser(orm *ent.Client) (*ent.User, error) {
seed := fmt.Sprintf("%d-%d", time.Now().UnixMilli(), rand.Intn(1000000))
return orm.User.
Create().
SetEmail(fmt.Sprintf("testuser-%s@localhost.localhost", seed)).
SetPassword("password").
SetName(fmt.Sprintf("Test User %s", seed)).
Save(context.Background())
}