discourse/plugins/discourse-ai/lib/mcp/oauth_flow.rb
Sam cbb63ef66a
FEATURE: Add advanced OAuth options for MCP servers (#38913)
Adds three new configurable fields to MCP server OAuth:

- `oauth_authorization_params` — JSON object merged into authorization
  requests (e.g. `{"access_type":"offline"}` for Google APIs)
- `oauth_token_params` — JSON object merged into token exchange and
  refresh requests (e.g. `{"audience":"..."}` for resource indicators)
- `oauth_require_refresh_token` — fails OAuth if the provider does not
  return a refresh token, surfacing misconfiguration early

The OAuth flow is also improved in several ways:

- Reads `token_endpoint_auth_methods_supported` from discovery metadata
  and negotiates the correct client authentication method
  (client_secret_basic, client_secret_post, or none)
- Validates client registration requirements before starting the flow,
  giving actionable error messages when dynamic registration is
  unavailable
- Null values in custom params remove default parameters, allowing
  overrides like removing the `resource` indicator

Additionally, the MCP client now passes through tool result errors
(isError: true) instead of raising exceptions, so the AI can see
and reason about tool-level failures.
2026-04-01 08:47:23 +11:00

394 lines
14 KiB
Ruby
Vendored

# frozen_string_literal: true
require "base64"
require "openssl"
module DiscourseAi
module Mcp
class OAuthFlow
# Lifetime of the Rails.cache entry that start! writes for each flow (keyed by the
# `state` value and holding the server id, user id and PKCE code verifier).
# complete! treats a missing entry as an invalid state, so the user must come back
# from the provider within this window.
STATE_TTL = 10.minutes
# Error raised by complete! when the OAuth callback cannot be finished.
# Carries the failure message and, when it had already been loaded, the
# server record the callback referred to.
class OAuthError < StandardError
  # @return [Object, nil] the server being authorized, or nil if it was never resolved
  attr_reader :server

  # @param message [String, nil] human-readable failure reason
  # @param server [Object, nil] server the failure relates to
  def initialize(message = nil, server: nil)
    super(message)
    @server = server
  end
end
class << self
# Begins an OAuth authorization-code + PKCE flow for an MCP server.
#
# Steps:
# 1. Validate the local callback / client metadata URLs.
# 2. Run discovery and store the result on the server.
# 3. Register the client dynamically when registration is not "manual", no client id
#    is set and the provider advertises a registration endpoint.
# 4. Verify a client id exists or can be obtained, and that a token endpoint auth
#    method can be negotiated.
# 5. Cache a one-time `state` entry for STATE_TTL and return the authorization URL.
#
# @param server [AiMcpServer] server being connected
# @param user [Object] user starting the flow; its id is cached and re-checked by complete!
# @return [String] authorization URL to send the user to
# @raise [DiscourseAi::Mcp::Client::Error] when local URL, registration or auth-method validation fails
def start!(server:, user:)
  validate_local_oauth_urls!(server)

  discovery = OAuthDiscovery.discover!(server)
  server.store_oauth_discovery!(discovery)

  wants_dynamic_registration =
    server.oauth_client_registration != "manual" && server.oauth_client_id.blank? &&
      discovery.registration_endpoint.present?
  if wants_dynamic_registration
    OAuthClientRegistration.register!(server: server, discovery: discovery)
    # Pick up the client credentials the registration just persisted.
    server.reload
  end

  validate_client_registration!(server, discovery)
  # Called only so it can raise now, before the user is redirected, when no usable
  # token endpoint auth method exists; the chosen method is negotiated again later.
  token_endpoint_client_auth_method(server, discovery)

  flow_state = SecureRandom.hex(32)
  verifier = generate_code_verifier
  cached_flow = {
    "ai_mcp_server_id" => server.id,
    "user_id" => user.id,
    "code_verifier" => verifier,
  }
  Rails.cache.write(state_cache_key(flow_state), cached_flow, expires_in: STATE_TTL)

  build_authorization_url(
    server: server,
    discovery: discovery,
    state: flow_state,
    code_verifier: verifier,
  )
end
# Handles the provider's redirect back to Discourse and exchanges the
# authorization code for tokens.
#
# The cached `state` entry is deleted on every call, so each state value can be
# used at most once. The entry must exist and must belong to +current_user+.
#
# @param params [Hash-like] callback parameters read with symbol keys
#   (:state, :code, :error, :error_description)
# @param current_user [Object] user handling the callback
# @return [AiMcpServer] the server, marked as authorized
# @raise [OAuthError] wrapping any failure; if the server had been loaded it is
#   marked with the error message and attached to the exception
def complete!(params:, current_user:)
  cache_key = state_cache_key(params[:state].to_s)
  flow = Rails.cache.read(cache_key)
  Rails.cache.delete(cache_key)

  owned_by_user = flow.present? && flow["user_id"].to_i == current_user.id
  unless owned_by_user
    raise DiscourseAi::Mcp::Client::Error,
          I18n.t("discourse_ai.mcp_servers.errors.oauth_state_invalid")
  end

  server = AiMcpServer.find(flow["ai_mcp_server_id"])

  if params[:error].present?
    # The provider redirected back with an error instead of an authorization code.
    provider_message = params[:error_description].presence || params[:error].to_s.tr("_", " ")
    raise DiscourseAi::Mcp::Client::Error, provider_message
  end

  discovery = server.oauth_discovery_result || OAuthDiscovery.discover!(server)
  token_endpoint = discovery.token_endpoint
  auth_method = token_endpoint_client_auth_method(server, discovery)
  grant_params =
    merged_oauth_params(
      {
        "grant_type" => "authorization_code",
        "code" => params[:code].to_s,
        "redirect_uri" => server.oauth_callback_url,
        "code_verifier" => flow["code_verifier"],
        "client_id" => server.effective_oauth_client_id,
        "resource" => discovery.resource.presence || server.url,
      },
      server.oauth_token_params,
    )

  token_payload =
    token_request(
      server: server,
      endpoint: token_endpoint,
      auth_method: auth_method,
      params: grant_params,
    )
  store_token_response!(server, token_payload, preserve_refresh_token: false)

  server.mark_oauth_authorized!
  DiscourseAi::Mcp::ToolRegistry.invalidate!(server.id)
  AiAgent.agent_cache.flush!
  server
rescue StandardError => e
  # `server` is still nil here when the failure happened before the record was
  # loaded (invalid state or a failed lookup).
  server&.mark_oauth_error!(e.message)
  raise OAuthError.new(e.message, server: server), cause: e
end
# Uses the stored refresh token to obtain a new access token.
#
# If the provider's response omits a refresh token, the previously stored one
# is kept (preserve_refresh_token: true).
#
# @param server [AiMcpServer]
# @return [String, nil] the "access_token" value from the token response
# @raise re-raises any error after recording it with mark_oauth_refresh_failed!
#   (a missing stored refresh token raises DiscourseAi::Mcp::Client::Error)
def refresh!(server)
  discovery = server.oauth_discovery_result || OAuthDiscovery.discover!(server)

  stored_refresh_token = server.oauth_token_store.refresh_token
  if stored_refresh_token.blank?
    raise DiscourseAi::Mcp::Client::Error,
          I18n.t("discourse_ai.mcp_servers.errors.oauth_refresh_token_missing")
  end

  token_endpoint = discovery.token_endpoint
  auth_method = token_endpoint_client_auth_method(server, discovery)
  refresh_params =
    merged_oauth_params(
      {
        "grant_type" => "refresh_token",
        "refresh_token" => stored_refresh_token,
        "client_id" => server.effective_oauth_client_id,
        "resource" => discovery.resource.presence || server.url,
      },
      server.oauth_token_params,
    )

  response_payload =
    token_request(
      server: server,
      endpoint: token_endpoint,
      auth_method: auth_method,
      params: refresh_params,
    )
  store_token_response!(server, response_payload, preserve_refresh_token: true)
  response_payload["access_token"]
rescue StandardError => e
  server.mark_oauth_refresh_failed!(e.message)
  raise
end
# Disconnects an MCP server from its OAuth provider.
#
# If stored discovery data has a revocation endpoint, the access token and then
# the refresh token are revoked through revoke_token. Blank tokens are skipped
# there, and raised errors are logged rather than propagated. Local OAuth
# credentials are always cleared afterwards. The tool registry entry for this
# server is invalidated and the agent cache is flushed.
#
# @param server [AiMcpServer]
def disconnect!(server)
  revocation_endpoint = server.oauth_discovery_result&.revocation_endpoint

  if revocation_endpoint.present?
    %i[access_token refresh_token].each do |token_kind|
      revoke_token(server, revocation_endpoint, server.oauth_token_store.public_send(token_kind))
    end
  end

  server.clear_oauth_credentials!
  DiscourseAi::Mcp::ToolRegistry.invalidate!(server.id)
  AiAgent.agent_cache.flush!
end
private
def build_authorization_url(server:, discovery:, state:, code_verifier:)
code_challenge =
Base64.urlsafe_encode64(Digest::SHA256.digest(code_verifier)).tr("+/", "-_").delete("=")
query =
merged_oauth_params(
{
"response_type" => "code",
"client_id" => server.effective_oauth_client_id,
"redirect_uri" => server.oauth_callback_url,
"state" => state,
"code_challenge" => code_challenge,
"code_challenge_method" => "S256",
"resource" => discovery.resource.presence || server.url,
"scope" => server.oauth_scopes.presence,
},
server.oauth_authorization_params,
)
uri = URI(discovery.authorization_endpoint)
uri.query = Rack::Utils.build_query(query)
uri.to_s
end
# POSTs a grant (authorization_code or refresh_token) to the provider's token
# endpoint and returns the parsed token response.
#
# Client authentication follows +auth_method+:
# - "client_secret_basic": HTTP Basic header.
# - "client_secret_post": a client_secret form field.
# - anything else (e.g. "none"): nothing extra.
#
# @param server [AiMcpServer] supplies the timeout and client credentials
# @param endpoint [String] token endpoint URL, checked with validate_endpoint! first
# @param auth_method [String] result of token_endpoint_client_auth_method
# @param params [Hash] form fields for the grant; copied, never mutated
# @return [Hash] token response object with a non-blank "access_token"
# @raise [DiscourseAi::Mcp::Client::Error] for any of:
#   - a non-200 status
#   - an OAuth error object in the body
#   - a body that is not a JSON object
#   - a missing access token
def token_request(server:, endpoint:, auth_method:, params:)
  validate_endpoint!(endpoint)

  connection =
    Faraday.new(request: { timeout: server.timeout_seconds }) do |builder|
      builder.request :url_encoded
      builder.adapter FinalDestination::FaradayAdapter
    end

  headers = { "Accept" => "application/json" }
  # Work on a copy so the client secret never leaks into the caller's hash.
  request_params = params.deep_dup
  case auth_method
  when "client_secret_basic"
    headers["Authorization"] = basic_auth_header(
      server.effective_oauth_client_id,
      server.oauth_client_secret_value,
    )
  when "client_secret_post"
    request_params["client_secret"] = server.oauth_client_secret_value
  end

  response = connection.post(endpoint, request_params, headers)

  # nil when the body is not JSON (a nil body parses as "" and fails too) or is
  # not a JSON object.
  body =
    begin
      parsed = JSON.parse(response.body.to_s)
      parsed.is_a?(Hash) ? parsed : nil
    rescue JSON::ParserError
      nil
    end

  # Some providers (GitHub, for example) answer failed token requests with HTTP 200
  # and an RFC 6749 §5.2 error object, so an "error" member is a failure too.
  if response.status != 200 || (body && body["error"].present?)
    error_message =
      if body
        body["error_description"].presence || body["error"].to_s.tr("_", " ").presence
      end
    raise DiscourseAi::Mcp::Client::Error,
          error_message ||
            I18n.t(
              "discourse_ai.mcp_servers.errors.oauth_token_exchange_failed",
              status: response.status,
            )
  end

  # RFC 6749 §5.1: a successful response must include access_token. Rejecting it
  # here stops callers from storing an empty token (complete! would otherwise
  # mark the server as authorized).
  if body.nil? || body["access_token"].blank?
    raise DiscourseAi::Mcp::Client::Error,
          I18n.t("discourse_ai.mcp_servers.errors.invalid_response")
  end

  body
end
# Persists a token endpoint response on the server.
#
# @param server [AiMcpServer]
# @param token_payload [Hash] parsed token response (string keys)
# @param preserve_refresh_token [Boolean] when true, a blank "refresh_token" in
#   the response falls back to the refresh token already stored
# @return whatever server.update_oauth_tokens! returns
# @raise [DiscourseAi::Mcp::Client::Error] when server.oauth_require_refresh_token
#   is set and no refresh token remains; nothing is persisted in that case
def store_token_response!(server, token_payload, preserve_refresh_token:)
  refresh_token = token_payload["refresh_token"]
  if preserve_refresh_token
    refresh_token = refresh_token.presence || server.oauth_token_store.refresh_token
  end

  missing_required_refresh = server.oauth_require_refresh_token && refresh_token.blank?
  if missing_required_refresh
    raise DiscourseAi::Mcp::Client::Error,
          I18n.t("discourse_ai.mcp_servers.errors.oauth_refresh_token_required")
  end

  token_attributes = {
    access_token: token_payload["access_token"],
    refresh_token: refresh_token,
    token_type: token_payload["token_type"],
    expires_in: token_payload["expires_in"],
    granted_scopes: token_payload["scope"],
  }
  server.update_oauth_tokens!(**token_attributes)
end
# Best-effort RFC 7009 revocation of a single token.
#
# Blank tokens are skipped. The client authenticates with HTTP Basic when a
# client secret is configured, and client_id is always sent in the form body.
# Failures never propagate: non-200 responses and raised errors are logged as
# warnings, so callers such as disconnect! carry on.
#
# @param server [AiMcpServer]
# @param endpoint [String] revocation endpoint from discovery
# @param token [String, nil] access or refresh token to revoke
# @return the HTTP response, or nil when skipped or when an error was logged
def revoke_token(server, endpoint, token)
  return if token.blank?
  validate_endpoint!(endpoint)

  connection =
    Faraday.new(request: { timeout: server.timeout_seconds }) do |builder|
      builder.request :url_encoded
      builder.adapter FinalDestination::FaradayAdapter
    end

  headers = { "Accept" => "application/json" }
  if server.oauth_client_secret_value.present?
    headers["Authorization"] = basic_auth_header(
      server.effective_oauth_client_id,
      server.oauth_client_secret_value,
    )
  end

  response =
    connection.post(
      endpoint,
      { token: token, client_id: server.effective_oauth_client_id },
      headers,
    )

  # RFC 7009 §2.2: the server answers 200 on success and for tokens it does not
  # recognise, so any other status means the token may still be valid upstream.
  if response.status != 200
    Rails.logger.warn(
      "Discourse AI MCP OAuth revoke failed for server #{server.id}: HTTP #{response.status}",
    )
  end

  response
rescue StandardError => e
  Rails.logger.warn(
    "Discourse AI MCP OAuth revoke failed for server #{server.id}: #{e.message}",
  )
end
# Builds an HTTP Basic `Authorization` header value (RFC 7617) from client
# credentials joined with ":" and strict-Base64 encoded (no line breaks).
# A nil secret is rendered as an empty string.
#
# @param client_id [String]
# @param client_secret [String, nil]
# @return [String] e.g. "Basic Y2lkOnNlY3JldA==" for ("cid", "secret")
def basic_auth_header(client_id, client_secret)
  credentials = format("%s:%s", client_id, client_secret)
  "Basic " + Base64.strict_encode64(credentials)
end
# Combines default OAuth request parameters with admin-configured overrides.
#
# - Override keys are stringified so they replace defaults of the same name.
# - Every key whose final value is nil is dropped, so a nil override removes a
#   default (e.g. "resource").
# - Overrides whose #to_h raises NoMethodError or TypeError are ignored.
#
# @param defaults [Hash{String=>Object}]
# @param overrides [Hash, nil, Object] admin-supplied extra parameters
# @return [Hash{String=>Object}]
def merged_oauth_params(defaults, overrides)
  extra =
    begin
      overrides.to_h.transform_keys(&:to_s)
    rescue NoMethodError, TypeError
      {}
    end
  defaults.merge(extra).compact
end
# Generates a PKCE code verifier: 32 random bytes encoded as unpadded base64url,
# i.e. a 43-character string over [A-Za-z0-9_-] (within RFC 7636's 43..128 range).
def generate_code_verifier
  random_bytes = OpenSSL::Random.random_bytes(32)
  Base64.urlsafe_encode64(random_bytes, padding: false)
end
# Rails.cache key under which start! stores, and complete! reads and deletes,
# the flow payload for a given OAuth `state` value.
def state_cache_key(state)
  format("discourse-ai:mcp-oauth-state:%s", state)
end
# Checks a provider endpoint (token / revocation) before any request is sent to it.
#
# @param url [String] endpoint URL
# @return whatever AiMcpServer.validate_hostname_public! returns
# @raise [DiscourseAi::Mcp::Client::Error] for either of:
#   - "invalid_url_not_https" when AiMcpServer.parse_public_uri returns nil
#   - "invalid_url_not_reachable" when an SSRF, socket or URI parse error is raised
def validate_endpoint!(url)
  public_uri = AiMcpServer.parse_public_uri(url)
  return AiMcpServer.validate_hostname_public!(public_uri.hostname) unless public_uri.nil?

  raise DiscourseAi::Mcp::Client::Error,
        I18n.t("discourse_ai.mcp_servers.invalid_url_not_https")
rescue FinalDestination::SSRFError, SocketError, URI::InvalidURIError
  raise DiscourseAi::Mcp::Client::Error,
        I18n.t("discourse_ai.mcp_servers.invalid_url_not_reachable")
end
# Checks this site's own OAuth URLs before a flow starts.
#
# The callback URL must use https. Unless client registration is "manual", the
# client metadata URL must also be accepted by AiMcpServer.parse_public_uri and
# pass AiMcpServer.validate_hostname_public!.
#
# @param server [AiMcpServer]
# @return nil for manual registration, otherwise the hostname check's result
# @raise [DiscourseAi::Mcp::Client::Error] for either of:
#   - "oauth_https_required" for a non-https callback
#   - "oauth_client_metadata_public_https_required" (with the metadata URL) when
#     the metadata URL is rejected, or when an SSRF, socket or URI parse error is
#     raised anywhere in this check
def validate_local_oauth_urls!(server)
  unless URI.parse(server.oauth_callback_url).scheme == "https"
    raise DiscourseAi::Mcp::Client::Error,
          I18n.t("discourse_ai.mcp_servers.errors.oauth_https_required")
  end

  # The metadata URL is only checked for non-manual registration modes.
  return if server.oauth_client_registration == "manual"

  metadata_url = server.oauth_client_metadata_url
  public_metadata_uri = AiMcpServer.parse_public_uri(metadata_url)
  unless public_metadata_uri.nil?
    return AiMcpServer.validate_hostname_public!(public_metadata_uri.hostname)
  end

  raise DiscourseAi::Mcp::Client::Error,
        I18n.t(
          "discourse_ai.mcp_servers.errors.oauth_client_metadata_public_https_required",
          url: metadata_url,
        )
rescue FinalDestination::SSRFError, SocketError, URI::InvalidURIError
  raise DiscourseAi::Mcp::Client::Error,
        I18n.t(
          "discourse_ai.mcp_servers.errors.oauth_client_metadata_public_https_required",
          url: server.oauth_client_metadata_url,
        )
end
# Raises before the flow starts when no client id is configured and dynamic
# client registration is unavailable.
#
# Passes (returns nil) when a client id is present, when registration is
# "manual", or when the provider advertises a registration endpoint.
#
# @raise [DiscourseAi::Mcp::Client::Error] "oauth_manual_client_registration_required",
#   naming the provider via oauth_issuer_label, in every other case
def validate_client_registration!(server, discovery)
  client_obtainable =
    server.oauth_client_id.present? || server.oauth_client_registration == "manual" ||
      discovery.registration_endpoint.present?
  return if client_obtainable

  raise DiscourseAi::Mcp::Client::Error,
        I18n.t(
          "discourse_ai.mcp_servers.errors.oauth_manual_client_registration_required",
          issuer: oauth_issuer_label(discovery, server),
        )
end
# Chooses how the client authenticates at the token endpoint. The choice depends
# on the provider's token_endpoint_auth_methods_supported list (compared as
# strings) and on whether a client secret is configured.
#
# - Nothing advertised: "client_secret_basic" with a secret, otherwise "none".
# - With a secret: the first of client_secret_basic, client_secret_post, none
#   that the provider lists.
# - Without a secret: "none" if listed. If only secret-based methods are listed,
#   the admin is told a secret is required.
#
# @return [String] "client_secret_basic", "client_secret_post" or "none"
# @raise [DiscourseAi::Mcp::Client::Error] when no advertised method can be used
def token_endpoint_client_auth_method(server, discovery)
  supported = Array(discovery.token_endpoint_auth_methods_supported).map(&:to_s)
  has_secret = server.oauth_client_secret_value.present?

  return(has_secret ? "client_secret_basic" : "none") if supported.empty?

  if has_secret
    preferred =
      %w[client_secret_basic client_secret_post none].find do |candidate|
        supported.include?(candidate)
      end
    return preferred if preferred
  elsif supported.include?("none")
    return "none"
  elsif (supported & %w[client_secret_basic client_secret_post]).any?
    raise DiscourseAi::Mcp::Client::Error,
          I18n.t(
            "discourse_ai.mcp_servers.errors.oauth_client_secret_required",
            issuer: oauth_issuer_label(discovery, server),
            methods: supported.join(", "),
          )
  end

  raise unsupported_token_endpoint_client_auth_method_error(server, discovery, supported)
end
# Builds (does not raise) the error used when none of the provider's advertised
# token endpoint auth methods can be used by this client.
#
# @param server [AiMcpServer]
# @param discovery [Object] discovery result, used for the issuer label
# @param methods [Array<String>] methods the provider advertised
# @return [DiscourseAi::Mcp::Client::Error]
def unsupported_token_endpoint_client_auth_method_error(server, discovery, methods)
  message =
    I18n.t(
      "discourse_ai.mcp_servers.errors.oauth_token_endpoint_auth_method_unsupported",
      issuer: oauth_issuer_label(discovery, server),
      methods: methods.join(", "),
    )
  DiscourseAi::Mcp::Client::Error.new(message)
end
# Name of the authorization server shown in error messages: the discovered
# issuer, else the authorization endpoint, else the MCP server's own URL.
def oauth_issuer_label(discovery, server)
  label = discovery.issuer.presence
  label ||= discovery.authorization_endpoint.presence
  label || server.url
end
end
end
end
end