Reorganized directories and packages.

This commit is contained in:
mikestefanello 2022-11-02 19:23:26 -04:00
parent 965fb540c7
commit dceb232cb2
61 changed files with 83 additions and 83 deletions

233
pkg/services/auth.go Normal file
View file

@ -0,0 +1,233 @@
package services
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"time"
"github.com/golang-jwt/jwt"
"github.com/mikestefanello/pagoda/config"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/ent/passwordtoken"
"github.com/mikestefanello/pagoda/ent/user"
"github.com/mikestefanello/pagoda/pkg/context"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"golang.org/x/crypto/bcrypt"
)
const (
// authSessionName stores the name of the session which contains authentication data
authSessionName = "ua"
// authSessionKeyUserID stores the key used to store the user ID in the session
authSessionKeyUserID = "user_id"
// authSessionKeyAuthenticated stores the key used to store the authentication status in the session
authSessionKeyAuthenticated = "authenticated"
)
// NotAuthenticatedError is an error returned when a user is not authenticated
type NotAuthenticatedError struct{}
// Error implements the error interface.
func (e NotAuthenticatedError) Error() string {
return "user not authenticated"
}
// InvalidPasswordTokenError is an error returned when an invalid token is provided
type InvalidPasswordTokenError struct{}
// Error implements the error interface.
func (e InvalidPasswordTokenError) Error() string {
return "invalid password token"
}
// AuthClient is the client that handles authentication requests
type AuthClient struct {
config *config.Config
orm *ent.Client
}
// NewAuthClient creates a new authentication client
func NewAuthClient(cfg *config.Config, orm *ent.Client) *AuthClient {
return &AuthClient{
config: cfg,
orm: orm,
}
}
// Login logs in a user of a given ID
func (c *AuthClient) Login(ctx echo.Context, userID int) error {
sess, err := session.Get(authSessionName, ctx)
if err != nil {
return err
}
sess.Values[authSessionKeyUserID] = userID
sess.Values[authSessionKeyAuthenticated] = true
return sess.Save(ctx.Request(), ctx.Response())
}
// Logout logs the requesting user out
func (c *AuthClient) Logout(ctx echo.Context) error {
sess, err := session.Get(authSessionName, ctx)
if err != nil {
return err
}
sess.Values[authSessionKeyAuthenticated] = false
return sess.Save(ctx.Request(), ctx.Response())
}
// GetAuthenticatedUserID returns the authenticated user's ID, if the user is logged in
func (c *AuthClient) GetAuthenticatedUserID(ctx echo.Context) (int, error) {
sess, err := session.Get(authSessionName, ctx)
if err != nil {
return 0, err
}
if sess.Values[authSessionKeyAuthenticated] == true {
return sess.Values[authSessionKeyUserID].(int), nil
}
return 0, NotAuthenticatedError{}
}
// GetAuthenticatedUser returns the authenticated user if the user is logged in
func (c *AuthClient) GetAuthenticatedUser(ctx echo.Context) (*ent.User, error) {
if userID, err := c.GetAuthenticatedUserID(ctx); err == nil {
return c.orm.User.Query().
Where(user.ID(userID)).
Only(ctx.Request().Context())
}
return nil, NotAuthenticatedError{}
}
// HashPassword returns a hash of a given password
func (c *AuthClient) HashPassword(password string) (string, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(hash), nil
}
// CheckPassword check if a given password matches a given hash
func (c *AuthClient) CheckPassword(password, hash string) error {
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
}
// GeneratePasswordResetToken generates a password reset token for a given user.
// For security purposes, the token itself is not stored in the database but rather
// a hash of the token, exactly how passwords are handled. This method returns both
// the generated token as well as the token entity which only contains the hash.
func (c *AuthClient) GeneratePasswordResetToken(ctx echo.Context, userID int) (string, *ent.PasswordToken, error) {
// Generate the token, which is what will go in the URL, but not the database
token, err := c.RandomToken(c.config.App.PasswordToken.Length)
if err != nil {
return "", nil, err
}
// Hash the token, which is what will be stored in the database
hash, err := c.HashPassword(token)
if err != nil {
return "", nil, err
}
// Create and save the password reset token
pt, err := c.orm.PasswordToken.
Create().
SetHash(hash).
SetUserID(userID).
Save(ctx.Request().Context())
return token, pt, err
}
// GetValidPasswordToken returns a valid, non-expired password token entity for a given user, token ID and token.
// Since the actual token is not stored in the database for security purposes, if a matching password token entity is
// found a hash of the provided token is compared with the hash stored in the database in order to validate.
func (c *AuthClient) GetValidPasswordToken(ctx echo.Context, userID, tokenID int, token string) (*ent.PasswordToken, error) {
// Ensure expired tokens are never returned
expiration := time.Now().Add(-c.config.App.PasswordToken.Expiration)
// Query to find a password token entity that matches the given user and token ID
pt, err := c.orm.PasswordToken.
Query().
Where(passwordtoken.ID(tokenID)).
Where(passwordtoken.HasUserWith(user.ID(userID))).
Where(passwordtoken.CreatedAtGTE(expiration)).
Only(ctx.Request().Context())
switch err.(type) {
case *ent.NotFoundError:
case nil:
// Check the token for a hash match
if err := c.CheckPassword(token, pt.Hash); err == nil {
return pt, nil
}
default:
if !context.IsCanceledError(err) {
return nil, err
}
}
return nil, InvalidPasswordTokenError{}
}
// DeletePasswordTokens deletes all password tokens in the database for a belonging to a given user.
// This should be called after a successful password reset.
func (c *AuthClient) DeletePasswordTokens(ctx echo.Context, userID int) error {
_, err := c.orm.PasswordToken.
Delete().
Where(passwordtoken.HasUserWith(user.ID(userID))).
Exec(ctx.Request().Context())
return err
}
// RandomToken generates a random token string of a given length
func (c *AuthClient) RandomToken(length int) (string, error) {
b := make([]byte, (length/2)+1)
if _, err := rand.Read(b); err != nil {
return "", err
}
token := hex.EncodeToString(b)
return token[:length], nil
}
// GenerateEmailVerificationToken generates an email verification token for a given email address using JWT which
// is set to expire based on the duration stored in configuration
func (c *AuthClient) GenerateEmailVerificationToken(email string) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"email": email,
"exp": time.Now().Add(c.config.App.EmailVerificationTokenExpiration).Unix(),
})
return token.SignedString([]byte(c.config.App.EncryptionKey))
}
// ValidateEmailVerificationToken validates an email verification token and returns the associated email address if
// the token is valid and has not expired
func (c *AuthClient) ValidateEmailVerificationToken(token string) (string, error) {
t, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return []byte(c.config.App.EncryptionKey), nil
})
if err != nil {
return "", err
}
if claims, ok := t.Claims.(jwt.MapClaims); ok && t.Valid {
return claims["email"].(string), nil
}
return "", errors.New("invalid or expired token")
}

