diff --git a/config/initializers/008-rack-cors.rb b/config/initializers/008-rack-cors.rb index 2ba7b01fb3f..689cdd3e2ea 100644 --- a/config/initializers/008-rack-cors.rb +++ b/config/initializers/008-rack-cors.rb @@ -1,41 +1,52 @@ -if GlobalSetting.enable_cors - class Discourse::Cors - def initialize(app, options = nil) - @app = app - if GlobalSetting.enable_cors && GlobalSetting.cors_origin.present? - @global_origins = GlobalSetting.cors_origin.split(',').map(&:strip) - end - end +# frozen_string_literal: true - def call(env) - if env['REQUEST_METHOD'] == ('OPTIONS') && env['HTTP_ACCESS_CONTROL_REQUEST_METHOD'] - return [200, apply_headers(env), []] - end +class Discourse::Cors + ORIGINS_ENV = "Discourse_Cors_Origins" - status, headers, body = @app.call(env) - [status, apply_headers(env, headers), body] - end - - def apply_headers(env, headers = nil) - headers ||= {} - - origin = nil - cors_origins = @global_origins || [] - cors_origins += SiteSetting.cors_origins.split('|') if SiteSetting.cors_origins - - if cors_origins - if origin = env['HTTP_ORIGIN'] - origin = nil unless cors_origins.include?(origin) - end - - headers['Access-Control-Allow-Origin'] = origin || cors_origins[0] - headers['Access-Control-Allow-Headers'] = 'X-Requested-With, X-CSRF-Token, Discourse-Visible' - headers['Access-Control-Allow-Credentials'] = 'true' - end - - headers + def initialize(app, options = nil) + @app = app + if GlobalSetting.enable_cors && GlobalSetting.cors_origin.present? + @global_origins = GlobalSetting.cors_origin.split(',').map(&:strip) end end + def call(env) + + cors_origins = @global_origins || [] + cors_origins += SiteSetting.cors_origins.split('|') if SiteSetting.cors_origins.present? + cors_origins = cors_origins.presence + + if env['REQUEST_METHOD'] == ('OPTIONS') && env['HTTP_ACCESS_CONTROL_REQUEST_METHOD'] + return [200, Discourse::Cors.apply_headers(cors_origins, env, {}), []] + end + + env[Discourse::Cors::ORIGINS_ENV] = cors_origins if cors_origins + + status, headers, body = @app.call(env) + headers ||= {} + + Discourse::Cors.apply_headers(cors_origins, env, headers) if cors_origins + + [status, headers, body] + end + + def self.apply_headers(cors_origins, env, headers) + origin = nil + + if cors_origins + if origin = env['HTTP_ORIGIN'] + origin = nil unless cors_origins.include?(origin) + end + + headers['Access-Control-Allow-Origin'] = origin || cors_origins[0] + headers['Access-Control-Allow-Headers'] = 'X-Requested-With, X-CSRF-Token, Discourse-Visible' + headers['Access-Control-Allow-Credentials'] = 'true' + end + + headers + end +end + +if GlobalSetting.enable_cors Rails.configuration.middleware.insert_before ActionDispatch::Flash, Discourse::Cors end diff --git a/lib/hijack.rb b/lib/hijack.rb index c2722e957a4..026077dbb61 100644 --- a/lib/hijack.rb +++ b/lib/hijack.rb @@ -55,6 +55,11 @@ module Hijack body = response.body headers = response.headers + # add cors if needed + if cors_origins = env_copy[Discourse::Cors::ORIGINS_ENV] + Discourse::Cors.apply_headers(cors_origins, env_copy, headers) + end + headers['Content-Length'] = body.bytesize headers['Content-Type'] = response.content_type || "text/plain" headers['Connection'] = "close" diff --git a/spec/components/hijack_spec.rb b/spec/components/hijack_spec.rb index 64a5ec5b054..3a9522ab9b6 100644 --- a/spec/components/hijack_spec.rb +++ b/spec/components/hijack_spec.rb @@ -79,6 +79,41 @@ describe Hijack do expect(copy_req.object_id).not_to eq(orig_req.object_id) end + it "handles cors" do + SiteSetting.cors_origins = "www.rainbows.com" + + app = lambda do |env| + tester = Hijack::Tester.new(env) + tester.hijack_test do + render body: "hello", status: 201 + end + + expect(tester.io.string).to include("Access-Control-Allow-Origin: www.rainbows.com") + end + + env = {} + middleware = Discourse::Cors.new(app) + middleware.call(env) + + # it can do pre-flight + env = { + 'REQUEST_METHOD' => 'OPTIONS', + 'HTTP_ACCESS_CONTROL_REQUEST_METHOD' => 'GET' + } + + status, headers, _body = middleware.call(env) + + expect(status).to eq(200) + + expected = { + "Access-Control-Allow-Origin" => "www.rainbows.com", + "Access-Control-Allow-Headers" => "X-Requested-With, X-CSRF-Token, Discourse-Visible", + "Access-Control-Allow-Credentials" => "true" + } + + expect(headers).to eq(expected) + end + it "handles expires_in" do tester.hijack_test do expires_in 1.year