diff --git a/container/auth_test.go b/container/auth_test.go index cb984ae..68b8b99 100644 --- a/container/auth_test.go +++ b/container/auth_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" "goweb/ent/passwordtoken" "goweb/ent/user" @@ -65,25 +66,42 @@ func TestGeneratePasswordResetToken(t *testing.T) { } func TestGetValidPasswordToken(t *testing.T) { + // Check that a fake token is not valid _, err := c.Auth.GetValidPasswordToken(ctx, "faketoken", usr.ID) assert.Error(t, err) + // Generate a valid token and check that it is returned token, pt, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID) require.NoError(t, err) pt2, err := c.Auth.GetValidPasswordToken(ctx, token, usr.ID) require.NoError(t, err) assert.Equal(t, pt.ID, pt2.ID) + + // Expire the token by pushed the date far enough back + _, err = c.ORM.PasswordToken. + Update(). + SetCreatedAt(time.Now().Add(-(c.Config.App.PasswordToken.Expiration + 10))). + Where(passwordtoken.ID(pt.ID)). + Save(context.Background()) + require.NoError(t, err) + + // Expired tokens should not be valid + _, err = c.Auth.GetValidPasswordToken(ctx, token, usr.ID) + assert.Error(t, err) } func TestDeletePasswordTokens(t *testing.T) { + // Create three tokens for the user for i := 0; i < 3; i++ { _, _, err := c.Auth.GeneratePasswordResetToken(ctx, usr.ID) require.NoError(t, err) } + // Delete all tokens for the user err := c.Auth.DeletePasswordTokens(ctx, usr.ID) require.NoError(t, err) + // Check that no tokens remain count, err := c.ORM.PasswordToken. Query(). Where(passwordtoken.HasUserWith(user.ID(usr.ID))). diff --git a/ent/passwordtoken_update.go b/ent/passwordtoken_update.go index 9f4e3ff..9dc5cad 100644 --- a/ent/passwordtoken_update.go +++ b/ent/passwordtoken_update.go @@ -9,6 +9,7 @@ import ( "goweb/ent/passwordtoken" "goweb/ent/predicate" "goweb/ent/user" + "time" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -34,6 +35,20 @@ func (ptu *PasswordTokenUpdate) SetHash(s string) *PasswordTokenUpdate { return ptu } +// SetCreatedAt sets the "created_at" field. +func (ptu *PasswordTokenUpdate) SetCreatedAt(t time.Time) *PasswordTokenUpdate { + ptu.mutation.SetCreatedAt(t) + return ptu +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (ptu *PasswordTokenUpdate) SetNillableCreatedAt(t *time.Time) *PasswordTokenUpdate { + if t != nil { + ptu.SetCreatedAt(*t) + } + return ptu +} + // SetUserID sets the "user" edge to the User entity by ID. func (ptu *PasswordTokenUpdate) SetUserID(id int) *PasswordTokenUpdate { ptu.mutation.SetUserID(id) @@ -154,6 +169,13 @@ func (ptu *PasswordTokenUpdate) sqlSave(ctx context.Context) (n int, err error) Column: passwordtoken.FieldHash, }) } + if value, ok := ptu.mutation.CreatedAt(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeTime, + Value: value, + Column: passwordtoken.FieldCreatedAt, + }) + } if ptu.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -214,6 +236,20 @@ func (ptuo *PasswordTokenUpdateOne) SetHash(s string) *PasswordTokenUpdateOne { return ptuo } +// SetCreatedAt sets the "created_at" field. +func (ptuo *PasswordTokenUpdateOne) SetCreatedAt(t time.Time) *PasswordTokenUpdateOne { + ptuo.mutation.SetCreatedAt(t) + return ptuo +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (ptuo *PasswordTokenUpdateOne) SetNillableCreatedAt(t *time.Time) *PasswordTokenUpdateOne { + if t != nil { + ptuo.SetCreatedAt(*t) + } + return ptuo +} + // SetUserID sets the "user" edge to the User entity by ID. func (ptuo *PasswordTokenUpdateOne) SetUserID(id int) *PasswordTokenUpdateOne { ptuo.mutation.SetUserID(id) @@ -358,6 +394,13 @@ func (ptuo *PasswordTokenUpdateOne) sqlSave(ctx context.Context) (_node *Passwor Column: passwordtoken.FieldHash, }) } + if value, ok := ptuo.mutation.CreatedAt(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeTime, + Value: value, + Column: passwordtoken.FieldCreatedAt, + }) + } if ptuo.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/ent/schema/passwordtoken.go b/ent/schema/passwordtoken.go index 3384172..e518d16 100644 --- a/ent/schema/passwordtoken.go +++ b/ent/schema/passwordtoken.go @@ -20,8 +20,7 @@ func (PasswordToken) Fields() []ent.Field { Sensitive(). NotEmpty(), field.Time("created_at"). - Default(time.Now). - Immutable(), + Default(time.Now), } }