Skip to content

Commit

Permalink
A way to output simulated qa in a format compatible with eval (#34479)
Browse files Browse the repository at this point in the history
* Changed the parameter from max_count to limit

* Add a method to have output in qa format from simulator

* simulation_result_limit was missing from main
  • Loading branch information
nagkumar91 authored Feb 28, 2024
1 parent 6452024 commit a7868e9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,19 @@ def to_json_lines(self):
for item in self:
json_lines += json.dumps(item) + "\n"
return json_lines

def to_eval_qa_json_lines(self):
json_lines = ""
for item in self:
user_message = None
assistant_message = None
for message in item['messages']:
if message['role'] == 'user':
user_message = message['content']
elif message['role'] == 'assistant':
assistant_message = message['content']
if user_message and assistant_message:
json_lines += json.dumps({'question': user_message, 'answer': assistant_message}) + "\n"
return json_lines


Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ async def simulate_async(
api_call_retry_limit: int = 3,
api_call_retry_sleep_sec: int = 1,
api_call_delay_sec: float = 0,
concurrent_async_task: int = 3
concurrent_async_task: int = 3,
simulation_result_limit: int = 3,
):
"""Asynchronously simulate conversations using the provided template and parameters
Expand All @@ -241,6 +242,8 @@ async def simulate_async(
:paramtype api_call_delay_sec: float, optional
:keyword concurrent_async_task: The maximum number of asynchronous tasks to run concurrently. Defaults to 3.
:paramtype concurrent_async_task: int, optional
:keyword simulation_result_limit: The maximum number of simulation results to return. Defaults to 3.
:paramtype simulation_result_limit: int, optional
:return: A list of dictionaries containing the simulation results.
:rtype: List[Dict]
Expand Down Expand Up @@ -271,7 +274,6 @@ async def simulate_async(
semaphore = asyncio.Semaphore(concurrent_async_task)
sim_results = []
tasks = []

for t in templates:
for p in t.template_parameters:
if jailbreak:
Expand All @@ -294,6 +296,12 @@ async def simulate_async(
)
)

if len(tasks) >= simulation_result_limit:
break

if len(tasks) >= simulation_result_limit:
break

sim_results = await asyncio.gather(*tasks)

return JsonLineList(sim_results)
Expand Down Expand Up @@ -324,6 +332,8 @@ async def _simulate_async(
api_call_delay_sec (float, optional): The time in seconds to wait between API calls. Defaults to 0.
concurrent_async_task (int, optional): The maximum number of asynchronous tasks to run concurrently.
Defaults to 3.
simulation_result_limit (int, optional): The maximum number of simulation results to return. Defaults to 3.
Returns:
List[Dict]: A list of dictionaries containing the simulation results.
Expand Down

0 comments on commit a7868e9

Please sign in to comment.