146
pkg/services/auth_test.go Normal file
View file

@ -0,0 +1,146 @@
package services
import (
"context"
"errors"
"testing"
"time"
"github.com/mikestefanello/pagoda/ent/passwordtoken"
"github.com/mikestefanello/pagoda/ent/user"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/assert"
)
func TestAuthClient_Auth(t *testing.T) {
assertNoAuth := func() {
_, err := c.Auth.GetAuthenticatedUserID(ctx)
assert.True(t, errors.Is(err, NotAuthenticatedError{}))
_, err = c.Auth.GetAuthenticatedUser(ctx)
assert.True(t, errors.Is(err, NotAuthenticatedError{}))
}
assertNoAuth()
err := c.Auth.Login(ctx, usr.ID)
require.NoError(t, err)
uid, err := c.Auth.GetAuthenticatedUserID(ctx)
require.NoError(t, err)
assert.Equal(t, usr.ID, uid)
u, err := c.Auth.GetAuthenticatedUser(ctx)
require.NoError(t, err)
assert.Equal(t, u.ID, usr.ID)
err = c.Auth.Logout(ctx)
require.NoError(t, err)
assertNoAuth()
}
func TestAuthClient_PasswordHashing(t *testing.T) {
pw := "testcheckpassword"
hash, err := c.Auth.HashPassword(pw)
assert.NoError(t, err)
assert.NotEqual(t, hash, pw)
err = c.Auth.CheckPassword(pw, hash)
assert.NoError(t, err)
}
func TestAuthClient_GeneratePasswordResetToken(t *testing.T) {
token, pt, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID)
require.NoError(t, err)
assert.Len(t, token, c.Config.App.PasswordToken.Length)
assert.NoError(t, c.Auth.CheckPassword(token, pt.Hash))
}
func TestAuthClient_GetValidPasswordToken(t *testing.T) {
// Check that a fake token is not valid
_, err := c.Auth.GetValidPasswordToken(ctx, usr.ID, 1, "faketoken")
assert.Error(t, err)
// Generate a valid token and check that it is returned
token, pt, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID)
require.NoError(t, err)
pt2, err := c.Auth.GetValidPasswordToken(ctx, usr.ID, pt.ID, token)
require.NoError(t, err)
assert.Equal(t, pt.ID, pt2.ID)
// Expire the token by pushing the date far enough back
count, err := c.ORM.PasswordToken.
Update().
SetCreatedAt(time.Now().Add(-(c.Config.App.PasswordToken.Expiration + time.Hour))).
Where(passwordtoken.ID(pt.ID)).
Save(context.Background())
require.NoError(t, err)
require.Equal(t, 1, count)
// Expired tokens should not be valid
_, err = c.Auth.GetValidPasswordToken(ctx, usr.ID, pt.ID, token)
assert.Error(t, err)
}
func TestAuthClient_DeletePasswordTokens(t *testing.T) {
// Create three tokens for the user
for i := 0; i < 3; i++ {
_, _, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID)
require.NoError(t, err)
}
// Delete all tokens for the user
err := c.Auth.DeletePasswordTokens(ctx, usr.ID)
require.NoError(t, err)
// Check that no tokens remain
count, err := c.ORM.PasswordToken.
Query().
Where(passwordtoken.HasUserWith(user.ID(usr.ID))).
Count(context.Background())
require.NoError(t, err)
assert.Equal(t, 0, count)
}
func TestAuthClient_RandomToken(t *testing.T) {
length := c.Config.App.PasswordToken.Length
a, err := c.Auth.RandomToken(length)
require.NoError(t, err)
b, err := c.Auth.RandomToken(length)
require.NoError(t, err)
assert.Len(t, a, length)
assert.Len(t, b, length)
assert.NotEqual(t, a, b)
}
func TestAuthClient_EmailVerificationToken(t *testing.T) {
t.Run("valid token", func(t *testing.T) {
email := "test@localhost.com"
token, err := c.Auth.GenerateEmailVerificationToken(email)
require.NoError(t, err)
tokenEmail, err := c.Auth.ValidateEmailVerificationToken(token)
require.NoError(t, err)
assert.Equal(t, email, tokenEmail)
})
t.Run("invalid token", func(t *testing.T) {
badToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6InRlc3RAbG9jYWxob3N0LmNvbSIsImV4cCI6MTkxNzg2NDAwMH0.ScJCpfEEzlilKfRs_aVouzwPNKI28M3AIm-hyImQHUQ"
_, err := c.Auth.ValidateEmailVerificationToken(badToken)
assert.Error(t, err)
})
t.Run("expired token", func(t *testing.T) {
c.Config.App.EmailVerificationTokenExpiration = -time.Hour
email := "test@localhost.com"
token, err := c.Auth.GenerateEmailVerificationToken(email)
require.NoError(t, err)
_, err = c.Auth.ValidateEmailVerificationToken(token)
assert.Error(t, err)
c.Config.App.EmailVerificationTokenExpiration = time.Hour * 12
})
}

228
pkg/services/cache.go Normal file
View file

@ -0,0 +1,228 @@
package services
import (
"context"
"errors"
"fmt"
"time"
"github.com/eko/gocache/v2/cache"
"github.com/eko/gocache/v2/marshaler"
"github.com/eko/gocache/v2/store"
"github.com/go-redis/redis/v8"
"github.com/mikestefanello/pagoda/config"
)
type (
// CacheClient is the client that allows you to interact with the cache
CacheClient struct {
// Client stores the client to the underlying cache service
Client *redis.Client
// cache stores the cache interface
cache *cache.Cache
}
// cacheSet handles chaining a set operation
cacheSet struct {
client *CacheClient
key string
group string
data interface{}
expiration time.Duration
tags []string
}
// cacheGet handles chaining a get operation
cacheGet struct {
client *CacheClient
key string
group string
dataType interface{}
}
// cacheFlush handles chaining a flush operation
cacheFlush struct {
client *CacheClient
key string
group string
tags []string
}
)
// NewCacheClient creates a new cache client
func NewCacheClient(cfg *config.Config) (*CacheClient, error) {
// Determine the database based on the environment
db := cfg.Cache.Database
if cfg.App.Environment == config.EnvTest {
db = cfg.Cache.TestDatabase
}
// Connect to the cache
c := &CacheClient{}
c.Client = redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", cfg.Cache.Hostname, cfg.Cache.Port),
Password: cfg.Cache.Password,
DB: db,
})
if _, err := c.Client.Ping(context.Background()).Result(); err != nil {
return c, err
}
// Flush the database if this is the test environment
if cfg.App.Environment == config.EnvTest {
if err := c.Client.FlushDB(context.Background()).Err(); err != nil {
return c, err
}
}
cacheStore := store.NewRedis(c.Client, nil)
c.cache = cache.New(cacheStore)
return c, nil
}
// Close closes the connection to the cache
func (c *CacheClient) Close() error {
return c.Client.Close()
}
// Set creates a cache set operation
func (c *CacheClient) Set() *cacheSet {
return &cacheSet{
client: c,
}
}
// Get creates a cache get operation
func (c *CacheClient) Get() *cacheGet {
return &cacheGet{
client: c,
}
}
// Flush creates a cache flush operation
func (c *CacheClient) Flush() *cacheFlush {
return &cacheFlush{
client: c,
}
}
// cacheKey formats a cache key with an optional group
func (c *CacheClient) cacheKey(group, key string) string {
if group != "" {
return fmt.Sprintf("%s::%s", group, key)
}
return key
}
// Key sets the cache key
func (c *cacheSet) Key(key string) *cacheSet {
c.key = key
return c
}
// Group sets the cache group
func (c *cacheSet) Group(group string) *cacheSet {
c.group = group
return c
}
// Data sets the data to cache
func (c *cacheSet) Data(data interface{}) *cacheSet {
c.data = data
return c
}
// Expiration sets the expiration duration of the cached data
func (c *cacheSet) Expiration(expiration time.Duration) *cacheSet {
c.expiration = expiration
return c
}
// Tags sets the cache tags
func (c *cacheSet) Tags(tags ...string) *cacheSet {
c.tags = tags
return c
}
// Save saves the data in the cache
func (c *cacheSet) Save(ctx context.Context) error {
if c.key == "" {
return errors.New("no cache key specified")
}
opts := &store.Options{
Expiration: c.expiration,
Tags: c.tags,
}
return marshaler.
New(c.client.cache).
Set(ctx, c.client.cacheKey(c.group, c.key), c.data, opts)
}
// Key sets the cache key
func (c *cacheGet) Key(key string) *cacheGet {
c.key = key
return c
}
// Group sets the cache group
func (c *cacheGet) Group(group string) *cacheGet {
c.group = group
return c
}
// Type sets the expected Go type of the data being retrieved from the cache
func (c *cacheGet) Type(expectedType interface{}) *cacheGet {
c.dataType = expectedType
return c
}
// Fetch fetches the data from the cache
func (c *cacheGet) Fetch(ctx context.Context) (interface{}, error) {
if c.key == "" {
return nil, errors.New("no cache key specified")
}
return marshaler.New(c.client.cache).Get(
ctx,
c.client.cacheKey(c.group, c.key),
c.dataType,
)
}
// Key sets the cache key
func (c *cacheFlush) Key(key string) *cacheFlush {
c.key = key
return c
}
// Group sets the cache group
func (c *cacheFlush) Group(group string) *cacheFlush {
c.group = group
return c
}
// Tags sets the cache tags
func (c *cacheFlush) Tags(tags ...string) *cacheFlush {
c.tags = tags
return c
}
// Execute flushes the data from the cache
func (c *cacheFlush) Execute(ctx context.Context) error {
if len(c.tags) > 0 {
if err := c.client.cache.Invalidate(ctx, store.InvalidateOptions{
Tags: c.tags,
}); err != nil {
return err
}
}
if c.key != "" {
return c.client.cache.Delete(ctx, c.client.cacheKey(c.group, c.key))
}
return nil
}

105
pkg/services/cache_test.go Normal file
View file

@ -0,0 +1,105 @@
package services
import (
"context"
"testing"
"time"
"github.com/go-redis/redis/v8"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCacheClient(t *testing.T) {
type cacheTest struct {
Value string
}
// Cache some data
data := cacheTest{Value: "abcdef"}
group := "testgroup"
key := "testkey"
err := c.Cache.
Set().
Group(group).
Key(key).
Data(data).
Save(context.Background())
require.NoError(t, err)
// Get the data
fromCache, err := c.Cache.
Get().
Group(group).
Key(key).
Type(new(cacheTest)).
Fetch(context.Background())
require.NoError(t, err)
cast, ok := fromCache.(*cacheTest)
require.True(t, ok)
assert.Equal(t, data, *cast)
// The same key with the wrong group should fail
_, err = c.Cache.
Get().
Key(key).
Type(new(cacheTest)).
Fetch(context.Background())
assert.Error(t, err)
// Flush the data
err = c.Cache.
Flush().
Group(group).
Key(key).
Execute(context.Background())
require.NoError(t, err)
// The data should be gone
assertFlushed := func() {
// The data should be gone
_, err = c.Cache.
Get().
Group(group).
Key(key).
Type(new(cacheTest)).
Fetch(context.Background())
assert.Equal(t, redis.Nil, err)
}
assertFlushed()
// Set with tags
err = c.Cache.
Set().
Group(group).
Key(key).
Data(data).
Tags("tag1").
Save(context.Background())
require.NoError(t, err)
// Flush the tag
err = c.Cache.
Flush().
Tags("tag1").
Execute(context.Background())
require.NoError(t, err)
// The data should be gone
assertFlushed()
// Set with expiration
err = c.Cache.
Set().
Group(group).
Key(key).
Data(data).
Expiration(time.Millisecond).
Save(context.Background())
require.NoError(t, err)
// Wait for expiration
time.Sleep(time.Millisecond * 2)
// The data should be gone
assertFlushed()
}

201
pkg/services/container.go Normal file
View file

@ -0,0 +1,201 @@
package services
import (
"context"
"database/sql"
"fmt"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/schema"
// Required by ent
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
"github.com/mikestefanello/pagoda/config"
"github.com/mikestefanello/pagoda/ent"
// Require by ent
_ "github.com/mikestefanello/pagoda/ent/runtime"
)
// Container contains all services used by the application and provides an easy way to handle dependency
// injection including within tests
type Container struct {
// Validator stores a validator
Validator *Validator
// Web stores the web framework
Web *echo.Echo
// Config stores the application configuration
Config *config.Config
// Cache contains the cache client
Cache *CacheClient
// Database stores the connection to the database
Database *sql.DB
// ORM stores a client to the ORM
ORM *ent.Client
// Mail stores an email sending client
Mail *MailClient
// Auth stores an authentication client
Auth *AuthClient
// TemplateRenderer stores a service to easily render and cache templates
TemplateRenderer *TemplateRenderer
// Tasks stores the task client
Tasks *TaskClient
}
// NewContainer creates and initializes a new Container
func NewContainer() *Container {
c := new(Container)
c.initConfig()
c.initValidator()
c.initWeb()
c.initCache()
c.initDatabase()
c.initORM()
c.initAuth()
c.initTemplateRenderer()
c.initMail()
c.initTasks()
return c
}
// Shutdown shuts the Container down and disconnects all connections
func (c *Container) Shutdown() error {
if err := c.Tasks.Close(); err != nil {
return err
}
if err := c.Cache.Close(); err != nil {
return err
}
if err := c.ORM.Close(); err != nil {
return err
}
if err := c.Database.Close(); err != nil {
return err
}
return nil
}
// initConfig initializes configuration
func (c *Container) initConfig() {
cfg, err := config.GetConfig()
if err != nil {
panic(fmt.Sprintf("failed to load config: %v", err))
}
c.Config = &cfg
}
// initValidator initializes the validator
func (c *Container) initValidator() {
c.Validator = NewValidator()
}
// initWeb initializes the web framework
func (c *Container) initWeb() {
c.Web = echo.New()
// Configure logging
switch c.Config.App.Environment {
case config.EnvProduction:
c.Web.Logger.SetLevel(log.WARN)
default:
c.Web.Logger.SetLevel(log.DEBUG)
}
c.Web.Validator = c.Validator
}
// initCache initializes the cache
func (c *Container) initCache() {
var err error
if c.Cache, err = NewCacheClient(c.Config); err != nil {
panic(err)
}
}
// initDatabase initializes the database
// If the environment is set to test, the test database will be used and will be dropped, recreated and migrated
func (c *Container) initDatabase() {
var err error
getAddr := func(dbName string) string {
return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s",
c.Config.Database.User,
c.Config.Database.Password,
c.Config.Database.Hostname,
c.Config.Database.Port,
dbName,
)
}
c.Database, err = sql.Open("pgx", getAddr(c.Config.Database.Database))
if err != nil {
panic(fmt.Sprintf("failed to connect to database: %v", err))
}
// Check if this is a test environment
if c.Config.App.Environment == config.EnvTest {
// Drop the test database, ignoring errors in case it doesn't yet exist
_, _ = c.Database.Exec("DROP DATABASE " + c.Config.Database.TestDatabase)
// Create the test database
if _, err = c.Database.Exec("CREATE DATABASE " + c.Config.Database.TestDatabase); err != nil {
panic(fmt.Sprintf("failed to create test database: %v", err))
}
// Connect to the test database
if err = c.Database.Close(); err != nil {
panic(fmt.Sprintf("failed to close database connection: %v", err))
}
c.Database, err = sql.Open("pgx", getAddr(c.Config.Database.TestDatabase))
if err != nil {
panic(fmt.Sprintf("failed to connect to database: %v", err))
}
}
}
// initORM initializes the ORM
func (c *Container) initORM() {
drv := entsql.OpenDB(dialect.Postgres, c.Database)
c.ORM = ent.NewClient(ent.Driver(drv))
if err := c.ORM.Schema.Create(context.Background(), schema.WithAtlas(true)); err != nil {
panic(fmt.Sprintf("failed to create database schema: %v", err))
}
}
// initAuth initializes the authentication client
func (c *Container) initAuth() {
c.Auth = NewAuthClient(c.Config, c.ORM)
}
// initTemplateRenderer initializes the template renderer
func (c *Container) initTemplateRenderer() {
c.TemplateRenderer = NewTemplateRenderer(c.Config)
}
// initMail initialize the mail client
func (c *Container) initMail() {
var err error
c.Mail, err = NewMailClient(c.Config, c.TemplateRenderer)
if err != nil {
panic(fmt.Sprintf("failed to create mail client: %v", err))
}
}
// initTasks initializes the task client
func (c *Container) initTasks() {
c.Tasks = NewTaskClient(c.Config)
}

