Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/smoke-claude.lock.yml

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions .github/workflows/smoke-claude.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ permissions:
actions: read

name: Smoke Claude
models:
disallowed: ["*opus*"]
max-turns: 100
engine:
id: claude
Expand Down
80 changes: 67 additions & 13 deletions pkg/parser/import_field_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type importAccumulator struct {
caches []string
features []map[string]any
models []map[string][]string // model alias maps from each imported file (appended in import order)
modelPolicies []map[string][]string // model policy sets from each imported file (appended in import order)
modelCosts []map[string]any // model pricing overlays from each imported file (appended in import order)
runInstallScripts bool // true if any imported workflow sets runtimes.node.run-install-scripts: true
agentFile string
Expand Down Expand Up @@ -89,6 +90,11 @@ type importAccumulator struct {
warnings []string
}

const (
modelPolicyAllowedKey = "allowed"
modelPolicyDisallowedKey = "disallowed"
)

// newImportAccumulator creates and initializes a new importAccumulator.
// Maps (botsSet, etc.) are explicitly initialized to prevent nil map panics
// during deduplication. Slices are left as nil, which is valid for append operations.
Expand Down Expand Up @@ -621,41 +627,88 @@ func (acc *importAccumulator) appendModelsField(fm map[string]any) {
if jsonErr := json.Unmarshal([]byte(modelsContent), &rawModels); jsonErr != nil {
return
}
if _, hasProviders := rawModels["providers"]; hasProviders {
acc.modelCosts = append(acc.modelCosts, rawModels)
if providers, ok := rawModels["providers"].(map[string]any); ok {
parserLog.Printf("Extracted model costs from import: providers=%d", len(providers))
if modelPolicy := normalizeModelPolicies(rawModels); len(modelPolicy) > 0 {
acc.modelPolicies = append(acc.modelPolicies, modelPolicy)
parserLog.Printf("Extracted model policy from import: allowed=%d, disallowed=%d", len(modelPolicy["allowed"]), len(modelPolicy["disallowed"]))
}
if providers, hasProviders := rawModels["providers"]; hasProviders {
acc.modelCosts = append(acc.modelCosts, map[string]any{"providers": providers})
if providerMap, ok := providers.(map[string]any); ok {
parserLog.Printf("Extracted model costs from import: providers=%d", len(providerMap))
} else {
parserLog.Printf("Extracted model costs from import")
}
return
}

modelsMap := normalizeModelAliases(rawModels)
aliasModels := make(map[string]any, len(rawModels))
for key, value := range rawModels {
if isModelPolicyKey(key) {
continue
}
aliasModels[key] = value
}
if len(aliasModels) == 0 {
return
}
modelsMap := normalizeModelAliases(aliasModels)
if len(modelsMap) > 0 {
acc.models = append(acc.models, modelsMap)
parserLog.Printf("Extracted model aliases from import: %d entries", len(modelsMap))
}
}

func normalizeModelPolicies(rawModels map[string]any) map[string][]string {
parse := func(key string) []string {
return parseStringSliceField(rawModels[key], false)
}
allowed := parse(modelPolicyAllowedKey)
disallowed := parse(modelPolicyDisallowedKey)
if len(allowed) == 0 && len(disallowed) == 0 {
return nil
}
return map[string][]string{
modelPolicyAllowedKey: allowed,
modelPolicyDisallowedKey: disallowed,
}
}

func normalizeModelAliases(rawModels map[string]any) map[string][]string {
modelsMap := make(map[string][]string, len(rawModels))
for k, v := range rawModels {
patterns, ok := v.([]any)
if !ok {
strs := parseStringSliceField(v, true)
if len(strs) == 0 {
continue
}
strs := make([]string, 0, len(patterns))
for _, p := range patterns {
if s, ok := p.(string); ok {
strs = append(strs, s)
}
}
modelsMap[k] = strs
}
return modelsMap
}

func parseStringSliceField(value any, keepEmpty bool) []string {
values, ok := value.([]any)
if !ok {
return nil
}
result := make([]string, 0, len(values))
for _, v := range values {
if s, ok := v.(string); ok {
if s == "" && !keepEmpty {
continue
}
result = append(result, s)
}
}
if len(result) == 0 {
return nil
}
return result
}

func isModelPolicyKey(key string) bool {
return key == modelPolicyAllowedKey || key == modelPolicyDisallowedKey
}

func (acc *importAccumulator) extractRunInstallScripts(fm map[string]any, fullPath string) {
if acc.runInstallScripts {
return
Expand Down Expand Up @@ -737,6 +790,7 @@ func (acc *importAccumulator) toImportsResult(topologicalOrder []string) *Import
MergedEnvSources: acc.envSources,
MergedFeatures: acc.features,
MergedModels: acc.models,
MergedModelPolicies: acc.modelPolicies,
MergedModelCosts: acc.modelCosts,
MergedObservability: mergeObservabilityConfigs(acc.observabilityConfigs),
ImportedFiles: topologicalOrder,
Expand Down
47 changes: 47 additions & 0 deletions pkg/parser/import_field_extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,3 +699,50 @@ func TestExtractConfigFields_FirstWinsAndAccumulates(t *testing.T) {
assert.Contains(t, acc.secretMaskingBuilder.String(), "enabled")
assert.Contains(t, acc.secretMaskingBuilder.String(), "log-mask")
}

func TestAppendModelsField_ExtractsModelPolicySets(t *testing.T) {
acc := newImportAccumulator()
fm := map[string]any{
"models": map[string]any{
"allowed": []any{"gpt-5", "claude-sonnet"},
"disallowed": []any{"gpt-5-pro"},
},
}

acc.appendModelsField(fm)

require.Len(t, acc.modelPolicies, 1, "expected one model policy set")
assert.Equal(t, []string{"gpt-5", "claude-sonnet"}, acc.modelPolicies[0]["allowed"])
assert.Equal(t, []string{"gpt-5-pro"}, acc.modelPolicies[0]["disallowed"])
assert.Empty(t, acc.models, "policy fields should not be interpreted as model aliases")
}

func TestAppendModelsField_ExtractsModelCostsAndPolicyTogether(t *testing.T) {
acc := newImportAccumulator()
fm := map[string]any{
"models": map[string]any{
"allowed": []any{"gpt-5-mini"},
"providers": map[string]any{
"openai": map[string]any{
"models": map[string]any{
"gpt-5-mini": map[string]any{
"cost": map[string]any{"input": "1e-6"},
},
},
},
},
},
}

acc.appendModelsField(fm)

require.Len(t, acc.modelCosts, 1, "expected one model cost overlay")
require.Len(t, acc.modelPolicies, 1, "expected one model policy set")
assert.Equal(t, []string{"gpt-5-mini"}, acc.modelPolicies[0]["allowed"])
assert.Contains(t, acc.modelCosts[0], "providers")
assert.Len(t, acc.modelCosts[0], 1)
for _, key := range []string{"allowed", "disallowed"} {
_, present := acc.modelCosts[0][key]
assert.Falsef(t, present, "model cost overlay should not contain policy key %q", key)
}
}
1 change: 1 addition & 0 deletions pkg/parser/import_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type ImportsResult struct {
MergedEnvSources map[string]string // env var name → source import path (for conflict detection and lock file header listing)
MergedFeatures []map[string]any // Merged features configuration from all imports (parsed YAML structures)
MergedModels []map[string][]string // Merged model alias definitions from all imports (first import to define a key wins among imports)
MergedModelPolicies []map[string][]string // Merged model policy sets from all imports (models.allowed/disallowed)
MergedModelCosts []map[string]any // Merged model pricing overlays (models.json provider structure) from all imports
MergedObservability string // Merged observability config (JSON) from all imports as an endpoint array (deduped by URL)
MergedEngineMCPToolTimeout string // First engine.mcp.tool-timeout found across all imports (Go duration string, e.g. "10m")
Expand Down
17 changes: 15 additions & 2 deletions pkg/parser/schemas/main_workflow_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2693,10 +2693,23 @@
]
},
"models": {
"description": "Custom model pricing data in the same structure as models.json. Merged with the built-in models.json at runtime; frontmatter entries override matching models and fill gaps for unknown models. Useful for custom or private models, or to adjust pricing for AI Credits cost accounting.",
"description": "Model policy and optional pricing configuration. The policy fields (allowed/disallowed) are merged as unions across imports. The providers field is optional and supplies pricing data merged by provider/model key.",
"type": "object",
"required": ["providers"],
"properties": {
"allowed": {
"type": "array",
"description": "Allowlist of model names/patterns. Mapped to AWF apiProxy.allowedModels.",
"items": {
"type": "string"
}
},
"disallowed": {
"type": "array",
"description": "Denylist of model names/patterns. Mapped to AWF apiProxy.disallowedModels.",
"items": {
"type": "string"
}
},
"providers": {
"type": "object",
"description": "Provider-keyed map of model pricing data.",
Expand Down
33 changes: 33 additions & 0 deletions pkg/workflow/awf_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ import (
"github.com/github/gh-aw/pkg/jsonutil"
"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/setutil"
"github.com/github/gh-aw/pkg/workflow/compilerenv"
)

//go:embed schemas/awf-config.schema.json
Expand Down Expand Up @@ -242,6 +243,11 @@ type AWFAPIProxyConfig struct {
// AWF resolves aliases recursively; loops are not permitted.
// Per the AWF config schema, this lives under apiProxy.models.
Models map[string][]string `json:"models,omitempty"`

// AllowedModels is the explicit allowlist policy for model names/patterns.
AllowedModels []string `json:"allowedModels,omitempty"`
// DisallowedModels is the explicit denylist policy for model names/patterns.
DisallowedModels []string `json:"disallowedModels,omitempty"`
}

// AWFModelFallbackConfig is the "apiProxy.modelFallback" section of the AWF config file.
Expand Down Expand Up @@ -492,6 +498,15 @@ func BuildAWFConfigJSON(config AWFCommandConfig) (string, error) {
apiProxy.Models = config.WorkflowData.ModelMappings
awfConfigLog.Printf("Models section: %d alias entries", len(config.WorkflowData.ModelMappings))
}
allowedModels, disallowedModels := resolveModelPolicyForAWFConfig(config.WorkflowData)
if len(allowedModels) > 0 {
apiProxy.AllowedModels = allowedModels
awfConfigLog.Printf("Models policy: %d allowed model pattern(s)", len(allowedModels))
}
if len(disallowedModels) > 0 {
apiProxy.DisallowedModels = disallowedModels
awfConfigLog.Printf("Models policy: %d disallowed model pattern(s)", len(disallowedModels))
}

awfConfig.APIProxy = apiProxy

Expand Down Expand Up @@ -550,6 +565,24 @@ func splitDomainList(domains string) []string {
return result
}

func resolveModelPolicyForAWFConfig(workflowData *WorkflowData) ([]string, []string) {
envAllowed, hasAllowedOverride := compilerenv.ResolvePolicyModelsAllowed()
envDisallowed, hasDisallowedOverride := compilerenv.ResolvePolicyModelsDisallowed()
var allowed []string
var disallowed []string
if hasAllowedOverride {
allowed = envAllowed
} else if workflowData != nil {
allowed = workflowData.ModelPolicyAllowed
}
if hasDisallowedOverride {
disallowed = envDisallowed
} else if workflowData != nil {
disallowed = workflowData.ModelPolicyDisallowed
}
return allowed, disallowed
}

func extractModelMultipliers(workflowData *WorkflowData) map[string]float64 {
if workflowData == nil || workflowData.EngineConfig == nil || workflowData.EngineConfig.TokenWeights == nil {
return nil
Expand Down
45 changes: 45 additions & 0 deletions pkg/workflow/awf_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1619,3 +1619,48 @@ func TestBuildAWFTopologyAttachList(t *testing.T) {
assert.Equal(t, []string{"awmg-mcpg", "awmg-cli-proxy"}, targets)
})
}

func TestBuildAWFConfigJSON_EmitsModelPolicyFromWorkflowData(t *testing.T) {
config := AWFCommandConfig{
EngineName: "copilot",
AllowedDomains: "github.com",
WorkflowData: &WorkflowData{
EngineConfig: &EngineConfig{ID: "copilot"},
NetworkPermissions: &NetworkPermissions{
Firewall: &FirewallConfig{Enabled: true},
},
ModelPolicyAllowed: []string{"gpt-5", "claude-sonnet"},
ModelPolicyDisallowed: []string{"gpt-5-pro", "claude-opus"},
},
}

jsonStr, err := BuildAWFConfigJSON(config)
require.NoError(t, err)
assert.Contains(t, jsonStr, `"allowedModels":["gpt-5","claude-sonnet"]`)
assert.Contains(t, jsonStr, `"disallowedModels":["gpt-5-pro","claude-opus"]`)
}

func TestBuildAWFConfigJSON_ModelPolicyEnvOverridePrecedence(t *testing.T) {
t.Setenv(compilerenv.PolicyModelsAllowed, "gemini-pro,gpt-5-mini")
t.Setenv(compilerenv.PolicyModelsDisallowed, "claude-opus, gpt-5-pro")

config := AWFCommandConfig{
EngineName: "copilot",
AllowedDomains: "github.com",
WorkflowData: &WorkflowData{
EngineConfig: &EngineConfig{ID: "copilot"},
NetworkPermissions: &NetworkPermissions{
Firewall: &FirewallConfig{Enabled: true},
},
ModelPolicyAllowed: []string{"frontmatter-allowed"},
ModelPolicyDisallowed: []string{"frontmatter-disallowed"},
},
}

jsonStr, err := BuildAWFConfigJSON(config)
require.NoError(t, err)
assert.Contains(t, jsonStr, `"allowedModels":["gemini-pro","gpt-5-mini"]`)
assert.Contains(t, jsonStr, `"disallowedModels":["claude-opus","gpt-5-pro"]`)
assert.NotContains(t, jsonStr, "frontmatter-allowed")
assert.NotContains(t, jsonStr, "frontmatter-disallowed")
}
2 changes: 2 additions & 0 deletions pkg/workflow/compiler_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,8 @@ type WorkflowData struct {
KnownActionCredentialEnvVars map[string]struct{} // env vars for clean_known_action_credentials.sh; keyed by GH_AW_CLEAN_* names; nil when no known credential-leaking actions are detected
ModelMappings map[string][]string // merged model alias map (builtins + imported workflow aliases + main frontmatter overrides, in priority order); NOT yet emitted to AWF config JSON — pending AWF firewall support (config.models)
ModelCosts map[string]any // model pricing data from frontmatter `models` field (providers structure); merged with built-in models.json at runtime by generate_aw_info.cjs
ModelPolicyAllowed []string // merged models.allowed policy list (union across imports + main frontmatter)
ModelPolicyDisallowed []string // merged models.disallowed policy list (union across imports + main frontmatter)
ActionPinMappings map[string]string // action-pin redirect table from aw.json action_pins: maps "owner/repo@version" → "owner/repo@version"
}

Expand Down
Loading
Loading