-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'EN-7354-openai-matching-clean' into staging
- Loading branch information
Showing
10 changed files
with
335 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
module OpenaiServices | ||
class BasicPerformer | ||
attr_reader :configuration, :client, :callback, :assistant_id, :instance | ||
|
||
class BasicPerformerCallback < Callback | ||
end | ||
|
||
def initialize instance: | ||
@callback = BasicPerformerCallback.new | ||
|
||
@configuration = get_configuration | ||
|
||
@client = OpenAI::Client.new(access_token: @configuration.api_key) | ||
@assistant_id = @configuration.assistant_id | ||
|
||
@instance = instance | ||
end | ||
|
||
def perform | ||
yield callback if block_given? | ||
|
||
# create new thread | ||
thread = client.threads.create | ||
|
||
# create instance message | ||
message = client.messages.create(thread_id: thread['id'], parameters: user_message) | ||
|
||
# run the thread | ||
run = client.runs.create(thread_id: thread['id'], parameters: { | ||
assistant_id: assistant_id, | ||
max_prompt_tokens: configuration.max_prompt_tokens, | ||
max_completion_tokens: configuration.max_completion_tokens | ||
}) | ||
|
||
# wait for completion | ||
status = status_loop(thread['id'], run['id']) | ||
|
||
return callback.on_failure.try(:call, "Failure status #{status}") unless ['completed', 'requires_action'].include?(status) | ||
|
||
response = get_response_class.new(response: find_run_message(thread['id'], run['id'])) | ||
|
||
return callback.on_failure.try(:call, "Response not valid", response) unless response.valid? | ||
|
||
callback.on_success.try(:call, response) | ||
rescue => e | ||
callback.on_failure.try(:call, e.message, nil) | ||
end | ||
|
||
def status_loop thread_id, run_id | ||
status = nil | ||
|
||
while true do | ||
response = client.runs.retrieve(id: run_id, thread_id: thread_id) | ||
status = response['status'] | ||
|
||
break if ['completed'].include?(status) # success | ||
break if ['requires_action'].include?(status) # success | ||
break if ['cancelled', 'failed', 'expired'].include?(status) # error | ||
break if ['incomplete'].include?(status) # ??? | ||
|
||
sleep 1 if ['queued', 'in_progress', 'cancelling'].include?(status) | ||
end | ||
|
||
status | ||
end | ||
|
||
def find_run_message thread_id, run_id | ||
messages = client.messages.list(thread_id: thread_id) | ||
messages['data'].find { |message| message['run_id'] == run_id && message['role'] == 'assistant' } | ||
end | ||
|
||
private | ||
|
||
# OpenaiAssistant.find_by_version(?) | ||
def get_configuration | ||
raise NotImplementedError, "this method get_configuration has to be defined in your class" | ||
end | ||
|
||
# format: { role: string, content: { type: "text", text: string }} | ||
def user_message | ||
raise NotImplementedError, "this method user_message has to be defined in your class" | ||
end | ||
|
||
# example: MatchingResponse | ||
def get_response_class | ||
raise NotImplementedError, "this method get_response_class has to be defined in your class" | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
module OpenaiServices | ||
class MatchingPerformer < BasicPerformer | ||
attr_reader :user | ||
|
||
class MatcherCallback < Callback | ||
end | ||
|
||
def initialize instance: | ||
super(instance: instance) | ||
|
||
@user = instance.user | ||
end | ||
|
||
def get_configuration | ||
OpenaiAssistant.find_by_module_type(:matching) | ||
end | ||
|
||
def user_message | ||
{ | ||
role: "user", | ||
content: [ | ||
{ type: "text", text: get_formatted_prompt }, | ||
{ type: "text", text: get_recommandations.to_json } | ||
] | ||
} | ||
end | ||
|
||
def get_response_class | ||
MatchingResponse | ||
end | ||
|
||
private | ||
|
||
def get_formatted_prompt | ||
action_type = opposite_action_type = instance.class.name.camelize.downcase | ||
|
||
if instance.respond_to?(:action) && instance.action? | ||
action_type = instance.contribution? ? 'contribution' : 'solicitation' | ||
opposite_action_type = instance.contribution? ? 'solicitation' : 'contribution' | ||
end | ||
|
||
@configuration.prompt | ||
.gsub("{{action_type}}", action_type) | ||
.gsub("{{opposite_action_type}}", opposite_action_type) | ||
.gsub("{{name}}", instance.name) | ||
.gsub("{{description}}", instance.description) | ||
end | ||
|
||
def get_recommandations | ||
{ | ||
recommandations: | ||
get_contributions.map { |contribution| Openai::ContributionSerializer.new(contribution).as_json } + | ||
get_solicitations.map { |solicitation| Openai::SolicitationSerializer.new(solicitation).as_json } + | ||
get_outings.map { |outing| Openai::OutingSerializer.new(outing).as_json } + | ||
get_pois.map { |poi| Openai::PoiSerializer.new(poi).as_json } + | ||
get_resources.map { |resource| Openai::ResourceSerializer.new(resource).as_json } | ||
} | ||
end | ||
|
||
def get_contributions | ||
return [] if instance.is_a?(Entourage) && instance.contribution? | ||
|
||
ContributionServices::Finder.new(user, Hash.new) | ||
.find_all | ||
.where("created_at > ?", @configuration.days_for_actions.days.ago) | ||
.limit(100) | ||
end | ||
|
||
def get_solicitations | ||
return [] if instance.is_a?(Entourage) && instance.solicitation? | ||
|
||
SolicitationServices::Finder.new(user, Hash.new) | ||
.find_all | ||
.where("created_at > ?", @configuration.days_for_actions.days.ago) | ||
.limit(100) | ||
end | ||
|
||
def get_outings | ||
OutingsServices::Finder.new(user, Hash.new) | ||
.find_all | ||
.between(Time.zone.now, @configuration.days_for_outings.days.from_now) | ||
.limit(100) | ||
end | ||
|
||
def get_pois | ||
return if @configuration.poi_from_file | ||
|
||
Poi.validated.around(instance.latitude, instance.longitude, user.travel_distance).limit(300) | ||
end | ||
|
||
def get_resources | ||
return if @configuration.resource_from_file | ||
|
||
Resource.where(status: :active) | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
module OpenaiServices | ||
# response example | ||
# {"recommandations"=> | ||
# [{ | ||
# "type"=>"resource", | ||
# "id"=>"e8bWJqPHAcxY", | ||
# "name"=>"Sophie : les portraits des bénévoles", | ||
# "score"=>"0.96", | ||
# "explanation"=>"Ce ressource présente des histoires de bénévoles et peut vous inspirer pour obtenir de l'aide." | ||
# }] | ||
# } | ||
|
||
MatchingResponse = Struct.new(:response) do | ||
TYPES = %w{contribution solicitation outing resource poi} | ||
|
||
def initialize(response: nil) | ||
@response = response | ||
@parsed_response = parsed_response | ||
end | ||
|
||
def valid? | ||
recommandations.any? | ||
end | ||
|
||
def parsed_response | ||
return unless @response | ||
return unless content = @response["content"] | ||
return unless content.any? && first_content = content[0] | ||
return unless first_content["type"] == "text" | ||
return unless value = first_content["text"]["value"]&.gsub("\n", "") | ||
return unless json = value[/\{.*\}/m] | ||
|
||
JSON.parse(json) | ||
end | ||
|
||
def to_json | ||
@response.to_json | ||
end | ||
|
||
def recommandations | ||
return [] unless @parsed_response | ||
|
||
@parsed_response["recommandations"] | ||
end | ||
|
||
def metadata | ||
{ | ||
message_id: @response["id"], | ||
assistant_id: @response["assistant_id"], | ||
thread_id: @response["thread_id"], | ||
run_id: @response["run_id"] | ||
} | ||
end | ||
|
||
def best_recommandation | ||
each_recommandation do |instance, score, explanation, index| | ||
return { | ||
instance: instance, | ||
score: score, | ||
explanation: explanation, | ||
index: index, | ||
} | ||
end | ||
end | ||
|
||
def each_recommandation &block | ||
recommandations.each_with_index do |recommandation, index| | ||
next unless recommandation["id"] | ||
next unless TYPES.include?(recommandation["type"]) | ||
next unless instance = recommandation["type"].classify.constantize.find_by_id(recommandation["id"]) | ||
|
||
yield(instance, recommandation["score"], recommandation["explanation"], index) | ||
end | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 6 additions & 0 deletions
6
db/migrate/20241216160300_add_token_lengths_to_openai_assistants.rb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
class AddTokenLengthsToOpenaiAssistants < ActiveRecord::Migration[6.1] | ||
def change | ||
add_column :openai_assistants, :max_prompt_tokens, :integer, default: 1024*1024 | ||
add_column :openai_assistants, :max_completion_tokens, :integer, default: 1024 | ||
end | ||
end |
5 changes: 5 additions & 0 deletions
5
db/migrate/20241216160301_add_module_type_to_openai_assistants.rb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
class AddModuleTypeToOpenaiAssistants < ActiveRecord::Migration[6.1] | ||
def change | ||
add_column :openai_assistants, :module_type, :string, default: :matching | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
require 'rails_helper' | ||
|
||
describe OpenaiServices::MatchingResponse do | ||
describe '#parsed_response' do | ||
subject { described_class.new(response: response).parsed_response } | ||
|
||
context 'when response is nil' do | ||
let(:response) { nil } | ||
|
||
it { expect(subject).to be_nil } | ||
end | ||
|
||
context 'when response has no content' do | ||
let(:response) { { "content" => [] } } | ||
|
||
it { expect(subject).to be_nil } | ||
end | ||
|
||
context 'when content type is not "text"' do | ||
let(:response) { { "content" => [{ "type" => "image", "text" => { "value" => '{"key":"value"}' } }] } } | ||
|
||
it { expect(subject).to be_nil } | ||
end | ||
|
||
context 'when content type is "text" but text is malformed' do | ||
let(:response) { { "content" => [{ "type" => "text", "text" => { "value" => 'invalid json' } }] } } | ||
|
||
it { expect(subject).to be_nil } | ||
end | ||
|
||
context 'when content type is "text" and text contains valid JSON' do | ||
let(:response) { { "content" => [{ "type" => "text", "text" => { "value" => '{"key":"value"}' } }] } } | ||
|
||
it { expect(subject).to eq({ "key" => "value" }) } | ||
end | ||
|
||
context 'when text contains extraneous data but valid JSON inside' do | ||
let(:response) { { "content" => [{ "type" => "text", "text" => { "value" => 'Random text before ```json{"key":"value"}``` and after' } }] } } | ||
|
||
it { expect(subject).to eq({ "key" => "value" }) } | ||
end | ||
end | ||
end |