View file

@ -0,0 +1,20 @@
package services
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewContainer(t *testing.T) {
assert.NotNil(t, c.Web)
assert.NotNil(t, c.Config)
assert.NotNil(t, c.Validator)
assert.NotNil(t, c.Cache)
assert.NotNil(t, c.Database)
assert.NotNil(t, c.ORM)
assert.NotNil(t, c.Mail)
assert.NotNil(t, c.Auth)
assert.NotNil(t, c.TemplateRenderer)
assert.NotNil(t, c.Tasks)
}

139
pkg/services/mail.go Normal file
View file

@ -0,0 +1,139 @@
package services
import (
"errors"
"fmt"
"github.com/mikestefanello/pagoda/config"
"github.com/labstack/echo/v4"
)
type (
// MailClient provides a client for sending email
// This is purposely not completed because there are many different methods and services
// for sending email, many of which are very different. Choose what works best for you
// and populate the methods below
MailClient struct {
// config stores application configuration
config *config.Config
// templates stores the template renderer
templates *TemplateRenderer
}
// mail represents an email to be sent
mail struct {
client *MailClient
from string
to string
subject string
body string
template string
templateData interface{}
}
)
// NewMailClient creates a new MailClient
func NewMailClient(cfg *config.Config, templates *TemplateRenderer) (*MailClient, error) {
return &MailClient{
config: cfg,
templates: templates,
}, nil
}
// Compose creates a new email
func (m *MailClient) Compose() *mail {
return &mail{
client: m,
from: m.config.Mail.FromAddress,
}
}
// skipSend determines if mail sending should be skipped
func (m *MailClient) skipSend() bool {
return m.config.App.Environment != config.EnvProduction
}
// send attempts to send the email
func (m *MailClient) send(email *mail, ctx echo.Context) error {
switch {
case email.to == "":
return errors.New("email cannot be sent without a to address")
case email.body == "" && email.template == "":
return errors.New("email cannot be sent without a body or template")
}
// Check if a template was supplied
if email.template != "" {
// Parse and execute template
buf, err := m.templates.
Parse().
Group("mail").
Key(email.template).
Base(email.template).
Files(fmt.Sprintf("emails/%s", email.template)).
Execute(email.templateData)
if err != nil {
return err
}
email.body = buf.String()
}
// Check if mail sending should be skipped
if m.skipSend() {
ctx.Logger().Debugf("skipping email sent to: %s", email.to)
return nil
}
// TODO: Finish based on your mail sender of choice!
return nil
}
// From sets the email from address
func (m *mail) From(from string) *mail {
m.from = from
return m
}
// To sets the email address this email will be sent to
func (m *mail) To(to string) *mail {
m.to = to
return m
}
// Subject sets the subject line of the email
func (m *mail) Subject(subject string) *mail {
m.subject = subject
return m
}
// Body sets the body of the email
// This is not required and will be ignored if a template via Template()
func (m *mail) Body(body string) *mail {
m.body = body
return m
}
// Template sets the template to be used to produce the body of the email
// The template name should only include the filename without the extension or directory.
// The template must reside within the emails sub-directory.
// The funcmap will be automatically added to the template.
// Use TemplateData() to supply the data that will be passed in to the template.
func (m *mail) Template(template string) *mail {
m.template = template
return m
}
// TemplateData sets the data that will be passed to the template specified when calling Template()
func (m *mail) TemplateData(data interface{}) *mail {
m.templateData = data
return m
}
// Send attempts to send the email
func (m *mail) Send(ctx echo.Context) error {
return m.client.send(m, ctx)
}

