Reorganized directories and packages.
This commit is contained in:
parent
1018d82d13
commit
72ce41c828
61 changed files with 83 additions and 83 deletions
25
pkg/context/context.go
Normal file
25
pkg/context/context.go
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
package context
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
const (
|
||||
// AuthenticatedUserKey is the key value used to store the authenticated user in context
|
||||
AuthenticatedUserKey = "auth_user"
|
||||
|
||||
// UserKey is the key value used to store a user in context
|
||||
UserKey = "user"
|
||||
|
||||
// FormKey is the key value used to store a form in context
|
||||
FormKey = "form"
|
||||
|
||||
// PasswordTokenKey is the key value used to store a password token in context
|
||||
PasswordTokenKey = "password_token"
|
||||
)
|
||||
|
||||
// IsCanceledError determines if an error is due to a context cancelation
|
||||
func IsCanceledError(err error) bool {
|
||||
return errors.Is(err, context.Canceled)
|
||||
}
|
||||
24
pkg/context/context_test.go
Normal file
24
pkg/context/context_test.go
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
package context
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
assert.False(t, IsCanceledError(ctx.Err()))
|
||||
cancel()
|
||||
assert.True(t, IsCanceledError(ctx.Err()))
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), time.Microsecond)
|
||||
time.Sleep(time.Microsecond * 5)
|
||||
cancel()
|
||||
assert.False(t, IsCanceledError(ctx.Err()))
|
||||
|
||||
assert.False(t, IsCanceledError(errors.New("test error")))
|
||||
}
|
||||
171
pkg/controller/controller.go
Normal file
171
pkg/controller/controller.go
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
"github.com/mikestefanello/pagoda/pkg/htmx"
|
||||
"github.com/mikestefanello/pagoda/pkg/middleware"
|
||||
"github.com/mikestefanello/pagoda/pkg/services"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// Controller provides base functionality and dependencies to routes.
|
||||
// The proposed pattern is to embed a Controller in each individual route struct and to use
|
||||
// the router to inject the container so your routes have access to the services within the container
|
||||
type Controller struct {
|
||||
// Container stores a services container which contains dependencies
|
||||
Container *services.Container
|
||||
}
|
||||
|
||||
// NewController creates a new Controller
|
||||
func NewController(c *services.Container) Controller {
|
||||
return Controller{
|
||||
Container: c,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderPage renders a Page as an HTTP response
|
||||
func (c *Controller) RenderPage(ctx echo.Context, page Page) error {
|
||||
var buf *bytes.Buffer
|
||||
var err error
|
||||
|
||||
// Page name is required
|
||||
if page.Name == "" {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "page render failed due to missing name")
|
||||
}
|
||||
|
||||
// Use the app name in configuration if a value was not set
|
||||
if page.AppName == "" {
|
||||
page.AppName = c.Container.Config.App.Name
|
||||
}
|
||||
|
||||
// Check if this is an HTMX non-boosted request which indicates that only partial
|
||||
// content should be rendered
|
||||
if page.HTMX.Request.Enabled && !page.HTMX.Request.Boosted {
|
||||
// Parse and execute the templates only for the content portion of the page
|
||||
// The templates used for this partial request will be:
|
||||
// 1. The base htmx template which omits the layout and only includes the content template
|
||||
// 2. The content template specified in Page.Name
|
||||
// 3. All templates within the components directory
|
||||
// Also included is the function map provided by the funcmap package
|
||||
buf, err = c.Container.TemplateRenderer.
|
||||
Parse().
|
||||
Group("page:htmx").
|
||||
Key(page.Name).
|
||||
Base("htmx").
|
||||
Files(
|
||||
"htmx",
|
||||
fmt.Sprintf("pages/%s", page.Name),
|
||||
).
|
||||
Directories("components").
|
||||
Execute(page)
|
||||
} else {
|
||||
// Parse and execute the templates for the Page
|
||||
// As mentioned in the documentation for the Page struct, the templates used for the page will be:
|
||||
// 1. The layout/base template specified in Page.Layout
|
||||
// 2. The content template specified in Page.Name
|
||||
// 3. All templates within the components directory
|
||||
// Also included is the function map provided by the funcmap package
|
||||
buf, err = c.Container.TemplateRenderer.
|
||||
Parse().
|
||||
Group("page").
|
||||
Key(page.Name).
|
||||
Base(page.Layout).
|
||||
Files(
|
||||
fmt.Sprintf("layouts/%s", page.Layout),
|
||||
fmt.Sprintf("pages/%s", page.Name),
|
||||
).
|
||||
Directories("components").
|
||||
Execute(page)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return c.Fail(err, "failed to parse and execute templates")
|
||||
}
|
||||
|
||||
// Set the status code
|
||||
ctx.Response().Status = page.StatusCode
|
||||
|
||||
// Set any headers
|
||||
for k, v := range page.Headers {
|
||||
ctx.Response().Header().Set(k, v)
|
||||
}
|
||||
|
||||
// Apply the HTMX response, if one
|
||||
if page.HTMX.Response != nil {
|
||||
page.HTMX.Response.Apply(ctx)
|
||||
}
|
||||
|
||||
// Cache this page, if caching was enabled
|
||||
c.cachePage(ctx, page, buf)
|
||||
|
||||
return ctx.HTMLBlob(ctx.Response().Status, buf.Bytes())
|
||||
}
|
||||
|
||||
// cachePage caches the HTML for a given Page if the Page has caching enabled
|
||||
func (c *Controller) cachePage(ctx echo.Context, page Page, html *bytes.Buffer) {
|
||||
if !page.Cache.Enabled || page.IsAuth {
|
||||
return
|
||||
}
|
||||
|
||||
// If no expiration time was provided, default to the configuration value
|
||||
if page.Cache.Expiration == 0 {
|
||||
page.Cache.Expiration = c.Container.Config.Cache.Expiration.Page
|
||||
}
|
||||
|
||||
// Extract the headers
|
||||
headers := make(map[string]string)
|
||||
for k, v := range ctx.Response().Header() {
|
||||
headers[k] = v[0]
|
||||
}
|
||||
|
||||
// The request URL is used as the cache key so the middleware can serve the
|
||||
// cached page on matching requests
|
||||
key := ctx.Request().URL.String()
|
||||
cp := middleware.CachedPage{
|
||||
URL: key,
|
||||
HTML: html.Bytes(),
|
||||
Headers: headers,
|
||||
StatusCode: ctx.Response().Status,
|
||||
}
|
||||
|
||||
err := c.Container.Cache.
|
||||
Set().
|
||||
Group(middleware.CachedPageGroup).
|
||||
Key(key).
|
||||
Tags(page.Cache.Tags...).
|
||||
Expiration(page.Cache.Expiration).
|
||||
Data(cp).
|
||||
Save(ctx.Request().Context())
|
||||
|
||||
switch {
|
||||
case err == nil:
|
||||
ctx.Logger().Info("cached page")
|
||||
case !context.IsCanceledError(err):
|
||||
ctx.Logger().Errorf("failed to cache page: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Redirect redirects to a given route name with optional route parameters
|
||||
func (c *Controller) Redirect(ctx echo.Context, route string, routeParams ...interface{}) error {
|
||||
url := ctx.Echo().Reverse(route, routeParams...)
|
||||
|
||||
if htmx.GetRequest(ctx).Boosted {
|
||||
htmx.Response{
|
||||
Redirect: url,
|
||||
}.Apply(ctx)
|
||||
|
||||
return nil
|
||||
} else {
|
||||
return ctx.Redirect(http.StatusFound, url)
|
||||
}
|
||||
}
|
||||
|
||||
// Fail is a helper to fail a request by returning a 500 error and logging the error
|
||||
func (c *Controller) Fail(err error, log string) error {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("%s: %v", log, err))
|
||||
}
|
||||
187
pkg/controller/controller_test.go
Normal file
187
pkg/controller/controller_test.go
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mikestefanello/pagoda/config"
|
||||
"github.com/mikestefanello/pagoda/pkg/htmx"
|
||||
"github.com/mikestefanello/pagoda/pkg/middleware"
|
||||
"github.com/mikestefanello/pagoda/pkg/services"
|
||||
"github.com/mikestefanello/pagoda/pkg/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
var (
|
||||
c *services.Container
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Set the environment to test
|
||||
config.SwitchEnvironment(config.EnvTest)
|
||||
|
||||
// Create a new container
|
||||
c = services.NewContainer()
|
||||
|
||||
// Run tests
|
||||
exitVal := m.Run()
|
||||
|
||||
// Shutdown the container
|
||||
if err := c.Shutdown(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
os.Exit(exitVal)
|
||||
}
|
||||
|
||||
func TestController_Redirect(t *testing.T) {
|
||||
c.Web.GET("/path/:first/and/:second", func(c echo.Context) error {
|
||||
return nil
|
||||
}).Name = "redirect-test"
|
||||
|
||||
ctx, _ := tests.NewContext(c.Web, "/abc")
|
||||
ctr := NewController(c)
|
||||
err := ctr.Redirect(ctx, "redirect-test", "one", "two")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/path/one/and/two", ctx.Response().Header().Get(echo.HeaderLocation))
|
||||
assert.Equal(t, http.StatusFound, ctx.Response().Status)
|
||||
}
|
||||
|
||||
func TestController_RenderPage(t *testing.T) {
|
||||
setup := func() (echo.Context, *httptest.ResponseRecorder, Controller, Page) {
|
||||
ctx, rec := tests.NewContext(c.Web, "/test/TestController_RenderPage")
|
||||
tests.InitSession(ctx)
|
||||
ctr := NewController(c)
|
||||
|
||||
p := NewPage(ctx)
|
||||
p.Name = "home"
|
||||
p.Layout = "main"
|
||||
p.Cache.Enabled = false
|
||||
p.Headers["A"] = "b"
|
||||
p.Headers["C"] = "d"
|
||||
p.StatusCode = http.StatusCreated
|
||||
return ctx, rec, ctr, p
|
||||
}
|
||||
|
||||
t.Run("missing name", func(t *testing.T) {
|
||||
// Rendering should fail if the Page has no name
|
||||
ctx, _, ctr, p := setup()
|
||||
p.Name = ""
|
||||
err := ctr.RenderPage(ctx, p)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("no page cache", func(t *testing.T) {
|
||||
ctx, _, ctr, p := setup()
|
||||
err := ctr.RenderPage(ctx, p)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check status code and headers
|
||||
assert.Equal(t, http.StatusCreated, ctx.Response().Status)
|
||||
for k, v := range p.Headers {
|
||||
assert.Equal(t, v, ctx.Response().Header().Get(k))
|
||||
}
|
||||
|
||||
// Check the template cache
|
||||
parsed, err := c.TemplateRenderer.Load("page", p.Name)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check that all expected templates were parsed.
|
||||
// This includes the name, layout and all components
|
||||
expectedTemplates := make(map[string]bool)
|
||||
expectedTemplates[p.Name+config.TemplateExt] = true
|
||||
expectedTemplates[p.Layout+config.TemplateExt] = true
|
||||
components, err := ioutil.ReadDir(c.TemplateRenderer.GetTemplatesPath() + "/components")
|
||||
require.NoError(t, err)
|
||||
for _, f := range components {
|
||||
expectedTemplates[f.Name()] = true
|
||||
}
|
||||
|
||||
for _, v := range parsed.Template.Templates() {
|
||||
delete(expectedTemplates, v.Name())
|
||||
}
|
||||
assert.Empty(t, expectedTemplates)
|
||||
})
|
||||
|
||||
t.Run("htmx rendering", func(t *testing.T) {
|
||||
ctx, _, ctr, p := setup()
|
||||
p.HTMX.Request.Enabled = true
|
||||
p.HTMX.Response = &htmx.Response{
|
||||
Trigger: "trigger",
|
||||
}
|
||||
err := ctr.RenderPage(ctx, p)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check HTMX header
|
||||
assert.Equal(t, "trigger", ctx.Response().Header().Get(htmx.HeaderTrigger))
|
||||
|
||||
// Check the template cache
|
||||
parsed, err := c.TemplateRenderer.Load("page:htmx", p.Name)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check that all expected templates were parsed.
|
||||
// This includes the name, htmx and all components
|
||||
expectedTemplates := make(map[string]bool)
|
||||
expectedTemplates[p.Name+config.TemplateExt] = true
|
||||
expectedTemplates["htmx"+config.TemplateExt] = true
|
||||
components, err := ioutil.ReadDir(c.TemplateRenderer.GetTemplatesPath() + "/components")
|
||||
require.NoError(t, err)
|
||||
for _, f := range components {
|
||||
expectedTemplates[f.Name()] = true
|
||||
}
|
||||
|
||||
for _, v := range parsed.Template.Templates() {
|
||||
delete(expectedTemplates, v.Name())
|
||||
}
|
||||
assert.Empty(t, expectedTemplates)
|
||||
})
|
||||
|
||||
t.Run("page cache", func(t *testing.T) {
|
||||
ctx, rec, ctr, p := setup()
|
||||
p.Cache.Enabled = true
|
||||
p.Cache.Tags = []string{"tag1"}
|
||||
err := ctr.RenderPage(ctx, p)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch from the cache
|
||||
res, err := c.Cache.
|
||||
Get().
|
||||
Group(middleware.CachedPageGroup).
|
||||
Key(p.URL).
|
||||
Type(new(middleware.CachedPage)).
|
||||
Fetch(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compare the cached page
|
||||
cp, ok := res.(*middleware.CachedPage)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, p.URL, cp.URL)
|
||||
assert.Equal(t, p.Headers, cp.Headers)
|
||||
assert.Equal(t, p.StatusCode, cp.StatusCode)
|
||||
assert.Equal(t, rec.Body.Bytes(), cp.HTML)
|
||||
|
||||
// Clear the tag
|
||||
err = c.Cache.
|
||||
Flush().
|
||||
Tags(p.Cache.Tags[0]).
|
||||
Execute(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Refetch from the cache and expect no results
|
||||
_, err = c.Cache.
|
||||
Get().
|
||||
Group(middleware.CachedPageGroup).
|
||||
Key(p.URL).
|
||||
Type(new(middleware.CachedPage)).
|
||||
Fetch(context.Background())
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
104
pkg/controller/form.go
Normal file
104
pkg/controller/form.go
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"github.com/go-playground/validator/v10"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// FormSubmission represents the state of the submission of a form, not including the form itself
|
||||
type FormSubmission struct {
|
||||
// IsSubmitted indicates if the form has been submitted
|
||||
IsSubmitted bool
|
||||
|
||||
// Errors stores a slice of error message strings keyed by form struct field name
|
||||
Errors map[string][]string
|
||||
}
|
||||
|
||||
// Process processes a submission for a form
|
||||
func (f *FormSubmission) Process(ctx echo.Context, form interface{}) error {
|
||||
f.Errors = make(map[string][]string)
|
||||
f.IsSubmitted = true
|
||||
|
||||
// Validate the form
|
||||
if err := ctx.Validate(form); err != nil {
|
||||
f.setErrorMessages(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasErrors indicates if the submission has any validation errors
|
||||
func (f FormSubmission) HasErrors() bool {
|
||||
if f.Errors == nil {
|
||||
return false
|
||||
}
|
||||
return len(f.Errors) > 0
|
||||
}
|
||||
|
||||
// FieldHasErrors indicates if a given field on the form has any validation errors
|
||||
func (f FormSubmission) FieldHasErrors(fieldName string) bool {
|
||||
return len(f.GetFieldErrors(fieldName)) > 0
|
||||
}
|
||||
|
||||
// SetFieldError sets an error message for a given field name
|
||||
func (f *FormSubmission) SetFieldError(fieldName string, message string) {
|
||||
if f.Errors == nil {
|
||||
f.Errors = make(map[string][]string)
|
||||
}
|
||||
f.Errors[fieldName] = append(f.Errors[fieldName], message)
|
||||
}
|
||||
|
||||
// GetFieldErrors gets the errors for a given field name
|
||||
func (f FormSubmission) GetFieldErrors(fieldName string) []string {
|
||||
if f.Errors == nil {
|
||||
return []string{}
|
||||
}
|
||||
return f.Errors[fieldName]
|
||||
}
|
||||
|
||||
// GetFieldStatusClass returns an HTML class based on the status of the field
|
||||
func (f FormSubmission) GetFieldStatusClass(fieldName string) string {
|
||||
if f.IsSubmitted {
|
||||
if f.FieldHasErrors(fieldName) {
|
||||
return "is-danger"
|
||||
}
|
||||
return "is-success"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsDone indicates if the submission is considered done which is when it has been submitted
|
||||
// and there are no errors.
|
||||
func (f FormSubmission) IsDone() bool {
|
||||
return f.IsSubmitted && !f.HasErrors()
|
||||
}
|
||||
|
||||
// setErrorMessages sets errors messages on the submission for all fields that failed validation
|
||||
func (f *FormSubmission) setErrorMessages(err error) {
|
||||
// Only this is supported right now
|
||||
ves, ok := err.(validator.ValidationErrors)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for _, ve := range ves {
|
||||
var message string
|
||||
|
||||
// Provide better error messages depending on the failed validation tag
|
||||
// This should be expanded as you use additional tags in your validation
|
||||
switch ve.Tag() {
|
||||
case "required":
|
||||
message = "This field is required."
|
||||
case "email":
|
||||
message = "Enter a valid email address."
|
||||
case "eqfield":
|
||||
message = "Does not match."
|
||||
default:
|
||||
message = "Invalid value."
|
||||
}
|
||||
|
||||
// Add the error
|
||||
f.SetFieldError(ve.Field(), message)
|
||||
}
|
||||
}
|
||||
36
pkg/controller/form_test.go
Normal file
36
pkg/controller/form_test.go
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mikestefanello/pagoda/pkg/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFormSubmission(t *testing.T) {
|
||||
type formTest struct {
|
||||
Name string `validate:"required"`
|
||||
Email string `validate:"required,email"`
|
||||
Submission FormSubmission
|
||||
}
|
||||
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
form := formTest{
|
||||
Name: "",
|
||||
Email: "a@a.com",
|
||||
}
|
||||
err := form.Submission.Process(ctx, form)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, form.Submission.HasErrors())
|
||||
assert.True(t, form.Submission.FieldHasErrors("Name"))
|
||||
assert.False(t, form.Submission.FieldHasErrors("Email"))
|
||||
require.Len(t, form.Submission.GetFieldErrors("Name"), 1)
|
||||
assert.Len(t, form.Submission.GetFieldErrors("Email"), 0)
|
||||
assert.Equal(t, "This field is required.", form.Submission.GetFieldErrors("Name")[0])
|
||||
assert.Equal(t, "is-danger", form.Submission.GetFieldStatusClass("Name"))
|
||||
assert.Equal(t, "is-success", form.Submission.GetFieldStatusClass("Email"))
|
||||
assert.False(t, form.Submission.IsDone())
|
||||
}
|
||||
161
pkg/controller/page.go
Normal file
161
pkg/controller/page.go
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
"github.com/mikestefanello/pagoda/pkg/htmx"
|
||||
"github.com/mikestefanello/pagoda/pkg/msg"
|
||||
|
||||
echomw "github.com/labstack/echo/v4/middleware"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// Page consists of all data that will be used to render a page response for a given controller.
|
||||
// While it's not required for a controller to render a Page on a route, this is the common data
|
||||
// object that will be passed to the templates, making it easy for all controllers to share
|
||||
// functionality both on the back and frontend. The Page can be expanded to include anything else
|
||||
// your app wants to support.
|
||||
// Methods on this page also then become available in the templates, which can be more useful than
|
||||
// the funcmap if your methods require data stored in the page, such as the context.
|
||||
type Page struct {
|
||||
// AppName stores the name of the application.
|
||||
// If omitted, the configuration value will be used.
|
||||
AppName string
|
||||
|
||||
// Title stores the title of the page
|
||||
Title string
|
||||
|
||||
// Context stores the request context
|
||||
Context echo.Context
|
||||
|
||||
// ToURL is a function to convert a route name and optional route parameters to a URL
|
||||
ToURL func(name string, params ...interface{}) string
|
||||
|
||||
// Path stores the path of the current request
|
||||
Path string
|
||||
|
||||
// URL stores the URL of the current request
|
||||
URL string
|
||||
|
||||
// Data stores whatever additional data that needs to be passed to the templates.
|
||||
// This is what the controller uses to pass the content of the page.
|
||||
Data interface{}
|
||||
|
||||
// Form stores a struct that represents a form on the page.
|
||||
// This should be a struct with fields for each form field, using both "form" and "validate" tags
|
||||
// It should also contain a Submission field of type FormSubmission if you wish to have validation
|
||||
// messagesa and markup presented to the user
|
||||
Form interface{}
|
||||
|
||||
// Layout stores the name of the layout base template file which will be used when the page is rendered.
|
||||
// This should match a template file located within the layouts directory inside the templates directory.
|
||||
// The template extension should not be included in this value.
|
||||
Layout string
|
||||
|
||||
// Name stores the name of the page as well as the name of the template file which will be used to render
|
||||
// the content portion of the layout template.
|
||||
// This should match a template file located within the pages directory inside the templates directory.
|
||||
// The template extension should not be included in this value.
|
||||
Name string
|
||||
|
||||
// IsHome stores whether the requested page is the home page or not
|
||||
IsHome bool
|
||||
|
||||
// IsAuth stores whether or not the user is authenticated
|
||||
IsAuth bool
|
||||
|
||||
// AuthUser stores the authenticated user
|
||||
AuthUser *ent.User
|
||||
|
||||
// StatusCode stores the HTTP status code that will be returned
|
||||
StatusCode int
|
||||
|
||||
// Metatags stores metatag values
|
||||
Metatags struct {
|
||||
// Description stores the description metatag value
|
||||
Description string
|
||||
|
||||
// Keywords stores the keywords metatag values
|
||||
Keywords []string
|
||||
}
|
||||
|
||||
// Pager stores a pager which can be used to page lists of results
|
||||
Pager Pager
|
||||
|
||||
// CSRF stores the CSRF token for the given request.
|
||||
// This will only be populated if the CSRF middleware is in effect for the given request.
|
||||
// If this is populated, all forms must include this value otherwise the requests will be rejected.
|
||||
CSRF string
|
||||
|
||||
// Headers stores a list of HTTP headers and values to be set on the response
|
||||
Headers map[string]string
|
||||
|
||||
// RequestID stores the ID of the given request.
|
||||
// This will only be populated if the request ID middleware is in effect for the given request.
|
||||
RequestID string
|
||||
|
||||
HTMX struct {
|
||||
Request htmx.Request
|
||||
Response *htmx.Response
|
||||
}
|
||||
|
||||
// Cache stores values for caching the response of this page
|
||||
Cache struct {
|
||||
// Enabled dictates if the response of this page should be cached.
|
||||
// Cached responses are served via middleware.
|
||||
Enabled bool
|
||||
|
||||
// Expiration stores the amount of time that the cache entry should live for before expiring.
|
||||
// If omitted, the configuration value will be used.
|
||||
Expiration time.Duration
|
||||
|
||||
// Tags stores a list of tags to apply to the cache entry.
|
||||
// These are useful when invalidating cache for dynamic events such as entity operations.
|
||||
Tags []string
|
||||
}
|
||||
}
|
||||
|
||||
// NewPage creates and initiatizes a new Page for a given request context
|
||||
func NewPage(ctx echo.Context) Page {
|
||||
p := Page{
|
||||
Context: ctx,
|
||||
ToURL: ctx.Echo().Reverse,
|
||||
Path: ctx.Request().URL.Path,
|
||||
URL: ctx.Request().URL.String(),
|
||||
StatusCode: http.StatusOK,
|
||||
Pager: NewPager(ctx, DefaultItemsPerPage),
|
||||
Headers: make(map[string]string),
|
||||
RequestID: ctx.Response().Header().Get(echo.HeaderXRequestID),
|
||||
}
|
||||
|
||||
p.IsHome = p.Path == "/"
|
||||
|
||||
if csrf := ctx.Get(echomw.DefaultCSRFConfig.ContextKey); csrf != nil {
|
||||
p.CSRF = csrf.(string)
|
||||
}
|
||||
|
||||
if u := ctx.Get(context.AuthenticatedUserKey); u != nil {
|
||||
p.IsAuth = true
|
||||
p.AuthUser = u.(*ent.User)
|
||||
}
|
||||
|
||||
p.HTMX.Request = htmx.GetRequest(ctx)
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// GetMessages gets all flash messages for a given type.
|
||||
// This allows for easy access to flash messages from the templates.
|
||||
func (p Page) GetMessages(typ msg.Type) []template.HTML {
|
||||
strs := msg.Get(p.Context, typ)
|
||||
ret := make([]template.HTML, len(strs))
|
||||
for k, v := range strs {
|
||||
ret[k] = template.HTML(v)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
75
pkg/controller/page_test.go
Normal file
75
pkg/controller/page_test.go
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
"github.com/mikestefanello/pagoda/pkg/msg"
|
||||
"github.com/mikestefanello/pagoda/pkg/tests"
|
||||
|
||||
echomw "github.com/labstack/echo/v4/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewPage(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
p := NewPage(ctx)
|
||||
assert.Same(t, ctx, p.Context)
|
||||
assert.NotNil(t, p.ToURL)
|
||||
assert.Equal(t, "/", p.Path)
|
||||
assert.Equal(t, "/", p.URL)
|
||||
assert.Equal(t, http.StatusOK, p.StatusCode)
|
||||
assert.Equal(t, NewPager(ctx, DefaultItemsPerPage), p.Pager)
|
||||
assert.Empty(t, p.Headers)
|
||||
assert.True(t, p.IsHome)
|
||||
assert.False(t, p.IsAuth)
|
||||
assert.Empty(t, p.CSRF)
|
||||
assert.Empty(t, p.RequestID)
|
||||
assert.False(t, p.Cache.Enabled)
|
||||
|
||||
ctx, _ = tests.NewContext(c.Web, "/abc?def=123")
|
||||
usr, err := tests.CreateUser(c.ORM)
|
||||
require.NoError(t, err)
|
||||
ctx.Set(context.AuthenticatedUserKey, usr)
|
||||
ctx.Set(echomw.DefaultCSRFConfig.ContextKey, "csrf")
|
||||
p = NewPage(ctx)
|
||||
assert.Equal(t, "/abc", p.Path)
|
||||
assert.Equal(t, "/abc?def=123", p.URL)
|
||||
assert.False(t, p.IsHome)
|
||||
assert.True(t, p.IsAuth)
|
||||
assert.Equal(t, usr, p.AuthUser)
|
||||
assert.Equal(t, "csrf", p.CSRF)
|
||||
}
|
||||
|
||||
func TestPage_GetMessages(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
tests.InitSession(ctx)
|
||||
p := NewPage(ctx)
|
||||
|
||||
// Set messages
|
||||
msgTests := make(map[msg.Type][]string)
|
||||
msgTests[msg.TypeWarning] = []string{
|
||||
"abc",
|
||||
"def",
|
||||
}
|
||||
msgTests[msg.TypeInfo] = []string{
|
||||
"123",
|
||||
"456",
|
||||
}
|
||||
for typ, values := range msgTests {
|
||||
for _, value := range values {
|
||||
msg.Set(ctx, typ, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Get the messages
|
||||
for typ, values := range msgTests {
|
||||
msgs := p.GetMessages(typ)
|
||||
|
||||
for i, message := range msgs {
|
||||
assert.Equal(t, values[i], string(message))
|
||||
}
|
||||
}
|
||||
}
|
||||
80
pkg/controller/pager.go
Normal file
80
pkg/controller/pager.go
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultItemsPerPage stores the default amount of items per page
|
||||
DefaultItemsPerPage = 20
|
||||
|
||||
// PageQueryKey stores the query key used to indicate the current page
|
||||
PageQueryKey = "page"
|
||||
)
|
||||
|
||||
// Pager provides a mechanism to allow a user to page results via a query parameter
|
||||
type Pager struct {
|
||||
// Items stores the total amount of items in the result set
|
||||
Items int
|
||||
|
||||
// Page stores the current page number
|
||||
Page int
|
||||
|
||||
// ItemsPerPage stores the amount of items to display per page
|
||||
ItemsPerPage int
|
||||
|
||||
// Pages stores the total amount of pages in the result set
|
||||
Pages int
|
||||
}
|
||||
|
||||
// NewPager creates a new Pager
|
||||
func NewPager(ctx echo.Context, itemsPerPage int) Pager {
|
||||
p := Pager{
|
||||
ItemsPerPage: itemsPerPage,
|
||||
Page: 1,
|
||||
}
|
||||
|
||||
if page := ctx.QueryParam(PageQueryKey); page != "" {
|
||||
if pageInt, err := strconv.Atoi(page); err == nil {
|
||||
if pageInt > 0 {
|
||||
p.Page = pageInt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// SetItems sets the amount of items in total for the pager and calculate the amount
|
||||
// of total pages based off on the item per page.
|
||||
// This should be used rather than setting either items or pages directly.
|
||||
func (p *Pager) SetItems(items int) {
|
||||
p.Items = items
|
||||
p.Pages = int(math.Ceil(float64(items) / float64(p.ItemsPerPage)))
|
||||
|
||||
if p.Page > p.Pages {
|
||||
p.Page = p.Pages
|
||||
}
|
||||
}
|
||||
|
||||
// IsBeginning determines if the pager is at the beginning of the pages
|
||||
func (p Pager) IsBeginning() bool {
|
||||
return p.Page == 1
|
||||
}
|
||||
|
||||
// IsEnd determines if the pager is at the end of the pages
|
||||
func (p Pager) IsEnd() bool {
|
||||
return p.Page >= p.Pages
|
||||
}
|
||||
|
||||
// GetOffset determines the offset of the results in order to get the items for
|
||||
// the current page
|
||||
func (p Pager) GetOffset() int {
|
||||
if p.Page == 0 {
|
||||
p.Page = 1
|
||||
}
|
||||
return (p.Page - 1) * p.ItemsPerPage
|
||||
}
|
||||
67
pkg/controller/pager_test.go
Normal file
67
pkg/controller/pager_test.go
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/mikestefanello/pagoda/pkg/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewPager(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
pgr := NewPager(ctx, 10)
|
||||
assert.Equal(t, 10, pgr.ItemsPerPage)
|
||||
assert.Equal(t, 1, pgr.Page)
|
||||
assert.Equal(t, 0, pgr.Items)
|
||||
assert.Equal(t, 0, pgr.Pages)
|
||||
|
||||
ctx, _ = tests.NewContext(c.Web, fmt.Sprintf("/abc?%s=%d", PageQueryKey, 2))
|
||||
pgr = NewPager(ctx, 10)
|
||||
assert.Equal(t, 2, pgr.Page)
|
||||
|
||||
ctx, _ = tests.NewContext(c.Web, fmt.Sprintf("/abc?%s=%d", PageQueryKey, -2))
|
||||
pgr = NewPager(ctx, 10)
|
||||
assert.Equal(t, 1, pgr.Page)
|
||||
}
|
||||
|
||||
func TestPager_SetItems(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
pgr := NewPager(ctx, 20)
|
||||
pgr.SetItems(100)
|
||||
assert.Equal(t, 100, pgr.Items)
|
||||
assert.Equal(t, 5, pgr.Pages)
|
||||
}
|
||||
|
||||
func TestPager_IsBeginning(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
pgr := NewPager(ctx, 20)
|
||||
pgr.Pages = 10
|
||||
assert.True(t, pgr.IsBeginning())
|
||||
pgr.Page = 2
|
||||
assert.False(t, pgr.IsBeginning())
|
||||
pgr.Page = 1
|
||||
assert.True(t, pgr.IsBeginning())
|
||||
}
|
||||
|
||||
func TestPager_IsEnd(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
pgr := NewPager(ctx, 20)
|
||||
pgr.Pages = 10
|
||||
assert.False(t, pgr.IsEnd())
|
||||
pgr.Page = 10
|
||||
assert.True(t, pgr.IsEnd())
|
||||
pgr.Page = 1
|
||||
assert.False(t, pgr.IsEnd())
|
||||
}
|
||||
|
||||
func TestPager_GetOffset(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
pgr := NewPager(ctx, 20)
|
||||
assert.Equal(t, 0, pgr.GetOffset())
|
||||
pgr.Page = 2
|
||||
assert.Equal(t, 20, pgr.GetOffset())
|
||||
pgr.Page = 3
|
||||
assert.Equal(t, 40, pgr.GetOffset())
|
||||
}
|
||||
66
pkg/funcmap/funcmap.go
Normal file
66
pkg/funcmap/funcmap.go
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
package funcmap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/mikestefanello/pagoda/config"
|
||||
|
||||
"github.com/Masterminds/sprig"
|
||||
"github.com/labstack/gommon/random"
|
||||
)
|
||||
|
||||
var (
|
||||
// CacheBuster stores a random string used as a cache buster for static files.
|
||||
CacheBuster = random.String(10)
|
||||
)
|
||||
|
||||
// GetFuncMap provides a template function map
|
||||
func GetFuncMap() template.FuncMap {
|
||||
// See http://masterminds.github.io/sprig/ for available funcs
|
||||
funcMap := sprig.FuncMap()
|
||||
|
||||
// Provide a list of custom functions
|
||||
// Expand this as you add more functions to this package
|
||||
// Avoid using a name already in use by sprig
|
||||
f := template.FuncMap{
|
||||
"hasField": HasField,
|
||||
"file": File,
|
||||
"link": Link,
|
||||
}
|
||||
|
||||
for k, v := range f {
|
||||
funcMap[k] = v
|
||||
}
|
||||
|
||||
return funcMap
|
||||
}
|
||||
|
||||
// HasField checks if an interface contains a given field
|
||||
func HasField(v interface{}, name string) bool {
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() == reflect.Ptr {
|
||||
rv = rv.Elem()
|
||||
}
|
||||
if rv.Kind() != reflect.Struct {
|
||||
return false
|
||||
}
|
||||
return rv.FieldByName(name).IsValid()
|
||||
}
|
||||
|
||||
// File appends a cache buster to a given filepath so it can remain cached until the app is restarted
|
||||
func File(filepath string) string {
|
||||
return fmt.Sprintf("/%s/%s?v=%s", config.StaticPrefix, filepath, CacheBuster)
|
||||
}
|
||||
|
||||
// Link outputs HTML for a link element, providing the ability to dynamically set the active class
|
||||
func Link(url, text, currentPath string, classes ...string) template.HTML {
|
||||
if currentPath == url {
|
||||
classes = append(classes, "is-active")
|
||||
}
|
||||
|
||||
html := fmt.Sprintf(`<a class="%s" href="%s">%s</a>`, strings.Join(classes, " "), url, text)
|
||||
return template.HTML(html)
|
||||
}
|
||||
39
pkg/funcmap/funcmap_test.go
Normal file
39
pkg/funcmap/funcmap_test.go
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
package funcmap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/mikestefanello/pagoda/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHasField(t *testing.T) {
|
||||
type example struct {
|
||||
name string
|
||||
}
|
||||
var e example
|
||||
assert.True(t, HasField(e, "name"))
|
||||
assert.False(t, HasField(e, "abcd"))
|
||||
}
|
||||
|
||||
func TestLink(t *testing.T) {
|
||||
link := string(Link("/abc", "Text", "/abc"))
|
||||
expected := `<a class="is-active" href="/abc">Text</a>`
|
||||
assert.Equal(t, expected, link)
|
||||
|
||||
link = string(Link("/abc", "Text", "/abc", "first", "second"))
|
||||
expected = `<a class="first second is-active" href="/abc">Text</a>`
|
||||
assert.Equal(t, expected, link)
|
||||
|
||||
link = string(Link("/abc", "Text", "/def"))
|
||||
expected = `<a class="" href="/abc">Text</a>`
|
||||
assert.Equal(t, expected, link)
|
||||
}
|
||||
|
||||
func TestGetFuncMap(t *testing.T) {
|
||||
file := File("test.png")
|
||||
expected := fmt.Sprintf("/%s/test.png?v=%s", config.StaticPrefix, CacheBuster)
|
||||
assert.Equal(t, expected, file)
|
||||
}
|
||||
82
pkg/htmx/htmx.go
Normal file
82
pkg/htmx/htmx.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package htmx
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// Headers (https://htmx.org/docs/#requests)
|
||||
const (
|
||||
HeaderRequest = "HX-Request"
|
||||
HeaderBoosted = "HX-Boosted"
|
||||
HeaderTrigger = "HX-Trigger"
|
||||
HeaderTriggerName = "HX-Trigger-Name"
|
||||
HeaderTriggerAfterSwap = "HX-Trigger-After-Swap"
|
||||
HeaderTriggerAfterSettle = "HX-Trigger-After-Settle"
|
||||
HeaderTarget = "HX-Target"
|
||||
HeaderPrompt = "HX-Prompt"
|
||||
HeaderPush = "HX-Push"
|
||||
HeaderRedirect = "HX-Redirect"
|
||||
HeaderRefresh = "HX-Refresh"
|
||||
)
|
||||
|
||||
type (
|
||||
// Request contains data that HTMX provides during requests
|
||||
Request struct {
|
||||
Enabled bool
|
||||
Boosted bool
|
||||
Trigger string
|
||||
TriggerName string
|
||||
Target string
|
||||
Prompt string
|
||||
}
|
||||
|
||||
// Response contain data that the server can communicate back to HTMX
|
||||
Response struct {
|
||||
Push string
|
||||
Redirect string
|
||||
Refresh bool
|
||||
Trigger string
|
||||
TriggerAfterSwap string
|
||||
TriggerAfterSettle string
|
||||
NoContent bool
|
||||
}
|
||||
)
|
||||
|
||||
// GetRequest extracts HTMX data from the request
|
||||
func GetRequest(ctx echo.Context) Request {
|
||||
return Request{
|
||||
Enabled: ctx.Request().Header.Get(HeaderRequest) == "true",
|
||||
Boosted: ctx.Request().Header.Get(HeaderBoosted) == "true",
|
||||
Trigger: ctx.Request().Header.Get(HeaderTrigger),
|
||||
TriggerName: ctx.Request().Header.Get(HeaderTriggerName),
|
||||
Target: ctx.Request().Header.Get(HeaderTarget),
|
||||
Prompt: ctx.Request().Header.Get(HeaderPrompt),
|
||||
}
|
||||
}
|
||||
|
||||
// Apply applies data from a Response to a server response
|
||||
func (r Response) Apply(ctx echo.Context) {
|
||||
if r.Push != "" {
|
||||
ctx.Response().Header().Set(HeaderPush, r.Push)
|
||||
}
|
||||
if r.Redirect != "" {
|
||||
ctx.Response().Header().Set(HeaderRedirect, r.Redirect)
|
||||
}
|
||||
if r.Refresh {
|
||||
ctx.Response().Header().Set(HeaderRefresh, "true")
|
||||
}
|
||||
if r.Trigger != "" {
|
||||
ctx.Response().Header().Set(HeaderTrigger, r.Trigger)
|
||||
}
|
||||
if r.TriggerAfterSwap != "" {
|
||||
ctx.Response().Header().Set(HeaderTriggerAfterSwap, r.TriggerAfterSwap)
|
||||
}
|
||||
if r.TriggerAfterSettle != "" {
|
||||
ctx.Response().Header().Set(HeaderTriggerAfterSettle, r.TriggerAfterSettle)
|
||||
}
|
||||
if r.NoContent {
|
||||
ctx.Response().Status = http.StatusNoContent
|
||||
}
|
||||
}
|
||||
52
pkg/htmx/htmx_test.go
Normal file
52
pkg/htmx/htmx_test.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
package htmx
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/mikestefanello/pagoda/pkg/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
func TestSetRequest(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(echo.New(), "/")
|
||||
ctx.Request().Header.Set(HeaderRequest, "true")
|
||||
ctx.Request().Header.Set(HeaderBoosted, "true")
|
||||
ctx.Request().Header.Set(HeaderTrigger, "a")
|
||||
ctx.Request().Header.Set(HeaderTriggerName, "b")
|
||||
ctx.Request().Header.Set(HeaderTarget, "c")
|
||||
ctx.Request().Header.Set(HeaderPrompt, "d")
|
||||
|
||||
r := GetRequest(ctx)
|
||||
assert.Equal(t, true, r.Enabled)
|
||||
assert.Equal(t, true, r.Boosted)
|
||||
assert.Equal(t, "a", r.Trigger)
|
||||
assert.Equal(t, "b", r.TriggerName)
|
||||
assert.Equal(t, "c", r.Target)
|
||||
assert.Equal(t, "d", r.Prompt)
|
||||
}
|
||||
|
||||
func TestResponse_Apply(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(echo.New(), "/")
|
||||
r := Response{
|
||||
Push: "a",
|
||||
Redirect: "b",
|
||||
Refresh: true,
|
||||
Trigger: "c",
|
||||
TriggerAfterSwap: "d",
|
||||
TriggerAfterSettle: "e",
|
||||
NoContent: true,
|
||||
}
|
||||
r.Apply(ctx)
|
||||
|
||||
assert.Equal(t, "a", ctx.Response().Header().Get(HeaderPush))
|
||||
assert.Equal(t, "b", ctx.Response().Header().Get(HeaderRedirect))
|
||||
assert.Equal(t, "true", ctx.Response().Header().Get(HeaderRefresh))
|
||||
assert.Equal(t, "c", ctx.Response().Header().Get(HeaderTrigger))
|
||||
assert.Equal(t, "d", ctx.Response().Header().Get(HeaderTriggerAfterSwap))
|
||||
assert.Equal(t, "e", ctx.Response().Header().Get(HeaderTriggerAfterSettle))
|
||||
assert.Equal(t, http.StatusNoContent, ctx.Response().Status)
|
||||
}
|
||||
108
pkg/middleware/auth.go
Normal file
108
pkg/middleware/auth.go
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
"github.com/mikestefanello/pagoda/pkg/msg"
|
||||
"github.com/mikestefanello/pagoda/pkg/services"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// LoadAuthenticatedUser loads the authenticated user, if one, and stores in context
|
||||
func LoadAuthenticatedUser(authClient *services.AuthClient) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
u, err := authClient.GetAuthenticatedUser(c)
|
||||
switch err.(type) {
|
||||
case *ent.NotFoundError:
|
||||
c.Logger().Warn("auth user not found")
|
||||
case services.NotAuthenticatedError:
|
||||
case nil:
|
||||
c.Set(context.AuthenticatedUserKey, u)
|
||||
c.Logger().Infof("auth user loaded in to context: %d", u.ID)
|
||||
default:
|
||||
return echo.NewHTTPError(
|
||||
http.StatusInternalServerError,
|
||||
fmt.Sprintf("error querying for authenticated user: %v", err),
|
||||
)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LoadValidPasswordToken loads a valid password token entity that matches the user and token
|
||||
// provided in path parameters
|
||||
// If the token is invalid, the user will be redirected to the forgot password route
|
||||
// This requires that the user owning the token is loaded in to context
|
||||
func LoadValidPasswordToken(authClient *services.AuthClient) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Extract the user parameter
|
||||
if c.Get(context.UserKey) == nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError)
|
||||
}
|
||||
usr := c.Get(context.UserKey).(*ent.User)
|
||||
|
||||
// Extract the token ID
|
||||
tokenID, err := strconv.Atoi(c.Param("password_token"))
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusNotFound)
|
||||
}
|
||||
|
||||
// Attempt to load a valid password token
|
||||
token, err := authClient.GetValidPasswordToken(
|
||||
c,
|
||||
usr.ID,
|
||||
tokenID,
|
||||
c.Param("token"),
|
||||
)
|
||||
|
||||
switch err.(type) {
|
||||
case nil:
|
||||
c.Set(context.PasswordTokenKey, token)
|
||||
return next(c)
|
||||
case services.InvalidPasswordTokenError:
|
||||
msg.Warning(c, "The link is either invalid or has expired. Please request a new one.")
|
||||
return c.Redirect(http.StatusFound, c.Echo().Reverse("forgot_password"))
|
||||
default:
|
||||
return echo.NewHTTPError(
|
||||
http.StatusInternalServerError,
|
||||
fmt.Sprintf("error loading password token: %v", err),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAuthentication requires that the user be authenticated in order to proceed
|
||||
func RequireAuthentication() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if u := c.Get(context.AuthenticatedUserKey); u == nil {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RequireNoAuthentication requires that the user not be authenticated in order to proceed
|
||||
func RequireNoAuthentication() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if u := c.Get(context.AuthenticatedUserKey); u != nil {
|
||||
return echo.NewHTTPError(http.StatusForbidden)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
111
pkg/middleware/auth_test.go
Normal file
111
pkg/middleware/auth_test.go
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
"github.com/mikestefanello/pagoda/pkg/tests"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLoadAuthenticatedUser(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
tests.InitSession(ctx)
|
||||
mw := LoadAuthenticatedUser(c.Auth)
|
||||
|
||||
// Not authenticated
|
||||
_ = tests.ExecuteMiddleware(ctx, mw)
|
||||
assert.Nil(t, ctx.Get(context.AuthenticatedUserKey))
|
||||
|
||||
// Login
|
||||
err := c.Auth.Login(ctx, usr.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the midldeware returns the authenticated user
|
||||
_ = tests.ExecuteMiddleware(ctx, mw)
|
||||
require.NotNil(t, ctx.Get(context.AuthenticatedUserKey))
|
||||
ctxUsr, ok := ctx.Get(context.AuthenticatedUserKey).(*ent.User)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, usr.ID, ctxUsr.ID)
|
||||
}
|
||||
|
||||
func TestRequireAuthentication(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
tests.InitSession(ctx)
|
||||
|
||||
// Not logged in
|
||||
err := tests.ExecuteMiddleware(ctx, RequireAuthentication())
|
||||
tests.AssertHTTPErrorCode(t, err, http.StatusUnauthorized)
|
||||
|
||||
// Login
|
||||
err = c.Auth.Login(ctx, usr.ID)
|
||||
require.NoError(t, err)
|
||||
_ = tests.ExecuteMiddleware(ctx, LoadAuthenticatedUser(c.Auth))
|
||||
|
||||
// Logged in
|
||||
err = tests.ExecuteMiddleware(ctx, RequireAuthentication())
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestRequireNoAuthentication(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
tests.InitSession(ctx)
|
||||
|
||||
// Not logged in
|
||||
err := tests.ExecuteMiddleware(ctx, RequireNoAuthentication())
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Login
|
||||
err = c.Auth.Login(ctx, usr.ID)
|
||||
require.NoError(t, err)
|
||||
_ = tests.ExecuteMiddleware(ctx, LoadAuthenticatedUser(c.Auth))
|
||||
|
||||
// Logged in
|
||||
err = tests.ExecuteMiddleware(ctx, RequireNoAuthentication())
|
||||
tests.AssertHTTPErrorCode(t, err, http.StatusForbidden)
|
||||
}
|
||||
|
||||
func TestLoadValidPasswordToken(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
tests.InitSession(ctx)
|
||||
|
||||
// Missing user context
|
||||
err := tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth))
|
||||
tests.AssertHTTPErrorCode(t, err, http.StatusInternalServerError)
|
||||
|
||||
// Add user and password token context but no token and expect a redirect
|
||||
ctx.SetParamNames("user", "password_token")
|
||||
ctx.SetParamValues(fmt.Sprintf("%d", usr.ID), "1")
|
||||
_ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM))
|
||||
err = tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusFound, ctx.Response().Status)
|
||||
|
||||
// Add user context and invalid password token and expect a redirect
|
||||
ctx.SetParamNames("user", "password_token", "token")
|
||||
ctx.SetParamValues(fmt.Sprintf("%d", usr.ID), "1", "faketoken")
|
||||
_ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM))
|
||||
err = tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusFound, ctx.Response().Status)
|
||||
|
||||
// Create a valid token
|
||||
token, pt, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add user and valid password token
|
||||
ctx.SetParamNames("user", "password_token", "token")
|
||||
ctx.SetParamValues(fmt.Sprintf("%d", usr.ID), fmt.Sprintf("%d", pt.ID), token)
|
||||
_ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM))
|
||||
err = tests.ExecuteMiddleware(ctx, LoadValidPasswordToken(c.Auth))
|
||||
assert.Nil(t, err)
|
||||
ctxPt, ok := ctx.Get(context.PasswordTokenKey).(*ent.PasswordToken)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, pt.ID, ctxPt.ID)
|
||||
}
|
||||
102
pkg/middleware/cache.go
Normal file
102
pkg/middleware/cache.go
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
"github.com/mikestefanello/pagoda/pkg/services"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// CachedPageGroup stores the cache group for cached pages
|
||||
const CachedPageGroup = "page"
|
||||
|
||||
// CachedPage is what is used to store a rendered Page in the cache
|
||||
type CachedPage struct {
|
||||
// URL stores the URL of the requested page
|
||||
URL string
|
||||
|
||||
// HTML stores the complete HTML of the rendered Page
|
||||
HTML []byte
|
||||
|
||||
// StatusCode stores the HTTP status code
|
||||
StatusCode int
|
||||
|
||||
// Headers stores the HTTP headers
|
||||
Headers map[string]string
|
||||
}
|
||||
|
||||
// ServeCachedPage attempts to load a page from the cache by matching on the complete request URL
|
||||
// If a page is cached for the requested URL, it will be served here and the request terminated.
|
||||
// Any request made by an authenticated user or that is not a GET will be skipped.
|
||||
func ServeCachedPage(ch *services.CacheClient) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Skip non GET requests
|
||||
if c.Request().Method != http.MethodGet {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Skip if the user is authenticated
|
||||
if c.Get(context.AuthenticatedUserKey) != nil {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Attempt to load from cache
|
||||
res, err := ch.
|
||||
Get().
|
||||
Group(CachedPageGroup).
|
||||
Key(c.Request().URL.String()).
|
||||
Type(new(CachedPage)).
|
||||
Fetch(c.Request().Context())
|
||||
|
||||
if err != nil {
|
||||
switch {
|
||||
case err == redis.Nil:
|
||||
c.Logger().Info("no cached page found")
|
||||
case context.IsCanceledError(err):
|
||||
return nil
|
||||
default:
|
||||
c.Logger().Errorf("failed getting cached page: %v", err)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
|
||||
page, ok := res.(*CachedPage)
|
||||
if !ok {
|
||||
c.Logger().Errorf("failed casting cached page")
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Set any headers
|
||||
if page.Headers != nil {
|
||||
for k, v := range page.Headers {
|
||||
c.Response().Header().Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
c.Logger().Info("serving cached page")
|
||||
|
||||
return c.HTMLBlob(page.StatusCode, page.HTML)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CacheControl sets a Cache-Control header with a given max age
|
||||
func CacheControl(maxAge time.Duration) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
v := "no-cache, no-store"
|
||||
if maxAge > 0 {
|
||||
v = fmt.Sprintf("public, max-age=%.0f", maxAge.Seconds())
|
||||
}
|
||||
c.Response().Header().Set("Cache-Control", v)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
59
pkg/middleware/cache_test.go
Normal file
59
pkg/middleware/cache_test.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mikestefanello/pagoda/pkg/tests"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestServeCachedPage(t *testing.T) {
|
||||
// Cache a page
|
||||
cp := CachedPage{
|
||||
URL: "/cache",
|
||||
HTML: []byte("html"),
|
||||
Headers: make(map[string]string),
|
||||
StatusCode: http.StatusCreated,
|
||||
}
|
||||
cp.Headers["a"] = "b"
|
||||
cp.Headers["c"] = "d"
|
||||
|
||||
err := c.Cache.
|
||||
Set().
|
||||
Group(CachedPageGroup).
|
||||
Key(cp.URL).
|
||||
Data(cp).
|
||||
Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Request the URL of the cached page
|
||||
ctx, rec := tests.NewContext(c.Web, cp.URL)
|
||||
err = tests.ExecuteMiddleware(ctx, ServeCachedPage(c.Cache))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, cp.StatusCode, ctx.Response().Status)
|
||||
assert.Equal(t, cp.Headers["a"], ctx.Response().Header().Get("a"))
|
||||
assert.Equal(t, cp.Headers["c"], ctx.Response().Header().Get("c"))
|
||||
assert.Equal(t, cp.HTML, rec.Body.Bytes())
|
||||
|
||||
// Login and try again
|
||||
tests.InitSession(ctx)
|
||||
err = c.Auth.Login(ctx, usr.ID)
|
||||
require.NoError(t, err)
|
||||
_ = tests.ExecuteMiddleware(ctx, LoadAuthenticatedUser(c.Auth))
|
||||
err = tests.ExecuteMiddleware(ctx, ServeCachedPage(c.Cache))
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestCacheControl(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
_ = tests.ExecuteMiddleware(ctx, CacheControl(time.Second*5))
|
||||
assert.Equal(t, "public, max-age=5", ctx.Response().Header().Get("Cache-Control"))
|
||||
_ = tests.ExecuteMiddleware(ctx, CacheControl(0))
|
||||
assert.Equal(t, "no-cache, no-store", ctx.Response().Header().Get("Cache-Control"))
|
||||
}
|
||||
43
pkg/middleware/entity.go
Normal file
43
pkg/middleware/entity.go
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
"github.com/mikestefanello/pagoda/ent/user"
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// LoadUser loads the user based on the ID provided as a path parameter
|
||||
func LoadUser(orm *ent.Client) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID, err := strconv.Atoi(c.Param("user"))
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusNotFound)
|
||||
}
|
||||
|
||||
u, err := orm.User.
|
||||
Query().
|
||||
Where(user.ID(userID)).
|
||||
Only(c.Request().Context())
|
||||
|
||||
switch err.(type) {
|
||||
case nil:
|
||||
c.Set(context.UserKey, u)
|
||||
return next(c)
|
||||
case *ent.NotFoundError:
|
||||
return echo.NewHTTPError(http.StatusNotFound)
|
||||
default:
|
||||
return echo.NewHTTPError(
|
||||
http.StatusInternalServerError,
|
||||
fmt.Sprintf("error querying user: %v", err),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
23
pkg/middleware/entity_test.go
Normal file
23
pkg/middleware/entity_test.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
"github.com/mikestefanello/pagoda/pkg/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLoadUser(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
ctx.SetParamNames("user")
|
||||
ctx.SetParamValues(fmt.Sprintf("%d", usr.ID))
|
||||
_ = tests.ExecuteMiddleware(ctx, LoadUser(c.ORM))
|
||||
ctxUsr, ok := ctx.Get(context.UserKey).(*ent.User)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, usr.ID, ctxUsr.ID)
|
||||
}
|
||||
20
pkg/middleware/log.go
Normal file
20
pkg/middleware/log.go
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// LogRequestID includes the request ID in all logs for the given request
|
||||
// This requires that middleware that includes the request ID first execute
|
||||
func LogRequestID() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
rID := c.Response().Header().Get(echo.HeaderXRequestID)
|
||||
format := `{"time":"${time_rfc3339_nano}","id":"%s","level":"${level}","prefix":"${prefix}","file":"${short_file}","line":"${line}"}`
|
||||
c.Logger().SetHeader(fmt.Sprintf(format, rID))
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
27
pkg/middleware/log_test.go
Normal file
27
pkg/middleware/log_test.go
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/mikestefanello/pagoda/pkg/tests"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
echomw "github.com/labstack/echo/v4/middleware"
|
||||
)
|
||||
|
||||
func TestLogRequestID(t *testing.T) {
|
||||
ctx, _ := tests.NewContext(c.Web, "/")
|
||||
_ = tests.ExecuteMiddleware(ctx, echomw.RequestID())
|
||||
_ = tests.ExecuteMiddleware(ctx, LogRequestID())
|
||||
|
||||
var buf bytes.Buffer
|
||||
ctx.Logger().SetOutput(&buf)
|
||||
ctx.Logger().Info("test")
|
||||
rID := ctx.Response().Header().Get(echo.HeaderXRequestID)
|
||||
assert.Contains(t, buf.String(), fmt.Sprintf(`id":"%s"`, rID))
|
||||
}
|
||||
40
pkg/middleware/middleware_test.go
Normal file
40
pkg/middleware/middleware_test.go
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mikestefanello/pagoda/config"
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
"github.com/mikestefanello/pagoda/pkg/services"
|
||||
"github.com/mikestefanello/pagoda/pkg/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
c *services.Container
|
||||
usr *ent.User
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Set the environment to test
|
||||
config.SwitchEnvironment(config.EnvTest)
|
||||
|
||||
// Create a new container
|
||||
c = services.NewContainer()
|
||||
|
||||
// Create a user
|
||||
var err error
|
||||
if usr, err = tests.CreateUser(c.ORM); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
exitVal := m.Run()
|
||||
|
||||
// Shutdown the container
|
||||
if err = c.Shutdown(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
os.Exit(exitVal)
|
||||
}
|
||||
92
pkg/msg/msg.go
Normal file
92
pkg/msg/msg.go
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
package msg
|
||||
|
||||
import (
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/labstack/echo-contrib/session"
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// Type is a message type
|
||||
type Type string
|
||||
|
||||
const (
|
||||
// TypeSuccess represents a success message type
|
||||
TypeSuccess Type = "success"
|
||||
|
||||
// TypeInfo represents a info message type
|
||||
TypeInfo Type = "info"
|
||||
|
||||
// TypeWarning represents a warning message type
|
||||
TypeWarning Type = "warning"
|
||||
|
||||
// TypeDanger represents a danger message type
|
||||
TypeDanger Type = "danger"
|
||||
)
|
||||
|
||||
const (
|
||||
// sessionName stores the name of the session which contains flash messages
|
||||
sessionName = "msg"
|
||||
)
|
||||
|
||||
// Success sets a success flash message
|
||||
func Success(ctx echo.Context, message string) {
|
||||
Set(ctx, TypeSuccess, message)
|
||||
}
|
||||
|
||||
// Info sets an info flash message
|
||||
func Info(ctx echo.Context, message string) {
|
||||
Set(ctx, TypeInfo, message)
|
||||
}
|
||||
|
||||
// Warning sets a warning flash message
|
||||
func Warning(ctx echo.Context, message string) {
|
||||
Set(ctx, TypeWarning, message)
|
||||
}
|
||||
|
||||
// Danger sets a danger flash message
|
||||
func Danger(ctx echo.Context, message string) {
|
||||
Set(ctx, TypeDanger, message)
|
||||
}
|
||||
|
||||
// Set adds a new flash message of a given type into the session storage
|
||||
// Errors will logged and not returned
|
||||
func Set(ctx echo.Context, typ Type, message string) {
|
||||
if sess, err := getSession(ctx); err == nil {
|
||||
sess.AddFlash(message, string(typ))
|
||||
save(ctx, sess)
|
||||
}
|
||||
}
|
||||
|
||||
// Get gets flash messages of a given type from the session storage
|
||||
// Errors will logged and not returned
|
||||
func Get(ctx echo.Context, typ Type) []string {
|
||||
var msgs []string
|
||||
|
||||
if sess, err := getSession(ctx); err == nil {
|
||||
if flash := sess.Flashes(string(typ)); len(flash) > 0 {
|
||||
save(ctx, sess)
|
||||
|
||||
for _, m := range flash {
|
||||
msgs = append(msgs, m.(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return msgs
|
||||
}
|
||||
|
||||
// getSession gets the flash message session
|
||||
func getSession(ctx echo.Context) (*sessions.Session, error) {
|
||||
sess, err := session.Get(sessionName, ctx)
|
||||
if err != nil {
|
||||
ctx.Logger().Errorf("cannot load flash message session: %v", err)
|
||||
}
|
||||
return sess, err
|
||||
}
|
||||
|
||||
// save saves the flash message session
|
||||
func save(ctx echo.Context, sess *sessions.Session) {
|
||||
if err := sess.Save(ctx.Request(), ctx.Response()); err != nil {
|
||||
ctx.Logger().Errorf("failed to set flash message: %v", err)
|
||||
}
|
||||
}
|
||||
46
pkg/msg/msg_test.go
Normal file
46
pkg/msg/msg_test.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
package msg
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mikestefanello/pagoda/pkg/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
func TestMsg(t *testing.T) {
|
||||
e := echo.New()
|
||||
ctx, _ := tests.NewContext(e, "/")
|
||||
tests.InitSession(ctx)
|
||||
|
||||
assertMsg := func(typ Type, message string) {
|
||||
ret := Get(ctx, typ)
|
||||
require.Len(t, ret, 1)
|
||||
assert.Equal(t, message, ret[0])
|
||||
ret = Get(ctx, typ)
|
||||
require.Len(t, ret, 0)
|
||||
}
|
||||
|
||||
text := "aaa"
|
||||
Success(ctx, text)
|
||||
assertMsg(TypeSuccess, text)
|
||||
|
||||
text = "bbb"
|
||||
Info(ctx, text)
|
||||
assertMsg(TypeInfo, text)
|
||||
|
||||
text = "ccc"
|
||||
Danger(ctx, text)
|
||||
assertMsg(TypeDanger, text)
|
||||
|
||||
text = "ddd"
|
||||
Warning(ctx, text)
|
||||
assertMsg(TypeWarning, text)
|
||||
|
||||
text = "eee"
|
||||
Set(ctx, TypeSuccess, text)
|
||||
assertMsg(TypeSuccess, text)
|
||||
}
|
||||
69
pkg/routes/about.go
Normal file
69
pkg/routes/about.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
|
||||
"github.com/mikestefanello/pagoda/pkg/controller"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
about struct {
|
||||
controller.Controller
|
||||
}
|
||||
|
||||
aboutData struct {
|
||||
ShowCacheWarning bool
|
||||
FrontendTabs []aboutTab
|
||||
BackendTabs []aboutTab
|
||||
}
|
||||
|
||||
aboutTab struct {
|
||||
Title string
|
||||
Body template.HTML
|
||||
}
|
||||
)
|
||||
|
||||
func (c *about) Get(ctx echo.Context) error {
|
||||
page := controller.NewPage(ctx)
|
||||
page.Layout = "main"
|
||||
page.Name = "about"
|
||||
page.Title = "About"
|
||||
|
||||
// This page will be cached!
|
||||
page.Cache.Enabled = true
|
||||
page.Cache.Tags = []string{"page_about", "page:list"}
|
||||
|
||||
// A simple example of how the Data field can contain anything you want to send to the templates
|
||||
// even though you wouldn't normally send markup like this
|
||||
page.Data = aboutData{
|
||||
ShowCacheWarning: true,
|
||||
FrontendTabs: []aboutTab{
|
||||
{
|
||||
Title: "HTMX",
|
||||
Body: template.HTML(`Completes HTML as a hypertext by providing attributes to AJAXify anything and much more. Visit <a href="https://htmx.org/">htmx.org</a> to learn more.`),
|
||||
},
|
||||
{
|
||||
Title: "Alpine.js",
|
||||
Body: template.HTML(`Drop-in, Vue-like functionality written directly in your markup. Visit <a href="https://alpinejs.dev/">alpinejs.dev</a> to learn more.`),
|
||||
},
|
||||
{
|
||||
Title: "Bulma",
|
||||
Body: template.HTML(`Ready-to-use frontend components that you can easily combine to build responsive web interfaces with no JavaScript requirements. Visit <a href="https://bulma.io/">bulma.io</a> to learn more.`),
|
||||
},
|
||||
},
|
||||
BackendTabs: []aboutTab{
|
||||
{
|
||||
Title: "Echo",
|
||||
Body: template.HTML(`High performance, extensible, minimalist Go web framework. Visit <a href="https://echo.labstack.com/">echo.labstack.com</a> to learn more.`),
|
||||
},
|
||||
{
|
||||
Title: "Ent",
|
||||
Body: template.HTML(`Simple, yet powerful ORM for modeling and querying data. Visit <a href="https://entgo.io/">entgo.io</a> to learn more.`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return c.RenderPage(ctx, page)
|
||||
}
|
||||
23
pkg/routes/about_test.go
Normal file
23
pkg/routes/about_test.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Simple example of how to test routes and their markup using the test HTTP server spun up within
|
||||
// this test package
|
||||
func TestAbout_Get(t *testing.T) {
|
||||
doc := request(t).
|
||||
setRoute("about").
|
||||
get().
|
||||
assertStatusCode(http.StatusOK).
|
||||
toDoc()
|
||||
|
||||
// Goquery is an excellent package to use for testing HTML markup
|
||||
h1 := doc.Find("h1.title")
|
||||
assert.Len(t, h1.Nodes, 1)
|
||||
assert.Equal(t, "About", h1.Text())
|
||||
}
|
||||
65
pkg/routes/contact.go
Normal file
65
pkg/routes/contact.go
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
"github.com/mikestefanello/pagoda/pkg/controller"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
contact struct {
|
||||
controller.Controller
|
||||
}
|
||||
|
||||
contactForm struct {
|
||||
Email string `form:"email" validate:"required,email"`
|
||||
Message string `form:"message" validate:"required"`
|
||||
Submission controller.FormSubmission
|
||||
}
|
||||
)
|
||||
|
||||
func (c *contact) Get(ctx echo.Context) error {
|
||||
page := controller.NewPage(ctx)
|
||||
page.Layout = "main"
|
||||
page.Name = "contact"
|
||||
page.Title = "Contact us"
|
||||
page.Form = contactForm{}
|
||||
|
||||
if form := ctx.Get(context.FormKey); form != nil {
|
||||
page.Form = form.(*contactForm)
|
||||
}
|
||||
|
||||
return c.RenderPage(ctx, page)
|
||||
}
|
||||
|
||||
func (c *contact) Post(ctx echo.Context) error {
|
||||
var form contactForm
|
||||
ctx.Set(context.FormKey, &form)
|
||||
|
||||
// Parse the form values
|
||||
if err := ctx.Bind(&form); err != nil {
|
||||
return c.Fail(err, "unable to bind form")
|
||||
}
|
||||
|
||||
if err := form.Submission.Process(ctx, form); err != nil {
|
||||
return c.Fail(err, "unable to process form submission")
|
||||
}
|
||||
|
||||
if !form.Submission.HasErrors() {
|
||||
err := c.Container.Mail.
|
||||
Compose().
|
||||
To(form.Email).
|
||||
Subject("Contact form submitted").
|
||||
Body(fmt.Sprintf("The message is: %s", form.Message)).
|
||||
Send(ctx)
|
||||
|
||||
if err != nil {
|
||||
return c.Fail(err, "unable to send email")
|
||||
}
|
||||
}
|
||||
|
||||
return c.Get(ctx)
|
||||
}
|
||||
42
pkg/routes/error.go
Normal file
42
pkg/routes/error.go
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
"github.com/mikestefanello/pagoda/pkg/controller"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type errorHandler struct {
|
||||
controller.Controller
|
||||
}
|
||||
|
||||
func (e *errorHandler) Get(err error, ctx echo.Context) {
|
||||
if ctx.Response().Committed || context.IsCanceledError(err) {
|
||||
return
|
||||
}
|
||||
|
||||
code := http.StatusInternalServerError
|
||||
if he, ok := err.(*echo.HTTPError); ok {
|
||||
code = he.Code
|
||||
}
|
||||
|
||||
if code >= 500 {
|
||||
ctx.Logger().Error(err)
|
||||
} else {
|
||||
ctx.Logger().Info(err)
|
||||
}
|
||||
|
||||
page := controller.NewPage(ctx)
|
||||
page.Layout = "main"
|
||||
page.Title = http.StatusText(code)
|
||||
page.Name = "error"
|
||||
page.StatusCode = code
|
||||
page.HTMX.Request.Enabled = false
|
||||
|
||||
if err = e.RenderPage(ctx, page); err != nil {
|
||||
ctx.Logger().Error(err)
|
||||
}
|
||||
}
|
||||
100
pkg/routes/forgot_password.go
Normal file
100
pkg/routes/forgot_password.go
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
"github.com/mikestefanello/pagoda/ent/user"
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
"github.com/mikestefanello/pagoda/pkg/controller"
|
||||
"github.com/mikestefanello/pagoda/pkg/msg"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
forgotPassword struct {
|
||||
controller.Controller
|
||||
}
|
||||
|
||||
forgotPasswordForm struct {
|
||||
Email string `form:"email" validate:"required,email"`
|
||||
Submission controller.FormSubmission
|
||||
}
|
||||
)
|
||||
|
||||
func (c *forgotPassword) Get(ctx echo.Context) error {
|
||||
page := controller.NewPage(ctx)
|
||||
page.Layout = "auth"
|
||||
page.Name = "forgot-password"
|
||||
page.Title = "Forgot password"
|
||||
page.Form = forgotPasswordForm{}
|
||||
|
||||
if form := ctx.Get(context.FormKey); form != nil {
|
||||
page.Form = form.(*forgotPasswordForm)
|
||||
}
|
||||
|
||||
return c.RenderPage(ctx, page)
|
||||
}
|
||||
|
||||
func (c *forgotPassword) Post(ctx echo.Context) error {
|
||||
var form forgotPasswordForm
|
||||
ctx.Set(context.FormKey, &form)
|
||||
|
||||
succeed := func() error {
|
||||
ctx.Set(context.FormKey, nil)
|
||||
msg.Success(ctx, "An email containing a link to reset your password will be sent to this address if it exists in our system.")
|
||||
return c.Get(ctx)
|
||||
}
|
||||
|
||||
// Parse the form values
|
||||
if err := ctx.Bind(&form); err != nil {
|
||||
return c.Fail(err, "unable to parse forgot password form")
|
||||
}
|
||||
|
||||
if err := form.Submission.Process(ctx, form); err != nil {
|
||||
return c.Fail(err, "unable to process form submission")
|
||||
}
|
||||
|
||||
if form.Submission.HasErrors() {
|
||||
return c.Get(ctx)
|
||||
}
|
||||
|
||||
// Attempt to load the user
|
||||
u, err := c.Container.ORM.User.
|
||||
Query().
|
||||
Where(user.Email(strings.ToLower(form.Email))).
|
||||
Only(ctx.Request().Context())
|
||||
|
||||
switch err.(type) {
|
||||
case *ent.NotFoundError:
|
||||
return succeed()
|
||||
case nil:
|
||||
default:
|
||||
return c.Fail(err, "error querying user during forgot password")
|
||||
}
|
||||
|
||||
// Generate the token
|
||||
token, pt, err := c.Container.Auth.GeneratePasswordResetToken(ctx, u.ID)
|
||||
if err != nil {
|
||||
return c.Fail(err, "error generating password reset token")
|
||||
}
|
||||
|
||||
ctx.Logger().Infof("generated password reset token for user %d", u.ID)
|
||||
|
||||
// Email the user
|
||||
url := ctx.Echo().Reverse("reset_password", u.ID, pt.ID, token)
|
||||
err = c.Container.Mail.
|
||||
Compose().
|
||||
To(u.Email).
|
||||
Subject("Reset your password").
|
||||
Body(fmt.Sprintf("Go here to reset your password: %s", url)).
|
||||
Send(ctx)
|
||||
|
||||
if err != nil {
|
||||
return c.Fail(err, "error sending password reset email")
|
||||
}
|
||||
|
||||
return succeed()
|
||||
}
|
||||
46
pkg/routes/home.go
Normal file
46
pkg/routes/home.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mikestefanello/pagoda/pkg/controller"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
home struct {
|
||||
controller.Controller
|
||||
}
|
||||
|
||||
post struct {
|
||||
Title string
|
||||
Body string
|
||||
}
|
||||
)
|
||||
|
||||
func (c *home) Get(ctx echo.Context) error {
|
||||
page := controller.NewPage(ctx)
|
||||
page.Layout = "main"
|
||||
page.Name = "home"
|
||||
page.Metatags.Description = "Welcome to the homepage."
|
||||
page.Metatags.Keywords = []string{"Go", "MVC", "Web", "Software"}
|
||||
page.Pager = controller.NewPager(ctx, 4)
|
||||
page.Data = c.fetchPosts(&page.Pager)
|
||||
|
||||
return c.RenderPage(ctx, page)
|
||||
}
|
||||
|
||||
// fetchPosts is an mock example of fetching posts to illustrate how paging works
|
||||
func (c *home) fetchPosts(pager *controller.Pager) []post {
|
||||
pager.SetItems(20)
|
||||
posts := make([]post, 20)
|
||||
|
||||
for k := range posts {
|
||||
posts[k] = post{
|
||||
Title: fmt.Sprintf("Post example #%d", k+1),
|
||||
Body: fmt.Sprintf("Lorem ipsum example #%d ddolor sit amet, consectetur adipiscing elit. Nam elementum vulputate tristique.", k+1),
|
||||
}
|
||||
}
|
||||
return posts[pager.GetOffset() : pager.GetOffset()+pager.ItemsPerPage]
|
||||
}
|
||||
94
pkg/routes/login.go
Normal file
94
pkg/routes/login.go
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
"github.com/mikestefanello/pagoda/ent/user"
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
"github.com/mikestefanello/pagoda/pkg/controller"
|
||||
"github.com/mikestefanello/pagoda/pkg/msg"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
login struct {
|
||||
controller.Controller
|
||||
}
|
||||
|
||||
loginForm struct {
|
||||
Email string `form:"email" validate:"required,email"`
|
||||
Password string `form:"password" validate:"required"`
|
||||
Submission controller.FormSubmission
|
||||
}
|
||||
)
|
||||
|
||||
func (c *login) Get(ctx echo.Context) error {
|
||||
page := controller.NewPage(ctx)
|
||||
page.Layout = "auth"
|
||||
page.Name = "login"
|
||||
page.Title = "Log in"
|
||||
page.Form = loginForm{}
|
||||
|
||||
if form := ctx.Get(context.FormKey); form != nil {
|
||||
page.Form = form.(*loginForm)
|
||||
}
|
||||
|
||||
return c.RenderPage(ctx, page)
|
||||
}
|
||||
|
||||
func (c *login) Post(ctx echo.Context) error {
|
||||
var form loginForm
|
||||
ctx.Set(context.FormKey, &form)
|
||||
|
||||
authFailed := func() error {
|
||||
form.Submission.SetFieldError("Email", "")
|
||||
form.Submission.SetFieldError("Password", "")
|
||||
msg.Danger(ctx, "Invalid credentials. Please try again.")
|
||||
return c.Get(ctx)
|
||||
}
|
||||
|
||||
// Parse the form values
|
||||
if err := ctx.Bind(&form); err != nil {
|
||||
return c.Fail(err, "unable to parse login form")
|
||||
}
|
||||
|
||||
if err := form.Submission.Process(ctx, form); err != nil {
|
||||
return c.Fail(err, "unable to process form submission")
|
||||
}
|
||||
|
||||
if form.Submission.HasErrors() {
|
||||
return c.Get(ctx)
|
||||
}
|
||||
|
||||
// Attempt to load the user
|
||||
u, err := c.Container.ORM.User.
|
||||
Query().
|
||||
Where(user.Email(strings.ToLower(form.Email))).
|
||||
Only(ctx.Request().Context())
|
||||
|
||||
switch err.(type) {
|
||||
case *ent.NotFoundError:
|
||||
return authFailed()
|
||||
case nil:
|
||||
default:
|
||||
return c.Fail(err, "error querying user during login")
|
||||
}
|
||||
|
||||
// Check if the password is correct
|
||||
err = c.Container.Auth.CheckPassword(form.Password, u.Password)
|
||||
if err != nil {
|
||||
return authFailed()
|
||||
}
|
||||
|
||||
// Log the user in
|
||||
err = c.Container.Auth.Login(ctx, u.ID)
|
||||
if err != nil {
|
||||
return c.Fail(err, "unable to log in user")
|
||||
}
|
||||
|
||||
msg.Success(ctx, fmt.Sprintf("Welcome back, <strong>%s</strong>. You are now logged in.", u.Name))
|
||||
return c.Redirect(ctx, "home")
|
||||
}
|
||||
21
pkg/routes/logout.go
Normal file
21
pkg/routes/logout.go
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"github.com/mikestefanello/pagoda/pkg/controller"
|
||||
"github.com/mikestefanello/pagoda/pkg/msg"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type logout struct {
|
||||
controller.Controller
|
||||
}
|
||||
|
||||
func (l *logout) Get(c echo.Context) error {
|
||||
if err := l.Container.Auth.Logout(c); err == nil {
|
||||
msg.Success(c, "You have been logged out successfully.")
|
||||
} else {
|
||||
msg.Danger(c, "An error occurred. Please try again.")
|
||||
}
|
||||
return l.Redirect(c, "home")
|
||||
}
|
||||
122
pkg/routes/register.go
Normal file
122
pkg/routes/register.go
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
"github.com/mikestefanello/pagoda/pkg/controller"
|
||||
"github.com/mikestefanello/pagoda/pkg/msg"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
register struct {
|
||||
controller.Controller
|
||||
}
|
||||
|
||||
registerForm struct {
|
||||
Name string `form:"name" validate:"required"`
|
||||
Email string `form:"email" validate:"required,email"`
|
||||
Password string `form:"password" validate:"required"`
|
||||
ConfirmPassword string `form:"password-confirm" validate:"required,eqfield=Password"`
|
||||
Submission controller.FormSubmission
|
||||
}
|
||||
)
|
||||
|
||||
func (c *register) Get(ctx echo.Context) error {
|
||||
page := controller.NewPage(ctx)
|
||||
page.Layout = "auth"
|
||||
page.Name = "register"
|
||||
page.Title = "Register"
|
||||
page.Form = registerForm{}
|
||||
|
||||
if form := ctx.Get(context.FormKey); form != nil {
|
||||
page.Form = form.(*registerForm)
|
||||
}
|
||||
|
||||
return c.RenderPage(ctx, page)
|
||||
}
|
||||
|
||||
func (c *register) Post(ctx echo.Context) error {
|
||||
var form registerForm
|
||||
ctx.Set(context.FormKey, &form)
|
||||
|
||||
// Parse the form values
|
||||
if err := ctx.Bind(&form); err != nil {
|
||||
return c.Fail(err, "unable to parse register form")
|
||||
}
|
||||
|
||||
if err := form.Submission.Process(ctx, form); err != nil {
|
||||
return c.Fail(err, "unable to process form submission")
|
||||
}
|
||||
|
||||
if form.Submission.HasErrors() {
|
||||
return c.Get(ctx)
|
||||
}
|
||||
|
||||
// Hash the password
|
||||
pwHash, err := c.Container.Auth.HashPassword(form.Password)
|
||||
if err != nil {
|
||||
return c.Fail(err, "unable to hash password")
|
||||
}
|
||||
|
||||
// Attempt creating the user
|
||||
u, err := c.Container.ORM.User.
|
||||
Create().
|
||||
SetName(form.Name).
|
||||
SetEmail(form.Email).
|
||||
SetPassword(pwHash).
|
||||
Save(ctx.Request().Context())
|
||||
|
||||
switch err.(type) {
|
||||
case nil:
|
||||
ctx.Logger().Infof("user created: %s", u.Name)
|
||||
case *ent.ConstraintError:
|
||||
msg.Warning(ctx, "A user with this email address already exists. Please log in.")
|
||||
return c.Redirect(ctx, "login")
|
||||
default:
|
||||
return c.Fail(err, "unable to create user")
|
||||
}
|
||||
|
||||
// Log the user in
|
||||
err = c.Container.Auth.Login(ctx, u.ID)
|
||||
if err != nil {
|
||||
ctx.Logger().Errorf("unable to log in: %v", err)
|
||||
msg.Info(ctx, "Your account has been created.")
|
||||
return c.Redirect(ctx, "login")
|
||||
}
|
||||
|
||||
msg.Success(ctx, "Your account has been created. You are now logged in.")
|
||||
|
||||
// Send the verification email
|
||||
c.sendVerificationEmail(ctx, u)
|
||||
|
||||
return c.Redirect(ctx, "home")
|
||||
}
|
||||
|
||||
func (c *register) sendVerificationEmail(ctx echo.Context, usr *ent.User) {
|
||||
// Generate a token
|
||||
token, err := c.Container.Auth.GenerateEmailVerificationToken(usr.Email)
|
||||
if err != nil {
|
||||
ctx.Logger().Errorf("unable to generate email verification token: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Send the email
|
||||
url := ctx.Echo().Reverse("verify_email", token)
|
||||
err = c.Container.Mail.
|
||||
Compose().
|
||||
To(usr.Email).
|
||||
Subject("Confirm your email address").
|
||||
Body(fmt.Sprintf("Click here to confirm your email address: %s", url)).
|
||||
Send(ctx)
|
||||
|
||||
if err != nil {
|
||||
ctx.Logger().Errorf("unable to send email verification link: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
msg.Info(ctx, "An email was sent to you to verify your email address.")
|
||||
}
|
||||
82
pkg/routes/reset_password.go
Normal file
82
pkg/routes/reset_password.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
"github.com/mikestefanello/pagoda/pkg/controller"
|
||||
"github.com/mikestefanello/pagoda/pkg/msg"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
resetPassword struct {
|
||||
controller.Controller
|
||||
}
|
||||
|
||||
resetPasswordForm struct {
|
||||
Password string `form:"password" validate:"required"`
|
||||
ConfirmPassword string `form:"password-confirm" validate:"required,eqfield=Password"`
|
||||
Submission controller.FormSubmission
|
||||
}
|
||||
)
|
||||
|
||||
func (c *resetPassword) Get(ctx echo.Context) error {
|
||||
page := controller.NewPage(ctx)
|
||||
page.Layout = "auth"
|
||||
page.Name = "reset-password"
|
||||
page.Title = "Reset password"
|
||||
page.Form = resetPasswordForm{}
|
||||
|
||||
if form := ctx.Get(context.FormKey); form != nil {
|
||||
page.Form = form.(*resetPasswordForm)
|
||||
}
|
||||
|
||||
return c.RenderPage(ctx, page)
|
||||
}
|
||||
|
||||
func (c *resetPassword) Post(ctx echo.Context) error {
|
||||
var form resetPasswordForm
|
||||
ctx.Set(context.FormKey, &form)
|
||||
|
||||
// Parse the form values
|
||||
if err := ctx.Bind(&form); err != nil {
|
||||
return c.Fail(err, "unable to parse password reset form")
|
||||
}
|
||||
|
||||
if err := form.Submission.Process(ctx, form); err != nil {
|
||||
return c.Fail(err, "unable to process form submission")
|
||||
}
|
||||
|
||||
if form.Submission.HasErrors() {
|
||||
return c.Get(ctx)
|
||||
}
|
||||
|
||||
// Hash the new password
|
||||
hash, err := c.Container.Auth.HashPassword(form.Password)
|
||||
if err != nil {
|
||||
return c.Fail(err, "unable to hash password")
|
||||
}
|
||||
|
||||
// Get the requesting user
|
||||
usr := ctx.Get(context.UserKey).(*ent.User)
|
||||
|
||||
// Update the user
|
||||
_, err = usr.
|
||||
Update().
|
||||
SetPassword(hash).
|
||||
Save(ctx.Request().Context())
|
||||
|
||||
if err != nil {
|
||||
return c.Fail(err, "unable to update password")
|
||||
}
|
||||
|
||||
// Delete all password tokens for this user
|
||||
err = c.Container.Auth.DeletePasswordTokens(ctx, usr.ID)
|
||||
if err != nil {
|
||||
return c.Fail(err, "unable to delete password tokens")
|
||||
}
|
||||
|
||||
msg.Success(ctx, "Your password has been updated.")
|
||||
return c.Redirect(ctx, "login")
|
||||
}
|
||||
109
pkg/routes/router.go
Normal file
109
pkg/routes/router.go
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/mikestefanello/pagoda/config"
|
||||
"github.com/mikestefanello/pagoda/pkg/controller"
|
||||
"github.com/mikestefanello/pagoda/pkg/middleware"
|
||||
"github.com/mikestefanello/pagoda/pkg/services"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/labstack/echo-contrib/session"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
echomw "github.com/labstack/echo/v4/middleware"
|
||||
)
|
||||
|
||||
// BuildRouter builds the router
|
||||
func BuildRouter(c *services.Container) {
|
||||
// Static files with proper cache control
|
||||
// funcmap.File() should be used in templates to append a cache key to the URL in order to break cache
|
||||
// after each server restart
|
||||
c.Web.Group("", middleware.CacheControl(c.Config.Cache.Expiration.StaticFile)).
|
||||
Static(config.StaticPrefix, config.StaticDir)
|
||||
|
||||
// Non static file route group
|
||||
g := c.Web.Group("")
|
||||
|
||||
// Force HTTPS, if enabled
|
||||
if c.Config.HTTP.TLS.Enabled {
|
||||
g.Use(echomw.HTTPSRedirect())
|
||||
}
|
||||
|
||||
g.Use(
|
||||
echomw.RemoveTrailingSlashWithConfig(echomw.TrailingSlashConfig{
|
||||
RedirectCode: http.StatusMovedPermanently,
|
||||
}),
|
||||
echomw.Recover(),
|
||||
echomw.Secure(),
|
||||
echomw.RequestID(),
|
||||
echomw.Gzip(),
|
||||
echomw.Logger(),
|
||||
middleware.LogRequestID(),
|
||||
echomw.TimeoutWithConfig(echomw.TimeoutConfig{
|
||||
Timeout: c.Config.App.Timeout,
|
||||
}),
|
||||
session.Middleware(sessions.NewCookieStore([]byte(c.Config.App.EncryptionKey))),
|
||||
middleware.LoadAuthenticatedUser(c.Auth),
|
||||
middleware.ServeCachedPage(c.Cache),
|
||||
echomw.CSRFWithConfig(echomw.CSRFConfig{
|
||||
TokenLookup: "form:csrf",
|
||||
}),
|
||||
)
|
||||
|
||||
// Base controller
|
||||
ctr := controller.NewController(c)
|
||||
|
||||
// Error handler
|
||||
err := errorHandler{Controller: ctr}
|
||||
c.Web.HTTPErrorHandler = err.Get
|
||||
|
||||
// Example routes
|
||||
navRoutes(c, g, ctr)
|
||||
userRoutes(c, g, ctr)
|
||||
}
|
||||
|
||||
func navRoutes(c *services.Container, g *echo.Group, ctr controller.Controller) {
|
||||
home := home{Controller: ctr}
|
||||
g.GET("/", home.Get).Name = "home"
|
||||
|
||||
search := search{Controller: ctr}
|
||||
g.GET("/search", search.Get).Name = "search"
|
||||
|
||||
about := about{Controller: ctr}
|
||||
g.GET("/about", about.Get).Name = "about"
|
||||
|
||||
contact := contact{Controller: ctr}
|
||||
g.GET("/contact", contact.Get).Name = "contact"
|
||||
g.POST("/contact", contact.Post).Name = "contact.post"
|
||||
}
|
||||
|
||||
func userRoutes(c *services.Container, g *echo.Group, ctr controller.Controller) {
|
||||
logout := logout{Controller: ctr}
|
||||
g.GET("/logout", logout.Get, middleware.RequireAuthentication()).Name = "logout"
|
||||
|
||||
verifyEmail := verifyEmail{Controller: ctr}
|
||||
g.GET("/email/verify/:token", verifyEmail.Get).Name = "verify_email"
|
||||
|
||||
noAuth := g.Group("/user", middleware.RequireNoAuthentication())
|
||||
login := login{Controller: ctr}
|
||||
noAuth.GET("/login", login.Get).Name = "login"
|
||||
noAuth.POST("/login", login.Post).Name = "login.post"
|
||||
|
||||
register := register{Controller: ctr}
|
||||
noAuth.GET("/register", register.Get).Name = "register"
|
||||
noAuth.POST("/register", register.Post).Name = "register.post"
|
||||
|
||||
forgot := forgotPassword{Controller: ctr}
|
||||
noAuth.GET("/password", forgot.Get).Name = "forgot_password"
|
||||
noAuth.POST("/password", forgot.Post).Name = "forgot_password.post"
|
||||
|
||||
resetGroup := noAuth.Group("/password/reset",
|
||||
middleware.LoadUser(c.ORM),
|
||||
middleware.LoadValidPasswordToken(c.Auth),
|
||||
)
|
||||
reset := resetPassword{Controller: ctr}
|
||||
resetGroup.GET("/token/:user/:password_token/:token", reset.Get).Name = "reset_password"
|
||||
resetGroup.POST("/token/:user/:password_token/:token", reset.Post).Name = "reset_password.post"
|
||||
}
|
||||
136
pkg/routes/routes_test.go
Normal file
136
pkg/routes/routes_test.go
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mikestefanello/pagoda/config"
|
||||
"github.com/mikestefanello/pagoda/pkg/services"
|
||||
|
||||
"github.com/PuerkitoBio/goquery"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
srv *httptest.Server
|
||||
c *services.Container
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Set the environment to test
|
||||
config.SwitchEnvironment(config.EnvTest)
|
||||
|
||||
// Start a new container
|
||||
c = services.NewContainer()
|
||||
|
||||
// Start a test HTTP server
|
||||
BuildRouter(c)
|
||||
srv = httptest.NewServer(c.Web)
|
||||
|
||||
// Run tests
|
||||
exitVal := m.Run()
|
||||
|
||||
// Shutdown the container and test server
|
||||
if err := c.Shutdown(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
srv.Close()
|
||||
|
||||
os.Exit(exitVal)
|
||||
}
|
||||
|
||||
type httpRequest struct {
|
||||
route string
|
||||
client http.Client
|
||||
body url.Values
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func request(t *testing.T) *httpRequest {
|
||||
jar, err := cookiejar.New(nil)
|
||||
require.NoError(t, err)
|
||||
r := httpRequest{
|
||||
t: t,
|
||||
body: url.Values{},
|
||||
client: http.Client{
|
||||
Jar: jar,
|
||||
},
|
||||
}
|
||||
return &r
|
||||
}
|
||||
|
||||
func (h *httpRequest) setClient(client http.Client) *httpRequest {
|
||||
h.client = client
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *httpRequest) setRoute(route string, params ...interface{}) *httpRequest {
|
||||
h.route = srv.URL + c.Web.Reverse(route, params)
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *httpRequest) setBody(body url.Values) *httpRequest {
|
||||
h.body = body
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *httpRequest) get() *httpResponse {
|
||||
resp, err := h.client.Get(h.route)
|
||||
require.NoError(h.t, err)
|
||||
r := httpResponse{
|
||||
t: h.t,
|
||||
Response: resp,
|
||||
}
|
||||
return &r
|
||||
}
|
||||
|
||||
func (h *httpRequest) post() *httpResponse {
|
||||
// Make a get request to get the CSRF token
|
||||
doc := h.get().
|
||||
assertStatusCode(http.StatusOK).
|
||||
toDoc()
|
||||
|
||||
// Extract the CSRF and include it in the POST request body
|
||||
csrf := doc.Find(`input[name="csrf"]`).First()
|
||||
token, exists := csrf.Attr("value")
|
||||
assert.True(h.t, exists)
|
||||
h.body["csrf"] = []string{token}
|
||||
|
||||
// Make the POST requests
|
||||
resp, err := h.client.PostForm(h.route, h.body)
|
||||
require.NoError(h.t, err)
|
||||
r := httpResponse{
|
||||
t: h.t,
|
||||
Response: resp,
|
||||
}
|
||||
return &r
|
||||
}
|
||||
|
||||
type httpResponse struct {
|
||||
*http.Response
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (h *httpResponse) assertStatusCode(code int) *httpResponse {
|
||||
assert.Equal(h.t, code, h.Response.StatusCode)
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *httpResponse) assertRedirect(t *testing.T, route string, params ...interface{}) *httpResponse {
|
||||
assert.Equal(t, c.Web.Reverse(route, params), h.Header.Get("Location"))
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *httpResponse) toDoc() *goquery.Document {
|
||||
doc, err := goquery.NewDocumentFromReader(h.Body)
|
||||
require.NoError(h.t, err)
|
||||
err = h.Body.Close()
|
||||
assert.NoError(h.t, err)
|
||||
return doc
|
||||
}
|
||||
44
pkg/routes/search.go
Normal file
44
pkg/routes/search.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
|
||||
"github.com/mikestefanello/pagoda/pkg/controller"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
search struct {
|
||||
controller.Controller
|
||||
}
|
||||
|
||||
searchResult struct {
|
||||
Title string
|
||||
URL string
|
||||
}
|
||||
)
|
||||
|
||||
func (c *search) Get(ctx echo.Context) error {
|
||||
page := controller.NewPage(ctx)
|
||||
page.Layout = "main"
|
||||
page.Name = "search"
|
||||
|
||||
// Fake search results
|
||||
var results []searchResult
|
||||
if search := ctx.QueryParam("query"); search != "" {
|
||||
for i := 0; i < 5; i++ {
|
||||
title := "Lorem ipsum example ddolor sit amet"
|
||||
index := rand.Intn(len(title))
|
||||
title = title[:index] + search + title[index:]
|
||||
results = append(results, searchResult{
|
||||
Title: title,
|
||||
URL: fmt.Sprintf("https://www.%s.com", search),
|
||||
})
|
||||
}
|
||||
}
|
||||
page.Data = results
|
||||
|
||||
return c.RenderPage(ctx, page)
|
||||
}
|
||||
62
pkg/routes/verify_email.go
Normal file
62
pkg/routes/verify_email.go
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
"github.com/mikestefanello/pagoda/ent/user"
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
"github.com/mikestefanello/pagoda/pkg/controller"
|
||||
"github.com/mikestefanello/pagoda/pkg/msg"
|
||||
)
|
||||
|
||||
type verifyEmail struct {
|
||||
controller.Controller
|
||||
}
|
||||
|
||||
func (c *verifyEmail) Get(ctx echo.Context) error {
|
||||
var usr *ent.User
|
||||
|
||||
// Validate the token
|
||||
token := ctx.Param("token")
|
||||
email, err := c.Container.Auth.ValidateEmailVerificationToken(token)
|
||||
if err != nil {
|
||||
msg.Warning(ctx, "The link is either invalid or has expired.")
|
||||
return c.Redirect(ctx, "home")
|
||||
}
|
||||
|
||||
// Check if it matches the authenticated user
|
||||
if u := ctx.Get(context.AuthenticatedUserKey); u != nil {
|
||||
authUser := u.(*ent.User)
|
||||
|
||||
if authUser.Email == email {
|
||||
usr = authUser
|
||||
}
|
||||
}
|
||||
|
||||
// Query to find a matching user, if needed
|
||||
if usr == nil {
|
||||
usr, err = c.Container.ORM.User.
|
||||
Query().
|
||||
Where(user.Email(email)).
|
||||
Only(ctx.Request().Context())
|
||||
|
||||
if err != nil {
|
||||
return c.Fail(err, "query failed loading email verification token user")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the user, if needed
|
||||
if !usr.Verified {
|
||||
usr, err = usr.
|
||||
Update().
|
||||
SetVerified(true).
|
||||
Save(ctx.Request().Context())
|
||||
|
||||
if err != nil {
|
||||
return c.Fail(err, "failed to set user as verified")
|
||||
}
|
||||
}
|
||||
|
||||
msg.Success(ctx, "Your email has been successfully verified.")
|
||||
return c.Redirect(ctx, "home")
|
||||
}
|
||||
233
pkg/services/auth.go
Normal file
233
pkg/services/auth.go
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/mikestefanello/pagoda/config"
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
"github.com/mikestefanello/pagoda/ent/passwordtoken"
|
||||
"github.com/mikestefanello/pagoda/ent/user"
|
||||
"github.com/mikestefanello/pagoda/pkg/context"
|
||||
|
||||
"github.com/labstack/echo-contrib/session"
|
||||
"github.com/labstack/echo/v4"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
// authSessionName stores the name of the session which contains authentication data
|
||||
authSessionName = "ua"
|
||||
|
||||
// authSessionKeyUserID stores the key used to store the user ID in the session
|
||||
authSessionKeyUserID = "user_id"
|
||||
|
||||
// authSessionKeyAuthenticated stores the key used to store the authentication status in the session
|
||||
authSessionKeyAuthenticated = "authenticated"
|
||||
)
|
||||
|
||||
// NotAuthenticatedError is an error returned when a user is not authenticated
|
||||
type NotAuthenticatedError struct{}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e NotAuthenticatedError) Error() string {
|
||||
return "user not authenticated"
|
||||
}
|
||||
|
||||
// InvalidPasswordTokenError is an error returned when an invalid token is provided
|
||||
type InvalidPasswordTokenError struct{}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e InvalidPasswordTokenError) Error() string {
|
||||
return "invalid password token"
|
||||
}
|
||||
|
||||
// AuthClient is the client that handles authentication requests
|
||||
type AuthClient struct {
|
||||
config *config.Config
|
||||
orm *ent.Client
|
||||
}
|
||||
|
||||
// NewAuthClient creates a new authentication client
|
||||
func NewAuthClient(cfg *config.Config, orm *ent.Client) *AuthClient {
|
||||
return &AuthClient{
|
||||
config: cfg,
|
||||
orm: orm,
|
||||
}
|
||||
}
|
||||
|
||||
// Login logs in a user of a given ID
|
||||
func (c *AuthClient) Login(ctx echo.Context, userID int) error {
|
||||
sess, err := session.Get(authSessionName, ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sess.Values[authSessionKeyUserID] = userID
|
||||
sess.Values[authSessionKeyAuthenticated] = true
|
||||
return sess.Save(ctx.Request(), ctx.Response())
|
||||
}
|
||||
|
||||
// Logout logs the requesting user out
|
||||
func (c *AuthClient) Logout(ctx echo.Context) error {
|
||||
sess, err := session.Get(authSessionName, ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sess.Values[authSessionKeyAuthenticated] = false
|
||||
return sess.Save(ctx.Request(), ctx.Response())
|
||||
}
|
||||
|
||||
// GetAuthenticatedUserID returns the authenticated user's ID, if the user is logged in
|
||||
func (c *AuthClient) GetAuthenticatedUserID(ctx echo.Context) (int, error) {
|
||||
sess, err := session.Get(authSessionName, ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if sess.Values[authSessionKeyAuthenticated] == true {
|
||||
return sess.Values[authSessionKeyUserID].(int), nil
|
||||
}
|
||||
|
||||
return 0, NotAuthenticatedError{}
|
||||
}
|
||||
|
||||
// GetAuthenticatedUser returns the authenticated user if the user is logged in
|
||||
func (c *AuthClient) GetAuthenticatedUser(ctx echo.Context) (*ent.User, error) {
|
||||
if userID, err := c.GetAuthenticatedUserID(ctx); err == nil {
|
||||
return c.orm.User.Query().
|
||||
Where(user.ID(userID)).
|
||||
Only(ctx.Request().Context())
|
||||
}
|
||||
|
||||
return nil, NotAuthenticatedError{}
|
||||
}
|
||||
|
||||
// HashPassword returns a hash of a given password
|
||||
func (c *AuthClient) HashPassword(password string) (string, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hash), nil
|
||||
}
|
||||
|
||||
// CheckPassword check if a given password matches a given hash
|
||||
func (c *AuthClient) CheckPassword(password, hash string) error {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
}
|
||||
|
||||
// GeneratePasswordResetToken generates a password reset token for a given user.
|
||||
// For security purposes, the token itself is not stored in the database but rather
|
||||
// a hash of the token, exactly how passwords are handled. This method returns both
|
||||
// the generated token as well as the token entity which only contains the hash.
|
||||
func (c *AuthClient) GeneratePasswordResetToken(ctx echo.Context, userID int) (string, *ent.PasswordToken, error) {
|
||||
// Generate the token, which is what will go in the URL, but not the database
|
||||
token, err := c.RandomToken(c.config.App.PasswordToken.Length)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Hash the token, which is what will be stored in the database
|
||||
hash, err := c.HashPassword(token)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Create and save the password reset token
|
||||
pt, err := c.orm.PasswordToken.
|
||||
Create().
|
||||
SetHash(hash).
|
||||
SetUserID(userID).
|
||||
Save(ctx.Request().Context())
|
||||
|
||||
return token, pt, err
|
||||
}
|
||||
|
||||
// GetValidPasswordToken returns a valid, non-expired password token entity for a given user, token ID and token.
|
||||
// Since the actual token is not stored in the database for security purposes, if a matching password token entity is
|
||||
// found a hash of the provided token is compared with the hash stored in the database in order to validate.
|
||||
func (c *AuthClient) GetValidPasswordToken(ctx echo.Context, userID, tokenID int, token string) (*ent.PasswordToken, error) {
|
||||
// Ensure expired tokens are never returned
|
||||
expiration := time.Now().Add(-c.config.App.PasswordToken.Expiration)
|
||||
|
||||
// Query to find a password token entity that matches the given user and token ID
|
||||
pt, err := c.orm.PasswordToken.
|
||||
Query().
|
||||
Where(passwordtoken.ID(tokenID)).
|
||||
Where(passwordtoken.HasUserWith(user.ID(userID))).
|
||||
Where(passwordtoken.CreatedAtGTE(expiration)).
|
||||
Only(ctx.Request().Context())
|
||||
|
||||
switch err.(type) {
|
||||
case *ent.NotFoundError:
|
||||
case nil:
|
||||
// Check the token for a hash match
|
||||
if err := c.CheckPassword(token, pt.Hash); err == nil {
|
||||
return pt, nil
|
||||
}
|
||||
default:
|
||||
if !context.IsCanceledError(err) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, InvalidPasswordTokenError{}
|
||||
}
|
||||
|
||||
// DeletePasswordTokens deletes all password tokens in the database for a belonging to a given user.
|
||||
// This should be called after a successful password reset.
|
||||
func (c *AuthClient) DeletePasswordTokens(ctx echo.Context, userID int) error {
|
||||
_, err := c.orm.PasswordToken.
|
||||
Delete().
|
||||
Where(passwordtoken.HasUserWith(user.ID(userID))).
|
||||
Exec(ctx.Request().Context())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// RandomToken generates a random token string of a given length
|
||||
func (c *AuthClient) RandomToken(length int) (string, error) {
|
||||
b := make([]byte, (length/2)+1)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
token := hex.EncodeToString(b)
|
||||
return token[:length], nil
|
||||
}
|
||||
|
||||
// GenerateEmailVerificationToken generates an email verification token for a given email address using JWT which
|
||||
// is set to expire based on the duration stored in configuration
|
||||
func (c *AuthClient) GenerateEmailVerificationToken(email string) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"email": email,
|
||||
"exp": time.Now().Add(c.config.App.EmailVerificationTokenExpiration).Unix(),
|
||||
})
|
||||
|
||||
return token.SignedString([]byte(c.config.App.EncryptionKey))
|
||||
}
|
||||
|
||||
// ValidateEmailVerificationToken validates an email verification token and returns the associated email address if
|
||||
// the token is valid and has not expired
|
||||
func (c *AuthClient) ValidateEmailVerificationToken(token string) (string, error) {
|
||||
t, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
|
||||
return []byte(c.config.App.EncryptionKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if claims, ok := t.Claims.(jwt.MapClaims); ok && t.Valid {
|
||||
return claims["email"].(string), nil
|
||||
}
|
||||
|
||||
return "", errors.New("invalid or expired token")
|
||||
}
|
||||
146
pkg/services/auth_test.go
Normal file
146
pkg/services/auth_test.go
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mikestefanello/pagoda/ent/passwordtoken"
|
||||
"github.com/mikestefanello/pagoda/ent/user"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAuthClient_Auth(t *testing.T) {
|
||||
assertNoAuth := func() {
|
||||
_, err := c.Auth.GetAuthenticatedUserID(ctx)
|
||||
assert.True(t, errors.Is(err, NotAuthenticatedError{}))
|
||||
_, err = c.Auth.GetAuthenticatedUser(ctx)
|
||||
assert.True(t, errors.Is(err, NotAuthenticatedError{}))
|
||||
}
|
||||
|
||||
assertNoAuth()
|
||||
|
||||
err := c.Auth.Login(ctx, usr.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
uid, err := c.Auth.GetAuthenticatedUserID(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, usr.ID, uid)
|
||||
|
||||
u, err := c.Auth.GetAuthenticatedUser(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, u.ID, usr.ID)
|
||||
|
||||
err = c.Auth.Logout(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertNoAuth()
|
||||
}
|
||||
|
||||
func TestAuthClient_PasswordHashing(t *testing.T) {
|
||||
pw := "testcheckpassword"
|
||||
hash, err := c.Auth.HashPassword(pw)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, hash, pw)
|
||||
err = c.Auth.CheckPassword(pw, hash)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuthClient_GeneratePasswordResetToken(t *testing.T) {
|
||||
token, pt, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, token, c.Config.App.PasswordToken.Length)
|
||||
assert.NoError(t, c.Auth.CheckPassword(token, pt.Hash))
|
||||
}
|
||||
|
||||
func TestAuthClient_GetValidPasswordToken(t *testing.T) {
|
||||
// Check that a fake token is not valid
|
||||
_, err := c.Auth.GetValidPasswordToken(ctx, usr.ID, 1, "faketoken")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Generate a valid token and check that it is returned
|
||||
token, pt, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID)
|
||||
require.NoError(t, err)
|
||||
pt2, err := c.Auth.GetValidPasswordToken(ctx, usr.ID, pt.ID, token)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pt.ID, pt2.ID)
|
||||
|
||||
// Expire the token by pushing the date far enough back
|
||||
count, err := c.ORM.PasswordToken.
|
||||
Update().
|
||||
SetCreatedAt(time.Now().Add(-(c.Config.App.PasswordToken.Expiration + time.Hour))).
|
||||
Where(passwordtoken.ID(pt.ID)).
|
||||
Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
// Expired tokens should not be valid
|
||||
_, err = c.Auth.GetValidPasswordToken(ctx, usr.ID, pt.ID, token)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthClient_DeletePasswordTokens(t *testing.T) {
|
||||
// Create three tokens for the user
|
||||
for i := 0; i < 3; i++ {
|
||||
_, _, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Delete all tokens for the user
|
||||
err := c.Auth.DeletePasswordTokens(ctx, usr.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that no tokens remain
|
||||
count, err := c.ORM.PasswordToken.
|
||||
Query().
|
||||
Where(passwordtoken.HasUserWith(user.ID(usr.ID))).
|
||||
Count(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count)
|
||||
}
|
||||
|
||||
func TestAuthClient_RandomToken(t *testing.T) {
|
||||
length := c.Config.App.PasswordToken.Length
|
||||
a, err := c.Auth.RandomToken(length)
|
||||
require.NoError(t, err)
|
||||
b, err := c.Auth.RandomToken(length)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, a, length)
|
||||
assert.Len(t, b, length)
|
||||
assert.NotEqual(t, a, b)
|
||||
}
|
||||
|
||||
func TestAuthClient_EmailVerificationToken(t *testing.T) {
|
||||
t.Run("valid token", func(t *testing.T) {
|
||||
email := "test@localhost.com"
|
||||
token, err := c.Auth.GenerateEmailVerificationToken(email)
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenEmail, err := c.Auth.ValidateEmailVerificationToken(token)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, email, tokenEmail)
|
||||
})
|
||||
|
||||
t.Run("invalid token", func(t *testing.T) {
|
||||
badToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6InRlc3RAbG9jYWxob3N0LmNvbSIsImV4cCI6MTkxNzg2NDAwMH0.ScJCpfEEzlilKfRs_aVouzwPNKI28M3AIm-hyImQHUQ"
|
||||
_, err := c.Auth.ValidateEmailVerificationToken(badToken)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("expired token", func(t *testing.T) {
|
||||
c.Config.App.EmailVerificationTokenExpiration = -time.Hour
|
||||
email := "test@localhost.com"
|
||||
token, err := c.Auth.GenerateEmailVerificationToken(email)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = c.Auth.ValidateEmailVerificationToken(token)
|
||||
assert.Error(t, err)
|
||||
|
||||
c.Config.App.EmailVerificationTokenExpiration = time.Hour * 12
|
||||
})
|
||||
}
|
||||
228
pkg/services/cache.go
Normal file
228
pkg/services/cache.go
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/eko/gocache/v2/cache"
|
||||
"github.com/eko/gocache/v2/marshaler"
|
||||
"github.com/eko/gocache/v2/store"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/mikestefanello/pagoda/config"
|
||||
)
|
||||
|
||||
type (
|
||||
// CacheClient is the client that allows you to interact with the cache
|
||||
CacheClient struct {
|
||||
// Client stores the client to the underlying cache service
|
||||
Client *redis.Client
|
||||
|
||||
// cache stores the cache interface
|
||||
cache *cache.Cache
|
||||
}
|
||||
|
||||
// cacheSet handles chaining a set operation
|
||||
cacheSet struct {
|
||||
client *CacheClient
|
||||
key string
|
||||
group string
|
||||
data interface{}
|
||||
expiration time.Duration
|
||||
tags []string
|
||||
}
|
||||
|
||||
// cacheGet handles chaining a get operation
|
||||
cacheGet struct {
|
||||
client *CacheClient
|
||||
key string
|
||||
group string
|
||||
dataType interface{}
|
||||
}
|
||||
|
||||
// cacheFlush handles chaining a flush operation
|
||||
cacheFlush struct {
|
||||
client *CacheClient
|
||||
key string
|
||||
group string
|
||||
tags []string
|
||||
}
|
||||
)
|
||||
|
||||
// NewCacheClient creates a new cache client
|
||||
func NewCacheClient(cfg *config.Config) (*CacheClient, error) {
|
||||
// Determine the database based on the environment
|
||||
db := cfg.Cache.Database
|
||||
if cfg.App.Environment == config.EnvTest {
|
||||
db = cfg.Cache.TestDatabase
|
||||
}
|
||||
|
||||
// Connect to the cache
|
||||
c := &CacheClient{}
|
||||
c.Client = redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Cache.Hostname, cfg.Cache.Port),
|
||||
Password: cfg.Cache.Password,
|
||||
DB: db,
|
||||
})
|
||||
if _, err := c.Client.Ping(context.Background()).Result(); err != nil {
|
||||
return c, err
|
||||
}
|
||||
|
||||
// Flush the database if this is the test environment
|
||||
if cfg.App.Environment == config.EnvTest {
|
||||
if err := c.Client.FlushDB(context.Background()).Err(); err != nil {
|
||||
return c, err
|
||||
}
|
||||
}
|
||||
|
||||
cacheStore := store.NewRedis(c.Client, nil)
|
||||
c.cache = cache.New(cacheStore)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Close closes the connection to the cache
|
||||
func (c *CacheClient) Close() error {
|
||||
return c.Client.Close()
|
||||
}
|
||||
|
||||
// Set creates a cache set operation
|
||||
func (c *CacheClient) Set() *cacheSet {
|
||||
return &cacheSet{
|
||||
client: c,
|
||||
}
|
||||
}
|
||||
|
||||
// Get creates a cache get operation
|
||||
func (c *CacheClient) Get() *cacheGet {
|
||||
return &cacheGet{
|
||||
client: c,
|
||||
}
|
||||
}
|
||||
|
||||
// Flush creates a cache flush operation
|
||||
func (c *CacheClient) Flush() *cacheFlush {
|
||||
return &cacheFlush{
|
||||
client: c,
|
||||
}
|
||||
}
|
||||
|
||||
// cacheKey formats a cache key with an optional group
|
||||
func (c *CacheClient) cacheKey(group, key string) string {
|
||||
if group != "" {
|
||||
return fmt.Sprintf("%s::%s", group, key)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// Key sets the cache key
|
||||
func (c *cacheSet) Key(key string) *cacheSet {
|
||||
c.key = key
|
||||
return c
|
||||
}
|
||||
|
||||
// Group sets the cache group
|
||||
func (c *cacheSet) Group(group string) *cacheSet {
|
||||
c.group = group
|
||||
return c
|
||||
}
|
||||
|
||||
// Data sets the data to cache
|
||||
func (c *cacheSet) Data(data interface{}) *cacheSet {
|
||||
c.data = data
|
||||
return c
|
||||
}
|
||||
|
||||
// Expiration sets the expiration duration of the cached data
|
||||
func (c *cacheSet) Expiration(expiration time.Duration) *cacheSet {
|
||||
c.expiration = expiration
|
||||
return c
|
||||
}
|
||||
|
||||
// Tags sets the cache tags
|
||||
func (c *cacheSet) Tags(tags ...string) *cacheSet {
|
||||
c.tags = tags
|
||||
return c
|
||||
}
|
||||
|
||||
// Save saves the data in the cache
|
||||
func (c *cacheSet) Save(ctx context.Context) error {
|
||||
if c.key == "" {
|
||||
return errors.New("no cache key specified")
|
||||
}
|
||||
|
||||
opts := &store.Options{
|
||||
Expiration: c.expiration,
|
||||
Tags: c.tags,
|
||||
}
|
||||
|
||||
return marshaler.
|
||||
New(c.client.cache).
|
||||
Set(ctx, c.client.cacheKey(c.group, c.key), c.data, opts)
|
||||
}
|
||||
|
||||
// Key sets the cache key
|
||||
func (c *cacheGet) Key(key string) *cacheGet {
|
||||
c.key = key
|
||||
return c
|
||||
}
|
||||
|
||||
// Group sets the cache group
|
||||
func (c *cacheGet) Group(group string) *cacheGet {
|
||||
c.group = group
|
||||
return c
|
||||
}
|
||||
|
||||
// Type sets the expected Go type of the data being retrieved from the cache
|
||||
func (c *cacheGet) Type(expectedType interface{}) *cacheGet {
|
||||
c.dataType = expectedType
|
||||
return c
|
||||
}
|
||||
|
||||
// Fetch fetches the data from the cache
|
||||
func (c *cacheGet) Fetch(ctx context.Context) (interface{}, error) {
|
||||
if c.key == "" {
|
||||
return nil, errors.New("no cache key specified")
|
||||
}
|
||||
|
||||
return marshaler.New(c.client.cache).Get(
|
||||
ctx,
|
||||
c.client.cacheKey(c.group, c.key),
|
||||
c.dataType,
|
||||
)
|
||||
}
|
||||
|
||||
// Key sets the cache key
|
||||
func (c *cacheFlush) Key(key string) *cacheFlush {
|
||||
c.key = key
|
||||
return c
|
||||
}
|
||||
|
||||
// Group sets the cache group
|
||||
func (c *cacheFlush) Group(group string) *cacheFlush {
|
||||
c.group = group
|
||||
return c
|
||||
}
|
||||
|
||||
// Tags sets the cache tags
|
||||
func (c *cacheFlush) Tags(tags ...string) *cacheFlush {
|
||||
c.tags = tags
|
||||
return c
|
||||
}
|
||||
|
||||
// Execute flushes the data from the cache
|
||||
func (c *cacheFlush) Execute(ctx context.Context) error {
|
||||
if len(c.tags) > 0 {
|
||||
if err := c.client.cache.Invalidate(ctx, store.InvalidateOptions{
|
||||
Tags: c.tags,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if c.key != "" {
|
||||
return c.client.cache.Delete(ctx, c.client.cacheKey(c.group, c.key))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
105
pkg/services/cache_test.go
Normal file
105
pkg/services/cache_test.go
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCacheClient(t *testing.T) {
|
||||
type cacheTest struct {
|
||||
Value string
|
||||
}
|
||||
// Cache some data
|
||||
data := cacheTest{Value: "abcdef"}
|
||||
group := "testgroup"
|
||||
key := "testkey"
|
||||
err := c.Cache.
|
||||
Set().
|
||||
Group(group).
|
||||
Key(key).
|
||||
Data(data).
|
||||
Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the data
|
||||
fromCache, err := c.Cache.
|
||||
Get().
|
||||
Group(group).
|
||||
Key(key).
|
||||
Type(new(cacheTest)).
|
||||
Fetch(context.Background())
|
||||
require.NoError(t, err)
|
||||
cast, ok := fromCache.(*cacheTest)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, data, *cast)
|
||||
|
||||
// The same key with the wrong group should fail
|
||||
_, err = c.Cache.
|
||||
Get().
|
||||
Key(key).
|
||||
Type(new(cacheTest)).
|
||||
Fetch(context.Background())
|
||||
assert.Error(t, err)
|
||||
|
||||
// Flush the data
|
||||
err = c.Cache.
|
||||
Flush().
|
||||
Group(group).
|
||||
Key(key).
|
||||
Execute(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// The data should be gone
|
||||
assertFlushed := func() {
|
||||
// The data should be gone
|
||||
_, err = c.Cache.
|
||||
Get().
|
||||
Group(group).
|
||||
Key(key).
|
||||
Type(new(cacheTest)).
|
||||
Fetch(context.Background())
|
||||
assert.Equal(t, redis.Nil, err)
|
||||
}
|
||||
assertFlushed()
|
||||
|
||||
// Set with tags
|
||||
err = c.Cache.
|
||||
Set().
|
||||
Group(group).
|
||||
Key(key).
|
||||
Data(data).
|
||||
Tags("tag1").
|
||||
Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Flush the tag
|
||||
err = c.Cache.
|
||||
Flush().
|
||||
Tags("tag1").
|
||||
Execute(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// The data should be gone
|
||||
assertFlushed()
|
||||
|
||||
// Set with expiration
|
||||
err = c.Cache.
|
||||
Set().
|
||||
Group(group).
|
||||
Key(key).
|
||||
Data(data).
|
||||
Expiration(time.Millisecond).
|
||||
Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(time.Millisecond * 2)
|
||||
|
||||
// The data should be gone
|
||||
assertFlushed()
|
||||
}
|
||||
201
pkg/services/container.go
Normal file
201
pkg/services/container.go
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/schema"
|
||||
|
||||
// Required by ent
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/log"
|
||||
|
||||
"github.com/mikestefanello/pagoda/config"
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
|
||||
// Require by ent
|
||||
_ "github.com/mikestefanello/pagoda/ent/runtime"
|
||||
)
|
||||
|
||||
// Container contains all services used by the application and provides an easy way to handle dependency
|
||||
// injection including within tests
|
||||
type Container struct {
|
||||
// Validator stores a validator
|
||||
Validator *Validator
|
||||
|
||||
// Web stores the web framework
|
||||
Web *echo.Echo
|
||||
|
||||
// Config stores the application configuration
|
||||
Config *config.Config
|
||||
|
||||
// Cache contains the cache client
|
||||
Cache *CacheClient
|
||||
|
||||
// Database stores the connection to the database
|
||||
Database *sql.DB
|
||||
|
||||
// ORM stores a client to the ORM
|
||||
ORM *ent.Client
|
||||
|
||||
// Mail stores an email sending client
|
||||
Mail *MailClient
|
||||
|
||||
// Auth stores an authentication client
|
||||
Auth *AuthClient
|
||||
|
||||
// TemplateRenderer stores a service to easily render and cache templates
|
||||
TemplateRenderer *TemplateRenderer
|
||||
|
||||
// Tasks stores the task client
|
||||
Tasks *TaskClient
|
||||
}
|
||||
|
||||
// NewContainer creates and initializes a new Container
|
||||
func NewContainer() *Container {
|
||||
c := new(Container)
|
||||
c.initConfig()
|
||||
c.initValidator()
|
||||
c.initWeb()
|
||||
c.initCache()
|
||||
c.initDatabase()
|
||||
c.initORM()
|
||||
c.initAuth()
|
||||
c.initTemplateRenderer()
|
||||
c.initMail()
|
||||
c.initTasks()
|
||||
return c
|
||||
}
|
||||
|
||||
// Shutdown shuts the Container down and disconnects all connections
|
||||
func (c *Container) Shutdown() error {
|
||||
if err := c.Tasks.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.Cache.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.ORM.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.Database.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initConfig initializes configuration
|
||||
func (c *Container) initConfig() {
|
||||
cfg, err := config.GetConfig()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to load config: %v", err))
|
||||
}
|
||||
c.Config = &cfg
|
||||
}
|
||||
|
||||
// initValidator initializes the validator
|
||||
func (c *Container) initValidator() {
|
||||
c.Validator = NewValidator()
|
||||
}
|
||||
|
||||
// initWeb initializes the web framework
|
||||
func (c *Container) initWeb() {
|
||||
c.Web = echo.New()
|
||||
|
||||
// Configure logging
|
||||
switch c.Config.App.Environment {
|
||||
case config.EnvProduction:
|
||||
c.Web.Logger.SetLevel(log.WARN)
|
||||
default:
|
||||
c.Web.Logger.SetLevel(log.DEBUG)
|
||||
}
|
||||
|
||||
c.Web.Validator = c.Validator
|
||||
}
|
||||
|
||||
// initCache initializes the cache
|
||||
func (c *Container) initCache() {
|
||||
var err error
|
||||
if c.Cache, err = NewCacheClient(c.Config); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// initDatabase initializes the database
|
||||
// If the environment is set to test, the test database will be used and will be dropped, recreated and migrated
|
||||
func (c *Container) initDatabase() {
|
||||
var err error
|
||||
|
||||
getAddr := func(dbName string) string {
|
||||
return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s",
|
||||
c.Config.Database.User,
|
||||
c.Config.Database.Password,
|
||||
c.Config.Database.Hostname,
|
||||
c.Config.Database.Port,
|
||||
dbName,
|
||||
)
|
||||
}
|
||||
|
||||
c.Database, err = sql.Open("pgx", getAddr(c.Config.Database.Database))
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to connect to database: %v", err))
|
||||
}
|
||||
|
||||
// Check if this is a test environment
|
||||
if c.Config.App.Environment == config.EnvTest {
|
||||
// Drop the test database, ignoring errors in case it doesn't yet exist
|
||||
_, _ = c.Database.Exec("DROP DATABASE " + c.Config.Database.TestDatabase)
|
||||
|
||||
// Create the test database
|
||||
if _, err = c.Database.Exec("CREATE DATABASE " + c.Config.Database.TestDatabase); err != nil {
|
||||
panic(fmt.Sprintf("failed to create test database: %v", err))
|
||||
}
|
||||
|
||||
// Connect to the test database
|
||||
if err = c.Database.Close(); err != nil {
|
||||
panic(fmt.Sprintf("failed to close database connection: %v", err))
|
||||
}
|
||||
c.Database, err = sql.Open("pgx", getAddr(c.Config.Database.TestDatabase))
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to connect to database: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// initORM initializes the ORM
|
||||
func (c *Container) initORM() {
|
||||
drv := entsql.OpenDB(dialect.Postgres, c.Database)
|
||||
c.ORM = ent.NewClient(ent.Driver(drv))
|
||||
if err := c.ORM.Schema.Create(context.Background(), schema.WithAtlas(true)); err != nil {
|
||||
panic(fmt.Sprintf("failed to create database schema: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// initAuth initializes the authentication client
|
||||
func (c *Container) initAuth() {
|
||||
c.Auth = NewAuthClient(c.Config, c.ORM)
|
||||
}
|
||||
|
||||
// initTemplateRenderer initializes the template renderer
|
||||
func (c *Container) initTemplateRenderer() {
|
||||
c.TemplateRenderer = NewTemplateRenderer(c.Config)
|
||||
}
|
||||
|
||||
// initMail initialize the mail client
|
||||
func (c *Container) initMail() {
|
||||
var err error
|
||||
c.Mail, err = NewMailClient(c.Config, c.TemplateRenderer)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create mail client: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// initTasks initializes the task client
|
||||
func (c *Container) initTasks() {
|
||||
c.Tasks = NewTaskClient(c.Config)
|
||||
}
|
||||
20
pkg/services/container_test.go
Normal file
20
pkg/services/container_test.go
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewContainer(t *testing.T) {
|
||||
assert.NotNil(t, c.Web)
|
||||
assert.NotNil(t, c.Config)
|
||||
assert.NotNil(t, c.Validator)
|
||||
assert.NotNil(t, c.Cache)
|
||||
assert.NotNil(t, c.Database)
|
||||
assert.NotNil(t, c.ORM)
|
||||
assert.NotNil(t, c.Mail)
|
||||
assert.NotNil(t, c.Auth)
|
||||
assert.NotNil(t, c.TemplateRenderer)
|
||||
assert.NotNil(t, c.Tasks)
|
||||
}
|
||||
139
pkg/services/mail.go
Normal file
139
pkg/services/mail.go
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/mikestefanello/pagoda/config"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
// MailClient provides a client for sending email
|
||||
// This is purposely not completed because there are many different methods and services
|
||||
// for sending email, many of which are very different. Choose what works best for you
|
||||
// and populate the methods below
|
||||
MailClient struct {
|
||||
// config stores application configuration
|
||||
config *config.Config
|
||||
|
||||
// templates stores the template renderer
|
||||
templates *TemplateRenderer
|
||||
}
|
||||
|
||||
// mail represents an email to be sent
|
||||
mail struct {
|
||||
client *MailClient
|
||||
from string
|
||||
to string
|
||||
subject string
|
||||
body string
|
||||
template string
|
||||
templateData interface{}
|
||||
}
|
||||
)
|
||||
|
||||
// NewMailClient creates a new MailClient
|
||||
func NewMailClient(cfg *config.Config, templates *TemplateRenderer) (*MailClient, error) {
|
||||
return &MailClient{
|
||||
config: cfg,
|
||||
templates: templates,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Compose creates a new email
|
||||
func (m *MailClient) Compose() *mail {
|
||||
return &mail{
|
||||
client: m,
|
||||
from: m.config.Mail.FromAddress,
|
||||
}
|
||||
}
|
||||
|
||||
// skipSend determines if mail sending should be skipped
|
||||
func (m *MailClient) skipSend() bool {
|
||||
return m.config.App.Environment != config.EnvProduction
|
||||
}
|
||||
|
||||
// send attempts to send the email
|
||||
func (m *MailClient) send(email *mail, ctx echo.Context) error {
|
||||
switch {
|
||||
case email.to == "":
|
||||
return errors.New("email cannot be sent without a to address")
|
||||
case email.body == "" && email.template == "":
|
||||
return errors.New("email cannot be sent without a body or template")
|
||||
}
|
||||
|
||||
// Check if a template was supplied
|
||||
if email.template != "" {
|
||||
// Parse and execute template
|
||||
buf, err := m.templates.
|
||||
Parse().
|
||||
Group("mail").
|
||||
Key(email.template).
|
||||
Base(email.template).
|
||||
Files(fmt.Sprintf("emails/%s", email.template)).
|
||||
Execute(email.templateData)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
email.body = buf.String()
|
||||
}
|
||||
|
||||
// Check if mail sending should be skipped
|
||||
if m.skipSend() {
|
||||
ctx.Logger().Debugf("skipping email sent to: %s", email.to)
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Finish based on your mail sender of choice!
|
||||
return nil
|
||||
}
|
||||
|
||||
// From sets the email from address
|
||||
func (m *mail) From(from string) *mail {
|
||||
m.from = from
|
||||
return m
|
||||
}
|
||||
|
||||
// To sets the email address this email will be sent to
|
||||
func (m *mail) To(to string) *mail {
|
||||
m.to = to
|
||||
return m
|
||||
}
|
||||
|
||||
// Subject sets the subject line of the email
|
||||
func (m *mail) Subject(subject string) *mail {
|
||||
m.subject = subject
|
||||
return m
|
||||
}
|
||||
|
||||
// Body sets the body of the email
|
||||
// This is not required and will be ignored if a template via Template()
|
||||
func (m *mail) Body(body string) *mail {
|
||||
m.body = body
|
||||
return m
|
||||
}
|
||||
|
||||
// Template sets the template to be used to produce the body of the email
|
||||
// The template name should only include the filename without the extension or directory.
|
||||
// The template must reside within the emails sub-directory.
|
||||
// The funcmap will be automatically added to the template.
|
||||
// Use TemplateData() to supply the data that will be passed in to the template.
|
||||
func (m *mail) Template(template string) *mail {
|
||||
m.template = template
|
||||
return m
|
||||
}
|
||||
|
||||
// TemplateData sets the data that will be passed to the template specified when calling Template()
|
||||
func (m *mail) TemplateData(data interface{}) *mail {
|
||||
m.templateData = data
|
||||
return m
|
||||
}
|
||||
|
||||
// Send attempts to send the email
|
||||
func (m *mail) Send(ctx echo.Context) error {
|
||||
return m.client.send(m, ctx)
|
||||
}
|
||||
3
pkg/services/mail_test.go
Normal file
3
pkg/services/mail_test.go
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
package services
|
||||
|
||||
// Fill this in once you implement your mail client
|
||||
46
pkg/services/services_test.go
Normal file
46
pkg/services/services_test.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mikestefanello/pagoda/config"
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
"github.com/mikestefanello/pagoda/pkg/tests"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
var (
|
||||
c *Container
|
||||
ctx echo.Context
|
||||
usr *ent.User
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Set the environment to test
|
||||
config.SwitchEnvironment(config.EnvTest)
|
||||
|
||||
// Create a new container
|
||||
c = NewContainer()
|
||||
|
||||
// Create a web context
|
||||
ctx, _ = tests.NewContext(c.Web, "/")
|
||||
tests.InitSession(ctx)
|
||||
|
||||
// Create a test user
|
||||
var err error
|
||||
if usr, err = tests.CreateUser(c.ORM); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
exitVal := m.Run()
|
||||
|
||||
// Shutdown the container
|
||||
if err = c.Shutdown(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
os.Exit(exitVal)
|
||||
}
|
||||
179
pkg/services/tasks.go
Normal file
179
pkg/services/tasks.go
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"github.com/mikestefanello/pagoda/config"
|
||||
)
|
||||
|
||||
type (
|
||||
// TaskClient is that client that allows you to queue or schedule task execution
|
||||
TaskClient struct {
|
||||
// client stores the asynq client
|
||||
client *asynq.Client
|
||||
|
||||
// scheduler stores the asynq scheduler
|
||||
scheduler *asynq.Scheduler
|
||||
}
|
||||
|
||||
// task handles task creation operations
|
||||
task struct {
|
||||
client *TaskClient
|
||||
typ string
|
||||
payload interface{}
|
||||
periodic *string
|
||||
queue *string
|
||||
maxRetries *int
|
||||
timeout *time.Duration
|
||||
deadline *time.Time
|
||||
at *time.Time
|
||||
wait *time.Duration
|
||||
retain *time.Duration
|
||||
}
|
||||
)
|
||||
|
||||
// NewTaskClient creates a new task client
|
||||
func NewTaskClient(cfg *config.Config) *TaskClient {
|
||||
// Determine the database based on the environment
|
||||
db := cfg.Cache.Database
|
||||
if cfg.App.Environment == config.EnvTest {
|
||||
db = cfg.Cache.TestDatabase
|
||||
}
|
||||
|
||||
conn := asynq.RedisClientOpt{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Cache.Hostname, cfg.Cache.Port),
|
||||
Password: cfg.Cache.Password,
|
||||
DB: db,
|
||||
}
|
||||
|
||||
return &TaskClient{
|
||||
client: asynq.NewClient(conn),
|
||||
scheduler: asynq.NewScheduler(conn, nil),
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the connection to the task service
|
||||
func (t *TaskClient) Close() error {
|
||||
return t.client.Close()
|
||||
}
|
||||
|
||||
// StartScheduler starts the scheduler service which adds scheduled tasks to the queue
|
||||
// This must be running in order to queue tasks set for periodic execution
|
||||
func (t *TaskClient) StartScheduler() error {
|
||||
return t.scheduler.Run()
|
||||
}
|
||||
|
||||
// New starts a task creation operation
|
||||
func (t *TaskClient) New(typ string) *task {
|
||||
return &task{
|
||||
client: t,
|
||||
typ: typ,
|
||||
}
|
||||
}
|
||||
|
||||
// Payload sets the task payload data which will be sent to the task handler
|
||||
func (t *task) Payload(payload interface{}) *task {
|
||||
t.payload = payload
|
||||
return t
|
||||
}
|
||||
|
||||
// Periodic sets the task to execute periodically according to a given interval
|
||||
// The interval can be either in cron form ("*/5 * * * *") or "@every 30s"
|
||||
func (t *task) Periodic(interval string) *task {
|
||||
t.periodic = &interval
|
||||
return t
|
||||
}
|
||||
|
||||
// Queue specifies the name of the queue to add the task to
|
||||
// The default queue will be used if this is not set
|
||||
func (t *task) Queue(queue string) *task {
|
||||
t.queue = &queue
|
||||
return t
|
||||
}
|
||||
|
||||
// Timeout sets the task timeout, meaning the task must execute within a given duration
|
||||
func (t *task) Timeout(timeout time.Duration) *task {
|
||||
t.timeout = &timeout
|
||||
return t
|
||||
}
|
||||
|
||||
// Deadline sets the task execution deadline to a specific date and time
|
||||
func (t *task) Deadline(deadline time.Time) *task {
|
||||
t.deadline = &deadline
|
||||
return t
|
||||
}
|
||||
|
||||
// At sets the exact date and time the task should be executed
|
||||
func (t *task) At(processAt time.Time) *task {
|
||||
t.at = &processAt
|
||||
return t
|
||||
}
|
||||
|
||||
// Wait instructs the task to wait a given duration before it is executed
|
||||
func (t *task) Wait(duration time.Duration) *task {
|
||||
t.wait = &duration
|
||||
return t
|
||||
}
|
||||
|
||||
// Retain instructs the task service to retain the task data for a given duration after execution is complete
|
||||
func (t *task) Retain(duration time.Duration) *task {
|
||||
t.retain = &duration
|
||||
return t
|
||||
}
|
||||
|
||||
// MaxRetries sets the maximum amount of times to retry executing the task in the event of a failure
|
||||
func (t *task) MaxRetries(retries int) *task {
|
||||
t.maxRetries = &retries
|
||||
return t
|
||||
}
|
||||
|
||||
// Save saves the task so it can be executed
|
||||
func (t *task) Save() error {
|
||||
var err error
|
||||
|
||||
// Build the payload
|
||||
var payload []byte
|
||||
if t.payload != nil {
|
||||
if payload, err = json.Marshal(t.payload); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Build the task options
|
||||
opts := make([]asynq.Option, 0)
|
||||
if t.queue != nil {
|
||||
opts = append(opts, asynq.Queue(*t.queue))
|
||||
}
|
||||
if t.maxRetries != nil {
|
||||
opts = append(opts, asynq.MaxRetry(*t.maxRetries))
|
||||
}
|
||||
if t.timeout != nil {
|
||||
opts = append(opts, asynq.Timeout(*t.timeout))
|
||||
}
|
||||
if t.deadline != nil {
|
||||
opts = append(opts, asynq.Deadline(*t.deadline))
|
||||
}
|
||||
if t.wait != nil {
|
||||
opts = append(opts, asynq.ProcessIn(*t.wait))
|
||||
}
|
||||
if t.retain != nil {
|
||||
opts = append(opts, asynq.Retention(*t.retain))
|
||||
}
|
||||
if t.at != nil {
|
||||
opts = append(opts, asynq.ProcessAt(*t.at))
|
||||
}
|
||||
|
||||
// Build the task
|
||||
task := asynq.NewTask(t.typ, payload, opts...)
|
||||
|
||||
// Schedule, if needed
|
||||
if t.periodic != nil {
|
||||
_, err = t.client.scheduler.Register(*t.periodic, task)
|
||||
} else {
|
||||
_, err = t.client.client.Enqueue(task)
|
||||
}
|
||||
return err
|
||||
}
|
||||
35
pkg/services/tasks_test.go
Normal file
35
pkg/services/tasks_test.go
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTaskClient_New(t *testing.T) {
|
||||
now := time.Now()
|
||||
tk := c.Tasks.
|
||||
New("task1").
|
||||
Payload("payload").
|
||||
Queue("queue").
|
||||
Periodic("@every 5s").
|
||||
MaxRetries(5).
|
||||
Timeout(5 * time.Second).
|
||||
Deadline(now).
|
||||
At(now).
|
||||
Wait(6 * time.Second).
|
||||
Retain(7 * time.Second)
|
||||
|
||||
assert.Equal(t, "task1", tk.typ)
|
||||
assert.Equal(t, "payload", tk.payload.(string))
|
||||
assert.Equal(t, "queue", *tk.queue)
|
||||
assert.Equal(t, "@every 5s", *tk.periodic)
|
||||
assert.Equal(t, 5, *tk.maxRetries)
|
||||
assert.Equal(t, 5*time.Second, *tk.timeout)
|
||||
assert.Equal(t, now, *tk.deadline)
|
||||
assert.Equal(t, now, *tk.at)
|
||||
assert.Equal(t, 6*time.Second, *tk.wait)
|
||||
assert.Equal(t, 7*time.Second, *tk.retain)
|
||||
assert.NoError(t, tk.Save())
|
||||
}
|
||||
233
pkg/services/template_renderer.go
Normal file
233
pkg/services/template_renderer.go
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/mikestefanello/pagoda/config"
|
||||
"github.com/mikestefanello/pagoda/pkg/funcmap"
|
||||
)
|
||||
|
||||
type (
|
||||
// TemplateRenderer provides a flexible and easy to use method of rendering simple templates or complex sets of
|
||||
// templates while also providing caching and/or hot-reloading depending on your current environment
|
||||
TemplateRenderer struct {
|
||||
// templateCache stores a cache of parsed page templates
|
||||
templateCache sync.Map
|
||||
|
||||
// funcMap stores the template function map
|
||||
funcMap template.FuncMap
|
||||
|
||||
// templatePath stores the complete path to the templates directory
|
||||
templatesPath string
|
||||
|
||||
// config stores application configuration
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// TemplateParsed is a wrapper around parsed templates which are stored in the TemplateRenderer cache
|
||||
TemplateParsed struct {
|
||||
// Template is the parsed template
|
||||
Template *template.Template
|
||||
|
||||
// build stores the build data used to parse the template
|
||||
build *templateBuild
|
||||
}
|
||||
|
||||
// templateBuild stores the build data used to parse a template
|
||||
templateBuild struct {
|
||||
group string
|
||||
key string
|
||||
base string
|
||||
files []string
|
||||
directories []string
|
||||
}
|
||||
|
||||
// templateBuilder handles chaining a template parse operation
|
||||
templateBuilder struct {
|
||||
build *templateBuild
|
||||
renderer *TemplateRenderer
|
||||
}
|
||||
)
|
||||
|
||||
// NewTemplateRenderer creates a new TemplateRenderer
|
||||
func NewTemplateRenderer(cfg *config.Config) *TemplateRenderer {
|
||||
t := &TemplateRenderer{
|
||||
templateCache: sync.Map{},
|
||||
funcMap: funcmap.GetFuncMap(),
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
// Gets the complete templates directory path
|
||||
// This is needed in case this is called from a package outside of main, such as within tests
|
||||
_, b, _, _ := runtime.Caller(0)
|
||||
d := path.Join(path.Dir(b))
|
||||
t.templatesPath = filepath.Join(filepath.Dir(d), config.TemplateDir)
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// Parse creates a template build operation
|
||||
func (t *TemplateRenderer) Parse() *templateBuilder {
|
||||
return &templateBuilder{
|
||||
renderer: t,
|
||||
build: &templateBuild{},
|
||||
}
|
||||
}
|
||||
|
||||
// GetTemplatesPath gets the complete path to the templates directory
|
||||
func (t *TemplateRenderer) GetTemplatesPath() string {
|
||||
return t.templatesPath
|
||||
}
|
||||
|
||||
// getCacheKey gets a cache key for a given group and ID
|
||||
func (t *TemplateRenderer) getCacheKey(group, key string) string {
|
||||
if group != "" {
|
||||
return fmt.Sprintf("%s:%s", group, key)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// parse parses a set of templates and caches them for quick execution
|
||||
// If the application environment is set to local, the cache will be bypassed and templates will be
|
||||
// parsed upon each request so hot-reloading is possible without restarts.
|
||||
// Also included will be the function map provided by the funcmap package.
|
||||
func (t *TemplateRenderer) parse(build *templateBuild) (*TemplateParsed, error) {
|
||||
var tp *TemplateParsed
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case build.key == "":
|
||||
return nil, errors.New("cannot parse template without key")
|
||||
case len(build.files) == 0 && len(build.directories) == 0:
|
||||
return nil, errors.New("cannot parse template without files or directories")
|
||||
case build.base == "":
|
||||
return nil, errors.New("cannot parse template without base")
|
||||
}
|
||||
|
||||
// Generate the cache key
|
||||
cacheKey := t.getCacheKey(build.group, build.key)
|
||||
|
||||
// Check if the template has not yet been parsed or if the app environment is local, so that
|
||||
// templates reflect changes without having the restart the server
|
||||
if tp, err = t.Load(build.group, build.key); err != nil || t.config.App.Environment == config.EnvLocal {
|
||||
// Initialize the parsed template with the function map
|
||||
parsed := template.New(build.base + config.TemplateExt).
|
||||
Funcs(t.funcMap)
|
||||
|
||||
// Parse all files provided
|
||||
if len(build.files) > 0 {
|
||||
for k, v := range build.files {
|
||||
build.files[k] = fmt.Sprintf("%s/%s%s", t.templatesPath, v, config.TemplateExt)
|
||||
}
|
||||
|
||||
parsed, err = parsed.ParseFiles(build.files...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Parse all templates within the provided directories
|
||||
for _, dir := range build.directories {
|
||||
dir = fmt.Sprintf("%s/%s/*%s", t.templatesPath, dir, config.TemplateExt)
|
||||
parsed, err = parsed.ParseGlob(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Store the template so this process only happens once
|
||||
tp = &TemplateParsed{
|
||||
Template: parsed,
|
||||
build: build,
|
||||
}
|
||||
t.templateCache.Store(cacheKey, tp)
|
||||
}
|
||||
|
||||
return tp, nil
|
||||
}
|
||||
|
||||
// Load loads a template from the cache
|
||||
func (t *TemplateRenderer) Load(group, key string) (*TemplateParsed, error) {
|
||||
load, ok := t.templateCache.Load(t.getCacheKey(group, key))
|
||||
if !ok {
|
||||
return nil, errors.New("uncached page template requested")
|
||||
}
|
||||
|
||||
tmpl, ok := load.(*TemplateParsed)
|
||||
if !ok {
|
||||
return nil, errors.New("unable to cast cached template")
|
||||
}
|
||||
|
||||
return tmpl, nil
|
||||
}
|
||||
|
||||
// Execute executes a template with the given data and provides the output
|
||||
func (t *TemplateParsed) Execute(data interface{}) (*bytes.Buffer, error) {
|
||||
if t.Template == nil {
|
||||
return nil, errors.New("cannot execute template: template not initialized")
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
err := t.Template.ExecuteTemplate(buf, t.build.base+config.TemplateExt, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Group sets the cache group for the template being built
|
||||
func (t *templateBuilder) Group(group string) *templateBuilder {
|
||||
t.build.group = group
|
||||
return t
|
||||
}
|
||||
|
||||
// Key sets the cache key for the template being built
|
||||
func (t *templateBuilder) Key(key string) *templateBuilder {
|
||||
t.build.key = key
|
||||
return t
|
||||
}
|
||||
|
||||
// Base sets the name of the base template to be used during template parsing and execution.
|
||||
// This should be only the file name without a directory or extension.
|
||||
func (t *templateBuilder) Base(base string) *templateBuilder {
|
||||
t.build.base = base
|
||||
return t
|
||||
}
|
||||
|
||||
// Files sets a list of template files to include in the parse.
|
||||
// This should not include the file extension and the paths should be relative to the templates directory.
|
||||
func (t *templateBuilder) Files(files ...string) *templateBuilder {
|
||||
t.build.files = files
|
||||
return t
|
||||
}
|
||||
|
||||
// Directories sets a list of directories that all template files within will be parsed.
|
||||
// The paths should be relative to the templates directory.
|
||||
func (t *templateBuilder) Directories(directories ...string) *templateBuilder {
|
||||
t.build.directories = directories
|
||||
return t
|
||||
}
|
||||
|
||||
// Store parsed the templates and stores them in the cache
|
||||
func (t *templateBuilder) Store() (*TemplateParsed, error) {
|
||||
return t.renderer.parse(t.build)
|
||||
}
|
||||
|
||||
// Execute executes the template with the given data.
|
||||
// If the template has not already been cached, this will parse and cache the template
|
||||
func (t *templateBuilder) Execute(data interface{}) (*bytes.Buffer, error) {
|
||||
tp, err := t.Store()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tp.Execute(data)
|
||||
}
|
||||
72
pkg/services/template_renderer_test.go
Normal file
72
pkg/services/template_renderer_test.go
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
|
||||
"github.com/mikestefanello/pagoda/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTemplateRenderer(t *testing.T) {
|
||||
group := "test"
|
||||
id := "parse"
|
||||
|
||||
// Should not exist yet
|
||||
_, err := c.TemplateRenderer.Load(group, id)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Parse in to the cache
|
||||
tpl, err := c.TemplateRenderer.
|
||||
Parse().
|
||||
Group(group).
|
||||
Key(id).
|
||||
Base("htmx").
|
||||
Files("htmx", "pages/error").
|
||||
Directories("components").
|
||||
Store()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should exist now
|
||||
parsed, err := c.TemplateRenderer.Load(group, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that all expected templates are included
|
||||
expectedTemplates := make(map[string]bool)
|
||||
expectedTemplates["htmx"+config.TemplateExt] = true
|
||||
expectedTemplates["error"+config.TemplateExt] = true
|
||||
components, err := ioutil.ReadDir(c.TemplateRenderer.GetTemplatesPath() + "/components")
|
||||
require.NoError(t, err)
|
||||
for _, f := range components {
|
||||
expectedTemplates[f.Name()] = true
|
||||
}
|
||||
for _, v := range parsed.Template.Templates() {
|
||||
delete(expectedTemplates, v.Name())
|
||||
}
|
||||
assert.Empty(t, expectedTemplates)
|
||||
|
||||
data := struct {
|
||||
StatusCode int
|
||||
}{
|
||||
StatusCode: 500,
|
||||
}
|
||||
buf, err := tpl.Execute(data)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, buf)
|
||||
assert.Contains(t, buf.String(), "Please try again")
|
||||
|
||||
buf, err = c.TemplateRenderer.
|
||||
Parse().
|
||||
Group(group).
|
||||
Key(id).
|
||||
Base("htmx").
|
||||
Files("htmx", "pages/error").
|
||||
Directories("components").
|
||||
Execute(data)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, buf)
|
||||
assert.Contains(t, buf.String(), "Please try again")
|
||||
}
|
||||
26
pkg/services/validator.go
Normal file
26
pkg/services/validator.go
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// Validator provides validation mainly validating structs within the web context
|
||||
type Validator struct {
|
||||
// validator stores the underlying validator
|
||||
validator *validator.Validate
|
||||
}
|
||||
|
||||
// NewValidator creats a new Validator
|
||||
func NewValidator() *Validator {
|
||||
return &Validator{
|
||||
validator: validator.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates a struct
|
||||
func (v *Validator) Validate(i interface{}) error {
|
||||
if err := v.validator.Struct(i); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
19
pkg/services/validator_test.go
Normal file
19
pkg/services/validator_test.go
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestValidator(t *testing.T) {
|
||||
type example struct {
|
||||
Value string `validate:"required"`
|
||||
}
|
||||
e := example{}
|
||||
err := c.Validator.Validate(e)
|
||||
assert.Error(t, err)
|
||||
e.Value = "a"
|
||||
err = c.Validator.Validate(e)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
22
pkg/tasks/example.go
Normal file
22
pkg/tasks/example.go
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
package tasks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
)
|
||||
|
||||
// TypeExample is the type for the example task.
|
||||
// This is what is passed in to TaskClient.New() when creating a new task
|
||||
const TypeExample = "example_task"
|
||||
|
||||
// ExampleProcessor processes example tasks
|
||||
type ExampleProcessor struct {
|
||||
}
|
||||
|
||||
// ProcessTask handles the processing of the task
|
||||
func (p *ExampleProcessor) ProcessTask(ctx context.Context, t *asynq.Task) error {
|
||||
log.Printf("executing task: %s", t.Type())
|
||||
return nil
|
||||
}
|
||||
60
pkg/tests/tests.go
Normal file
60
pkg/tests/tests.go
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mikestefanello/pagoda/ent"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/labstack/echo-contrib/session"
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// NewContext creates a new Echo context for tests using an HTTP test request and response recorder
|
||||
func NewContext(e *echo.Echo, url string) (echo.Context, *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest(http.MethodGet, url, strings.NewReader(""))
|
||||
rec := httptest.NewRecorder()
|
||||
return e.NewContext(req, rec), rec
|
||||
}
|
||||
|
||||
// InitSession initializes a session for a given Echo context
|
||||
func InitSession(ctx echo.Context) {
|
||||
mw := session.Middleware(sessions.NewCookieStore([]byte("secret")))
|
||||
_ = ExecuteMiddleware(ctx, mw)
|
||||
}
|
||||
|
||||
// ExecuteMiddleware executes a middleware function on a given Echo context
|
||||
func ExecuteMiddleware(ctx echo.Context, mw echo.MiddlewareFunc) error {
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return nil
|
||||
})
|
||||
return handler(ctx)
|
||||
}
|
||||
|
||||
// AssertHTTPErrorCode asserts an HTTP status code on a given Echo HTTP error
|
||||
func AssertHTTPErrorCode(t *testing.T, err error, code int) {
|
||||
httpError, ok := err.(*echo.HTTPError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, code, httpError.Code)
|
||||
}
|
||||
|
||||
// CreateUser creates a random user entity
|
||||
func CreateUser(orm *ent.Client) (*ent.User, error) {
|
||||
seed := fmt.Sprintf("%d-%d", time.Now().UnixMilli(), rand.Intn(1000000))
|
||||
return orm.User.
|
||||
Create().
|
||||
SetEmail(fmt.Sprintf("testuser-%s@localhost.localhost", seed)).
|
||||
SetPassword("password").
|
||||
SetName(fmt.Sprintf("Test User %s", seed)).
|
||||
Save(context.Background())
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue