Before Fine Tune

This commit is contained in:
MangoPig
2026-05-26 13:43:09 +01:00
parent 4f79137d89
commit f29aff25f5
35 changed files with 6953 additions and 142 deletions

View File

@@ -11,6 +11,13 @@ import (
"time"
)
// requestStyle names the wire format used for requests to the AI endpoint.
type requestStyle string
// Supported request styles. Service.requestStyle picks one by inspecting the
// configured endpoint URL.
const (
// requestStyleResponses is the fallback style; its request body carries the
// prompts under "input" and the schema under "text.format".
requestStyleResponses requestStyle = "responses"
// requestStyleChatCompletions is selected when the endpoint URL contains
// "/chat/completions"; its request body carries the prompts under "messages"
// and the schema under "response_format".
requestStyleChatCompletions requestStyle = "chat_completions"
)
type Service struct {
endpoint string
apiKey string
@@ -109,14 +116,9 @@ func (s *Service) ReviewSubmission(ctx context.Context, input AssignmentReviewIn
return nil, fmt.Errorf("marshal AI review input: %w", err)
}
body := map[string]any{
"model": s.model,
"input": []map[string]any{
{
"role": "system",
"content": []map[string]any{{
"type": "input_text",
"text": strings.TrimSpace(`You are reviewing student homework submissions for a teacher workflow.
outputText, err := s.runStructuredRequest(
ctx,
strings.TrimSpace(`You are reviewing student homework submissions for a teacher workflow.
You must assess the student's understanding by looking at the student's final answer and working against the saved correct answer when one is available. Do not re-grade weighting.
@@ -142,56 +144,11 @@ Interpretation guidance:
- support = the student shows meaningful gaps and likely needs targeted help.
- redo = the student should redo the assignment because understanding is broadly too weak or incomplete.
Review the full assignment in one pass and produce a short assignment-level summary.`),
}},
},
{
"role": "user",
"content": []map[string]any{{
"type": "input_text",
"text": string(payloadJSON),
}},
},
},
"text": map[string]any{
"format": map[string]any{
"type": "json_schema",
"name": "assignment_review",
"strict": true,
"schema": reviewSchema(),
},
},
}
bodyBytes, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("marshal AI review request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.endpoint, bytes.NewReader(bodyBytes))
if err != nil {
return nil, fmt.Errorf("build AI review request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("api-key", s.apiKey)
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("send AI review request: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read AI review response: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("AI review request failed: status %d body %s", resp.StatusCode, strings.TrimSpace(string(respBytes)))
}
outputText, err := extractOutputText(respBytes)
Review the full assignment in one pass and produce a short assignment-level summary.`),
string(payloadJSON),
"assignment_review",
reviewSchema(),
)
if err != nil {
return nil, err
}
@@ -215,14 +172,9 @@ func (s *Service) PlanRedoAssignment(ctx context.Context, input RedoPlanInput) (
return nil, fmt.Errorf("marshal redo plan input: %w", err)
}
body := map[string]any{
"model": s.model,
"input": []map[string]any{
{
"role": "system",
"content": []map[string]any{{
"type": "input_text",
"text": strings.TrimSpace(`You are planning the next redo assignment for a student.
outputText, err := s.runStructuredRequest(
ctx,
strings.TrimSpace(`You are planning the next redo assignment for a student.
You are NOT writing final math questions. You are only producing a structured topic+difficulty blueprint for a later generator layer.
@@ -238,55 +190,10 @@ Rules:
- reason on each item should briefly explain why that topic/difficulty belongs in the redo set.
- Do not invent topics outside the allowed topic vocabulary.
- Do not output prose outside the JSON schema.`),
}},
},
{
"role": "user",
"content": []map[string]any{{
"type": "input_text",
"text": string(payloadJSON),
}},
},
},
"text": map[string]any{
"format": map[string]any{
"type": "json_schema",
"name": "redo_assignment_plan",
"strict": true,
"schema": redoPlanSchema(),
},
},
}
bodyBytes, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("marshal redo plan request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.endpoint, bytes.NewReader(bodyBytes))
if err != nil {
return nil, fmt.Errorf("build redo plan request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("api-key", s.apiKey)
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("send redo plan request: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read redo plan response: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("redo plan request failed: status %d body %s", resp.StatusCode, strings.TrimSpace(string(respBytes)))
}
outputText, err := extractOutputText(respBytes)
string(payloadJSON),
"redo_assignment_plan",
redoPlanSchema(),
)
if err != nil {
return nil, err
}
@@ -300,6 +207,146 @@ Rules:
return &result, nil
}
// runStructuredRequest sends a schema-constrained prompt pair to the configured
// endpoint and returns the structured output text found in the reply.
//
// The raw response is decoded according to the request style derived from the
// endpoint: chat-completions replies go through extractChatCompletionText and
// every other style goes through extractResponsesOutputText. Errors from the
// transport step and from decoding are returned unchanged.
func (s *Service) runStructuredRequest(ctx context.Context, systemPrompt, userPrompt, schemaName string, schema map[string]any) (string, error) {
	raw, err := s.sendStructuredRequest(ctx, systemPrompt, userPrompt, schemaName, schema)
	if err != nil {
		return "", err
	}

	if s.requestStyle() == requestStyleChatCompletions {
		return extractChatCompletionText(raw)
	}
	return extractResponsesOutputText(raw)
}
// sendStructuredRequest POSTs a structured-output request to s.endpoint and
// returns the raw response body.
//
// The JSON body comes from buildStructuredRequestBody, so its shape follows
// the endpoint's request style; authentication is delegated to
// applyAuthHeader. A status code outside the 2xx range is reported as an error
// that carries the status code and the whitespace-trimmed response body.
//
// This helper is shared by every structured request kind, so each error it
// creates is prefixed with the schema name. Without the prefix a failure reads
// "AI review ..." regardless of which request actually failed. The original
// message text is kept verbatim after the prefix.
func (s *Service) sendStructuredRequest(ctx context.Context, systemPrompt, userPrompt, schemaName string, schema map[string]any) ([]byte, error) {
	body, err := s.buildStructuredRequestBody(systemPrompt, userPrompt, schemaName, schema)
	if err != nil {
		return nil, err
	}

	bodyBytes, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("schema %q: marshal AI review request: %w", schemaName, err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.endpoint, bytes.NewReader(bodyBytes))
	if err != nil {
		return nil, fmt.Errorf("schema %q: build AI review request: %w", schemaName, err)
	}
	req.Header.Set("Content-Type", "application/json")
	s.applyAuthHeader(req)

	resp, err := s.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("schema %q: send AI review request: %w", schemaName, err)
	}
	defer resp.Body.Close()

	// Read the body before checking the status so that a failure response can
	// be quoted in the returned error.
	respBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("schema %q: read AI review response: %w", schemaName, err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("schema %q: AI review request failed: status %d body %s", schemaName, resp.StatusCode, strings.TrimSpace(string(respBytes)))
	}
	return respBytes, nil
}
// buildStructuredRequestBody assembles the JSON-serialisable request body for
// a structured-output call, shaped for the endpoint's request style.
//
// For the chat-completions style the prompts are sent as plain-string
// "messages", temperature is pinned to 0, and the schema is attached under
// "response_format". When shouldDisableThinking reports true, a
// "chat_template_kwargs" entry with "enable_thinking": false is added as well.
// For every other style the prompts are sent as "input_text" content parts
// under "input" and the schema is attached under "text.format".
//
// The returned error is always nil in the current implementation.
func (s *Service) buildStructuredRequestBody(systemPrompt, userPrompt, schemaName string, schema map[string]any) (map[string]any, error) {
	if s.requestStyle() != requestStyleChatCompletions {
		// Each prompt is wrapped in a single-element list of input_text parts.
		textPart := func(text string) []map[string]any {
			return []map[string]any{{
				"type": "input_text",
				"text": text,
			}}
		}
		return map[string]any{
			"model": s.model,
			"input": []map[string]any{
				{"role": "system", "content": textPart(systemPrompt)},
				{"role": "user", "content": textPart(userPrompt)},
			},
			"text": map[string]any{
				"format": map[string]any{
					"type":   "json_schema",
					"name":   schemaName,
					"strict": true,
					"schema": schema,
				},
			},
		}, nil
	}

	messages := []map[string]any{
		{"role": "system", "content": systemPrompt},
		{"role": "user", "content": userPrompt},
	}
	responseFormat := map[string]any{
		"type": "json_schema",
		"json_schema": map[string]any{
			"name":   schemaName,
			"strict": true,
			"schema": schema,
		},
	}
	request := map[string]any{
		"model":           s.model,
		"messages":        messages,
		"temperature":     0,
		"response_format": responseFormat,
	}
	if s.shouldDisableThinking() {
		request["chat_template_kwargs"] = map[string]any{
			"enable_thinking": false,
		}
	}
	return request, nil
}
// requestStyle reports which wire format the configured endpoint expects.
//
// The endpoint is trimmed and lower-cased before matching: any URL containing
// "/chat/completions" selects requestStyleChatCompletions, and everything else
// falls back to requestStyleResponses.
func (s *Service) requestStyle() requestStyle {
	normalized := strings.ToLower(strings.TrimSpace(s.endpoint))
	switch {
	case strings.Contains(normalized, "/chat/completions"):
		return requestStyleChatCompletions
	default:
		return requestStyleResponses
	}
}
// applyAuthHeader attaches the API key to req using the header the endpoint
// expects: "api-key" when isAzureEndpoint reports true, otherwise an
// "Authorization: Bearer <key>" header.
func (s *Service) applyAuthHeader(req *http.Request) {
	name, value := "Authorization", "Bearer "+s.apiKey
	if s.isAzureEndpoint() {
		name, value = "api-key", s.apiKey
	}
	req.Header.Set(name, value)
}
// shouldDisableThinking reports whether the request body should carry the
// "enable_thinking": false template argument. It is true only for
// chat-completions endpoints that are not Azure endpoints.
// NOTE(review): presumably the flag targets self-hosted chat-completions
// servers that understand chat_template_kwargs — confirm against deployment.
func (s *Service) shouldDisableThinking() bool {
	if s.isAzureEndpoint() {
		return false
	}
	return s.requestStyle() == requestStyleChatCompletions
}
// isAzureEndpoint reports whether the configured endpoint looks like an Azure
// one. The endpoint is trimmed and lower-cased, then checked for either of the
// substrings "cognitiveservices.azure.com" or ".openai.azure.com" anywhere in
// the URL.
func (s *Service) isAzureEndpoint() bool {
	normalized := strings.ToLower(strings.TrimSpace(s.endpoint))
	for _, marker := range []string{"cognitiveservices.azure.com", ".openai.azure.com"} {
		if strings.Contains(normalized, marker) {
			return true
		}
	}
	return false
}
func reviewSchema() map[string]any {
return map[string]any{
"type": "object",
@@ -358,7 +405,7 @@ func redoPlanSchema() map[string]any {
}
}
func extractOutputText(respBytes []byte) (string, error) {
func extractResponsesOutputText(respBytes []byte) (string, error) {
var direct struct {
OutputText string `json:"output_text"`
}
@@ -380,6 +427,47 @@ func extractOutputText(respBytes []byte) (string, error) {
return "", fmt.Errorf("AI review response did not contain structured output text")
}
// extractChatCompletionText decodes a chat-completions response and returns
// the trimmed message content of the first choice whose content is non-empty.
//
// The content field is decoded loosely (as any) and handed to
// extractMessageContent for flattening. An error is returned when the payload
// is not valid JSON or when no choice yields any text.
func extractChatCompletionText(respBytes []byte) (string, error) {
	type chatMessage struct {
		Content any `json:"content"`
	}
	type chatChoice struct {
		Message chatMessage `json:"message"`
	}
	var decoded struct {
		Choices []chatChoice `json:"choices"`
	}
	if err := json.Unmarshal(respBytes, &decoded); err != nil {
		return "", fmt.Errorf("decode AI review chat completion response: %w", err)
	}

	for i := range decoded.Choices {
		candidate := strings.TrimSpace(extractMessageContent(decoded.Choices[i].Message.Content))
		if candidate == "" {
			continue
		}
		return candidate, nil
	}
	return "", fmt.Errorf("AI review chat completion response did not contain message content")
}
// extractMessageContent flattens a decoded message "content" value to a
// single string.
//
// A string is returned unchanged. A list has each element passed through
// findOutputText; the trimmed, non-empty results are joined with newlines.
// Any other type (including nil) yields the empty string.
func extractMessageContent(content any) string {
	if text, ok := content.(string); ok {
		return text
	}
	items, ok := content.([]any)
	if !ok {
		return ""
	}

	var joined strings.Builder
	for _, item := range items {
		piece := strings.TrimSpace(findOutputText(item))
		if piece == "" {
			continue
		}
		// Pieces are never empty, so a non-empty builder means a previous
		// piece was written and a separator is due.
		if joined.Len() > 0 {
			joined.WriteByte('\n')
		}
		joined.WriteString(piece)
	}
	return joined.String()
}
func findOutputText(value any) string {
switch typed := value.(type) {
case map[string]any: