diff --git a/ent/admin/extension.go b/ent/admin/extension.go index c9f5a6a..556b577 100644 --- a/ent/admin/extension.go +++ b/ent/admin/extension.go @@ -8,6 +8,7 @@ import ( "entgo.io/ent/entc" "entgo.io/ent/entc/gen" + "entgo.io/ent/schema/field" ) var ( @@ -24,8 +25,9 @@ func (*Extension) Templates() []*gen.Template { gen.MustParse( gen.NewTemplate("admin"). Funcs(template.FuncMap{ - "fieldName": fieldName, - "fieldLabel": fieldLabel, + "fieldName": fieldName, + "fieldLabel": fieldLabel, + "fieldIsPointer": fieldIsPointer, }). ParseFS(templateDir, "templates/*tmpl"), ), @@ -58,6 +60,17 @@ func fieldLabel(name string) string { return upperFirst(out) } +func fieldIsPointer(f *gen.Field) bool { + switch { + case f.Type.Type == field.TypeBool: + return false + case f.Optional, + f.Default: + return true + } + return false +} + func upperFirst(s string) string { if len(s) == 0 { return s diff --git a/ent/admin/handler.go b/ent/admin/handler.go index fdcbc50..666c347 100644 --- a/ent/admin/handler.go +++ b/ent/admin/handler.go @@ -90,8 +90,9 @@ func (h *Handler) PasswordTokenCreate(ctx echo.Context) error { op := h.client.PasswordToken.Create() op.SetHash(payload.Hash) op.SetUserID(payload.UserID) - op.SetCreatedAt(payload.CreatedAt) - // op.SetUserID(payload.User) + if payload.CreatedAt != nil { + op.SetCreatedAt(*payload.CreatedAt) + } _, err := op.Save(ctx.Request().Context()) return err } @@ -110,8 +111,9 @@ func (h *Handler) PasswordTokenUpdate(ctx echo.Context, id int) error { op := entity.Update() op.SetHash(payload.Hash) op.SetUserID(payload.UserID) - op.SetCreatedAt(payload.CreatedAt) - // op.SetUserID(payload.User) + if payload.CreatedAt != nil { + op.SetCreatedAt(*payload.CreatedAt) + } _, err = op.Save(ctx.Request().Context()) return err } @@ -176,7 +178,9 @@ func (h *Handler) UserCreate(ctx echo.Context) error { op.SetEmail(payload.Email) op.SetPassword(payload.Password) op.SetVerified(payload.Verified) - op.SetCreatedAt(payload.CreatedAt) + if payload.CreatedAt != nil { + op.SetCreatedAt(*payload.CreatedAt) + } _, err := op.Save(ctx.Request().Context()) return err } diff --git a/ent/admin/templates/handler.tmpl b/ent/admin/templates/handler.tmpl index c5dbc66..d6fc38c 100644 --- a/ent/admin/templates/handler.tmpl +++ b/ent/admin/templates/handler.tmpl @@ -96,11 +96,12 @@ op := h.client.{{ $n.Name }}.Create() {{- range $f := $n.Fields }} - op.Set{{ fieldName $f.Name }}(payload.{{ fieldName $f.Name }}) - {{- end }} - {{- range $e := $n.Edges }} - {{- if not $e.Inverse}} - // op.Set{{ fieldName $e.Name }}ID(payload.{{ fieldName $e.Name }}) + {{- if (fieldIsPointer $f) }} + if payload.{{ fieldName $f.Name }} != nil { + op.Set{{ fieldName $f.Name }}(*payload.{{ fieldName $f.Name }}) + } + {{- else }} + op.Set{{ fieldName $f.Name }}(payload.{{ fieldName $f.Name }}) {{- end }} {{- end }} _, err := op.Save(ctx.Request().Context()) @@ -121,12 +122,13 @@ op := entity.Update() {{- range $f := $n.Fields }} {{- if not $f.Immutable }} - op.Set{{ fieldName $f.Name }}(payload.{{ fieldName $f.Name }}) - {{- end }} - {{- end }} - {{- range $e := $n.Edges }} - {{- if not $e.Inverse}} - // op.Set{{ fieldName $e.Name }}ID(payload.{{ fieldName $e.Name }}) + {{- if (fieldIsPointer $f) }} + if payload.{{ fieldName $f.Name }} != nil { + op.Set{{ fieldName $f.Name }}(*payload.{{ fieldName $f.Name }}) + } + {{- else }} + op.Set{{ fieldName $f.Name }}(payload.{{ fieldName $f.Name }}) + {{- end }} {{- end }} {{- end }} _, err = op.Save(ctx.Request().Context()) diff --git a/ent/admin/templates/types.tmpl b/ent/admin/templates/types.tmpl index 5adea0d..86fca5f 100644 --- a/ent/admin/templates/types.tmpl +++ b/ent/admin/templates/types.tmpl @@ -8,7 +8,7 @@ {{ range $n := $.Nodes }} type {{ $n.Name }} struct { {{- range $f := $n.Fields }} - {{ fieldName $f.Name }} {{ $f.Type }} `form:"{{ $f.Name }}"` + {{ fieldName $f.Name }} {{ if (fieldIsPointer $f) }}*{{ end }}{{ $f.Type }} `form:"{{ $f.Name }}"` {{- end }} } {{ end }} diff --git a/ent/admin/types.go b/ent/admin/types.go index 7767fb8..833ec71 100644 --- a/ent/admin/types.go +++ b/ent/admin/types.go @@ -4,17 +4,17 @@ package admin import "time" type PasswordToken struct { - Hash string `form:"hash"` - UserID int `form:"user_id"` - CreatedAt time.Time `form:"created_at"` + Hash string `form:"hash"` + UserID int `form:"user_id"` + CreatedAt *time.Time `form:"created_at"` } type User struct { - Name string `form:"name"` - Email string `form:"email"` - Password string `form:"password"` - Verified bool `form:"verified"` - CreatedAt time.Time `form:"created_at"` + Name string `form:"name"` + Email string `form:"email"` + Password string `form:"password"` + Verified bool `form:"verified"` + CreatedAt *time.Time `form:"created_at"` } type EntityList struct {