385 lines
11 KiB
Go
385 lines
11 KiB
Go
// Path: Backend/internal/handlers/auth/auth.go
|
|
|
|
package auth
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"strings"
|
|
"time"
|
|
|
|
"boostai-backend/internal/config"
|
|
"boostai-backend/internal/database"
|
|
"boostai-backend/internal/http/respond"
|
|
authmw "boostai-backend/internal/middleware"
|
|
"boostai-backend/internal/sqlc"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
const authQueryTimeout = 5 * time.Second
|
|
|
|
// Handler serves the authentication endpoints: registration, login, logout,
// and reading/updating the current user's profile.
type Handler struct {
	db      *database.DB           // pool owner; used directly to open transactions
	queries *sqlc.Queries          // sqlc query set built from db.Pool in NewHandler
	cfg     *config.Config         // provides the session cookie name and IsProduction flag
	auth    *authmw.AuthMiddleware // issues the session tokens stored in the cookie
}
|
|
|
|
// authProfileResponse is the JSON shape of a user's optional profile data.
// Every field is a pointer without omitempty, so a SQL NULL column (mapped to
// nil by textPointer/timePointer) is serialized as an explicit JSON null.
type authProfileResponse struct {
	PreferredName  *string    `json:"preferred_name"`
	ProfileIconURL *string    `json:"profile_icon_url"`
	Headline       *string    `json:"headline"`
	Bio            *string    `json:"bio"`
	Timezone       *string    `json:"timezone"`
	Locale         *string    `json:"locale"`
	GradeLevel     *string    `json:"grade_level"`
	LearningGoal   *string    `json:"learning_goal"`
	CreatedAt      *time.Time `json:"created_at"`
	UpdatedAt      *time.Time `json:"updated_at"`
}
|
|
|
|
// authUserResponse is the JSON shape of a user returned by the auth endpoints.
// CreatedAt/UpdatedAt carry omitempty, so they are dropped from the output
// when the underlying timestamps are NULL; Profile is always present.
type authUserResponse struct {
	ID        int64               `json:"id"`
	Email     string              `json:"email"`
	Role      string              `json:"role"`
	FullName  string              `json:"full_name"`
	IsActive  bool                `json:"is_active"`
	CreatedAt *time.Time          `json:"created_at,omitempty"`
	UpdatedAt *time.Time          `json:"updated_at,omitempty"`
	Profile   authProfileResponse `json:"profile"`
}
|
|
|
|
// authResponse is the top-level envelope `{"user": {...}}` returned by
// RegisterUser, Login, Me and UpdateMe.
type authResponse struct {
	User authUserResponse `json:"user"`
}
|
|
|
|
// registerRequest is the request body accepted by RegisterUser. FirstName and
// LastName are joined into the stored full name; Role is optional and, when
// blank, RegisterUser defaults it to student.
type registerRequest struct {
	FirstName string `json:"first_name"`
	LastName  string `json:"last_name"`
	Email     string `json:"email"`
	Password  string `json:"password"`
	Role      string `json:"role"`
}
|
|
|
|
// loginRequest is the request body accepted by Login. RememberMe selects the
// longer session lifetime in setSessionCookie (30 days instead of 24 hours).
type loginRequest struct {
	Email      string `json:"email"`
	Password   string `json:"password"`
	RememberMe bool   `json:"remember_me"`
}
|
|
|
|
// updateProfileRequest is the partial-update body accepted by UpdateMe.
// A nil (omitted) field leaves the stored value unchanged. For the nullable
// profile fields, a value that trims to "" clears the column (see
// mergeNullableText). FullName can be changed but not blanked (see
// mergeRequiredString).
type updateProfileRequest struct {
	FullName       *string `json:"full_name"`
	PreferredName  *string `json:"preferred_name"`
	ProfileIconURL *string `json:"profile_icon_url"`
	Headline       *string `json:"headline"`
	Bio            *string `json:"bio"`
	Timezone       *string `json:"timezone"`
	Locale         *string `json:"locale"`
	GradeLevel     *string `json:"grade_level"`
	LearningGoal   *string `json:"learning_goal"`
}
|
|
|
|
func NewHandler(cfg *config.Config, db *database.DB, auth *authmw.AuthMiddleware) *Handler {
|
|
return &Handler{db: db, queries: sqlc.New(db.Pool), cfg: cfg, auth: auth}
|
|
}
|
|
|
|
func (h *Handler) RegisterUser(c *fiber.Ctx) error {
|
|
var req registerRequest
|
|
if err := c.BodyParser(&req); err != nil {
|
|
return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body")
|
|
}
|
|
|
|
fullName := strings.TrimSpace(strings.TrimSpace(req.FirstName) + " " + strings.TrimSpace(req.LastName))
|
|
if fullName == "" || strings.TrimSpace(req.Email) == "" || strings.TrimSpace(req.Password) == "" {
|
|
return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "first_name, last_name, email, and password are required")
|
|
}
|
|
|
|
role := sqlc.UserRoleStudent
|
|
if strings.TrimSpace(req.Role) != "" {
|
|
role = sqlc.UserRole(strings.TrimSpace(req.Role))
|
|
}
|
|
if role != sqlc.UserRoleStudent && role != sqlc.UserRoleTeacher {
|
|
return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "role must be student or teacher")
|
|
}
|
|
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return respond.Error(c, fiber.StatusInternalServerError, "auth_error", "Unable to secure password")
|
|
}
|
|
|
|
ctx, cancel := withTimeout()
|
|
defer cancel()
|
|
|
|
user, err := h.queries.CreateUser(ctx, sqlc.CreateUserParams{
|
|
Email: strings.TrimSpace(strings.ToLower(req.Email)),
|
|
PasswordHash: pgtype.Text{String: string(hashedPassword), Valid: true},
|
|
Role: role,
|
|
FullName: fullName,
|
|
})
|
|
if err != nil {
|
|
return respond.DatabaseError(c, err)
|
|
}
|
|
|
|
if err := h.setSessionCookie(c, user, false); err != nil {
|
|
return respond.Error(c, fiber.StatusInternalServerError, "auth_error", "Unable to create session")
|
|
}
|
|
|
|
authUser, err := h.queries.GetAuthUserByID(ctx, user.ID)
|
|
if err != nil {
|
|
return respond.DatabaseError(c, err)
|
|
}
|
|
|
|
return c.Status(fiber.StatusCreated).JSON(authResponse{User: mapAuthUserByID(authUser)})
|
|
}
|
|
|
|
func (h *Handler) Login(c *fiber.Ctx) error {
|
|
var req loginRequest
|
|
if err := c.BodyParser(&req); err != nil {
|
|
return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body")
|
|
}
|
|
|
|
if strings.TrimSpace(req.Email) == "" || strings.TrimSpace(req.Password) == "" {
|
|
return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "email and password are required")
|
|
}
|
|
|
|
ctx, cancel := withTimeout()
|
|
defer cancel()
|
|
|
|
user, err := h.queries.GetUserByEmail(ctx, strings.TrimSpace(strings.ToLower(req.Email)))
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return respond.Error(c, fiber.StatusUnauthorized, "invalid_credentials", "Invalid email or password")
|
|
}
|
|
return respond.DatabaseError(c, err)
|
|
}
|
|
|
|
if !user.IsActive || !user.PasswordHash.Valid {
|
|
return respond.Error(c, fiber.StatusUnauthorized, "invalid_credentials", "Invalid email or password")
|
|
}
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash.String), []byte(req.Password)); err != nil {
|
|
return respond.Error(c, fiber.StatusUnauthorized, "invalid_credentials", "Invalid email or password")
|
|
}
|
|
|
|
if err := h.setSessionCookie(c, user, req.RememberMe); err != nil {
|
|
return respond.Error(c, fiber.StatusInternalServerError, "auth_error", "Unable to create session")
|
|
}
|
|
|
|
authUser, err := h.queries.GetAuthUserByID(ctx, user.ID)
|
|
if err != nil {
|
|
return respond.DatabaseError(c, err)
|
|
}
|
|
|
|
return c.JSON(authResponse{User: mapAuthUserByID(authUser)})
|
|
}
|
|
|
|
func (h *Handler) Me(c *fiber.Ctx) error {
|
|
userID := authmw.CurrentUserID(c)
|
|
if userID == 0 {
|
|
return respond.Error(c, fiber.StatusUnauthorized, "unauthorized", "Authentication required")
|
|
}
|
|
|
|
ctx, cancel := withTimeout()
|
|
defer cancel()
|
|
|
|
user, err := h.queries.GetAuthUserByID(ctx, userID)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return respond.Error(c, fiber.StatusUnauthorized, "unauthorized", "User not found")
|
|
}
|
|
return respond.DatabaseError(c, err)
|
|
}
|
|
|
|
return c.JSON(authResponse{User: mapAuthUserByID(user)})
|
|
}
|
|
|
|
func (h *Handler) UpdateMe(c *fiber.Ctx) error {
|
|
userID := authmw.CurrentUserID(c)
|
|
if userID == 0 {
|
|
return respond.Error(c, fiber.StatusUnauthorized, "unauthorized", "Authentication required")
|
|
}
|
|
|
|
var req updateProfileRequest
|
|
if err := c.BodyParser(&req); err != nil {
|
|
return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body")
|
|
}
|
|
|
|
ctx, cancel := withTimeout()
|
|
defer cancel()
|
|
|
|
tx, err := h.db.Pool.Begin(ctx)
|
|
if err != nil {
|
|
return respond.DatabaseError(c, err)
|
|
}
|
|
defer func() {
|
|
_ = tx.Rollback(ctx)
|
|
}()
|
|
|
|
queries := h.queries.WithTx(tx)
|
|
current, err := queries.GetAuthUserByID(ctx, userID)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return respond.Error(c, fiber.StatusUnauthorized, "unauthorized", "User not found")
|
|
}
|
|
return respond.DatabaseError(c, err)
|
|
}
|
|
|
|
fullName, err := mergeRequiredString(current.UserFullName, req.FullName, "full_name")
|
|
if err != nil {
|
|
return respond.Error(c, fiber.StatusBadRequest, "invalid_request", err.Error())
|
|
}
|
|
|
|
if fullName != current.UserFullName {
|
|
if _, err := queries.UpdateUserFullName(ctx, sqlc.UpdateUserFullNameParams{ID: userID, FullName: fullName}); err != nil {
|
|
return respond.DatabaseError(c, err)
|
|
}
|
|
}
|
|
|
|
if _, err := queries.UpsertUserProfile(ctx, sqlc.UpsertUserProfileParams{
|
|
UserID: userID,
|
|
PreferredName: mergeNullableText(current.PreferredName, req.PreferredName),
|
|
ProfileIconUrl: mergeNullableText(current.ProfileIconUrl, req.ProfileIconURL),
|
|
Headline: mergeNullableText(current.Headline, req.Headline),
|
|
Bio: mergeNullableText(current.Bio, req.Bio),
|
|
Timezone: mergeNullableText(current.Timezone, req.Timezone),
|
|
Locale: mergeNullableText(current.Locale, req.Locale),
|
|
GradeLevel: mergeNullableText(current.GradeLevel, req.GradeLevel),
|
|
LearningGoal: mergeNullableText(current.LearningGoal, req.LearningGoal),
|
|
}); err != nil {
|
|
return respond.DatabaseError(c, err)
|
|
}
|
|
|
|
updated, err := queries.GetAuthUserByID(ctx, userID)
|
|
if err != nil {
|
|
return respond.DatabaseError(c, err)
|
|
}
|
|
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return respond.DatabaseError(c, err)
|
|
}
|
|
|
|
return c.JSON(authResponse{User: mapAuthUserByID(updated)})
|
|
}
|
|
|
|
func (h *Handler) Logout(c *fiber.Ctx) error {
|
|
h.clearSessionCookie(c)
|
|
return c.JSON(fiber.Map{"status": "ok"})
|
|
}
|
|
|
|
func (h *Handler) setSessionCookie(c *fiber.Ctx, user sqlc.User, rememberMe bool) error {
|
|
ttl := 24 * time.Hour
|
|
if rememberMe {
|
|
ttl = 30 * 24 * time.Hour
|
|
}
|
|
|
|
token, err := h.auth.CreateToken(user.ID, user.Role, user.Email, ttl)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
c.Cookie(&fiber.Cookie{
|
|
Name: h.cfg.SessionCookie,
|
|
Value: token,
|
|
HTTPOnly: true,
|
|
Secure: h.cfg.IsProduction(),
|
|
SameSite: fiber.CookieSameSiteLaxMode,
|
|
Path: "/",
|
|
Expires: time.Now().UTC().Add(ttl),
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *Handler) clearSessionCookie(c *fiber.Ctx) {
|
|
c.Cookie(&fiber.Cookie{
|
|
Name: h.cfg.SessionCookie,
|
|
Value: "",
|
|
HTTPOnly: true,
|
|
Secure: h.cfg.IsProduction(),
|
|
SameSite: fiber.CookieSameSiteLaxMode,
|
|
Path: "/",
|
|
Expires: time.Unix(0, 0),
|
|
MaxAge: -1,
|
|
})
|
|
}
|
|
|
|
func mapAuthUserByID(user sqlc.GetAuthUserByIDRow) authUserResponse {
|
|
return authUserResponse{
|
|
ID: user.UserID,
|
|
Email: user.UserEmail,
|
|
Role: string(user.UserRole),
|
|
FullName: user.UserFullName,
|
|
IsActive: user.UserIsActive,
|
|
CreatedAt: timePointer(user.UserCreatedAt),
|
|
UpdatedAt: timePointer(user.UserUpdatedAt),
|
|
Profile: mapAuthProfile(user.PreferredName, user.ProfileIconUrl, user.Headline, user.Bio, user.Timezone, user.Locale, user.GradeLevel, user.LearningGoal, user.ProfileCreatedAt, user.ProfileUpdatedAt),
|
|
}
|
|
}
|
|
|
|
func mapAuthProfile(preferredName, profileIconURL, headline, bio, timezone, locale, gradeLevel, learningGoal pgtype.Text, createdAt, updatedAt pgtype.Timestamptz) authProfileResponse {
|
|
return authProfileResponse{
|
|
PreferredName: textPointer(preferredName),
|
|
ProfileIconURL: textPointer(profileIconURL),
|
|
Headline: textPointer(headline),
|
|
Bio: textPointer(bio),
|
|
Timezone: textPointer(timezone),
|
|
Locale: textPointer(locale),
|
|
GradeLevel: textPointer(gradeLevel),
|
|
LearningGoal: textPointer(learningGoal),
|
|
CreatedAt: timePointer(createdAt),
|
|
UpdatedAt: timePointer(updatedAt),
|
|
}
|
|
}
|
|
|
|
func withTimeout() (context.Context, context.CancelFunc) {
|
|
return context.WithTimeout(context.Background(), authQueryTimeout)
|
|
}
|
|
|
|
// mergeRequiredString resolves a partial update of a mandatory string field.
//
// A nil input keeps current. Otherwise the whitespace-trimmed input is
// returned. If it trims to "", the result is "" with the error
// "<fieldName> cannot be empty".
func mergeRequiredString(current string, input *string, fieldName string) (string, error) {
	if input != nil {
		if trimmed := strings.TrimSpace(*input); trimmed != "" {
			return trimmed, nil
		}
		return "", errors.New(fieldName + " cannot be empty")
	}
	return current, nil
}
|
|
|
|
func mergeNullableText(current pgtype.Text, input *string) pgtype.Text {
|
|
if input == nil {
|
|
return current
|
|
}
|
|
|
|
value := strings.TrimSpace(*input)
|
|
if value == "" {
|
|
return pgtype.Text{}
|
|
}
|
|
|
|
return pgtype.Text{String: value, Valid: true}
|
|
}
|
|
|
|
func textPointer(value pgtype.Text) *string {
|
|
if !value.Valid {
|
|
return nil
|
|
}
|
|
|
|
text := value.String
|
|
return &text
|
|
}
|
|
|
|
func timePointer(value pgtype.Timestamptz) *time.Time {
|
|
if !value.Valid {
|
|
return nil
|
|
}
|
|
|
|
timestamp := value.Time.UTC()
|
|
return ×tamp
|
|
}
|