diff --git a/lib/graphql/dataloader.rb b/lib/graphql/dataloader.rb
index b419f534c4..b2e72d7a91 100644
--- a/lib/graphql/dataloader.rb
+++ b/lib/graphql/dataloader.rb
@@ -24,18 +24,23 @@ module GraphQL
   #
   class Dataloader
     class << self
-      attr_accessor :default_nonblocking
+      attr_accessor :default_nonblocking, :default_fiber_limit
     end
 
-    NonblockingDataloader = Class.new(self) { self.default_nonblocking = true }
-
-    def self.use(schema, nonblocking: nil)
-      schema.dataloader_class = if nonblocking
+    def self.use(schema, nonblocking: nil, fiber_limit: nil)
+      dataloader_class = if nonblocking
         warn("`nonblocking: true` is deprecated from `GraphQL::Dataloader`, please use `GraphQL::Dataloader::AsyncDataloader` instead. Docs: https://graphql-ruby.org/dataloader/async_dataloader.")
-        NonblockingDataloader
+        Class.new(self) { self.default_nonblocking = true }
       else
         self
       end
+
+      if fiber_limit
+        dataloader_class = Class.new(dataloader_class)
+        dataloader_class.default_fiber_limit = fiber_limit
+      end
+
+      schema.dataloader_class = dataloader_class
     end
 
     # Call the block with a Dataloader instance,
@@ -50,14 +55,18 @@ def self.with_dataloading(&block)
       result
     end
 
-    def initialize(nonblocking: self.class.default_nonblocking)
+    def initialize(nonblocking: self.class.default_nonblocking, fiber_limit: self.class.default_fiber_limit)
       @source_cache = Hash.new { |h, k| h[k] = {} }
       @pending_jobs = []
       if !nonblocking.nil?
         @nonblocking = nonblocking
       end
+      @fiber_limit = fiber_limit
     end
 
+    # @return [Integer, nil]
+    attr_reader :fiber_limit
+
     def nonblocking?
       @nonblocking
     end
@@ -178,6 +187,7 @@ def run_isolated
     end
 
     def run
+      jobs_fiber_limit, total_fiber_limit = calculate_fiber_limit
       job_fibers = []
       next_job_fibers = []
       source_fibers = []
@@ -187,7 +197,7 @@ def run
       while first_pass || job_fibers.any?
         first_pass = false
 
-        while (f = (job_fibers.shift || spawn_job_fiber))
+        while (f = (job_fibers.shift || (((next_job_fibers.size + job_fibers.size) < jobs_fiber_limit) && spawn_job_fiber)))
           if f.alive?
             finished = run_fiber(f)
             if !finished
@@ -197,8 +207,8 @@ def run
         end
         join_queues(job_fibers, next_job_fibers)
 
-        while source_fibers.any? || @source_cache.each_value.any? { |group_sources| group_sources.each_value.any?(&:pending?) }
-          while (f = source_fibers.shift || spawn_source_fiber)
+        while (source_fibers.any? || @source_cache.each_value.any? { |group_sources| group_sources.each_value.any?(&:pending?) })
+          while (f = (source_fibers.shift || (((job_fibers.size + source_fibers.size + next_source_fibers.size + next_job_fibers.size) < total_fiber_limit) && spawn_source_fiber)))
            if f.alive?
              finished = run_fiber(f)
              if !finished
@@ -242,6 +252,17 @@ def spawn_fiber
 
     private
 
+    def calculate_fiber_limit
+      total_fiber_limit = @fiber_limit || Float::INFINITY
+      if total_fiber_limit < 4
+        raise ArgumentError, "Dataloader fiber limit is too low (#{total_fiber_limit}), it must be at least 4"
+      end
+      total_fiber_limit -= 1 # deduct one fiber for `manager`
+      # Deduct at least one fiber for sources
+      jobs_fiber_limit = total_fiber_limit - 2
+      return jobs_fiber_limit, total_fiber_limit
+    end
+
     def join_queues(prev_queue, new_queue)
       @nonblocking && Fiber.scheduler.run
       prev_queue.concat(new_queue)
diff --git a/lib/graphql/dataloader/async_dataloader.rb b/lib/graphql/dataloader/async_dataloader.rb
index 07b4f3184a..e8d730eb74 100644
--- a/lib/graphql/dataloader/async_dataloader.rb
+++ b/lib/graphql/dataloader/async_dataloader.rb
@@ -12,6 +12,7 @@ def yield
       end
 
       def run
+        jobs_fiber_limit, total_fiber_limit = calculate_fiber_limit
         job_fibers = []
         next_job_fibers = []
         source_tasks = []
@@ -23,7 +24,7 @@ def run
           first_pass = false
           fiber_vars = get_fiber_variables
 
-          while (f = (job_fibers.shift || spawn_job_fiber))
+          while (f = (job_fibers.shift || (((job_fibers.size + next_job_fibers.size + source_tasks.size) < jobs_fiber_limit) && spawn_job_fiber)))
            if f.alive?
              finished = run_fiber(f)
              if !finished
@@ -37,7 +38,7 @@ def run
         Sync do |root_task|
           set_fiber_variables(fiber_vars)
           while source_tasks.any? || @source_cache.each_value.any? { |group_sources| group_sources.each_value.any?(&:pending?) }
-            while (task = source_tasks.shift || spawn_source_task(root_task, sources_condition))
+            while (task = (source_tasks.shift || (((job_fibers.size + next_job_fibers.size + source_tasks.size + next_source_tasks.size) < total_fiber_limit) && spawn_source_task(root_task, sources_condition))))
              if task.alive?
                root_task.yield # give the source task a chance to run
                next_source_tasks << task
diff --git a/lib/graphql/dataloader/source.rb b/lib/graphql/dataloader/source.rb
index 12ec8785e9..d70db0a2a4 100644
--- a/lib/graphql/dataloader/source.rb
+++ b/lib/graphql/dataloader/source.rb
@@ -98,7 +98,7 @@ def sync(pending_result_keys)
       while pending_result_keys.any? { |key| !@results.key?(key) }
         iterations += 1
         if iterations > MAX_ITERATIONS
-          raise "#{self.class}#sync tried #{MAX_ITERATIONS} times to load pending keys (#{pending_result_keys}), but they still weren't loaded. There is likely a circular dependency."
+          raise "#{self.class}#sync tried #{MAX_ITERATIONS} times to load pending keys (#{pending_result_keys}), but they still weren't loaded. There is likely a circular dependency#{@dataloader.fiber_limit ? " or `fiber_limit: #{@dataloader.fiber_limit}` is set too low" : ""}."
         end
         @dataloader.yield
       end
diff --git a/spec/graphql/dataloader/nonblocking_dataloader_spec.rb b/spec/graphql/dataloader/nonblocking_dataloader_spec.rb
index d2eceb1e6c..3e3aec77dc 100644
--- a/spec/graphql/dataloader/nonblocking_dataloader_spec.rb
+++ b/spec/graphql/dataloader/nonblocking_dataloader_spec.rb
@@ -2,7 +2,7 @@
 require "spec_helper"
 
 if Fiber.respond_to?(:scheduler) # Ruby 3+
-  describe GraphQL::Dataloader::NonblockingDataloader do
+  describe "GraphQL::Dataloader::NonblockingDataloader" do
     class NonblockingSchema < GraphQL::Schema
       class SleepSource < GraphQL::Dataloader::Source
         def fetch(keys)
@@ -84,7 +84,7 @@ def wait_for(tag:, wait:)
     end
 
     query(Query)
-    use GraphQL::Dataloader::NonblockingDataloader
+    use GraphQL::Dataloader, nonblocking: true
   end
 
   def with_scheduler
@@ -99,7 +99,7 @@ def self.included(child_class)
     child_class.class_eval do
 
       it "runs IO in parallel by default" do
-        dataloader = GraphQL::Dataloader::NonblockingDataloader.new
+        dataloader = GraphQL::Dataloader.new(nonblocking: true)
         results = {}
         dataloader.append_job { sleep(0.1); results[:a] = 1 }
         dataloader.append_job { sleep(0.2); results[:b] = 2 }
@@ -115,7 +115,7 @@ def self.included(child_class)
       end
 
       it "works with sources" do
-        dataloader = GraphQL::Dataloader::NonblockingDataloader.new
+        dataloader = GraphQL::Dataloader.new(nonblocking: true)
         r1 = dataloader.with(NonblockingSchema::SleepSource).request(0.1)
         r2 = dataloader.with(NonblockingSchema::SleepSource).request(0.2)
         r3 = dataloader.with(NonblockingSchema::SleepSource).request(0.3)
diff --git a/spec/graphql/dataloader/source_spec.rb b/spec/graphql/dataloader/source_spec.rb
index f9825d2b41..1006e18fff 100644
--- a/spec/graphql/dataloader/source_spec.rb
+++ b/spec/graphql/dataloader/source_spec.rb
@@ -16,6 +16,14 @@ def fetch(keys)
     end
     expected_message = "FailsToLoadSource#sync tried 1000 times to load pending keys ([1]), but they still weren't loaded. There is likely a circular dependency."
     assert_equal expected_message, err.message
+
+    dl = GraphQL::Dataloader.new(fiber_limit: 10000)
+    dl.append_job { dl.with(FailsToLoadSource).load(1) }
+    err = assert_raises RuntimeError do
+      dl.run
+    end
+    expected_message = "FailsToLoadSource#sync tried 1000 times to load pending keys ([1]), but they still weren't loaded. There is likely a circular dependency or `fiber_limit: 10000` is set too low."
+    assert_equal expected_message, err.message
   end
 
   it "is pending when waiting for false and nil" do
diff --git a/spec/graphql/dataloader_spec.rb b/spec/graphql/dataloader_spec.rb
index 25f0fbe952..8cf6fc2763 100644
--- a/spec/graphql/dataloader_spec.rb
+++ b/spec/graphql/dataloader_spec.rb
@@ -515,6 +515,52 @@ class Query < GraphQL::Schema::Object
 end
 
 module DataloaderAssertions
+  module FiberCounting
+    class << self
+      attr_accessor :starting_fiber_count, :last_spawn_fiber_count, :last_max_fiber_count
+
+      def current_fiber_count
+        count_active_fibers - starting_fiber_count
+      end
+
+      def count_active_fibers
+        GC.start
+        ObjectSpace.each_object(Fiber).count
+      end
+    end
+
+    def initialize(*args, **kwargs, &block)
+      super
+      FiberCounting.starting_fiber_count = FiberCounting.count_active_fibers
+      FiberCounting.last_max_fiber_count = 0
+      FiberCounting.last_spawn_fiber_count = 0
+    end
+
+    def spawn_fiber
+      result = super
+      update_fiber_counts
+      result
+    end
+
+    def spawn_source_task(parent_task, condition)
+      result = super
+      if result
+        update_fiber_counts
+      end
+      result
+    end
+
+    private
+
+    def update_fiber_counts
+      FiberCounting.last_spawn_fiber_count += 1
+      current_count = FiberCounting.current_fiber_count
+      if current_count > FiberCounting.last_max_fiber_count
+        FiberCounting.last_max_fiber_count = current_count
+      end
+    end
+  end
+
   def self.included(child_class)
     child_class.class_eval do
       let(:schema) { make_schema_from(FiberSchema) }
@@ -1038,6 +1084,92 @@ def self.included(child_class)
         response = parts_schema.execute(query).to_h
         assert_equal [4, 4, 4, 4], response["data"]["manufacturers"].map { |parts_obj| parts_obj["parts"].size }
       end
+
+      describe "fiber_limit" do
+        def assert_last_max_fiber_count(expected_last_max_fiber_count)
+          if schema.dataloader_class == GraphQL::Dataloader::AsyncDataloader && FiberCounting.last_max_fiber_count == (expected_last_max_fiber_count + 1)
+            # TODO why does this happen sometimes?
+ warn "AsyncDataloader had +1 last_max_fiber_count" + assert_equal (expected_last_max_fiber_count + 1), FiberCounting.last_max_fiber_count + else + assert_equal expected_last_max_fiber_count, FiberCounting.last_max_fiber_count + end + end + + it "respects a configured fiber_limit" do + query_str = <<-GRAPHQL + { + recipes { + ingredients { + name + } + } + nestedIngredient(id: 2) { + name + } + keyIngredient(id: 4) { + name + } + commonIngredientsWithLoad(recipe1Id: 5, recipe2Id: 6) { + name + } + } + GRAPHQL + + fiber_counting_dataloader_class = Class.new(schema.dataloader_class) + fiber_counting_dataloader_class.include(FiberCounting) + + res = schema.execute(query_str, context: { dataloader: fiber_counting_dataloader_class.new }) + assert_nil res.context.dataloader.fiber_limit + assert_equal 12, FiberCounting.last_spawn_fiber_count + assert_last_max_fiber_count(9) + + res = schema.execute(query_str, context: { dataloader: fiber_counting_dataloader_class.new(fiber_limit: 4) }) + assert_equal 4, res.context.dataloader.fiber_limit + assert_equal 14, FiberCounting.last_spawn_fiber_count + assert_last_max_fiber_count(4) + + res = schema.execute(query_str, context: { dataloader: fiber_counting_dataloader_class.new(fiber_limit: 6) }) + assert_equal 6, res.context.dataloader.fiber_limit + assert_equal 10, FiberCounting.last_spawn_fiber_count + assert_last_max_fiber_count(6) + end + + it "accepts a default fiber_limit config" do + schema = Class.new(FiberSchema) do + use GraphQL::Dataloader, fiber_limit: 4 + end + query_str = <<-GRAPHQL + { + recipes { + ingredients { + name + } + } + nestedIngredient(id: 2) { + name + } + keyIngredient(id: 4) { + name + } + commonIngredientsWithLoad(recipe1Id: 5, recipe2Id: 6) { + name + } + } + GRAPHQL + res = schema.execute(query_str) + assert_equal 4, res.context.dataloader.fiber_limit + assert_nil res["errors"] + end + + it "requires at least three fibers" do + dl = GraphQL::Dataloader.new(fiber_limit: 2) + err = assert_raises ArgumentError do + dl.run + end + assert_equal "Dataloader fiber limit is too low (2), it must be at least 4", err.message + end + end end end end