Before Fine Tune
This commit is contained in:
@@ -11,6 +11,13 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type requestStyle string
|
||||
|
||||
const (
|
||||
requestStyleResponses requestStyle = "responses"
|
||||
requestStyleChatCompletions requestStyle = "chat_completions"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
endpoint string
|
||||
apiKey string
|
||||
@@ -109,14 +116,9 @@ func (s *Service) ReviewSubmission(ctx context.Context, input AssignmentReviewIn
|
||||
return nil, fmt.Errorf("marshal AI review input: %w", err)
|
||||
}
|
||||
|
||||
body := map[string]any{
|
||||
"model": s.model,
|
||||
"input": []map[string]any{
|
||||
{
|
||||
"role": "system",
|
||||
"content": []map[string]any{{
|
||||
"type": "input_text",
|
||||
"text": strings.TrimSpace(`You are reviewing student homework submissions for a teacher workflow.
|
||||
outputText, err := s.runStructuredRequest(
|
||||
ctx,
|
||||
strings.TrimSpace(`You are reviewing student homework submissions for a teacher workflow.
|
||||
|
||||
You must assess the student's understanding by looking at the student's final answer and working against the saved correct answer when one is available. Do not re-grade weighting.
|
||||
|
||||
@@ -142,56 +144,11 @@ Interpretation guidance:
|
||||
- support = the student shows meaningful gaps and likely needs targeted help.
|
||||
- redo = the student should redo the assignment because understanding is broadly too weak or incomplete.
|
||||
|
||||
Review the full assignment in one pass and produce a short assignment-level summary.`),
|
||||
}},
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{{
|
||||
"type": "input_text",
|
||||
"text": string(payloadJSON),
|
||||
}},
|
||||
},
|
||||
},
|
||||
"text": map[string]any{
|
||||
"format": map[string]any{
|
||||
"type": "json_schema",
|
||||
"name": "assignment_review",
|
||||
"strict": true,
|
||||
"schema": reviewSchema(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal AI review request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.endpoint, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build AI review request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("api-key", s.apiKey)
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send AI review request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read AI review response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("AI review request failed: status %d body %s", resp.StatusCode, strings.TrimSpace(string(respBytes)))
|
||||
}
|
||||
|
||||
outputText, err := extractOutputText(respBytes)
|
||||
Review the full assignment in one pass and produce a short assignment-level summary.`),
|
||||
string(payloadJSON),
|
||||
"assignment_review",
|
||||
reviewSchema(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -215,14 +172,9 @@ func (s *Service) PlanRedoAssignment(ctx context.Context, input RedoPlanInput) (
|
||||
return nil, fmt.Errorf("marshal redo plan input: %w", err)
|
||||
}
|
||||
|
||||
body := map[string]any{
|
||||
"model": s.model,
|
||||
"input": []map[string]any{
|
||||
{
|
||||
"role": "system",
|
||||
"content": []map[string]any{{
|
||||
"type": "input_text",
|
||||
"text": strings.TrimSpace(`You are planning the next redo assignment for a student.
|
||||
outputText, err := s.runStructuredRequest(
|
||||
ctx,
|
||||
strings.TrimSpace(`You are planning the next redo assignment for a student.
|
||||
|
||||
You are NOT writing final math questions. You are only producing a structured topic+difficulty blueprint for a later generator layer.
|
||||
|
||||
@@ -238,55 +190,10 @@ Rules:
|
||||
- reason on each item should briefly explain why that topic/difficulty belongs in the redo set.
|
||||
- Do not invent topics outside the allowed topic vocabulary.
|
||||
- Do not output prose outside the JSON schema.`),
|
||||
}},
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{{
|
||||
"type": "input_text",
|
||||
"text": string(payloadJSON),
|
||||
}},
|
||||
},
|
||||
},
|
||||
"text": map[string]any{
|
||||
"format": map[string]any{
|
||||
"type": "json_schema",
|
||||
"name": "redo_assignment_plan",
|
||||
"strict": true,
|
||||
"schema": redoPlanSchema(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal redo plan request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.endpoint, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build redo plan request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("api-key", s.apiKey)
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send redo plan request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read redo plan response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("redo plan request failed: status %d body %s", resp.StatusCode, strings.TrimSpace(string(respBytes)))
|
||||
}
|
||||
|
||||
outputText, err := extractOutputText(respBytes)
|
||||
string(payloadJSON),
|
||||
"redo_assignment_plan",
|
||||
redoPlanSchema(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -300,6 +207,146 @@ Rules:
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (s *Service) runStructuredRequest(ctx context.Context, systemPrompt, userPrompt, schemaName string, schema map[string]any) (string, error) {
|
||||
respBytes, err := s.sendStructuredRequest(ctx, systemPrompt, userPrompt, schemaName, schema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch s.requestStyle() {
|
||||
case requestStyleChatCompletions:
|
||||
return extractChatCompletionText(respBytes)
|
||||
default:
|
||||
return extractResponsesOutputText(respBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) sendStructuredRequest(ctx context.Context, systemPrompt, userPrompt, schemaName string, schema map[string]any) ([]byte, error) {
|
||||
body, err := s.buildStructuredRequestBody(systemPrompt, userPrompt, schemaName, schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal AI review request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.endpoint, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build AI review request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
s.applyAuthHeader(req)
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send AI review request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read AI review response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("AI review request failed: status %d body %s", resp.StatusCode, strings.TrimSpace(string(respBytes)))
|
||||
}
|
||||
|
||||
return respBytes, nil
|
||||
}
|
||||
|
||||
func (s *Service) buildStructuredRequestBody(systemPrompt, userPrompt, schemaName string, schema map[string]any) (map[string]any, error) {
|
||||
switch s.requestStyle() {
|
||||
case requestStyleChatCompletions:
|
||||
body := map[string]any{
|
||||
"model": s.model,
|
||||
"messages": []map[string]any{
|
||||
{
|
||||
"role": "system",
|
||||
"content": systemPrompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": userPrompt,
|
||||
},
|
||||
},
|
||||
"temperature": 0,
|
||||
"response_format": map[string]any{
|
||||
"type": "json_schema",
|
||||
"json_schema": map[string]any{
|
||||
"name": schemaName,
|
||||
"strict": true,
|
||||
"schema": schema,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if s.shouldDisableThinking() {
|
||||
body["chat_template_kwargs"] = map[string]any{
|
||||
"enable_thinking": false,
|
||||
}
|
||||
}
|
||||
|
||||
return body, nil
|
||||
default:
|
||||
return map[string]any{
|
||||
"model": s.model,
|
||||
"input": []map[string]any{
|
||||
{
|
||||
"role": "system",
|
||||
"content": []map[string]any{{
|
||||
"type": "input_text",
|
||||
"text": systemPrompt,
|
||||
}},
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{{
|
||||
"type": "input_text",
|
||||
"text": userPrompt,
|
||||
}},
|
||||
},
|
||||
},
|
||||
"text": map[string]any{
|
||||
"format": map[string]any{
|
||||
"type": "json_schema",
|
||||
"name": schemaName,
|
||||
"strict": true,
|
||||
"schema": schema,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) requestStyle() requestStyle {
|
||||
endpoint := strings.ToLower(strings.TrimSpace(s.endpoint))
|
||||
if strings.Contains(endpoint, "/chat/completions") {
|
||||
return requestStyleChatCompletions
|
||||
}
|
||||
return requestStyleResponses
|
||||
}
|
||||
|
||||
func (s *Service) applyAuthHeader(req *http.Request) {
|
||||
if s.isAzureEndpoint() {
|
||||
req.Header.Set("api-key", s.apiKey)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+s.apiKey)
|
||||
}
|
||||
|
||||
func (s *Service) shouldDisableThinking() bool {
|
||||
return s.requestStyle() == requestStyleChatCompletions && !s.isAzureEndpoint()
|
||||
}
|
||||
|
||||
func (s *Service) isAzureEndpoint() bool {
|
||||
endpoint := strings.ToLower(strings.TrimSpace(s.endpoint))
|
||||
return strings.Contains(endpoint, "cognitiveservices.azure.com") || strings.Contains(endpoint, ".openai.azure.com")
|
||||
}
|
||||
|
||||
func reviewSchema() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
@@ -358,7 +405,7 @@ func redoPlanSchema() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func extractOutputText(respBytes []byte) (string, error) {
|
||||
func extractResponsesOutputText(respBytes []byte) (string, error) {
|
||||
var direct struct {
|
||||
OutputText string `json:"output_text"`
|
||||
}
|
||||
@@ -380,6 +427,47 @@ func extractOutputText(respBytes []byte) (string, error) {
|
||||
return "", fmt.Errorf("AI review response did not contain structured output text")
|
||||
}
|
||||
|
||||
func extractChatCompletionText(respBytes []byte) (string, error) {
|
||||
var payload struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content any `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(respBytes, &payload); err != nil {
|
||||
return "", fmt.Errorf("decode AI review chat completion response: %w", err)
|
||||
}
|
||||
|
||||
for _, choice := range payload.Choices {
|
||||
if text := strings.TrimSpace(extractMessageContent(choice.Message.Content)); text != "" {
|
||||
return text, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("AI review chat completion response did not contain message content")
|
||||
}
|
||||
|
||||
func extractMessageContent(content any) string {
|
||||
switch typed := content.(type) {
|
||||
case string:
|
||||
return typed
|
||||
case []any:
|
||||
parts := make([]string, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
text := strings.TrimSpace(findOutputText(item))
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, text)
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func findOutputText(value any) string {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
|
||||
177
Backend/internal/aireview/service_test.go
Normal file
177
Backend/internal/aireview/service_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package aireview
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildStructuredRequestBodyResponsesStyle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewService("https://example.test/v1/responses", "test-key", "test-model")
|
||||
body, err := svc.buildStructuredRequestBody("system prompt", "user prompt", "test_schema", reviewSchema())
|
||||
if err != nil {
|
||||
t.Fatalf("buildStructuredRequestBody returned error: %v", err)
|
||||
}
|
||||
|
||||
if got := body["model"]; got != "test-model" {
|
||||
t.Fatalf("model = %v, want test-model", got)
|
||||
}
|
||||
|
||||
if _, ok := body["input"]; !ok {
|
||||
t.Fatalf("responses body missing input field: %#v", body)
|
||||
}
|
||||
|
||||
if _, ok := body["text"]; !ok {
|
||||
t.Fatalf("responses body missing text field: %#v", body)
|
||||
}
|
||||
|
||||
if _, ok := body["messages"]; ok {
|
||||
t.Fatalf("responses body should not include messages: %#v", body)
|
||||
}
|
||||
if _, ok := body["response_format"]; ok {
|
||||
t.Fatalf("responses body should not include response_format: %#v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildStructuredRequestBodyVLLMChatStyle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewService("http://100.92.130.19:8000/v1/chat/completions", "test-key", "qwen3.6-27b")
|
||||
body, err := svc.buildStructuredRequestBody("system prompt", "user prompt", "assignment_review", reviewSchema())
|
||||
if err != nil {
|
||||
t.Fatalf("buildStructuredRequestBody returned error: %v", err)
|
||||
}
|
||||
|
||||
if got := body["model"]; got != "qwen3.6-27b" {
|
||||
t.Fatalf("model = %v, want qwen3.6-27b", got)
|
||||
}
|
||||
|
||||
if _, ok := body["messages"]; !ok {
|
||||
t.Fatalf("chat body missing messages: %#v", body)
|
||||
}
|
||||
|
||||
responseFormat, ok := body["response_format"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("response_format missing or wrong type: %#v", body["response_format"])
|
||||
}
|
||||
if got := responseFormat["type"]; got != "json_schema" {
|
||||
t.Fatalf("response_format.type = %v, want json_schema", got)
|
||||
}
|
||||
|
||||
kwargs, ok := body["chat_template_kwargs"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected chat_template_kwargs for vLLM chat style: %#v", body)
|
||||
}
|
||||
if got := kwargs["enable_thinking"]; got != false {
|
||||
t.Fatalf("enable_thinking = %v, want false", got)
|
||||
}
|
||||
|
||||
if _, ok := body["input"]; ok {
|
||||
t.Fatalf("chat body should not include responses input field: %#v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildStructuredRequestBodyAzureChatStyleDoesNotDisableThinking(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewService("https://example.openai.azure.com/openai/deployments/review/chat/completions?api-version=2025-01-01-preview", "test-key", "gpt-4.1-mini")
|
||||
body, err := svc.buildStructuredRequestBody("system prompt", "user prompt", "assignment_review", reviewSchema())
|
||||
if err != nil {
|
||||
t.Fatalf("buildStructuredRequestBody returned error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := body["chat_template_kwargs"]; ok {
|
||||
t.Fatalf("azure chat body should not include chat_template_kwargs: %#v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAuthHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("azure uses api-key", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewService("https://example.openai.azure.com/openai/deployments/review/chat/completions?api-version=2025-01-01-preview", "azure-key", "gpt-4.1-mini")
|
||||
req, err := http.NewRequest(http.MethodPost, svc.endpoint, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest returned error: %v", err)
|
||||
}
|
||||
|
||||
svc.applyAuthHeader(req)
|
||||
|
||||
if got := req.Header.Get("api-key"); got != "azure-key" {
|
||||
t.Fatalf("api-key header = %q, want azure-key", got)
|
||||
}
|
||||
if got := req.Header.Get("Authorization"); got != "" {
|
||||
t.Fatalf("Authorization header = %q, want empty", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-azure uses bearer auth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewService("http://100.92.130.19:8000/v1/chat/completions", "vllm-key", "qwen3.6-27b")
|
||||
req, err := http.NewRequest(http.MethodPost, svc.endpoint, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest returned error: %v", err)
|
||||
}
|
||||
|
||||
svc.applyAuthHeader(req)
|
||||
|
||||
if got := req.Header.Get("Authorization"); got != "Bearer vllm-key" {
|
||||
t.Fatalf("Authorization header = %q, want %q", got, "Bearer vllm-key")
|
||||
}
|
||||
if got := req.Header.Get("api-key"); got != "" {
|
||||
t.Fatalf("api-key header = %q, want empty", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractChatCompletionText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("string content", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
respBytes := []byte(`{"choices":[{"message":{"content":"{\"status\":\"ok\"}"}}]}`)
|
||||
got, err := extractChatCompletionText(respBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("extractChatCompletionText returned error: %v", err)
|
||||
}
|
||||
if got != `{"status":"ok"}` {
|
||||
t.Fatalf("content = %q, want %q", got, `{"status":"ok"}`)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("array content", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := map[string]any{
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"message": map[string]any{
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "{\"status\":"},
|
||||
map[string]any{"type": "text", "text": "\"ok\"}"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
respBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal returned error: %v", err)
|
||||
}
|
||||
|
||||
got, err := extractChatCompletionText(respBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("extractChatCompletionText returned error: %v", err)
|
||||
}
|
||||
if got != "{\"status\":\n\"ok\"}" {
|
||||
t.Fatalf("content = %q, want joined text output", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -8,28 +8,36 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Port string
|
||||
Environment string
|
||||
AllowedOrigins string
|
||||
DatabaseURL string
|
||||
JWTSecret string
|
||||
SessionCookie string
|
||||
AIReviewEndpoint string
|
||||
AIReviewAPIKey string
|
||||
AIReviewModel string
|
||||
Port string
|
||||
Environment string
|
||||
AllowedOrigins string
|
||||
DatabaseURL string
|
||||
JWTSecret string
|
||||
SessionCookie string
|
||||
MockDataDir string
|
||||
AdminReseedEnabled bool
|
||||
AdminReseedSecret string
|
||||
ReseedPagePassword string
|
||||
AIReviewEndpoint string
|
||||
AIReviewAPIKey string
|
||||
AIReviewModel string
|
||||
}
|
||||
|
||||
func Load() *Config {
|
||||
return &Config{
|
||||
Port: getEnv("BACKEND_INTERNAL_PORT", "8081"),
|
||||
Environment: getEnv("GO_ENV", "development"),
|
||||
AllowedOrigins: getEnv("ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:4321,http://localhost:8080,http://windows-wsl:8080"),
|
||||
DatabaseURL: getEnv("DATABASE_URL", "postgres://boostai:boostai_dev_password@localhost:5439/boostai?sslmode=disable"),
|
||||
JWTSecret: getEnv("JWT_SECRET", "boostai-dev-jwt-secret-change-me"),
|
||||
SessionCookie: getEnv("SESSION_COOKIE_NAME", "boostai_session"),
|
||||
AIReviewEndpoint: getEnv("AI_REVIEW_ENDPOINT", ""),
|
||||
AIReviewAPIKey: getEnv("AI_REVIEW_API_KEY", ""),
|
||||
AIReviewModel: getEnv("AI_REVIEW_MODEL", ""),
|
||||
Port: getEnv("BACKEND_INTERNAL_PORT", "8081"),
|
||||
Environment: getEnv("GO_ENV", "development"),
|
||||
AllowedOrigins: getEnv("ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:4321,http://localhost:8080,http://windows-wsl:8080"),
|
||||
DatabaseURL: getEnv("DATABASE_URL", "postgres://boostai:boostai_dev_password@localhost:5439/boostai?sslmode=disable"),
|
||||
JWTSecret: getEnv("JWT_SECRET", "boostai-dev-jwt-secret-change-me"),
|
||||
SessionCookie: getEnv("SESSION_COOKIE_NAME", "boostai_session"),
|
||||
MockDataDir: getEnv("MOCK_DATA_DIR", "../Mock-Data"),
|
||||
AdminReseedEnabled: getEnvBool("ENABLE_ADMIN_RESEED", false),
|
||||
AdminReseedSecret: getEnv("ADMIN_RESEED_SECRET", ""),
|
||||
ReseedPagePassword: getEnv("RESEED_PAGE_PASSWORD", "1588"),
|
||||
AIReviewEndpoint: getEnv("AI_REVIEW_ENDPOINT", ""),
|
||||
AIReviewAPIKey: getEnv("AI_REVIEW_API_KEY", ""),
|
||||
AIReviewModel: getEnv("AI_REVIEW_MODEL", ""),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,3 +56,19 @@ func getEnv(key, fallback string) string {
|
||||
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getEnvBool(key string, fallback bool) bool {
|
||||
value, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return fallback
|
||||
}
|
||||
|
||||
switch strings.TrimSpace(strings.ToLower(value)) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
default:
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
|
||||
100
Backend/internal/handlers/api/admin/handler.go
Normal file
100
Backend/internal/handlers/api/admin/handler.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"boostai-backend/internal/config"
|
||||
"boostai-backend/internal/database"
|
||||
"boostai-backend/internal/http/respond"
|
||||
"boostai-backend/internal/seeddata"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
ReseedHeaderName = "X-Admin-Reseed-Secret"
|
||||
ReseedConfirm = "RESEED"
|
||||
)
|
||||
|
||||
type Runner interface {
|
||||
Run(ctx context.Context, mockDataDir string) (seeddata.Summary, error)
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
cfg *config.Config
|
||||
runner runFunc
|
||||
}
|
||||
|
||||
type runFunc func(timeoutCtx context.Context, mockDataDir string) (seeddata.Summary, error)
|
||||
|
||||
type reseedRequest struct {
|
||||
Confirm string `json:"confirm"`
|
||||
}
|
||||
|
||||
type reseedResponse struct {
|
||||
OK bool `json:"ok"`
|
||||
Environment string `json:"environment"`
|
||||
TriggeredBy string `json:"triggered_by,omitempty"`
|
||||
TriggeredAt time.Time `json:"triggered_at"`
|
||||
Summary seeddata.Summary `json:"summary"`
|
||||
}
|
||||
|
||||
func NewHandler(db *database.DB, cfg *config.Config) *Handler {
|
||||
return &Handler{
|
||||
cfg: cfg,
|
||||
runner: func(timeoutCtx context.Context, mockDataDir string) (seeddata.Summary, error) {
|
||||
return seeddata.Run(timeoutCtx, db, mockDataDir)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) ReseedDatabase(c *fiber.Ctx) error {
|
||||
if !h.cfg.AdminReseedEnabled {
|
||||
return respond.Error(c, fiber.StatusNotFound, "not_found", "The requested endpoint does not exist")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(h.cfg.AdminReseedSecret) == "" {
|
||||
return respond.Error(c, fiber.StatusServiceUnavailable, "admin_reseed_unavailable", "Admin reseed is not configured")
|
||||
}
|
||||
|
||||
providedSecret := strings.TrimSpace(c.Get(ReseedHeaderName))
|
||||
if subtle.ConstantTimeCompare([]byte(providedSecret), []byte(h.cfg.AdminReseedSecret)) != 1 {
|
||||
return respond.Error(c, fiber.StatusForbidden, "forbidden", "Valid reseed secret required")
|
||||
}
|
||||
|
||||
var req reseedRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Invalid request body")
|
||||
}
|
||||
if strings.TrimSpace(req.Confirm) != ReseedConfirm {
|
||||
return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "confirm must equal RESEED")
|
||||
}
|
||||
|
||||
triggeredBy, _ := c.Locals("auth.email").(string)
|
||||
userID, _ := c.Locals("auth.user_id").(int64)
|
||||
startedAt := time.Now().UTC()
|
||||
log.Printf("admin reseed requested environment=%s user_id=%d email=%s ip=%s", h.cfg.Environment, userID, triggeredBy, c.IP())
|
||||
|
||||
timeoutCtx, cancelTimeout := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancelTimeout()
|
||||
|
||||
summary, err := h.runner(timeoutCtx, h.cfg.MockDataDir)
|
||||
if err != nil {
|
||||
log.Printf("admin reseed failed environment=%s user_id=%d email=%s err=%v", h.cfg.Environment, userID, triggeredBy, err)
|
||||
return respond.Error(c, fiber.StatusInternalServerError, "admin_reseed_failed", err.Error())
|
||||
}
|
||||
|
||||
log.Printf("admin reseed completed environment=%s user_id=%d email=%s users=%d assignments=%d student_answers=%d", h.cfg.Environment, userID, triggeredBy, summary.Users, summary.Assignments, summary.StudentAnswers)
|
||||
|
||||
return c.JSON(reseedResponse{
|
||||
OK: true,
|
||||
Environment: h.cfg.Environment,
|
||||
TriggeredBy: triggeredBy,
|
||||
TriggeredAt: startedAt,
|
||||
Summary: summary,
|
||||
})
|
||||
}
|
||||
108
Backend/internal/handlers/api/admin/handler_test.go
Normal file
108
Backend/internal/handlers/api/admin/handler_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"boostai-backend/internal/config"
|
||||
"boostai-backend/internal/seeddata"
|
||||
"boostai-backend/internal/sqlc"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func TestReseedDatabaseRequiresEnableFlag(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := &Handler{cfg: &config.Config{Environment: "production"}}
|
||||
status := performReseedRequest(t, h, map[string]any{"confirm": "RESEED"}, "secret", true)
|
||||
if status != fiber.StatusNotFound {
|
||||
t.Fatalf("expected 404, got %d", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReseedDatabaseRequiresSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandler(func(ctx context.Context, mockDataDir string) (seeddata.Summary, error) {
|
||||
return seeddata.Summary{}, nil
|
||||
})
|
||||
status := performReseedRequest(t, h, map[string]any{"confirm": "RESEED"}, "wrong", true)
|
||||
if status != fiber.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReseedDatabaseRequiresConfirm(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandler(func(ctx context.Context, mockDataDir string) (seeddata.Summary, error) {
|
||||
return seeddata.Summary{}, nil
|
||||
})
|
||||
status := performReseedRequest(t, h, map[string]any{"confirm": "nope"}, "secret", true)
|
||||
if status != fiber.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReseedDatabaseReturnsSummary(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandler(func(ctx context.Context, mockDataDir string) (seeddata.Summary, error) {
|
||||
if mockDataDir != "/app/Mock-Data" {
|
||||
t.Fatalf("expected mock data dir /app/Mock-Data, got %q", mockDataDir)
|
||||
}
|
||||
return seeddata.Summary{Users: 13, Assignments: 8, StudentAnswers: 588}, nil
|
||||
})
|
||||
status := performReseedRequest(t, h, map[string]any{"confirm": "RESEED"}, "secret", true)
|
||||
if status != fiber.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReseedDatabaseSurfacesRunnerError(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandler(func(ctx context.Context, mockDataDir string) (seeddata.Summary, error) {
|
||||
return seeddata.Summary{}, errors.New("boom")
|
||||
})
|
||||
status := performReseedRequest(t, h, map[string]any{"confirm": "RESEED"}, "secret", true)
|
||||
if status != fiber.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d", status)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestHandler(fn func(context.Context, string) (seeddata.Summary, error)) *Handler {
|
||||
return &Handler{
|
||||
cfg: &config.Config{Environment: "production", AdminReseedEnabled: true, AdminReseedSecret: "secret", MockDataDir: "/app/Mock-Data"},
|
||||
runner: fn,
|
||||
}
|
||||
}
|
||||
|
||||
func performReseedRequest(t *testing.T, handler *Handler, payload map[string]any, secret string, authenticated bool) int {
|
||||
t.Helper()
|
||||
app := fiber.New()
|
||||
app.Post("/internal/admin/reseed", func(c *fiber.Ctx) error {
|
||||
if authenticated {
|
||||
c.Locals("auth.user_id", int64(42))
|
||||
c.Locals("auth.role", sqlc.UserRoleTeacher)
|
||||
c.Locals("auth.email", "teacher@example.com")
|
||||
}
|
||||
return handler.ReseedDatabase(c)
|
||||
})
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/internal/admin/reseed", bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if secret != "" {
|
||||
req.Header.Set(ReseedHeaderName, secret)
|
||||
}
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("app.Test: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return resp.StatusCode
|
||||
}
|
||||
11
Backend/internal/handlers/api/admin/routes.go
Normal file
11
Backend/internal/handlers/api/admin/routes.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
authmw "boostai-backend/internal/middleware"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func RegisterRoutes(app fiber.Router, auth *authmw.AuthMiddleware, h *Handler) {
|
||||
app.Post("/internal/admin/reseed", auth.RequireTeacher(), h.ReseedDatabase)
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"boostai-backend/internal/assignmentgen"
|
||||
"boostai-backend/internal/config"
|
||||
"boostai-backend/internal/database"
|
||||
adminhandler "boostai-backend/internal/handlers/api/admin"
|
||||
answershandler "boostai-backend/internal/handlers/api/answers"
|
||||
assignmentshandler "boostai-backend/internal/handlers/api/assignments"
|
||||
classroomshandler "boostai-backend/internal/handlers/api/classrooms"
|
||||
@@ -22,6 +23,7 @@ type Handler struct {
|
||||
questions *questionshandler.Handler
|
||||
assignments *assignmentshandler.Handler
|
||||
answers *answershandler.Handler
|
||||
admin *adminhandler.Handler
|
||||
}
|
||||
|
||||
func NewHandler(db *database.DB, cfg *config.Config) *Handler {
|
||||
@@ -37,5 +39,6 @@ func NewHandler(db *database.DB, cfg *config.Config) *Handler {
|
||||
questions: questionshandler.NewHandler(queries, questionGenerator),
|
||||
assignments: assignmentshandler.NewHandler(queries, aiReviewService, assignmentGenerator),
|
||||
answers: answershandler.NewHandler(queries, aiReviewService),
|
||||
admin: adminhandler.NewHandler(db, cfg),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"boostai-backend/internal/handlers/api/admin"
|
||||
"boostai-backend/internal/handlers/api/answers"
|
||||
"boostai-backend/internal/handlers/api/assignments"
|
||||
"boostai-backend/internal/handlers/api/classrooms"
|
||||
@@ -19,4 +20,5 @@ func (h *Handler) Register(app fiber.Router, auth *authmw.AuthMiddleware) {
|
||||
questions.RegisterRoutes(app, auth, h.questions)
|
||||
assignments.RegisterRoutes(app, auth, h.assignments)
|
||||
answers.RegisterRoutes(app, auth, h.answers)
|
||||
admin.RegisterRoutes(app, auth, h.admin)
|
||||
}
|
||||
|
||||
338
Backend/internal/handlers/web/reseed/reseed.go
Normal file
338
Backend/internal/handlers/web/reseed/reseed.go
Normal file
@@ -0,0 +1,338 @@
|
||||
package reseed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"html"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"boostai-backend/internal/config"
|
||||
"boostai-backend/internal/database"
|
||||
"boostai-backend/internal/seeddata"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
const reseedCookieName = "boostai_reseed_auth"
|
||||
|
||||
type runFunc func(ctx context.Context, mockDataDir string) (seeddata.Summary, error)
|
||||
|
||||
type Handler struct {
|
||||
cfg *config.Config
|
||||
runner runFunc
|
||||
}
|
||||
|
||||
type pageData struct {
|
||||
Authorized bool
|
||||
Environment string
|
||||
MockDataDir string
|
||||
Error string
|
||||
Success string
|
||||
Summary *seeddata.Summary
|
||||
}
|
||||
|
||||
func NewHandler(db *database.DB, cfg *config.Config) *Handler {
|
||||
return &Handler{
|
||||
cfg: cfg,
|
||||
runner: func(ctx context.Context, mockDataDir string) (seeddata.Summary, error) {
|
||||
return seeddata.Run(ctx, db, mockDataDir)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) Page(c *fiber.Ctx) error {
|
||||
if !h.cfg.AdminReseedEnabled {
|
||||
return fiber.ErrNotFound
|
||||
}
|
||||
|
||||
return h.renderPage(c, pageData{
|
||||
Authorized: h.isAuthorized(c),
|
||||
Environment: h.cfg.Environment,
|
||||
MockDataDir: h.cfg.MockDataDir,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) Login(c *fiber.Ctx) error {
|
||||
if !h.cfg.AdminReseedEnabled {
|
||||
return fiber.ErrNotFound
|
||||
}
|
||||
|
||||
password := strings.TrimSpace(c.FormValue("password"))
|
||||
if subtle.ConstantTimeCompare([]byte(password), []byte(h.cfg.ReseedPagePassword)) != 1 {
|
||||
return h.renderPage(c.Status(fiber.StatusUnauthorized), pageData{
|
||||
Authorized: false,
|
||||
Environment: h.cfg.Environment,
|
||||
MockDataDir: h.cfg.MockDataDir,
|
||||
Error: "Invalid password",
|
||||
})
|
||||
}
|
||||
|
||||
h.setAuthCookie(c)
|
||||
return c.Redirect("/reseed", fiber.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (h *Handler) Run(c *fiber.Ctx) error {
|
||||
if !h.cfg.AdminReseedEnabled {
|
||||
return fiber.ErrNotFound
|
||||
}
|
||||
|
||||
if !h.isAuthorized(c) {
|
||||
return h.renderPage(c.Status(fiber.StatusUnauthorized), pageData{
|
||||
Authorized: false,
|
||||
Environment: h.cfg.Environment,
|
||||
MockDataDir: h.cfg.MockDataDir,
|
||||
Error: "Please unlock the reseed page first",
|
||||
})
|
||||
}
|
||||
|
||||
startedAt := time.Now().UTC()
|
||||
log.Printf("browser reseed requested environment=%s ip=%s", h.cfg.Environment, c.IP())
|
||||
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
summary, err := h.runner(timeoutCtx, h.cfg.MockDataDir)
|
||||
if err != nil {
|
||||
log.Printf("browser reseed failed environment=%s ip=%s err=%v", h.cfg.Environment, c.IP(), err)
|
||||
return h.renderPage(c.Status(fiber.StatusInternalServerError), pageData{
|
||||
Authorized: true,
|
||||
Environment: h.cfg.Environment,
|
||||
MockDataDir: h.cfg.MockDataDir,
|
||||
Error: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
log.Printf("browser reseed completed environment=%s ip=%s users=%d assignments=%d student_answers=%d", h.cfg.Environment, c.IP(), summary.Users, summary.Assignments, summary.StudentAnswers)
|
||||
|
||||
return h.renderPage(c, pageData{
|
||||
Authorized: true,
|
||||
Environment: h.cfg.Environment,
|
||||
MockDataDir: h.cfg.MockDataDir,
|
||||
Success: fmt.Sprintf("Reseed completed at %s UTC", startedAt.Format("2006-01-02 15:04:05")),
|
||||
Summary: &summary,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) Logout(c *fiber.Ctx) error {
|
||||
if !h.cfg.AdminReseedEnabled {
|
||||
return fiber.ErrNotFound
|
||||
}
|
||||
|
||||
h.clearAuthCookie(c)
|
||||
return c.Redirect("/reseed", fiber.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (h *Handler) isAuthorized(c *fiber.Ctx) bool {
|
||||
provided := strings.TrimSpace(c.Cookies(reseedCookieName))
|
||||
expected := h.authCookieValue()
|
||||
if provided == "" || expected == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return subtle.ConstantTimeCompare([]byte(provided), []byte(expected)) == 1
|
||||
}
|
||||
|
||||
func (h *Handler) setAuthCookie(c *fiber.Ctx) {
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: reseedCookieName,
|
||||
Value: h.authCookieValue(),
|
||||
HTTPOnly: true,
|
||||
Secure: h.cfg.IsProduction(),
|
||||
SameSite: fiber.CookieSameSiteLaxMode,
|
||||
Path: "/",
|
||||
Expires: time.Now().UTC().Add(12 * time.Hour),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) clearAuthCookie(c *fiber.Ctx) {
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: reseedCookieName,
|
||||
Value: "",
|
||||
HTTPOnly: true,
|
||||
Secure: h.cfg.IsProduction(),
|
||||
SameSite: fiber.CookieSameSiteLaxMode,
|
||||
Path: "/",
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) authCookieValue() string {
|
||||
sum := sha256.Sum256([]byte(h.cfg.JWTSecret + "|" + h.cfg.ReseedPagePassword))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func (h *Handler) renderPage(c *fiber.Ctx, data pageData) error {
|
||||
content := reseedPageHTML(data)
|
||||
c.Type("html", "utf-8")
|
||||
return c.SendString(content)
|
||||
}
|
||||
|
||||
func reseedPageHTML(data pageData) string {
|
||||
var statusHTML strings.Builder
|
||||
if data.Error != "" {
|
||||
statusHTML.WriteString(`<div class="notice notice-error">` + html.EscapeString(data.Error) + `</div>`)
|
||||
}
|
||||
if data.Success != "" {
|
||||
statusHTML.WriteString(`<div class="notice notice-success">` + html.EscapeString(data.Success) + `</div>`)
|
||||
}
|
||||
if data.Summary != nil {
|
||||
statusHTML.WriteString(`<pre class="summary">`)
|
||||
statusHTML.WriteString(html.EscapeString(fmt.Sprintf(
|
||||
"users: %d\nclassrooms: %d\nquestions: %d\ntags: %d\nassignments: %d\nassignment_links: %d\nstudent_answers: %d\nmock_data_dir: %s",
|
||||
data.Summary.Users,
|
||||
data.Summary.Classrooms,
|
||||
data.Summary.Questions,
|
||||
data.Summary.Tags,
|
||||
data.Summary.Assignments,
|
||||
data.Summary.AssignmentLinks,
|
||||
data.Summary.StudentAnswers,
|
||||
data.Summary.MockDataDir,
|
||||
)))
|
||||
statusHTML.WriteString(`</pre>`)
|
||||
}
|
||||
|
||||
var body strings.Builder
|
||||
if data.Authorized {
|
||||
body.WriteString(`
|
||||
<div class="card">
|
||||
<h2>Reseed database</h2>
|
||||
<p>This will clear seeded app data and repopulate it from Mock-Data.</p>
|
||||
<form method="post" action="/reseed/run">
|
||||
<button class="danger" type="submit">Reseed now</button>
|
||||
</form>
|
||||
<form method="post" action="/reseed/logout">
|
||||
<button type="submit">Lock page</button>
|
||||
</form>
|
||||
</div>
|
||||
`)
|
||||
} else {
|
||||
body.WriteString(`
|
||||
<div class="card">
|
||||
<h2>Unlock reseed</h2>
|
||||
<form method="post" action="/reseed/login">
|
||||
<label for="password">Password</label>
|
||||
<input id="password" name="password" type="password" autocomplete="current-password" required />
|
||||
<button type="submit">Unlock</button>
|
||||
</form>
|
||||
</div>
|
||||
`)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>BoostAI Reseed</title>
|
||||
<style>
|
||||
:root { color-scheme: dark; }
|
||||
body {
|
||||
margin: 0;
|
||||
font-family: Inter, Arial, sans-serif;
|
||||
background: #0f172a;
|
||||
color: #e2e8f0;
|
||||
}
|
||||
main {
|
||||
max-width: 720px;
|
||||
margin: 48px auto;
|
||||
padding: 0 20px 48px;
|
||||
}
|
||||
h1, h2 { margin-top: 0; }
|
||||
.card {
|
||||
background: #111827;
|
||||
border: 1px solid #334155;
|
||||
border-radius: 16px;
|
||||
padding: 24px;
|
||||
margin-bottom: 20px;
|
||||
box-shadow: 0 12px 30px rgba(0,0,0,0.25);
|
||||
}
|
||||
.meta {
|
||||
display: grid;
|
||||
grid-template-columns: 180px 1fr;
|
||||
gap: 8px 12px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.meta strong { color: #93c5fd; }
|
||||
label {
|
||||
display: block;
|
||||
margin-bottom: 8px;
|
||||
font-weight: 600;
|
||||
}
|
||||
input {
|
||||
width: 100%%;
|
||||
box-sizing: border-box;
|
||||
padding: 12px 14px;
|
||||
border-radius: 10px;
|
||||
border: 1px solid #475569;
|
||||
background: #0f172a;
|
||||
color: #e2e8f0;
|
||||
margin-bottom: 14px;
|
||||
}
|
||||
button {
|
||||
padding: 12px 16px;
|
||||
border-radius: 10px;
|
||||
border: 0;
|
||||
cursor: pointer;
|
||||
font-weight: 700;
|
||||
background: #38bdf8;
|
||||
color: #082f49;
|
||||
margin-right: 12px;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
button.danger {
|
||||
background: #f87171;
|
||||
color: #450a0a;
|
||||
}
|
||||
.notice {
|
||||
padding: 14px 16px;
|
||||
border-radius: 12px;
|
||||
margin-bottom: 16px;
|
||||
font-weight: 600;
|
||||
}
|
||||
.notice-error {
|
||||
background: rgba(220, 38, 38, 0.18);
|
||||
border: 1px solid rgba(248, 113, 113, 0.45);
|
||||
color: #fecaca;
|
||||
}
|
||||
.notice-success {
|
||||
background: rgba(22, 163, 74, 0.18);
|
||||
border: 1px solid rgba(74, 222, 128, 0.45);
|
||||
color: #bbf7d0;
|
||||
}
|
||||
.summary {
|
||||
background: #020617;
|
||||
border: 1px solid #334155;
|
||||
border-radius: 12px;
|
||||
padding: 16px;
|
||||
overflow: auto;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<main>
|
||||
<div class="card">
|
||||
<h1>BoostAI reseed</h1>
|
||||
<div class="meta">
|
||||
<strong>Environment</strong><span>%s</span>
|
||||
<strong>Mock data path</strong><span>%s</span>
|
||||
<strong>Mode</strong><span>Browser-protected destructive reseed</span>
|
||||
</div>
|
||||
%s
|
||||
</div>
|
||||
%s
|
||||
</main>
|
||||
</body>
|
||||
</html>`,
|
||||
html.EscapeString(data.Environment),
|
||||
html.EscapeString(data.MockDataDir),
|
||||
statusHTML.String(),
|
||||
body.String(),
|
||||
)
|
||||
}
|
||||
137
Backend/internal/handlers/web/reseed/reseed_test.go
Normal file
137
Backend/internal/handlers/web/reseed/reseed_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package reseed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"boostai-backend/internal/config"
|
||||
"boostai-backend/internal/seeddata"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func TestPageRequiresEnableFlag(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := &Handler{cfg: &config.Config{Environment: "production"}}
|
||||
status, _ := performRequest(t, h, http.MethodGet, "/reseed", "", "")
|
||||
if status != fiber.StatusNotFound {
|
||||
t.Fatalf("expected 404, got %d", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginSetsCookieAndRedirects(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandler(func(ctx context.Context, mockDataDir string) (seeddata.Summary, error) {
|
||||
return seeddata.Summary{}, nil
|
||||
})
|
||||
status, resp := performRequest(t, h, http.MethodPost, "/reseed/login", "password=1588", "")
|
||||
if status != fiber.StatusSeeOther {
|
||||
t.Fatalf("expected 303, got %d", status)
|
||||
}
|
||||
if location := resp.Header.Get("Location"); location != "/reseed" {
|
||||
t.Fatalf("expected redirect to /reseed, got %q", location)
|
||||
}
|
||||
if cookie := resp.Header.Get("Set-Cookie"); !strings.Contains(cookie, reseedCookieName+"=") {
|
||||
t.Fatalf("expected auth cookie, got %q", cookie)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunRequiresAuthCookie(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandler(func(ctx context.Context, mockDataDir string) (seeddata.Summary, error) {
|
||||
return seeddata.Summary{}, nil
|
||||
})
|
||||
status, body := performRequestBody(t, h, http.MethodPost, "/reseed/run", "", "")
|
||||
if status != fiber.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", status)
|
||||
}
|
||||
if !strings.Contains(body, "Please unlock the reseed page first") {
|
||||
t.Fatalf("expected unlock message, got %q", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunExecutesReseed(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandler(func(ctx context.Context, mockDataDir string) (seeddata.Summary, error) {
|
||||
if mockDataDir != "/app/Mock-Data" {
|
||||
t.Fatalf("expected mock data dir /app/Mock-Data, got %q", mockDataDir)
|
||||
}
|
||||
return seeddata.Summary{Users: 13, Assignments: 8, StudentAnswers: 588, MockDataDir: mockDataDir}, nil
|
||||
})
|
||||
status, body := performRequestBody(t, h, http.MethodPost, "/reseed/run", "", reseedCookieName+"="+h.authCookieValue())
|
||||
if status != fiber.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", status)
|
||||
}
|
||||
if !strings.Contains(body, "Reseed completed") || !strings.Contains(body, "student_answers: 588") {
|
||||
t.Fatalf("expected success summary, got %q", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSurfacesRunnerError(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandler(func(ctx context.Context, mockDataDir string) (seeddata.Summary, error) {
|
||||
return seeddata.Summary{}, errors.New("boom")
|
||||
})
|
||||
status, body := performRequestBody(t, h, http.MethodPost, "/reseed/run", "", reseedCookieName+"="+h.authCookieValue())
|
||||
if status != fiber.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d", status)
|
||||
}
|
||||
if !strings.Contains(body, "boom") {
|
||||
t.Fatalf("expected error body, got %q", body)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestHandler(fn func(context.Context, string) (seeddata.Summary, error)) *Handler {
|
||||
return &Handler{
|
||||
cfg: &config.Config{
|
||||
Environment: "production",
|
||||
AdminReseedEnabled: true,
|
||||
MockDataDir: "/app/Mock-Data",
|
||||
JWTSecret: "jwt-secret",
|
||||
ReseedPagePassword: "1588",
|
||||
},
|
||||
runner: fn,
|
||||
}
|
||||
}
|
||||
|
||||
func performRequest(t *testing.T, handler *Handler, method, path, formBody, cookieHeader string) (int, *http.Response) {
|
||||
t.Helper()
|
||||
app := testApp(handler)
|
||||
req := httptest.NewRequest(method, path, strings.NewReader(formBody))
|
||||
if formBody != "" {
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
}
|
||||
if cookieHeader != "" {
|
||||
req.Header.Set("Cookie", cookieHeader)
|
||||
}
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("app.Test: %v", err)
|
||||
}
|
||||
return resp.StatusCode, resp
|
||||
}
|
||||
|
||||
func performRequestBody(t *testing.T, handler *Handler, method, path, formBody, cookieHeader string) (int, string) {
|
||||
t.Helper()
|
||||
status, resp := performRequest(t, handler, method, path, formBody, cookieHeader)
|
||||
defer resp.Body.Close()
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll: %v", err)
|
||||
}
|
||||
return status, string(bodyBytes)
|
||||
}
|
||||
|
||||
func testApp(handler *Handler) *fiber.App {
|
||||
app := fiber.New()
|
||||
app.Get("/reseed", handler.Page)
|
||||
app.Post("/reseed/login", handler.Login)
|
||||
app.Post("/reseed/run", handler.Run)
|
||||
app.Post("/reseed/logout", handler.Logout)
|
||||
return app
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"boostai-backend/internal/database"
|
||||
webAuth "boostai-backend/internal/handlers/web/auth"
|
||||
"boostai-backend/internal/handlers/web/health"
|
||||
webReseed "boostai-backend/internal/handlers/web/reseed"
|
||||
"boostai-backend/internal/handlers/web/root"
|
||||
authmw "boostai-backend/internal/middleware"
|
||||
|
||||
@@ -19,9 +20,14 @@ func registerWebRoutes(app *fiber.App, cfg *config.Config, db *database.DB, auth
|
||||
rootHandler := root.NewHandler()
|
||||
healthHandler := health.NewHandler(cfg.Environment, db)
|
||||
authHandler := webAuth.NewHandler(cfg, db, authMiddleware)
|
||||
reseedHandler := webReseed.NewHandler(db, cfg)
|
||||
|
||||
app.Get("/", rootHandler.Index)
|
||||
app.Get("/health", healthHandler.Check)
|
||||
app.Get("/reseed", reseedHandler.Page)
|
||||
app.Post("/reseed/login", reseedHandler.Login)
|
||||
app.Post("/reseed/run", reseedHandler.Run)
|
||||
app.Post("/reseed/logout", reseedHandler.Logout)
|
||||
|
||||
authGroup := app.Group("/auth")
|
||||
authGroup.Post("/register", authHandler.RegisterUser)
|
||||
|
||||
975
Backend/internal/seeddata/seed.go
Normal file
975
Backend/internal/seeddata/seed.go
Normal file
@@ -0,0 +1,975 @@
|
||||
package seeddata
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"boostai-backend/internal/database"
|
||||
sharedapi "boostai-backend/internal/handlers/api/shared"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const SeededPassword = "password123"
|
||||
|
||||
type studentRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
FullName string `json:"fullname"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
Active bool `json:"active"`
|
||||
IsDeleted bool `json:"is_deleted"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
type classroomFile struct {
|
||||
Classroom classroomRecord `json:"classroom"`
|
||||
Tutor tutorRecord `json:"tutor"`
|
||||
ClassroomStudentRs []classroomStudentRecord `json:"classroom_student_rs"`
|
||||
}
|
||||
|
||||
type classroomRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
TutorID int64 `json:"tutor_id"`
|
||||
InviteCode string `json:"invite_code"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
type tutorRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
FullName string `json:"fullname"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
Active bool `json:"active"`
|
||||
IsDeleted bool `json:"is_deleted"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
type classroomStudentRecord struct {
|
||||
ClassroomID int64 `json:"classroom_id"`
|
||||
StudentID int64 `json:"student_id"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
type questionRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
Topic string `json:"topic"`
|
||||
SubTopic *string `json:"sub_topic"`
|
||||
Tag *string `json:"tag"`
|
||||
Difficulty string `json:"difficulty"`
|
||||
QuestionText string `json:"question_text"`
|
||||
CorrectAnswer string `json:"correct_answer"`
|
||||
Source string `json:"source"`
|
||||
TeacherID int64 `json:"teacher_id"`
|
||||
IsDeleted bool `json:"is_deleted"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
type assignmentRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
TeacherID int64 `json:"teacher_id"`
|
||||
Topic string `json:"topic"`
|
||||
DueDate int64 `json:"due_date"`
|
||||
Status string `json:"status"`
|
||||
IsDeleted bool `json:"is_deleted"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
type assignmentQuestionRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
AssignmentID int64 `json:"assignment_id"`
|
||||
QuestionBankID int64 `json:"question_bank_id"`
|
||||
QuestionOrder int32 `json:"question_order"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
type assignmentAssigneeRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
AssignmentID int64 `json:"assignment_id"`
|
||||
StudentID int64 `json:"student_id"`
|
||||
Status string `json:"status"`
|
||||
StartedAt *int64 `json:"started_at"`
|
||||
SubmittedAt *int64 `json:"submitted_at"`
|
||||
OverallScore *float64 `json:"overall_score"`
|
||||
AIFeedback *string `json:"ai_feedback"`
|
||||
NextStepOutcome *string `json:"next_step_outcome"`
|
||||
IsActive bool `json:"is_active"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
type studentAnswerRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
AssigneeID int64 `json:"assignee_id"`
|
||||
AssignmentQuestionID int64 `json:"assignment_question_id"`
|
||||
AnswerLatex *string `json:"answer_latex"`
|
||||
ExtractedAnswer *string `json:"extracted_answer"`
|
||||
SolveMode *string `json:"solve_mode"`
|
||||
WorkingSteps *string `json:"working_steps"`
|
||||
AIReasoning *string `json:"ai_reasoning"`
|
||||
IsCorrect *bool `json:"is_correct"`
|
||||
AIFeedback *string `json:"ai_feedback"`
|
||||
ReviewNeedsAttention *bool `json:"review_needs_attention"`
|
||||
ReviewIssueReason *string `json:"review_issue_reason"`
|
||||
ReviewCorrectnessScore *float64 `json:"review_correctness_score"`
|
||||
ReviewUnderstandingScore *float64 `json:"review_understanding_score"`
|
||||
ReviewQuestionScore *float64 `json:"review_question_score"`
|
||||
ReviewConfidence *float64 `json:"review_confidence"`
|
||||
ReviewTags []string `json:"review_tags"`
|
||||
GradingStatus string `json:"grading_status"`
|
||||
IsActive bool `json:"is_active"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
AnsweredAt *int64 `json:"_answered_at"`
|
||||
UnderSolveMode *string `json:"_solve_mode"`
|
||||
UnderIsCorrect *bool `json:"_is_correct"`
|
||||
UnderMisconceptionTag *string `json:"_misconception_tag"`
|
||||
}
|
||||
|
||||
type assignmentQuestionRef struct {
|
||||
AssignmentID int64
|
||||
QuestionID int64
|
||||
Position int32
|
||||
}
|
||||
|
||||
const DefaultMockDataDir = "../Mock-Data"
|
||||
|
||||
type Summary struct {
|
||||
Users int `json:"users"`
|
||||
Classrooms int `json:"classrooms"`
|
||||
Questions int `json:"questions"`
|
||||
Tags int `json:"tags"`
|
||||
Assignments int `json:"assignments"`
|
||||
AssignmentLinks int `json:"assignment_links"`
|
||||
StudentAnswers int `json:"student_answers"`
|
||||
MockDataDir string `json:"mock_data_dir"`
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, db *database.DB, mockDataDir string) (Summary, error) {
|
||||
if strings.TrimSpace(mockDataDir) == "" {
|
||||
mockDataDir = filepath.Clean(filepath.Join("..", "Mock-Data"))
|
||||
}
|
||||
|
||||
var (
|
||||
students []studentRecord
|
||||
classroomPayload classroomFile
|
||||
questions []questionRecord
|
||||
assignments []assignmentRecord
|
||||
assignmentQuestions []assignmentQuestionRecord
|
||||
assignmentAssignees []assignmentAssigneeRecord
|
||||
studentAnswers []studentAnswerRecord
|
||||
)
|
||||
|
||||
if err := loadJSON(filepath.Join(mockDataDir, "students.json"), &students); err != nil {
|
||||
return Summary{}, err
|
||||
}
|
||||
if err := loadJSON(filepath.Join(mockDataDir, "classroom.json"), &classroomPayload); err != nil {
|
||||
return Summary{}, err
|
||||
}
|
||||
if err := loadJSON(filepath.Join(mockDataDir, "question_bank.json"), &questions); err != nil {
|
||||
return Summary{}, err
|
||||
}
|
||||
if err := loadJSON(filepath.Join(mockDataDir, "assignments.json"), &assignments); err != nil {
|
||||
return Summary{}, err
|
||||
}
|
||||
if err := loadJSON(filepath.Join(mockDataDir, "assignment_questions.json"), &assignmentQuestions); err != nil {
|
||||
return Summary{}, err
|
||||
}
|
||||
if err := loadJSON(filepath.Join(mockDataDir, "assignment_assignees.json"), &assignmentAssignees); err != nil {
|
||||
return Summary{}, err
|
||||
}
|
||||
if err := loadJSON(filepath.Join(mockDataDir, "student_answers.json"), &studentAnswers); err != nil {
|
||||
return Summary{}, err
|
||||
}
|
||||
|
||||
if err := db.Migrate(); err != nil {
|
||||
return Summary{}, fmt.Errorf("migrate database: %w", err)
|
||||
}
|
||||
|
||||
tx, err := db.Pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return Summary{}, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
if err := resetSeedData(ctx, tx); err != nil {
|
||||
return Summary{}, fmt.Errorf("reset data: %w", err)
|
||||
}
|
||||
|
||||
if err := seedUsers(ctx, tx, classroomPayload.Tutor, students); err != nil {
|
||||
return Summary{}, fmt.Errorf("seed users: %w", err)
|
||||
}
|
||||
|
||||
if err := seedClassroom(ctx, tx, classroomPayload); err != nil {
|
||||
return Summary{}, fmt.Errorf("seed classroom: %w", err)
|
||||
}
|
||||
|
||||
tagIDs, err := seedQuestionsAndTags(ctx, tx, questions)
|
||||
if err != nil {
|
||||
return Summary{}, fmt.Errorf("seed questions: %w", err)
|
||||
}
|
||||
|
||||
if err := seedAssignments(ctx, tx, classroomPayload.Classroom.ID, assignments); err != nil {
|
||||
return Summary{}, fmt.Errorf("seed assignments: %w", err)
|
||||
}
|
||||
|
||||
if err := seedAssignmentAssignees(ctx, tx, assignmentAssignees); err != nil {
|
||||
return Summary{}, fmt.Errorf("seed assignment assignees: %w", err)
|
||||
}
|
||||
|
||||
assignmentQuestionMap, err := seedAssignmentQuestions(ctx, tx, assignmentQuestions)
|
||||
if err != nil {
|
||||
return Summary{}, fmt.Errorf("seed assignment questions: %w", err)
|
||||
}
|
||||
|
||||
if err := seedStudentAnswers(ctx, tx, assignmentAssignees, assignmentQuestionMap, studentAnswers); err != nil {
|
||||
return Summary{}, fmt.Errorf("seed student answers: %w", err)
|
||||
}
|
||||
|
||||
if err := seedMessages(ctx, tx, classroomPayload.Tutor, students); err != nil {
|
||||
return Summary{}, fmt.Errorf("seed messages: %w", err)
|
||||
}
|
||||
|
||||
if err := syncSequences(ctx, tx); err != nil {
|
||||
return Summary{}, fmt.Errorf("sync sequences: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return Summary{}, fmt.Errorf("commit seed transaction: %w", err)
|
||||
}
|
||||
|
||||
return Summary{
|
||||
Users: len(students) + 1,
|
||||
Classrooms: 1,
|
||||
Questions: len(questions),
|
||||
Tags: len(tagIDs),
|
||||
Assignments: len(assignments),
|
||||
AssignmentLinks: len(assignmentQuestions),
|
||||
StudentAnswers: len(studentAnswers),
|
||||
MockDataDir: mockDataDir,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func loadJSON(path string, target any) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", path, err)
|
||||
}
|
||||
if err := json.Unmarshal(data, target); err != nil {
|
||||
return fmt.Errorf("parse %s: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func resetSeedData(ctx context.Context, tx pgx.Tx) error {
|
||||
_, err := tx.Exec(ctx, `
|
||||
TRUNCATE TABLE
|
||||
messages,
|
||||
message_thread_participants,
|
||||
message_threads,
|
||||
assignment_student_questions,
|
||||
student_answers,
|
||||
assignment_questions,
|
||||
assignment_assignees,
|
||||
assignments,
|
||||
question_tags,
|
||||
tags,
|
||||
questions,
|
||||
classroom_students,
|
||||
classrooms,
|
||||
users
|
||||
RESTART IDENTITY CASCADE`)
|
||||
return err
|
||||
}
|
||||
|
||||
func seedUsers(ctx context.Context, tx pgx.Tx, tutor tutorRecord, students []studentRecord) error {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(SeededPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := insertUser(ctx, tx, tutor.ID, tutor.Email, string(hashedPassword), "teacher", tutor.FullName, tutor.Active && !tutor.IsDeleted, tutor.CreatedAt, tutor.UpdatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, student := range students {
|
||||
if student.IsDeleted {
|
||||
continue
|
||||
}
|
||||
if err := insertUser(ctx, tx, student.ID, student.Email, string(hashedPassword), "student", student.FullName, student.Active && !student.IsDeleted, student.CreatedAt, student.UpdatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func seedMessages(ctx context.Context, tx pgx.Tx, tutor tutorRecord, students []studentRecord) error {
|
||||
activeStudents := make([]studentRecord, 0, len(students))
|
||||
for _, student := range students {
|
||||
if student.IsDeleted || !student.Active {
|
||||
continue
|
||||
}
|
||||
activeStudents = append(activeStudents, student)
|
||||
}
|
||||
|
||||
if len(activeStudents) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
threadSeedCount := len(activeStudents)
|
||||
if threadSeedCount > 3 {
|
||||
threadSeedCount = 3
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
for idx := 0; idx < threadSeedCount; idx++ {
|
||||
student := activeStudents[idx]
|
||||
threadCreatedAt := now.Add(-time.Duration(threadSeedCount-idx) * 6 * time.Hour)
|
||||
subject := fmt.Sprintf("Study check-in for %s", firstName(student.FullName))
|
||||
|
||||
var threadID int64
|
||||
if err := tx.QueryRow(ctx, `
|
||||
INSERT INTO message_threads (created_by_user_id, subject, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $3)
|
||||
RETURNING id`, tutor.ID, subject, threadCreatedAt).Scan(&threadID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
teacherBody := fmt.Sprintf("Hi %s, send me a quick update when you finish today's maths block. If one question feels sticky, tell me which one and I'll help.", firstName(student.FullName))
|
||||
studentBody := fmt.Sprintf("Thanks %s — I started the assignment set and I'm feeling better about the fraction questions now.", firstName(tutor.FullName))
|
||||
|
||||
messageTimes := []time.Time{threadCreatedAt.Add(12 * time.Minute), threadCreatedAt.Add(54 * time.Minute)}
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO messages (thread_id, sender_user_id, body, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $4), ($1, $5, $6, $7, $7)`, threadID, tutor.ID, teacherBody, messageTimes[0], student.ID, studentBody, messageTimes[1]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
teacherReadAt := pgtype.Timestamptz{Time: messageTimes[1], Valid: true}
|
||||
studentReadAt := pgtype.Timestamptz{}
|
||||
if idx == 0 {
|
||||
studentReadAt = pgtype.Timestamptz{Time: messageTimes[1], Valid: true}
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO message_thread_participants (thread_id, user_id, joined_at, last_read_at)
|
||||
VALUES ($1, $2, $3, $4), ($1, $5, $3, $6)`, threadID, tutor.ID, threadCreatedAt, teacherReadAt, student.ID, studentReadAt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(ctx, `UPDATE message_threads SET updated_at = $2 WHERE id = $1`, threadID, messageTimes[1]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertUser(ctx context.Context, tx pgx.Tx, id int64, email, passwordHash, role, fullName string, active bool, createdAtMs, updatedAtMs int64) error {
|
||||
_, err := tx.Exec(ctx, `
|
||||
INSERT INTO users (id, email, password_hash, role, full_name, is_active, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4::user_role, $5, $6, $7, $8)`,
|
||||
id,
|
||||
email,
|
||||
passwordHash,
|
||||
role,
|
||||
fullName,
|
||||
active,
|
||||
msToTime(createdAtMs),
|
||||
msToTime(updatedAtMs),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func seedClassroom(ctx context.Context, tx pgx.Tx, payload classroomFile) error {
|
||||
_, err := tx.Exec(ctx, `
|
||||
INSERT INTO classrooms (id, teacher_id, name, code, description, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
|
||||
payload.Classroom.ID,
|
||||
payload.Classroom.TutorID,
|
||||
payload.Classroom.Name,
|
||||
sharedapi.NullableText(optionalString(payload.Classroom.InviteCode)),
|
||||
sharedapi.NullableText(classroomDescription(payload.Classroom.Name)),
|
||||
msToTime(payload.Classroom.CreatedAt),
|
||||
msToTime(payload.Classroom.UpdatedAt),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, relation := range payload.ClassroomStudentRs {
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO classroom_students (classroom_id, student_id, joined_at)
|
||||
VALUES ($1, $2, $3)`, relation.ClassroomID, relation.StudentID, msToTime(relation.CreatedAt)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func seedQuestionsAndTags(ctx context.Context, tx pgx.Tx, questions []questionRecord) ([]int64, error) {
|
||||
tagSet := map[string]struct{}{}
|
||||
for _, question := range questions {
|
||||
if question.Tag != nil {
|
||||
tag := strings.TrimSpace(*question.Tag)
|
||||
if tag != "" {
|
||||
tagSet[tag] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tagNames := make([]string, 0, len(tagSet))
|
||||
for tag := range tagSet {
|
||||
tagNames = append(tagNames, tag)
|
||||
}
|
||||
sort.Strings(tagNames)
|
||||
|
||||
tagIDByName := make(map[string]int64, len(tagNames))
|
||||
tagIDs := make([]int64, 0, len(tagNames))
|
||||
for _, tagName := range tagNames {
|
||||
var tagID int64
|
||||
if err := tx.QueryRow(ctx, `
|
||||
INSERT INTO tags (name)
|
||||
VALUES ($1)
|
||||
RETURNING id`, tagName).Scan(&tagID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tagIDByName[tagName] = tagID
|
||||
tagIDs = append(tagIDs, tagID)
|
||||
}
|
||||
|
||||
for _, question := range questions {
|
||||
if err := insertQuestion(ctx, tx, question); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if question.Tag == nil {
|
||||
continue
|
||||
}
|
||||
tagName := strings.TrimSpace(*question.Tag)
|
||||
if tagName == "" {
|
||||
continue
|
||||
}
|
||||
tagID := tagIDByName[tagName]
|
||||
if _, err := tx.Exec(ctx, `INSERT INTO question_tags (question_id, tag_id) VALUES ($1, $2)`, question.ID, tagID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return tagIDs, nil
|
||||
}
|
||||
|
||||
func insertQuestion(ctx context.Context, tx pgx.Tx, record questionRecord) error {
|
||||
status := "published"
|
||||
if record.IsDeleted {
|
||||
status = "archived"
|
||||
}
|
||||
|
||||
_, err := tx.Exec(ctx, `
|
||||
INSERT INTO questions (id, author_teacher_id, title, prompt, topic, subject, difficulty, source, correct_answer, status, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5::question_topic, $6, $7::question_difficulty, $8, $9, $10::question_status, $11, $12)`,
|
||||
record.ID,
|
||||
record.TeacherID,
|
||||
questionTitle(record.QuestionText),
|
||||
record.QuestionText,
|
||||
nullableTopic(record.Topic),
|
||||
nullableSubject(record.SubTopic, record.Topic),
|
||||
nullableDifficulty(record.Difficulty),
|
||||
nullableSource(record.Source),
|
||||
sharedapi.NullableText(&record.CorrectAnswer),
|
||||
status,
|
||||
msToTime(record.CreatedAt),
|
||||
msToTime(record.UpdatedAt),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func seedAssignments(ctx context.Context, tx pgx.Tx, classroomID int64, assignments []assignmentRecord) error {
|
||||
for _, assignment := range assignments {
|
||||
if assignment.IsDeleted {
|
||||
continue
|
||||
}
|
||||
|
||||
assignmentStatus := normalizeAssignmentStatus(assignment.Status)
|
||||
publishedAt := optionalPublishedAt(assignmentStatus, assignment.CreatedAt, assignment.UpdatedAt)
|
||||
|
||||
_, err := tx.Exec(ctx, `
|
||||
INSERT INTO assignments (id, teacher_id, classroom_id, title, instructions, due_at, published_at, pass_threshold, status, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9::assignment_status, $10, $11)`,
|
||||
assignment.ID,
|
||||
assignment.TeacherID,
|
||||
classroomID,
|
||||
assignment.Name,
|
||||
sharedapi.NullableText(optionalString(assignmentInstructions(assignment.Topic))),
|
||||
optionalDueAt(assignment.DueDate),
|
||||
publishedAt,
|
||||
requiredNumeric(6.0),
|
||||
assignmentStatus,
|
||||
msToTime(assignment.CreatedAt),
|
||||
msToTime(assignment.UpdatedAt),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func seedAssignmentQuestions(ctx context.Context, tx pgx.Tx, rows []assignmentQuestionRecord) (map[int64]assignmentQuestionRef, error) {
|
||||
assignmentQuestionMap := make(map[int64]assignmentQuestionRef, len(rows))
|
||||
for _, row := range rows {
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO assignment_questions (assignment_id, question_id, position)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (assignment_id, question_id) DO UPDATE
|
||||
SET position = EXCLUDED.position`, row.AssignmentID, row.QuestionBankID, row.QuestionOrder); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
assignmentQuestionMap[row.ID] = assignmentQuestionRef{
|
||||
AssignmentID: row.AssignmentID,
|
||||
QuestionID: row.QuestionBankID,
|
||||
Position: row.QuestionOrder,
|
||||
}
|
||||
}
|
||||
return assignmentQuestionMap, nil
|
||||
}
|
||||
|
||||
func seedAssignmentAssignees(ctx context.Context, tx pgx.Tx, rows []assignmentAssigneeRecord) error {
|
||||
for _, row := range rows {
|
||||
assignedAt := firstValidMs(row.StartedAt, row.SubmittedAt, row.CreatedAt)
|
||||
_, err := tx.Exec(ctx, `
|
||||
INSERT INTO assignment_assignees (assignment_id, student_id, assigned_at, ai_feedback, overall_score, pass_threshold, pass_status, next_step_outcome)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7::assignment_pass_status, $8::assignment_next_step_outcome)`,
|
||||
row.AssignmentID,
|
||||
row.StudentID,
|
||||
assignedAt,
|
||||
sharedapi.NullableText(row.AIFeedback),
|
||||
optionalNumeric(row.OverallScore),
|
||||
requiredNumeric(6.0),
|
||||
normalizePassStatus(row.OverallScore, 6.0),
|
||||
normalizeNextStepOutcome(row.NextStepOutcome),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func seedStudentAnswers(ctx context.Context, tx pgx.Tx, assignees []assignmentAssigneeRecord, assignmentQuestionMap map[int64]assignmentQuestionRef, answers []studentAnswerRecord) error {
|
||||
type assigneeRef struct {
|
||||
AssignmentID int64
|
||||
StudentID int64
|
||||
}
|
||||
|
||||
assigneeAssignment := make(map[int64]assigneeRef, len(assignees))
|
||||
for _, row := range assignees {
|
||||
assigneeAssignment[row.ID] = assigneeRef{AssignmentID: row.AssignmentID, StudentID: row.StudentID}
|
||||
}
|
||||
|
||||
for _, answer := range answers {
|
||||
questionRef, ok := assignmentQuestionMap[answer.AssignmentQuestionID]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing assignment question mapping for %d", answer.AssignmentQuestionID)
|
||||
}
|
||||
|
||||
assigneeRef, ok := assigneeAssignment[answer.AssigneeID]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing assignee mapping for %d", answer.AssigneeID)
|
||||
}
|
||||
|
||||
if assigneeRef.AssignmentID != questionRef.AssignmentID {
|
||||
return fmt.Errorf("assignment mismatch for assignee %d and assignment question %d", answer.AssigneeID, answer.AssignmentQuestionID)
|
||||
}
|
||||
|
||||
answerText := firstNonEmpty(answer.ExtractedAnswer, answer.AnswerLatex)
|
||||
answerStatus := normalizeAnswerStatus(answer.GradingStatus)
|
||||
solveMode := normalizeSolveMode(firstNonEmpty(answer.SolveMode, answer.UnderSolveMode))
|
||||
reviewedAt := optionalReviewedAt(answerStatus, answer.CreatedAt)
|
||||
|
||||
_, err := tx.Exec(ctx, `
|
||||
INSERT INTO student_answers (
|
||||
id,
|
||||
assignment_id,
|
||||
question_id,
|
||||
student_id,
|
||||
answer_text,
|
||||
ai_feedback,
|
||||
teacher_feedback,
|
||||
status,
|
||||
submitted_at,
|
||||
reviewed_at,
|
||||
created_at,
|
||||
updated_at,
|
||||
solve_mode,
|
||||
working_steps,
|
||||
is_correct,
|
||||
review_needs_attention,
|
||||
review_issue_reason,
|
||||
review_correctness_score,
|
||||
review_understanding_score,
|
||||
review_question_score,
|
||||
review_confidence,
|
||||
review_tags
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8::answer_status, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22
|
||||
)`,
|
||||
answer.ID,
|
||||
assigneeRef.AssignmentID,
|
||||
questionRef.QuestionID,
|
||||
assigneeRef.StudentID,
|
||||
sharedapi.NullableText(answerText),
|
||||
sharedapi.NullableText(answer.AIFeedback),
|
||||
nil,
|
||||
answerStatus,
|
||||
optionalMsToTime(answer.AnsweredAt),
|
||||
reviewedAt,
|
||||
msToTime(answer.CreatedAt),
|
||||
deriveUpdatedAt(answer.CreatedAt, questionRef.Position),
|
||||
solveMode,
|
||||
sharedapi.NullableText(answer.WorkingSteps),
|
||||
sharedapi.NullableBool(firstNonNilBool(answer.IsCorrect, answer.UnderIsCorrect)),
|
||||
boolOrDefault(answer.ReviewNeedsAttention, false),
|
||||
sharedapi.NullableText(answer.ReviewIssueReason),
|
||||
optionalNumeric(answer.ReviewCorrectnessScore),
|
||||
optionalNumeric(answer.ReviewUnderstandingScore),
|
||||
optionalNumeric(answer.ReviewQuestionScore),
|
||||
optionalNumeric(answer.ReviewConfidence),
|
||||
stringsArray(answer.ReviewTags),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func syncSequences(ctx context.Context, tx pgx.Tx) error {
|
||||
statements := []string{
|
||||
"SELECT setval('users_id_seq', COALESCE((SELECT MAX(id) FROM users), 1))",
|
||||
"SELECT setval('classrooms_id_seq', COALESCE((SELECT MAX(id) FROM classrooms), 1))",
|
||||
"SELECT setval('questions_id_seq', COALESCE((SELECT MAX(id) FROM questions), 1))",
|
||||
"SELECT setval('tags_id_seq', COALESCE((SELECT MAX(id) FROM tags), 1))",
|
||||
"SELECT setval('assignments_id_seq', COALESCE((SELECT MAX(id) FROM assignments), 1))",
|
||||
"SELECT setval('assignment_student_questions_id_seq', COALESCE((SELECT MAX(id) FROM assignment_student_questions), 1))",
|
||||
"SELECT setval('student_answers_id_seq', COALESCE((SELECT MAX(id) FROM student_answers), 1))",
|
||||
"SELECT setval('message_threads_id_seq', COALESCE((SELECT MAX(id) FROM message_threads), 1))",
|
||||
"SELECT setval('messages_id_seq', COALESCE((SELECT MAX(id) FROM messages), 1))",
|
||||
}
|
||||
|
||||
for _, statement := range statements {
|
||||
if _, err := tx.Exec(ctx, statement); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func msToTime(value int64) time.Time {
|
||||
if value <= 0 {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
return time.UnixMilli(value).UTC()
|
||||
}
|
||||
|
||||
func optionalMsToTime(value *int64) pgtype.Timestamptz {
|
||||
if value == nil || *value <= 0 {
|
||||
return pgtype.Timestamptz{}
|
||||
}
|
||||
return pgtype.Timestamptz{Time: time.UnixMilli(*value).UTC(), Valid: true}
|
||||
}
|
||||
|
||||
func optionalDueAt(value int64) pgtype.Timestamptz {
|
||||
if value <= 0 {
|
||||
return pgtype.Timestamptz{}
|
||||
}
|
||||
return pgtype.Timestamptz{Time: time.UnixMilli(value).UTC(), Valid: true}
|
||||
}
|
||||
|
||||
func deriveUpdatedAt(createdAtMs int64, position int32) time.Time {
|
||||
created := msToTime(createdAtMs)
|
||||
if position <= 0 {
|
||||
return created
|
||||
}
|
||||
return created.Add(time.Duration(position) * time.Minute)
|
||||
}
|
||||
|
||||
func normalizeAssignmentStatus(value string) string {
|
||||
switch strings.TrimSpace(strings.ToLower(value)) {
|
||||
case "published", "assigned", "open":
|
||||
return "assigned"
|
||||
case "closed", "complete", "completed":
|
||||
return "closed"
|
||||
default:
|
||||
return "draft"
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeAnswerStatus(value string) string {
|
||||
switch strings.TrimSpace(strings.ToLower(value)) {
|
||||
case "in_progress":
|
||||
return "in_progress"
|
||||
case "submitted":
|
||||
return "submitted"
|
||||
case "reviewed", "graded":
|
||||
return "reviewed"
|
||||
default:
|
||||
return "not_started"
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSolveMode(primary *string) string {
|
||||
if primary == nil {
|
||||
return "just_answer"
|
||||
}
|
||||
switch strings.TrimSpace(strings.ToLower(*primary)) {
|
||||
case "just_answer", "mental":
|
||||
return "just_answer"
|
||||
case "step_by_step", "calculator":
|
||||
return "step_by_step"
|
||||
case "solve_together":
|
||||
return "solve_together"
|
||||
case "handwritten", "written":
|
||||
return "handwritten"
|
||||
default:
|
||||
return "handwritten"
|
||||
}
|
||||
}
|
||||
|
||||
func nullableTopic(value string) any {
|
||||
return normalizeQuestionTopic(value)
|
||||
}
|
||||
|
||||
func nullableSubject(subTopic *string, topic string) any {
|
||||
if subTopic != nil {
|
||||
if trimmed := strings.TrimSpace(*subTopic); trimmed != "" {
|
||||
return sharedapi.NullableText(&trimmed)
|
||||
}
|
||||
}
|
||||
trimmed := strings.TrimSpace(topic)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
return sharedapi.NullableText(&trimmed)
|
||||
}
|
||||
|
||||
func nullableDifficulty(value string) any {
|
||||
trimmed := strings.TrimSpace(strings.ToLower(value))
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func nullableSource(value string) any {
|
||||
trimmed := strings.TrimSpace(strings.ToLower(value))
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func assignmentInstructions(topic string) string {
|
||||
trimmed := strings.TrimSpace(topic)
|
||||
if trimmed == "" {
|
||||
return "Complete each question and show your reasoning where needed."
|
||||
}
|
||||
return fmt.Sprintf("Complete each %s question and show your reasoning where needed.", strings.ReplaceAll(trimmed, "_", " "))
|
||||
}
|
||||
|
||||
func questionTitle(prompt string) string {
|
||||
trimmed := strings.TrimSpace(prompt)
|
||||
if trimmed == "" {
|
||||
return "Seeded question"
|
||||
}
|
||||
if len(trimmed) <= 60 {
|
||||
return trimmed
|
||||
}
|
||||
return trimmed[:57] + "..."
|
||||
}
|
||||
|
||||
func firstName(fullName string) string {
|
||||
parts := strings.Fields(strings.TrimSpace(fullName))
|
||||
if len(parts) == 0 {
|
||||
return "there"
|
||||
}
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...*string) *string {
|
||||
for _, value := range values {
|
||||
if value == nil {
|
||||
continue
|
||||
}
|
||||
trimmed := strings.TrimSpace(*value)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
copyValue := trimmed
|
||||
return ©Value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func firstNonNilBool(values ...*bool) *bool {
|
||||
for _, value := range values {
|
||||
if value != nil {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func optionalNumeric(value *float64) pgtype.Numeric {
|
||||
if value == nil {
|
||||
return pgtype.Numeric{}
|
||||
}
|
||||
numeric, err := sharedapi.NullableFloat64AsNumeric(value)
|
||||
if err != nil {
|
||||
return pgtype.Numeric{}
|
||||
}
|
||||
return numeric
|
||||
}
|
||||
|
||||
func stringsArray(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
out := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeQuestionTopic(value string) any {
|
||||
trimmed := strings.TrimSpace(strings.ToLower(value))
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
mapped := map[string]string{
|
||||
"place value": "place_value",
|
||||
"place_value": "place_value",
|
||||
"arithmetic": "arithmetic",
|
||||
"negative numbers": "negative_numbers",
|
||||
"negative_numbers": "negative_numbers",
|
||||
"bidmas": "bidmas",
|
||||
"fractions": "fractions",
|
||||
"algebra": "algebra",
|
||||
"geometry": "geometry",
|
||||
"data": "data",
|
||||
}
|
||||
if normalized, ok := mapped[trimmed]; ok {
|
||||
return normalized
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizePassStatus(score *float64, threshold float64) string {
|
||||
if score == nil {
|
||||
return "pending"
|
||||
}
|
||||
if *score >= threshold {
|
||||
return "pass"
|
||||
}
|
||||
return "no_pass"
|
||||
}
|
||||
|
||||
func normalizeNextStepOutcome(value *string) any {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
trimmed := strings.TrimSpace(strings.ToLower(*value))
|
||||
switch trimmed {
|
||||
case "redo", "accept", "support":
|
||||
return trimmed
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func requiredNumeric(value float64) pgtype.Numeric {
|
||||
numeric, err := sharedapi.NullableFloat64AsNumeric(&value)
|
||||
if err != nil {
|
||||
return pgtype.Numeric{}
|
||||
}
|
||||
return numeric
|
||||
}
|
||||
|
||||
func boolOrDefault(value *bool, fallback bool) bool {
|
||||
if value == nil {
|
||||
return fallback
|
||||
}
|
||||
return *value
|
||||
}
|
||||
|
||||
func firstValidMs(values ...any) time.Time {
|
||||
for _, value := range values {
|
||||
switch typed := value.(type) {
|
||||
case *int64:
|
||||
if typed != nil && *typed > 0 {
|
||||
return time.UnixMilli(*typed).UTC()
|
||||
}
|
||||
case int64:
|
||||
if typed > 0 {
|
||||
return time.UnixMilli(typed).UTC()
|
||||
}
|
||||
}
|
||||
}
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
func optionalPublishedAt(status string, createdAtMs, updatedAtMs int64) pgtype.Timestamptz {
|
||||
if status != "assigned" && status != "closed" {
|
||||
return pgtype.Timestamptz{}
|
||||
}
|
||||
timestamp := createdAtMs
|
||||
if updatedAtMs > 0 {
|
||||
timestamp = updatedAtMs
|
||||
}
|
||||
return optionalDueAt(timestamp)
|
||||
}
|
||||
|
||||
func optionalReviewedAt(status string, createdAtMs int64) pgtype.Timestamptz {
|
||||
if status != "reviewed" {
|
||||
return pgtype.Timestamptz{}
|
||||
}
|
||||
return pgtype.Timestamptz{Time: msToTime(createdAtMs), Valid: true}
|
||||
}
|
||||
|
||||
func optionalString(value string) *string {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
return &trimmed
|
||||
}
|
||||
|
||||
func classroomDescription(name string) *string {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
description := fmt.Sprintf("Seeded classroom for %s", trimmed)
|
||||
return &description
|
||||
}
|
||||
Reference in New Issue
Block a user