Files
BoostAI/Backend/internal/handlers/web/auth/auth.go
2026-05-25 17:05:06 +01:00

385 lines
11 KiB
Go

// Path: Backend/internal/handlers/web/auth/auth.go
package auth
import (
"context"
"errors"
"strings"
"time"
"boostai-backend/internal/config"
"boostai-backend/internal/database"
"boostai-backend/internal/http/respond"
authmw "boostai-backend/internal/middleware"
"boostai-backend/internal/sqlc"
"github.com/gofiber/fiber/v2"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"golang.org/x/crypto/bcrypt"
)
// authQueryTimeout bounds every database round-trip made by the auth
// handlers; it is applied through withTimeout.
const authQueryTimeout time.Duration = 5 * time.Second
// Handler serves the authentication endpoints: registration, login, logout,
// and reading/updating the current user. Construct it with NewHandler.
type Handler struct {
	// db is used directly to open transactions (UpdateMe).
	db *database.DB
	// queries is the sqlc query set bound to db.Pool.
	queries *sqlc.Queries
	// cfg supplies the session cookie name and the production flag that
	// controls the cookie's Secure attribute.
	cfg *config.Config
	// auth mints session tokens via CreateToken.
	auth *authmw.AuthMiddleware
}
// authProfileResponse is the JSON shape of a user's profile. None of the
// fields use omitempty, so NULL/missing values are emitted as JSON null and
// every key is always present.
type authProfileResponse struct {
	PreferredName  *string `json:"preferred_name"`
	ProfileIconURL *string `json:"profile_icon_url"`
	Headline       *string `json:"headline"`
	Bio            *string `json:"bio"`
	Timezone       *string `json:"timezone"`
	Locale         *string `json:"locale"`
	GradeLevel     *string `json:"grade_level"`
	LearningGoal   *string `json:"learning_goal"`
	// CreatedAt/UpdatedAt are the profile row's timestamps (UTC), not the
	// user row's; they are null when no profile row exists yet.
	CreatedAt *time.Time `json:"created_at"`
	UpdatedAt *time.Time `json:"updated_at"`
}
// authUserResponse is the JSON shape of a user returned by the auth endpoints.
type authUserResponse struct {
	ID       int64  `json:"id"`
	Email    string `json:"email"`
	Role     string `json:"role"`
	FullName string `json:"full_name"`
	IsActive bool   `json:"is_active"`
	// CreatedAt/UpdatedAt are omitted from the JSON when nil (omitempty),
	// unlike the profile timestamps, which are emitted as null.
	CreatedAt *time.Time `json:"created_at,omitempty"`
	UpdatedAt *time.Time `json:"updated_at,omitempty"`
	// Profile is always present; its fields may individually be null.
	Profile authProfileResponse `json:"profile"`
}
// authResponse is the envelope {"user": {...}} returned by RegisterUser,
// Login, Me and UpdateMe.
type authResponse struct {
	User authUserResponse `json:"user"`
}
// registerRequest is the request body accepted by RegisterUser.
type registerRequest struct {
	FirstName string `json:"first_name"`
	LastName  string `json:"last_name"`
	Email     string `json:"email"`
	Password  string `json:"password"`
	// Role is optional; RegisterUser defaults it to student and accepts only
	// "student" or "teacher".
	Role string `json:"role"`
}
// loginRequest is the request body accepted by Login.
type loginRequest struct {
	Email    string `json:"email"`
	Password string `json:"password"`
	// RememberMe extends the session from 24 hours to 30 days
	// (see setSessionCookie).
	RememberMe bool `json:"remember_me"`
}
// updateProfileRequest is the partial-update body accepted by UpdateMe.
// Every field is a pointer so that an absent (or null) key can be told apart
// from an explicit value: nil leaves the stored value unchanged. For the
// profile fields, a blank string clears the value. FullName may not be
// blanked.
type updateProfileRequest struct {
	FullName       *string `json:"full_name"`
	PreferredName  *string `json:"preferred_name"`
	ProfileIconURL *string `json:"profile_icon_url"`
	Headline       *string `json:"headline"`
	Bio            *string `json:"bio"`
	Timezone       *string `json:"timezone"`
	Locale         *string `json:"locale"`
	GradeLevel     *string `json:"grade_level"`
	LearningGoal   *string `json:"learning_goal"`
}
// NewHandler wires the auth handlers to their dependencies. The sqlc query
// set is built on top of db.Pool, so db must be non-nil.
func NewHandler(cfg *config.Config, db *database.DB, auth *authmw.AuthMiddleware) *Handler {
	h := &Handler{cfg: cfg, db: db, auth: auth}
	h.queries = sqlc.New(db.Pool)
	return h
}
// maxBcryptPasswordBytes is the longest password bcrypt can hash. Newer
// golang.org/x/crypto releases make GenerateFromPassword fail on longer
// input, while older ones silently ignore everything past this length.
const maxBcryptPasswordBytes = 72

// RegisterUser creates an account from a JSON body containing first_name,
// last_name, email, password and an optional role ("student" or "teacher",
// default student). The email is lower-cased and trimmed, and the password is
// stored as a bcrypt hash. On success it issues a 24-hour session cookie and
// responds 201 with the new user.
//
// Client errors (unparseable body, missing fields, over-long password, bad
// role) get 400 "invalid_request". Database failures are delegated to
// respond.DatabaseError.
func (h *Handler) RegisterUser(c *fiber.Ctx) error {
	var req registerRequest
	if err := c.BodyParser(&req); err != nil {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body")
	}
	fullName := strings.TrimSpace(strings.TrimSpace(req.FirstName) + " " + strings.TrimSpace(req.LastName))
	if fullName == "" || strings.TrimSpace(req.Email) == "" || strings.TrimSpace(req.Password) == "" {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "first_name, last_name, email, and password are required")
	}
	// Reject passwords bcrypt cannot fully hash. Without this check a long
	// password either produces a 500 (ErrPasswordTooLong) or is silently
	// truncated, depending on the x/crypto version.
	if len(req.Password) > maxBcryptPasswordBytes {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "password must be at most 72 bytes")
	}
	role := sqlc.UserRoleStudent
	if strings.TrimSpace(req.Role) != "" {
		role = sqlc.UserRole(strings.TrimSpace(req.Role))
	}
	if role != sqlc.UserRoleStudent && role != sqlc.UserRoleTeacher {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "role must be student or teacher")
	}
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
	if err != nil {
		return respond.Error(c, fiber.StatusInternalServerError, "auth_error", "Unable to secure password")
	}
	ctx, cancel := withTimeout()
	defer cancel()
	user, err := h.queries.CreateUser(ctx, sqlc.CreateUserParams{
		Email:        strings.TrimSpace(strings.ToLower(req.Email)),
		PasswordHash: pgtype.Text{String: string(hashedPassword), Valid: true},
		Role:         role,
		FullName:     fullName,
	})
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	// The account now exists, so the session is issued before the read-back.
	// If that read fails, the client is still signed in to the account it
	// just created.
	if err := h.setSessionCookie(c, user, false); err != nil {
		return respond.Error(c, fiber.StatusInternalServerError, "auth_error", "Unable to create session")
	}
	authUser, err := h.queries.GetAuthUserByID(ctx, user.ID)
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	return c.Status(fiber.StatusCreated).JSON(authResponse{User: mapAuthUserByID(authUser)})
}
// dummyPasswordHash is a well-formed bcrypt hash at cost 10
// (bcrypt.DefaultCost, the cost RegisterUser hashes with). Login compares the
// submitted password against it whenever no usable account is found. Every
// rejected attempt therefore pays for one bcrypt comparison, and response time
// reveals less about which emails are registered. The plaintext behind this
// hash is irrelevant because the comparison result is always discarded. Keep
// its cost in step with the cost used for real password hashes.
var dummyPasswordHash = []byte("$2a$10$N9qo8uLOickgx2ZMRZoMyeIjZAgcfl7p92ldGxad68LJZdL17lhWy")

// Login authenticates a user from a JSON body with email, password and an
// optional remember_me flag. Unknown emails, inactive accounts, accounts
// without a password hash, and wrong passwords all receive the same 401
// "invalid_credentials" response. On success it sets the session cookie
// (30 days with remember_me, otherwise 24 hours) and returns the user.
func (h *Handler) Login(c *fiber.Ctx) error {
	var req loginRequest
	if err := c.BodyParser(&req); err != nil {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body")
	}
	if strings.TrimSpace(req.Email) == "" || strings.TrimSpace(req.Password) == "" {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "email and password are required")
	}
	ctx, cancel := withTimeout()
	defer cancel()
	user, err := h.queries.GetUserByEmail(ctx, strings.TrimSpace(strings.ToLower(req.Email)))
	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
		return respond.DatabaseError(c, err)
	}
	if err != nil || !user.IsActive || !user.PasswordHash.Valid {
		// Spend a bcrypt comparison anyway. Otherwise these rejections return
		// noticeably faster than a wrong password for a real account, which
		// lets callers probe which emails exist.
		_ = bcrypt.CompareHashAndPassword(dummyPasswordHash, []byte(req.Password))
		return respond.Error(c, fiber.StatusUnauthorized, "invalid_credentials", "Invalid email or password")
	}
	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash.String), []byte(req.Password)); err != nil {
		return respond.Error(c, fiber.StatusUnauthorized, "invalid_credentials", "Invalid email or password")
	}
	// Load the response payload before issuing the session. A failed read
	// must not leave the client holding a session cookie on an error response.
	authUser, err := h.queries.GetAuthUserByID(ctx, user.ID)
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	if err := h.setSessionCookie(c, user, req.RememberMe); err != nil {
		return respond.Error(c, fiber.StatusInternalServerError, "auth_error", "Unable to create session")
	}
	return c.JSON(authResponse{User: mapAuthUserByID(authUser)})
}
// Me returns the authenticated caller's account and profile. It responds 401
// when the request has no current user ID, or when that user's row can no
// longer be found.
func (h *Handler) Me(c *fiber.Ctx) error {
	id := authmw.CurrentUserID(c)
	if id == 0 {
		return respond.Error(c, fiber.StatusUnauthorized, "unauthorized", "Authentication required")
	}

	ctx, cancel := withTimeout()
	defer cancel()

	row, err := h.queries.GetAuthUserByID(ctx, id)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return respond.Error(c, fiber.StatusUnauthorized, "unauthorized", "User not found")
	case err != nil:
		return respond.DatabaseError(c, err)
	}
	return c.JSON(authResponse{User: mapAuthUserByID(row)})
}
// UpdateMe applies a partial update to the authenticated caller's account and
// profile, then returns the refreshed user.
//
// Fields that are nil in the body keep their stored values. Blank profile
// fields are cleared, but full_name may not be blanked (400). The read, the
// writes and the read-back all run in a single transaction.
func (h *Handler) UpdateMe(c *fiber.Ctx) error {
	id := authmw.CurrentUserID(c)
	if id == 0 {
		return respond.Error(c, fiber.StatusUnauthorized, "unauthorized", "Authentication required")
	}
	var body updateProfileRequest
	if err := c.BodyParser(&body); err != nil {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body")
	}

	ctx, cancel := withTimeout()
	defer cancel()

	dbTx, err := h.db.Pool.Begin(ctx)
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	// One deferred rollback covers every early return. After a successful
	// Commit it only reports that the tx is already closed, which is ignored.
	defer func() { _ = dbTx.Rollback(ctx) }()
	txq := h.queries.WithTx(dbTx)

	existing, err := txq.GetAuthUserByID(ctx, id)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return respond.Error(c, fiber.StatusUnauthorized, "unauthorized", "User not found")
	case err != nil:
		return respond.DatabaseError(c, err)
	}

	newName, err := mergeRequiredString(existing.UserFullName, body.FullName, "full_name")
	if err != nil {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", err.Error())
	}
	// Only touch the users row when the name actually changes.
	if newName != existing.UserFullName {
		_, err = txq.UpdateUserFullName(ctx, sqlc.UpdateUserFullNameParams{ID: id, FullName: newName})
		if err != nil {
			return respond.DatabaseError(c, err)
		}
	}

	profile := sqlc.UpsertUserProfileParams{
		UserID:         id,
		PreferredName:  mergeNullableText(existing.PreferredName, body.PreferredName),
		ProfileIconUrl: mergeNullableText(existing.ProfileIconUrl, body.ProfileIconURL),
		Headline:       mergeNullableText(existing.Headline, body.Headline),
		Bio:            mergeNullableText(existing.Bio, body.Bio),
		Timezone:       mergeNullableText(existing.Timezone, body.Timezone),
		Locale:         mergeNullableText(existing.Locale, body.Locale),
		GradeLevel:     mergeNullableText(existing.GradeLevel, body.GradeLevel),
		LearningGoal:   mergeNullableText(existing.LearningGoal, body.LearningGoal),
	}
	if _, err = txq.UpsertUserProfile(ctx, profile); err != nil {
		return respond.DatabaseError(c, err)
	}

	refreshed, err := txq.GetAuthUserByID(ctx, id)
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	if err = dbTx.Commit(ctx); err != nil {
		return respond.DatabaseError(c, err)
	}
	return c.JSON(authResponse{User: mapAuthUserByID(refreshed)})
}
// Logout expires the session cookie and responds {"status":"ok"}. It succeeds
// whether or not the caller was signed in. This handler does not itself
// revoke the token that was stored in the cookie.
func (h *Handler) Logout(c *fiber.Ctx) error {
	h.clearSessionCookie(c)
	payload := fiber.Map{"status": "ok"}
	return c.JSON(payload)
}
// setSessionCookie mints a session token for user and stores it in the
// configured HTTP-only, SameSite=Lax cookie on path "/". The token and cookie
// both live 24 hours, or 30 days when rememberMe is set. The cookie is marked
// Secure only in production. Token creation errors are returned unchanged and
// no cookie is set in that case.
func (h *Handler) setSessionCookie(c *fiber.Ctx, user sqlc.User, rememberMe bool) error {
	days := 1
	if rememberMe {
		days = 30
	}
	lifetime := time.Duration(days) * 24 * time.Hour

	token, err := h.auth.CreateToken(user.ID, user.Role, user.Email, lifetime)
	if err != nil {
		return err
	}

	session := &fiber.Cookie{
		Name:     h.cfg.SessionCookie,
		Value:    token,
		Path:     "/",
		Expires:  time.Now().UTC().Add(lifetime),
		HTTPOnly: true,
		Secure:   h.cfg.IsProduction(),
		SameSite: fiber.CookieSameSiteLaxMode,
	}
	c.Cookie(session)
	return nil
}
// clearSessionCookie overwrites the session cookie with an empty, already
// expired one (Expires at the Unix epoch, MaxAge -1). It reuses the same
// name, path and security attributes that setSessionCookie writes.
func (h *Handler) clearSessionCookie(c *fiber.Ctx) {
	expired := &fiber.Cookie{
		Name:     h.cfg.SessionCookie,
		Path:     "/",
		Value:    "",
		Expires:  time.Unix(0, 0),
		MaxAge:   -1,
		HTTPOnly: true,
		Secure:   h.cfg.IsProduction(),
		SameSite: fiber.CookieSameSiteLaxMode,
	}
	c.Cookie(expired)
}
// mapAuthUserByID converts a GetAuthUserByID row (user columns plus the
// joined profile columns) into the API response shape. Timestamps are
// converted to UTC, and NULL columns become nil pointers.
func mapAuthUserByID(user sqlc.GetAuthUserByIDRow) authUserResponse {
	profile := mapAuthProfile(
		user.PreferredName,
		user.ProfileIconUrl,
		user.Headline,
		user.Bio,
		user.Timezone,
		user.Locale,
		user.GradeLevel,
		user.LearningGoal,
		user.ProfileCreatedAt,
		user.ProfileUpdatedAt,
	)
	return authUserResponse{
		ID:        user.UserID,
		Email:     user.UserEmail,
		Role:      string(user.UserRole),
		FullName:  user.UserFullName,
		IsActive:  user.UserIsActive,
		CreatedAt: timePointer(user.UserCreatedAt),
		UpdatedAt: timePointer(user.UserUpdatedAt),
		Profile:   profile,
	}
}
// mapAuthProfile builds the profile part of the response from nullable
// columns. Invalid (NULL) values map to nil pointers, and timestamps are
// converted to UTC.
func mapAuthProfile(preferredName, profileIconURL, headline, bio, timezone, locale, gradeLevel, learningGoal pgtype.Text, createdAt, updatedAt pgtype.Timestamptz) authProfileResponse {
	var out authProfileResponse
	out.PreferredName = textPointer(preferredName)
	out.ProfileIconURL = textPointer(profileIconURL)
	out.Headline = textPointer(headline)
	out.Bio = textPointer(bio)
	out.Timezone = textPointer(timezone)
	out.Locale = textPointer(locale)
	out.GradeLevel = textPointer(gradeLevel)
	out.LearningGoal = textPointer(learningGoal)
	out.CreatedAt = timePointer(createdAt)
	out.UpdatedAt = timePointer(updatedAt)
	return out
}
// withTimeout returns a context that expires after authQueryTimeout, together
// with its cancel func (callers must call it). The context derives from
// context.Background(), not from the request, so it is bounded only by the
// timeout.
func withTimeout() (context.Context, context.CancelFunc) {
	parent := context.Background()
	return context.WithTimeout(parent, authQueryTimeout)
}
// mergeRequiredString resolves a required string field in a partial update.
// A nil input keeps current. Otherwise the whitespace-trimmed input replaces
// it. An input that trims to "" is rejected with the error
// "<fieldName> cannot be empty".
func mergeRequiredString(current string, input *string, fieldName string) (string, error) {
	if input != nil {
		trimmed := strings.TrimSpace(*input)
		if trimmed == "" {
			return "", errors.New(fieldName + " cannot be empty")
		}
		return trimmed, nil
	}
	return current, nil
}
// mergeNullableText resolves an optional text column in a partial update.
// A nil input keeps current, and a whitespace-only input yields the zero
// (invalid, i.e. NULL) pgtype.Text. Any other input is stored trimmed.
func mergeNullableText(current pgtype.Text, input *string) pgtype.Text {
	if input != nil {
		var merged pgtype.Text
		if trimmed := strings.TrimSpace(*input); trimmed != "" {
			merged = pgtype.Text{String: trimmed, Valid: true}
		}
		return merged
	}
	return current
}
// textPointer converts a nullable text value to *string: nil for an invalid
// (NULL) value, otherwise a pointer to this call's own copy of the string.
func textPointer(value pgtype.Text) *string {
	if value.Valid {
		// value is a per-call copy, so pointing into it is safe.
		return &value.String
	}
	return nil
}
// timePointer converts a nullable timestamp to *time.Time. It returns nil for
// an invalid (NULL) value; otherwise it returns a pointer to the instant
// expressed in UTC.
func timePointer(value pgtype.Timestamptz) *time.Time {
	if value.Valid {
		utc := value.Time.UTC()
		return &utc
	}
	return nil
}