From 60dedc09444785c401a35fc5537fa981efedef73 Mon Sep 17 00:00:00 2001 From: mikestefanello Date: Mon, 13 Dec 2021 12:51:00 -0500 Subject: [PATCH] Support separate database for testing. --- Makefile | 4 +++ config/config.go | 12 ++++--- container/container.go | 64 ++++++++++++++++++++++++--------- controllers/controllers_test.go | 8 +++-- main.go | 11 ------ 5 files changed, 65 insertions(+), 34 deletions(-) diff --git a/Makefile b/Makefile index c98bc93..c40d451 100644 --- a/Makefile +++ b/Makefile @@ -2,6 +2,10 @@ pg: psql postgresql://admin:admin@localhost:5432/app +.PHONY: pg-test +pg-test: + psql postgresql://admin:admin@localhost:5432/app_test + .PHONY: ent-gen ent-gen: go generate ./ent diff --git a/config/config.go b/config/config.go index 393cbc4..d5d3cfe 100644 --- a/config/config.go +++ b/config/config.go @@ -17,6 +17,7 @@ type Env string const ( EnvLocal Env = "local" + EnvTest Env = "test" EnvDevelop Env = "dev" EnvStaging Env = "staging" EnvQA Env = "qa" @@ -60,11 +61,12 @@ type ( } DatabaseConfig struct { - Hostname string `env:"DB_HOSTNAME,default=localhost"` - Port uint16 `env:"DB_PORT,default=5432"` - User string `env:"DB_USER,default=admin"` - Password string `env:"DB_PASSWORD,default=admin"` - Database string `env:"DB_NAME,default=app"` + Hostname string `env:"DB_HOSTNAME,default=localhost"` + Port uint16 `env:"DB_PORT,default=5432"` + User string `env:"DB_USER,default=admin"` + Password string `env:"DB_PASSWORD,default=admin"` + Database string `env:"DB_NAME,default=app"` + TestDatabase string `env:"DB_NAME_TEST,default=app_test"` } ) diff --git a/container/container.go b/container/container.go index f942d4c..c31d09c 100644 --- a/container/container.go +++ b/container/container.go @@ -12,6 +12,7 @@ import ( "github.com/go-redis/redis/v8" _ "github.com/jackc/pgx/v4/stdlib" "github.com/labstack/echo/v4" + "github.com/labstack/gommon/log" "goweb/config" "goweb/ent" @@ -27,33 +28,41 @@ type Container struct { func NewContainer() *Container { c := new(Container) - c.initWeb() c.initConfig() + c.initWeb() c.initCache() c.initDatabase() c.initORM() return c } -func (c *Container) initWeb() { - c.Web = echo.New() -} - func (c *Container) initConfig() { cfg, err := config.GetConfig() if err != nil { - c.Web.Logger.Fatalf("failed to load configuration: %v", err) + panic(fmt.Sprintf("failed to load config: %v", err)) } c.Config = &cfg } +func (c *Container) initWeb() { + c.Web = echo.New() + + // Configure logging + switch c.Config.App.Environment { + case config.EnvProduction: + c.Web.Logger.SetLevel(log.WARN) + default: + c.Web.Logger.SetLevel(log.DEBUG) + } +} + func (c *Container) initCache() { cacheClient := redis.NewClient(&redis.Options{ Addr: fmt.Sprintf("%s:%d", c.Config.Cache.Hostname, c.Config.Cache.Port), Password: c.Config.Cache.Password, }) if _, err := cacheClient.Ping(context.Background()).Result(); err != nil { - c.Web.Logger.Fatalf("failed to connect to cache server: %v", err) + panic(fmt.Sprintf("failed to connect to cache server: %v", err)) } cacheStore := store.NewRedis(cacheClient, nil) c.Cache = cache.New(cacheStore) @@ -62,15 +71,38 @@ func (c *Container) initCache() { func (c *Container) initDatabase() { var err error - addr := fmt.Sprintf("postgresql://%s:%s@%s/%s", - c.Config.Database.User, - c.Config.Database.Password, - c.Config.Database.Hostname, - c.Config.Database.Database, - ) - c.Database, err = sql.Open("pgx", addr) + getAddr := func(dbName string) string { + return fmt.Sprintf("postgresql://%s:%s@%s/%s", + c.Config.Database.User, + c.Config.Database.Password, + c.Config.Database.Hostname, + dbName, + ) + } + + c.Database, err = sql.Open("pgx", getAddr(c.Config.Database.Database)) if err != nil { - c.Web.Logger.Fatalf("failed to connect to database: %v", err) + panic(fmt.Sprintf("failed to connect to database: %v", err)) + } + + // Check if this is a test environment + if c.Config.App.Environment == config.EnvTest { + // Drop the test database, ignoring errors in case it doesn't yet exist + _, _ = c.Database.Exec("DROP DATABASE " + c.Config.Database.TestDatabase) + + // Create the test database + if _, err = c.Database.Exec("CREATE DATABASE " + c.Config.Database.TestDatabase); err != nil { + panic(fmt.Sprintf("failed to create test database: %v", err)) + } + + // Connect to the test database + if err = c.Database.Close(); err != nil { + panic(fmt.Sprintf("failed to close database connection: %v", err)) + } + c.Database, err = sql.Open("pgx", getAddr(c.Config.Database.TestDatabase)) + if err != nil { + panic(fmt.Sprintf("failed to connect to database: %v", err)) + } } } @@ -78,6 +110,6 @@ func (c *Container) initORM() { drv := entsql.OpenDB(dialect.Postgres, c.Database) c.ORM = ent.NewClient(ent.Driver(drv)) if err := c.ORM.Schema.Create(context.Background()); err != nil { - c.Web.Logger.Fatalf("failed to create database schema: %v", err) + panic(fmt.Sprintf("failed to create database schema: %v", err)) } } diff --git a/controllers/controllers_test.go b/controllers/controllers_test.go index d93a490..bc8db84 100644 --- a/controllers/controllers_test.go +++ b/controllers/controllers_test.go @@ -7,12 +7,12 @@ import ( "os" "testing" + "goweb/config" "goweb/container" "github.com/PuerkitoBio/goquery" "github.com/stretchr/testify/assert" - "github.com/labstack/gommon/log" "github.com/stretchr/testify/require" ) @@ -22,10 +22,14 @@ var ( ) func TestMain(m *testing.M) { + // Set the environment to test + if err := os.Setenv("APP_ENV", string(config.EnvTest)); err != nil { + panic(err) + } + // Start a test HTTP server c = container.NewContainer() BuildRouter(c) - c.Web.Logger.SetLevel(log.DEBUG) srv = httptest.NewServer(c.Web) exitVal := m.Run() diff --git a/main.go b/main.go index bfbe12e..1c069a8 100644 --- a/main.go +++ b/main.go @@ -8,24 +8,13 @@ import ( "os/signal" "time" - "goweb/config" "goweb/container" "goweb/controllers" - - "github.com/labstack/gommon/log" ) func main() { c := container.NewContainer() - // Configure logging - switch c.Config.App.Environment { - case config.EnvProduction: - c.Web.Logger.SetLevel(log.WARN) - default: - c.Web.Logger.SetLevel(log.DEBUG) - } - // Build the router controllers.BuildRouter(c)