View file

@ -0,0 +1,3 @@
package services
// Fill this in once you implement your mail client

View file

@ -0,0 +1,46 @@
package services
import (
"os"
"testing"
"github.com/mikestefanello/pagoda/config"
"github.com/mikestefanello/pagoda/ent"
"github.com/mikestefanello/pagoda/pkg/tests"
"github.com/labstack/echo/v4"
)
var (
c *Container
ctx echo.Context
usr *ent.User
)
func TestMain(m *testing.M) {
// Set the environment to test
config.SwitchEnvironment(config.EnvTest)
// Create a new container
c = NewContainer()
// Create a web context
ctx, _ = tests.NewContext(c.Web, "/")
tests.InitSession(ctx)
// Create a test user
var err error
if usr, err = tests.CreateUser(c.ORM); err != nil {
panic(err)
}
// Run tests
exitVal := m.Run()
// Shutdown the container
if err = c.Shutdown(); err != nil {
panic(err)
}
os.Exit(exitVal)
}

179
pkg/services/tasks.go Normal file
View file

@ -0,0 +1,179 @@
package services
import (
"encoding/json"
"fmt"
"time"
"github.com/hibiken/asynq"
"github.com/mikestefanello/pagoda/config"
)
type (
// TaskClient is that client that allows you to queue or schedule task execution
TaskClient struct {
// client stores the asynq client
client *asynq.Client
// scheduler stores the asynq scheduler
scheduler *asynq.Scheduler
}
// task handles task creation operations
task struct {
client *TaskClient
typ string
payload interface{}
periodic *string
queue *string
maxRetries *int
timeout *time.Duration
deadline *time.Time
at *time.Time
wait *time.Duration
retain *time.Duration
}
)
// NewTaskClient creates a new task client
func NewTaskClient(cfg *config.Config) *TaskClient {
// Determine the database based on the environment
db := cfg.Cache.Database
if cfg.App.Environment == config.EnvTest {
db = cfg.Cache.TestDatabase
}
conn := asynq.RedisClientOpt{
Addr: fmt.Sprintf("%s:%d", cfg.Cache.Hostname, cfg.Cache.Port),
Password: cfg.Cache.Password,
DB: db,
}
return &TaskClient{
client: asynq.NewClient(conn),
scheduler: asynq.NewScheduler(conn, nil),
}
}
// Close closes the connection to the task service
func (t *TaskClient) Close() error {
return t.client.Close()
}
// StartScheduler starts the scheduler service which adds scheduled tasks to the queue
// This must be running in order to queue tasks set for periodic execution
func (t *TaskClient) StartScheduler() error {
return t.scheduler.Run()
}
// New starts a task creation operation
func (t *TaskClient) New(typ string) *task {
return &task{
client: t,
typ: typ,
}
}
// Payload sets the task payload data which will be sent to the task handler
func (t *task) Payload(payload interface{}) *task {
t.payload = payload
return t
}
// Periodic sets the task to execute periodically according to a given interval
// The interval can be either in cron form ("*/5 * * * *") or "@every 30s"
func (t *task) Periodic(interval string) *task {
t.periodic = &interval
return t
}
// Queue specifies the name of the queue to add the task to
// The default queue will be used if this is not set
func (t *task) Queue(queue string) *task {
t.queue = &queue
return t
}
// Timeout sets the task timeout, meaning the task must execute within a given duration
func (t *task) Timeout(timeout time.Duration) *task {
t.timeout = &timeout
return t
}
// Deadline sets the task execution deadline to a specific date and time
func (t *task) Deadline(deadline time.Time) *task {
t.deadline = &deadline
return t
}
// At sets the exact date and time the task should be executed
func (t *task) At(processAt time.Time) *task {
t.at = &processAt
return t
}
// Wait instructs the task to wait a given duration before it is executed
func (t *task) Wait(duration time.Duration) *task {
t.wait = &duration
return t
}
// Retain instructs the task service to retain the task data for a given duration after execution is complete
func (t *task) Retain(duration time.Duration) *task {
t.retain = &duration
return t
}
// MaxRetries sets the maximum amount of times to retry executing the task in the event of a failure
func (t *task) MaxRetries(retries int) *task {
t.maxRetries = &retries
return t
}
// Save saves the task so it can be executed
func (t *task) Save() error {
var err error
// Build the payload
var payload []byte
if t.payload != nil {
if payload, err = json.Marshal(t.payload); err != nil {
return err
}
}
// Build the task options
opts := make([]asynq.Option, 0)
if t.queue != nil {
opts = append(opts, asynq.Queue(*t.queue))
}
if t.maxRetries != nil {
opts = append(opts, asynq.MaxRetry(*t.maxRetries))
}
if t.timeout != nil {
opts = append(opts, asynq.Timeout(*t.timeout))
}
if t.deadline != nil {
opts = append(opts, asynq.Deadline(*t.deadline))
}
if t.wait != nil {
opts = append(opts, asynq.ProcessIn(*t.wait))
}
if t.retain != nil {
opts = append(opts, asynq.Retention(*t.retain))
}
if t.at != nil {
opts = append(opts, asynq.ProcessAt(*t.at))
}
// Build the task
task := asynq.NewTask(t.typ, payload, opts...)
// Schedule, if needed
if t.periodic != nil {
_, err = t.client.scheduler.Register(*t.periodic, task)
} else {
_, err = t.client.client.Enqueue(task)
}
return err
}

