From 2c9cf2a21a9682f4a2958c614506273aacc8f05a Mon Sep 17 00:00:00 2001 From: mikestefanello <552328+mikestefanello@users.noreply.github.com> Date: Sat, 19 Apr 2025 16:13:12 -0400 Subject: [PATCH] Added admin bool field and middleware. --- ent/admin/handler.go | 5 ++++ ent/admin/types.go | 1 + ent/migrate/schema.go | 1 + ent/mutation.go | 56 +++++++++++++++++++++++++++++++++++++++++- ent/runtime/runtime.go | 6 ++++- ent/schema/user.go | 2 ++ ent/user.go | 13 +++++++++- ent/user/user.go | 10 ++++++++ ent/user/where.go | 15 +++++++++++ ent/user_create.go | 25 +++++++++++++++++++ ent/user_update.go | 34 +++++++++++++++++++++++++ pkg/handlers/admin.go | 4 +-- pkg/middleware/auth.go | 16 ++++++++++++ 13 files changed, 183 insertions(+), 5 deletions(-) diff --git a/ent/admin/handler.go b/ent/admin/handler.go index bf13d8a..670854d 100644 --- a/ent/admin/handler.go +++ b/ent/admin/handler.go @@ -194,6 +194,7 @@ func (h *Handler) UserCreate(ctx echo.Context) error { op.SetPassword(*payload.Password) } op.SetVerified(payload.Verified) + op.SetAdmin(payload.Admin) if payload.CreatedAt != nil { op.SetCreatedAt(*payload.CreatedAt) } @@ -219,6 +220,7 @@ func (h *Handler) UserUpdate(ctx echo.Context, id int) error { op.SetPassword(*payload.Password) } op.SetVerified(payload.Verified) + op.SetAdmin(payload.Admin) _, err = op.Save(ctx.Request().Context()) return err } @@ -246,6 +248,7 @@ func (h *Handler) UserList(ctx echo.Context) (*EntityList, error) { "Name", "Email", "Verified", + "Admin", "Created at", }, Entities: make([]EntityValues, 0, len(res)), @@ -260,6 +263,7 @@ func (h *Handler) UserList(ctx echo.Context) (*EntityList, error) { res[i].Name, res[i].Email, fmt.Sprint(res[i].Verified), + fmt.Sprint(res[i].Admin), res[i].CreatedAt.Format(h.Config.TimeFormat), }, }) @@ -278,6 +282,7 @@ func (h *Handler) UserGet(ctx echo.Context, id int) (url.Values, error) { v.Set("name", entity.Name) v.Set("email", entity.Email) v.Set("verified", fmt.Sprint(entity.Verified)) + v.Set("admin", fmt.Sprint(entity.Admin)) return v, err } diff --git a/ent/admin/types.go b/ent/admin/types.go index 48c4bc9..a786a92 100644 --- a/ent/admin/types.go +++ b/ent/admin/types.go @@ -14,6 +14,7 @@ type User struct { Email string `form:"email"` Password *string `form:"password"` Verified bool `form:"verified"` + Admin bool `form:"admin"` CreatedAt *time.Time `form:"created_at"` } diff --git a/ent/migrate/schema.go b/ent/migrate/schema.go index 825c99e..07dd0c7 100644 --- a/ent/migrate/schema.go +++ b/ent/migrate/schema.go @@ -36,6 +36,7 @@ var ( {Name: "email", Type: field.TypeString, Unique: true}, {Name: "password", Type: field.TypeString}, {Name: "verified", Type: field.TypeBool, Default: false}, + {Name: "admin", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime}, } // UsersTable holds the schema information for the "users" table. diff --git a/ent/mutation.go b/ent/mutation.go index 2167b5f..37d071d 100644 --- a/ent/mutation.go +++ b/ent/mutation.go @@ -530,6 +530,7 @@ type UserMutation struct { email *string password *string verified *bool + admin *bool created_at *time.Time clearedFields map[string]struct{} owner map[int]struct{} @@ -782,6 +783,42 @@ func (m *UserMutation) ResetVerified() { m.verified = nil } +// SetAdmin sets the "admin" field. +func (m *UserMutation) SetAdmin(b bool) { + m.admin = &b +} + +// Admin returns the value of the "admin" field in the mutation. +func (m *UserMutation) Admin() (r bool, exists bool) { + v := m.admin + if v == nil { + return + } + return *v, true +} + +// OldAdmin returns the old "admin" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldAdmin(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAdmin is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAdmin requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAdmin: %w", err) + } + return oldValue.Admin, nil +} + +// ResetAdmin resets all changes to the "admin" field. +func (m *UserMutation) ResetAdmin() { + m.admin = nil +} + // SetCreatedAt sets the "created_at" field. func (m *UserMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -906,7 +943,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 5) + fields := make([]string, 0, 6) if m.name != nil { fields = append(fields, user.FieldName) } @@ -919,6 +956,9 @@ func (m *UserMutation) Fields() []string { if m.verified != nil { fields = append(fields, user.FieldVerified) } + if m.admin != nil { + fields = append(fields, user.FieldAdmin) + } if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -938,6 +978,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.Password() case user.FieldVerified: return m.Verified() + case user.FieldAdmin: + return m.Admin() case user.FieldCreatedAt: return m.CreatedAt() } @@ -957,6 +999,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldPassword(ctx) case user.FieldVerified: return m.OldVerified(ctx) + case user.FieldAdmin: + return m.OldAdmin(ctx) case user.FieldCreatedAt: return m.OldCreatedAt(ctx) } @@ -996,6 +1040,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetVerified(v) return nil + case user.FieldAdmin: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAdmin(v) + return nil case user.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -1064,6 +1115,9 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldVerified: m.ResetVerified() return nil + case user.FieldAdmin: + m.ResetAdmin() + return nil case user.FieldCreatedAt: m.ResetCreatedAt() return nil diff --git a/ent/runtime/runtime.go b/ent/runtime/runtime.go index dc38f10..5b0401e 100644 --- a/ent/runtime/runtime.go +++ b/ent/runtime/runtime.go @@ -60,8 +60,12 @@ func init() { userDescVerified := userFields[3].Descriptor() // user.DefaultVerified holds the default value on creation for the verified field. user.DefaultVerified = userDescVerified.Default.(bool) + // userDescAdmin is the schema descriptor for admin field. + userDescAdmin := userFields[4].Descriptor() + // user.DefaultAdmin holds the default value on creation for the admin field. + user.DefaultAdmin = userDescAdmin.Default.(bool) // userDescCreatedAt is the schema descriptor for created_at field. - userDescCreatedAt := userFields[4].Descriptor() + userDescCreatedAt := userFields[5].Descriptor() // user.DefaultCreatedAt holds the default value on creation for the created_at field. user.DefaultCreatedAt = userDescCreatedAt.Default.(func() time.Time) } diff --git a/ent/schema/user.go b/ent/schema/user.go index 5e624b0..ac2d35e 100644 --- a/ent/schema/user.go +++ b/ent/schema/user.go @@ -37,6 +37,8 @@ func (User) Fields() []ent.Field { NotEmpty(), field.Bool("verified"). Default(false), + field.Bool("admin"). + Default(false), field.Time("created_at"). Default(time.Now). Immutable(), diff --git a/ent/user.go b/ent/user.go index 576a575..70de1e5 100644 --- a/ent/user.go +++ b/ent/user.go @@ -25,6 +25,8 @@ type User struct { Password string `json:"-"` // Verified holds the value of the "verified" field. Verified bool `json:"verified,omitempty"` + // Admin holds the value of the "admin" field. + Admin bool `json:"admin,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // Edges holds the relations/edges for other nodes in the graph. @@ -56,7 +58,7 @@ func (*User) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case user.FieldVerified: + case user.FieldVerified, user.FieldAdmin: values[i] = new(sql.NullBool) case user.FieldID: values[i] = new(sql.NullInt64) @@ -109,6 +111,12 @@ func (u *User) assignValues(columns []string, values []any) error { } else if value.Valid { u.Verified = value.Bool } + case user.FieldAdmin: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field admin", values[i]) + } else if value.Valid { + u.Admin = value.Bool + } case user.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -167,6 +175,9 @@ func (u *User) String() string { builder.WriteString("verified=") builder.WriteString(fmt.Sprintf("%v", u.Verified)) builder.WriteString(", ") + builder.WriteString("admin=") + builder.WriteString(fmt.Sprintf("%v", u.Admin)) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(u.CreatedAt.Format(time.ANSIC)) builder.WriteByte(')') diff --git a/ent/user/user.go b/ent/user/user.go index f516800..d0c7dc7 100644 --- a/ent/user/user.go +++ b/ent/user/user.go @@ -23,6 +23,8 @@ const ( FieldPassword = "password" // FieldVerified holds the string denoting the verified field in the database. FieldVerified = "verified" + // FieldAdmin holds the string denoting the admin field in the database. + FieldAdmin = "admin" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // EdgeOwner holds the string denoting the owner edge name in mutations. @@ -45,6 +47,7 @@ var Columns = []string{ FieldEmail, FieldPassword, FieldVerified, + FieldAdmin, FieldCreatedAt, } @@ -73,6 +76,8 @@ var ( PasswordValidator func(string) error // DefaultVerified holds the default value on creation for the "verified" field. DefaultVerified bool + // DefaultAdmin holds the default value on creation for the "admin" field. + DefaultAdmin bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time ) @@ -105,6 +110,11 @@ func ByVerified(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldVerified, opts...).ToFunc() } +// ByAdmin orders the results by the admin field. +func ByAdmin(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAdmin, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/ent/user/where.go b/ent/user/where.go index 5ec1168..064afdb 100644 --- a/ent/user/where.go +++ b/ent/user/where.go @@ -75,6 +75,11 @@ func Verified(v bool) predicate.User { return predicate.User(sql.FieldEQ(FieldVerified, v)) } +// Admin applies equality check predicate on the "admin" field. It's identical to AdminEQ. +func Admin(v bool) predicate.User { + return predicate.User(sql.FieldEQ(FieldAdmin, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) @@ -285,6 +290,16 @@ func VerifiedNEQ(v bool) predicate.User { return predicate.User(sql.FieldNEQ(FieldVerified, v)) } +// AdminEQ applies the EQ predicate on the "admin" field. +func AdminEQ(v bool) predicate.User { + return predicate.User(sql.FieldEQ(FieldAdmin, v)) +} + +// AdminNEQ applies the NEQ predicate on the "admin" field. +func AdminNEQ(v bool) predicate.User { + return predicate.User(sql.FieldNEQ(FieldAdmin, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/ent/user_create.go b/ent/user_create.go index 6f98b46..dfca9e6 100644 --- a/ent/user_create.go +++ b/ent/user_create.go @@ -53,6 +53,20 @@ func (uc *UserCreate) SetNillableVerified(b *bool) *UserCreate { return uc } +// SetAdmin sets the "admin" field. +func (uc *UserCreate) SetAdmin(b bool) *UserCreate { + uc.mutation.SetAdmin(b) + return uc +} + +// SetNillableAdmin sets the "admin" field if the given value is not nil. +func (uc *UserCreate) SetNillableAdmin(b *bool) *UserCreate { + if b != nil { + uc.SetAdmin(*b) + } + return uc +} + // SetCreatedAt sets the "created_at" field. func (uc *UserCreate) SetCreatedAt(t time.Time) *UserCreate { uc.mutation.SetCreatedAt(t) @@ -123,6 +137,10 @@ func (uc *UserCreate) defaults() error { v := user.DefaultVerified uc.mutation.SetVerified(v) } + if _, ok := uc.mutation.Admin(); !ok { + v := user.DefaultAdmin + uc.mutation.SetAdmin(v) + } if _, ok := uc.mutation.CreatedAt(); !ok { if user.DefaultCreatedAt == nil { return fmt.Errorf("ent: uninitialized user.DefaultCreatedAt (forgotten import ent/runtime?)") @@ -162,6 +180,9 @@ func (uc *UserCreate) check() error { if _, ok := uc.mutation.Verified(); !ok { return &ValidationError{Name: "verified", err: errors.New(`ent: missing required field "User.verified"`)} } + if _, ok := uc.mutation.Admin(); !ok { + return &ValidationError{Name: "admin", err: errors.New(`ent: missing required field "User.admin"`)} + } if _, ok := uc.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "User.created_at"`)} } @@ -207,6 +228,10 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldVerified, field.TypeBool, value) _node.Verified = value } + if value, ok := uc.mutation.Admin(); ok { + _spec.SetField(user.FieldAdmin, field.TypeBool, value) + _node.Admin = value + } if value, ok := uc.mutation.CreatedAt(); ok { _spec.SetField(user.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value diff --git a/ent/user_update.go b/ent/user_update.go index 827ac93..23dc8a5 100644 --- a/ent/user_update.go +++ b/ent/user_update.go @@ -84,6 +84,20 @@ func (uu *UserUpdate) SetNillableVerified(b *bool) *UserUpdate { return uu } +// SetAdmin sets the "admin" field. +func (uu *UserUpdate) SetAdmin(b bool) *UserUpdate { + uu.mutation.SetAdmin(b) + return uu +} + +// SetNillableAdmin sets the "admin" field if the given value is not nil. +func (uu *UserUpdate) SetNillableAdmin(b *bool) *UserUpdate { + if b != nil { + uu.SetAdmin(*b) + } + return uu +} + // AddOwnerIDs adds the "owner" edge to the PasswordToken entity by IDs. func (uu *UserUpdate) AddOwnerIDs(ids ...int) *UserUpdate { uu.mutation.AddOwnerIDs(ids...) @@ -196,6 +210,9 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) { if value, ok := uu.mutation.Verified(); ok { _spec.SetField(user.FieldVerified, field.TypeBool, value) } + if value, ok := uu.mutation.Admin(); ok { + _spec.SetField(user.FieldAdmin, field.TypeBool, value) + } if uu.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -317,6 +334,20 @@ func (uuo *UserUpdateOne) SetNillableVerified(b *bool) *UserUpdateOne { return uuo } +// SetAdmin sets the "admin" field. +func (uuo *UserUpdateOne) SetAdmin(b bool) *UserUpdateOne { + uuo.mutation.SetAdmin(b) + return uuo +} + +// SetNillableAdmin sets the "admin" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableAdmin(b *bool) *UserUpdateOne { + if b != nil { + uuo.SetAdmin(*b) + } + return uuo +} + // AddOwnerIDs adds the "owner" edge to the PasswordToken entity by IDs. func (uuo *UserUpdateOne) AddOwnerIDs(ids ...int) *UserUpdateOne { uuo.mutation.AddOwnerIDs(ids...) @@ -459,6 +490,9 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) if value, ok := uuo.mutation.Verified(); ok { _spec.SetField(user.FieldVerified, field.TypeBool, value) } + if value, ok := uuo.mutation.Admin(); ok { + _spec.SetField(user.FieldAdmin, field.TypeBool, value) + } if uuo.mutation.OwnerCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/pkg/handlers/admin.go b/pkg/handlers/admin.go index a612b0e..c01ab9b 100644 --- a/pkg/handlers/admin.go +++ b/pkg/handlers/admin.go @@ -12,6 +12,7 @@ import ( "github.com/labstack/echo/v4" "github.com/mikestefanello/pagoda/ent" "github.com/mikestefanello/pagoda/ent/admin" + "github.com/mikestefanello/pagoda/pkg/middleware" "github.com/mikestefanello/pagoda/pkg/msg" "github.com/mikestefanello/pagoda/pkg/pager" "github.com/mikestefanello/pagoda/pkg/redirect" @@ -46,8 +47,7 @@ func (h *Admin) Init(c *services.Container) error { } func (h *Admin) Routes(g *echo.Group) { - // TODO admin user status middleware - entities := g.Group("/admin/content") + entities := g.Group("/admin/content", middleware.RequireAdmin) for _, n := range h.graph.Nodes { ng := entities.Group(fmt.Sprintf("/%s", strings.ToLower(n.Name))) diff --git a/pkg/middleware/auth.go b/pkg/middleware/auth.go index 542d1b9..d06391f 100644 --- a/pkg/middleware/auth.go +++ b/pkg/middleware/auth.go @@ -103,3 +103,19 @@ func RequireNoAuthentication(next echo.HandlerFunc) echo.HandlerFunc { return next(c) } } + +// RequireAdmin requires that the user be an admin in order to proceed. +func RequireAdmin(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if u := c.Get(context.AuthenticatedUserKey); u != nil { + if user, ok := u.(*ent.User); ok { + if user.Admin { + // TODO tests + return next(c) + } + } + } + + return echo.NewHTTPError(http.StatusUnauthorized) + } +}