// Path: Backend/internal/handlers/auth/auth.go package auth import ( "context" "errors" "strings" "time" "boostai-backend/internal/config" "boostai-backend/internal/database" "boostai-backend/internal/http/respond" authmw "boostai-backend/internal/middleware" "boostai-backend/internal/sqlc" "github.com/gofiber/fiber/v2" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "golang.org/x/crypto/bcrypt" ) const authQueryTimeout = 5 * time.Second type Handler struct { db *database.DB queries *sqlc.Queries cfg *config.Config auth *authmw.AuthMiddleware } type authProfileResponse struct { PreferredName *string `json:"preferred_name"` ProfileIconURL *string `json:"profile_icon_url"` Headline *string `json:"headline"` Bio *string `json:"bio"` Timezone *string `json:"timezone"` Locale *string `json:"locale"` GradeLevel *string `json:"grade_level"` LearningGoal *string `json:"learning_goal"` CreatedAt *time.Time `json:"created_at"` UpdatedAt *time.Time `json:"updated_at"` } type authUserResponse struct { ID int64 `json:"id"` Email string `json:"email"` Role string `json:"role"` FullName string `json:"full_name"` IsActive bool `json:"is_active"` CreatedAt *time.Time `json:"created_at,omitempty"` UpdatedAt *time.Time `json:"updated_at,omitempty"` Profile authProfileResponse `json:"profile"` } type authResponse struct { User authUserResponse `json:"user"` } type registerRequest struct { FirstName string `json:"first_name"` LastName string `json:"last_name"` Email string `json:"email"` Password string `json:"password"` Role string `json:"role"` } type loginRequest struct { Email string `json:"email"` Password string `json:"password"` RememberMe bool `json:"remember_me"` } type updateProfileRequest struct { FullName *string `json:"full_name"` PreferredName *string `json:"preferred_name"` ProfileIconURL *string `json:"profile_icon_url"` Headline *string `json:"headline"` Bio *string `json:"bio"` Timezone *string `json:"timezone"` Locale *string `json:"locale"` GradeLevel *string `json:"grade_level"` LearningGoal *string `json:"learning_goal"` } func NewHandler(cfg *config.Config, db *database.DB, auth *authmw.AuthMiddleware) *Handler { return &Handler{db: db, queries: sqlc.New(db.Pool), cfg: cfg, auth: auth} } func (h *Handler) RegisterUser(c *fiber.Ctx) error { var req registerRequest if err := c.BodyParser(&req); err != nil { return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body") } fullName := strings.TrimSpace(strings.TrimSpace(req.FirstName) + " " + strings.TrimSpace(req.LastName)) if fullName == "" || strings.TrimSpace(req.Email) == "" || strings.TrimSpace(req.Password) == "" { return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "first_name, last_name, email, and password are required") } role := sqlc.UserRoleStudent if strings.TrimSpace(req.Role) != "" { role = sqlc.UserRole(strings.TrimSpace(req.Role)) } if role != sqlc.UserRoleStudent && role != sqlc.UserRoleTeacher { return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "role must be student or teacher") } hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) if err != nil { return respond.Error(c, fiber.StatusInternalServerError, "auth_error", "Unable to secure password") } ctx, cancel := withTimeout() defer cancel() user, err := h.queries.CreateUser(ctx, sqlc.CreateUserParams{ Email: strings.TrimSpace(strings.ToLower(req.Email)), PasswordHash: pgtype.Text{String: string(hashedPassword), Valid: true}, Role: role, FullName: fullName, }) if err != nil { return respond.DatabaseError(c, err) } if err := h.setSessionCookie(c, user, false); err != nil { return respond.Error(c, fiber.StatusInternalServerError, "auth_error", "Unable to create session") } authUser, err := h.queries.GetAuthUserByID(ctx, user.ID) if err != nil { return respond.DatabaseError(c, err) } return c.Status(fiber.StatusCreated).JSON(authResponse{User: mapAuthUserByID(authUser)}) } func (h *Handler) Login(c *fiber.Ctx) error { var req loginRequest if err := c.BodyParser(&req); err != nil { return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body") } if strings.TrimSpace(req.Email) == "" || strings.TrimSpace(req.Password) == "" { return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "email and password are required") } ctx, cancel := withTimeout() defer cancel() user, err := h.queries.GetUserByEmail(ctx, strings.TrimSpace(strings.ToLower(req.Email))) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return respond.Error(c, fiber.StatusUnauthorized, "invalid_credentials", "Invalid email or password") } return respond.DatabaseError(c, err) } if !user.IsActive || !user.PasswordHash.Valid { return respond.Error(c, fiber.StatusUnauthorized, "invalid_credentials", "Invalid email or password") } if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash.String), []byte(req.Password)); err != nil { return respond.Error(c, fiber.StatusUnauthorized, "invalid_credentials", "Invalid email or password") } if err := h.setSessionCookie(c, user, req.RememberMe); err != nil { return respond.Error(c, fiber.StatusInternalServerError, "auth_error", "Unable to create session") } authUser, err := h.queries.GetAuthUserByID(ctx, user.ID) if err != nil { return respond.DatabaseError(c, err) } return c.JSON(authResponse{User: mapAuthUserByID(authUser)}) } func (h *Handler) Me(c *fiber.Ctx) error { userID := authmw.CurrentUserID(c) if userID == 0 { return respond.Error(c, fiber.StatusUnauthorized, "unauthorized", "Authentication required") } ctx, cancel := withTimeout() defer cancel() user, err := h.queries.GetAuthUserByID(ctx, userID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return respond.Error(c, fiber.StatusUnauthorized, "unauthorized", "User not found") } return respond.DatabaseError(c, err) } return c.JSON(authResponse{User: mapAuthUserByID(user)}) } func (h *Handler) UpdateMe(c *fiber.Ctx) error { userID := authmw.CurrentUserID(c) if userID == 0 { return respond.Error(c, fiber.StatusUnauthorized, "unauthorized", "Authentication required") } var req updateProfileRequest if err := c.BodyParser(&req); err != nil { return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body") } ctx, cancel := withTimeout() defer cancel() tx, err := h.db.Pool.Begin(ctx) if err != nil { return respond.DatabaseError(c, err) } defer func() { _ = tx.Rollback(ctx) }() queries := h.queries.WithTx(tx) current, err := queries.GetAuthUserByID(ctx, userID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return respond.Error(c, fiber.StatusUnauthorized, "unauthorized", "User not found") } return respond.DatabaseError(c, err) } fullName, err := mergeRequiredString(current.UserFullName, req.FullName, "full_name") if err != nil { return respond.Error(c, fiber.StatusBadRequest, "invalid_request", err.Error()) } if fullName != current.UserFullName { if _, err := queries.UpdateUserFullName(ctx, sqlc.UpdateUserFullNameParams{ID: userID, FullName: fullName}); err != nil { return respond.DatabaseError(c, err) } } if _, err := queries.UpsertUserProfile(ctx, sqlc.UpsertUserProfileParams{ UserID: userID, PreferredName: mergeNullableText(current.PreferredName, req.PreferredName), ProfileIconUrl: mergeNullableText(current.ProfileIconUrl, req.ProfileIconURL), Headline: mergeNullableText(current.Headline, req.Headline), Bio: mergeNullableText(current.Bio, req.Bio), Timezone: mergeNullableText(current.Timezone, req.Timezone), Locale: mergeNullableText(current.Locale, req.Locale), GradeLevel: mergeNullableText(current.GradeLevel, req.GradeLevel), LearningGoal: mergeNullableText(current.LearningGoal, req.LearningGoal), }); err != nil { return respond.DatabaseError(c, err) } updated, err := queries.GetAuthUserByID(ctx, userID) if err != nil { return respond.DatabaseError(c, err) } if err := tx.Commit(ctx); err != nil { return respond.DatabaseError(c, err) } return c.JSON(authResponse{User: mapAuthUserByID(updated)}) } func (h *Handler) Logout(c *fiber.Ctx) error { h.clearSessionCookie(c) return c.JSON(fiber.Map{"status": "ok"}) } func (h *Handler) setSessionCookie(c *fiber.Ctx, user sqlc.User, rememberMe bool) error { ttl := 24 * time.Hour if rememberMe { ttl = 30 * 24 * time.Hour } token, err := h.auth.CreateToken(user.ID, user.Role, user.Email, ttl) if err != nil { return err } c.Cookie(&fiber.Cookie{ Name: h.cfg.SessionCookie, Value: token, HTTPOnly: true, Secure: h.cfg.IsProduction(), SameSite: fiber.CookieSameSiteLaxMode, Path: "/", Expires: time.Now().UTC().Add(ttl), }) return nil } func (h *Handler) clearSessionCookie(c *fiber.Ctx) { c.Cookie(&fiber.Cookie{ Name: h.cfg.SessionCookie, Value: "", HTTPOnly: true, Secure: h.cfg.IsProduction(), SameSite: fiber.CookieSameSiteLaxMode, Path: "/", Expires: time.Unix(0, 0), MaxAge: -1, }) } func mapAuthUserByID(user sqlc.GetAuthUserByIDRow) authUserResponse { return authUserResponse{ ID: user.UserID, Email: user.UserEmail, Role: string(user.UserRole), FullName: user.UserFullName, IsActive: user.UserIsActive, CreatedAt: timePointer(user.UserCreatedAt), UpdatedAt: timePointer(user.UserUpdatedAt), Profile: mapAuthProfile(user.PreferredName, user.ProfileIconUrl, user.Headline, user.Bio, user.Timezone, user.Locale, user.GradeLevel, user.LearningGoal, user.ProfileCreatedAt, user.ProfileUpdatedAt), } } func mapAuthProfile(preferredName, profileIconURL, headline, bio, timezone, locale, gradeLevel, learningGoal pgtype.Text, createdAt, updatedAt pgtype.Timestamptz) authProfileResponse { return authProfileResponse{ PreferredName: textPointer(preferredName), ProfileIconURL: textPointer(profileIconURL), Headline: textPointer(headline), Bio: textPointer(bio), Timezone: textPointer(timezone), Locale: textPointer(locale), GradeLevel: textPointer(gradeLevel), LearningGoal: textPointer(learningGoal), CreatedAt: timePointer(createdAt), UpdatedAt: timePointer(updatedAt), } } func withTimeout() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), authQueryTimeout) } func mergeRequiredString(current string, input *string, fieldName string) (string, error) { if input == nil { return current, nil } value := strings.TrimSpace(*input) if value == "" { return "", errors.New(fieldName + " cannot be empty") } return value, nil } func mergeNullableText(current pgtype.Text, input *string) pgtype.Text { if input == nil { return current } value := strings.TrimSpace(*input) if value == "" { return pgtype.Text{} } return pgtype.Text{String: value, Valid: true} } func textPointer(value pgtype.Text) *string { if !value.Valid { return nil } text := value.String return &text } func timePointer(value pgtype.Timestamptz) *time.Time { if !value.Valid { return nil } timestamp := value.Time.UTC() return ×tamp }