View file

@ -0,0 +1,35 @@
package services
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestTaskClient_New(t *testing.T) {
now := time.Now()
tk := c.Tasks.
New("task1").
Payload("payload").
Queue("queue").
Periodic("@every 5s").
MaxRetries(5).
Timeout(5 * time.Second).
Deadline(now).
At(now).
Wait(6 * time.Second).
Retain(7 * time.Second)
assert.Equal(t, "task1", tk.typ)
assert.Equal(t, "payload", tk.payload.(string))
assert.Equal(t, "queue", *tk.queue)
assert.Equal(t, "@every 5s", *tk.periodic)
assert.Equal(t, 5, *tk.maxRetries)
assert.Equal(t, 5*time.Second, *tk.timeout)
assert.Equal(t, now, *tk.deadline)
assert.Equal(t, now, *tk.at)
assert.Equal(t, 6*time.Second, *tk.wait)
assert.Equal(t, 7*time.Second, *tk.retain)
assert.NoError(t, tk.Save())
}

View file

@ -0,0 +1,233 @@
package services
import (
"bytes"
"errors"
"fmt"
"html/template"
"path"
"path/filepath"
"runtime"
"sync"
"github.com/mikestefanello/pagoda/config"
"github.com/mikestefanello/pagoda/pkg/funcmap"
)
type (
// TemplateRenderer provides a flexible and easy to use method of rendering simple templates or complex sets of
// templates while also providing caching and/or hot-reloading depending on your current environment
TemplateRenderer struct {
// templateCache stores a cache of parsed page templates
templateCache sync.Map
// funcMap stores the template function map
funcMap template.FuncMap
// templatePath stores the complete path to the templates directory
templatesPath string
// config stores application configuration
config *config.Config
}
// TemplateParsed is a wrapper around parsed templates which are stored in the TemplateRenderer cache
TemplateParsed struct {
// Template is the parsed template
Template *template.Template
// build stores the build data used to parse the template
build *templateBuild
}
// templateBuild stores the build data used to parse a template
templateBuild struct {
group string
key string
base string
files []string
directories []string
}
// templateBuilder handles chaining a template parse operation
templateBuilder struct {
build *templateBuild
renderer *TemplateRenderer
}
)
// NewTemplateRenderer creates a new TemplateRenderer
func NewTemplateRenderer(cfg *config.Config) *TemplateRenderer {
t := &TemplateRenderer{
templateCache: sync.Map{},
funcMap: funcmap.GetFuncMap(),
config: cfg,
}
// Gets the complete templates directory path
// This is needed in case this is called from a package outside of main, such as within tests
_, b, _, _ := runtime.Caller(0)
d := path.Join(path.Dir(b))
t.templatesPath = filepath.Join(filepath.Dir(d), config.TemplateDir)
return t
}
// Parse creates a template build operation
func (t *TemplateRenderer) Parse() *templateBuilder {
return &templateBuilder{
renderer: t,
build: &templateBuild{},
}
}
// GetTemplatesPath gets the complete path to the templates directory
func (t *TemplateRenderer) GetTemplatesPath() string {
return t.templatesPath
}
// getCacheKey gets a cache key for a given group and ID
func (t *TemplateRenderer) getCacheKey(group, key string) string {
if group != "" {
return fmt.Sprintf("%s:%s", group, key)
}
return key
}
// parse parses a set of templates and caches them for quick execution
// If the application environment is set to local, the cache will be bypassed and templates will be
// parsed upon each request so hot-reloading is possible without restarts.
// Also included will be the function map provided by the funcmap package.
func (t *TemplateRenderer) parse(build *templateBuild) (*TemplateParsed, error) {
var tp *TemplateParsed
var err error
switch {
case build.key == "":
return nil, errors.New("cannot parse template without key")
case len(build.files) == 0 && len(build.directories) == 0:
return nil, errors.New("cannot parse template without files or directories")
case build.base == "":
return nil, errors.New("cannot parse template without base")
}
// Generate the cache key
cacheKey := t.getCacheKey(build.group, build.key)
// Check if the template has not yet been parsed or if the app environment is local, so that
// templates reflect changes without having the restart the server
if tp, err = t.Load(build.group, build.key); err != nil || t.config.App.Environment == config.EnvLocal {
// Initialize the parsed template with the function map
parsed := template.New(build.base + config.TemplateExt).
Funcs(t.funcMap)
// Parse all files provided
if len(build.files) > 0 {
for k, v := range build.files {
build.files[k] = fmt.Sprintf("%s/%s%s", t.templatesPath, v, config.TemplateExt)
}
parsed, err = parsed.ParseFiles(build.files...)
if err != nil {
return nil, err
}
}
// Parse all templates within the provided directories
for _, dir := range build.directories {
dir = fmt.Sprintf("%s/%s/*%s", t.templatesPath, dir, config.TemplateExt)
parsed, err = parsed.ParseGlob(dir)
if err != nil {
return nil, err
}
}
// Store the template so this process only happens once
tp = &TemplateParsed{
Template: parsed,
build: build,
}
t.templateCache.Store(cacheKey, tp)
}
return tp, nil
}
// Load loads a template from the cache
func (t *TemplateRenderer) Load(group, key string) (*TemplateParsed, error) {
load, ok := t.templateCache.Load(t.getCacheKey(group, key))
if !ok {
return nil, errors.New("uncached page template requested")
}
tmpl, ok := load.(*TemplateParsed)
if !ok {
return nil, errors.New("unable to cast cached template")
}
return tmpl, nil
}
// Execute executes a template with the given data and provides the output
func (t *TemplateParsed) Execute(data interface{}) (*bytes.Buffer, error) {
if t.Template == nil {
return nil, errors.New("cannot execute template: template not initialized")
}
buf := new(bytes.Buffer)
err := t.Template.ExecuteTemplate(buf, t.build.base+config.TemplateExt, data)
if err != nil {
return nil, err
}
return buf, nil
}
// Group sets the cache group for the template being built
func (t *templateBuilder) Group(group string) *templateBuilder {
t.build.group = group
return t
}
// Key sets the cache key for the template being built
func (t *templateBuilder) Key(key string) *templateBuilder {
t.build.key = key
return t
}
// Base sets the name of the base template to be used during template parsing and execution.
// This should be only the file name without a directory or extension.
func (t *templateBuilder) Base(base string) *templateBuilder {
t.build.base = base
return t
}
// Files sets a list of template files to include in the parse.
// This should not include the file extension and the paths should be relative to the templates directory.
func (t *templateBuilder) Files(files ...string) *templateBuilder {
t.build.files = files
return t
}
// Directories sets a list of directories that all template files within will be parsed.
// The paths should be relative to the templates directory.
func (t *templateBuilder) Directories(directories ...string) *templateBuilder {
t.build.directories = directories
return t
}
// Store parsed the templates and stores them in the cache
func (t *templateBuilder) Store() (*TemplateParsed, error) {
return t.renderer.parse(t.build)
}
// Execute executes the template with the given data.
// If the template has not already been cached, this will parse and cache the template
func (t *templateBuilder) Execute(data interface{}) (*bytes.Buffer, error) {
tp, err := t.Store()
if err != nil {
return nil, err
}
return tp.Execute(data)
}

View file

@ -0,0 +1,72 @@
package services
import (
"io/ioutil"
"testing"
"github.com/mikestefanello/pagoda/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTemplateRenderer(t *testing.T) {
group := "test"
id := "parse"
// Should not exist yet
_, err := c.TemplateRenderer.Load(group, id)
assert.Error(t, err)
// Parse in to the cache
tpl, err := c.TemplateRenderer.
Parse().
Group(group).
Key(id).
Base("htmx").
Files("htmx", "pages/error").
Directories("components").
Store()
require.NoError(t, err)
// Should exist now
parsed, err := c.TemplateRenderer.Load(group, id)
require.NoError(t, err)
// Check that all expected templates are included
expectedTemplates := make(map[string]bool)
expectedTemplates["htmx"+config.TemplateExt] = true
expectedTemplates["error"+config.TemplateExt] = true
components, err := ioutil.ReadDir(c.TemplateRenderer.GetTemplatesPath() + "/components")
require.NoError(t, err)
for _, f := range components {
expectedTemplates[f.Name()] = true
}
for _, v := range parsed.Template.Templates() {
delete(expectedTemplates, v.Name())
}
assert.Empty(t, expectedTemplates)
data := struct {
StatusCode int
}{
StatusCode: 500,
}
buf, err := tpl.Execute(data)
require.NoError(t, err)
require.NotNil(t, buf)
assert.Contains(t, buf.String(), "Please try again")
buf, err = c.TemplateRenderer.
Parse().
Group(group).
Key(id).
Base("htmx").
Files("htmx", "pages/error").
Directories("components").
Execute(data)
require.NoError(t, err)
require.NotNil(t, buf)
assert.Contains(t, buf.String(), "Please try again")
}

26
pkg/services/validator.go Normal file
View file

@ -0,0 +1,26 @@
package services
import (
"github.com/go-playground/validator/v10"
)
// Validator provides validation mainly validating structs within the web context
type Validator struct {
// validator stores the underlying validator
validator *validator.Validate
}
// NewValidator creats a new Validator
func NewValidator() *Validator {
return &Validator{
validator: validator.New(),
}
}
// Validate validates a struct
func (v *Validator) Validate(i interface{}) error {
if err := v.validator.Struct(i); err != nil {
return err
}
return nil
}

View file

@ -0,0 +1,19 @@
package services
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestValidator(t *testing.T) {
type example struct {
Value string `validate:"required"`
}
e := example{}
err := c.Validator.Validate(e)
assert.Error(t, err)
e.Value = "a"
err = c.Validator.Validate(e)
assert.NoError(t, err)
}