Sample Program Walkthrough
The `DatabaseChatProgram` is a sample program that illustrates how to use the Waii Java SDK to generate SQL queries from natural-language questions and run them.
Class and Variables
Class Declaration
public class DatabaseChatProgram {
The main class of the program.
Variables
// Waii SDK client used for every API call
private static Waii waii;
// Key of the database connection that is currently active
private static String dbConnectionKey;
// Question/SQL pairs of the current session, sent as context with each generation request
private static List<Tweak> tweaks = new ArrayList<>();
// UUID of the most recently generated query (static fields start out null)
private static String parentUuid;
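`waii`: An instance of the Waii SDK, used to interact with the Waii API.
`dbConnectionKey`: The key of the currently active database connection.
`tweaks`: A list of `Tweak` objects representing modifications to queries; each one pairs a question with the SQL generated for it, and the list is sent with every query-generation request as history.
`parentUuid`: The UUID of the parent query (the most recently generated one), used for keeping track of query history.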
Main Method
Main Method Declaration
public static void main(String[] args) throws InterruptedException {
The entry point of the program.
Initializing Waii SDK
// Use HTTPS so the API key is not sent in clear text; replace the placeholder with your own key.
waii = new Waii("https://sql.waii.ai/api/", "<your-api-key>");
Initializes the Waii SDK with the API URL and key.
User Input for Database Connection
// Reader for standard input; main reuses it later for the question prompts.
Scanner scanner = new Scanner(System.in);
// Ask whether to register a new connection or pick an existing one.
System.out.print("Create new connection? (y/n): ");
String yesno = scanner.nextLine();
Prompts the user to decide whether to create a new database connection or use an existing one.
Creating or Activating a Connection
// Accept "y"/"Y" with surrounding whitespace; any other answer selects an existing connection.
if (yesno.trim().equalsIgnoreCase("y")) {
    // Prompt for the Snowflake account name, username, password, database, warehouse and role
    // (as in the full program below), then create and activate the connection:
    createAndActivateDbConnection(accountName, username, password, database, warehouse, role);
} else {
    // Let the user pick one of the connections that already exist
    activateExistingConnection();
}
Based on user input, either creates a new database connection or activates an existing one.
Main Loop for User Questions
while (true) {
    System.out.print("Enter your question: ");
    // Stop cleanly at end of input instead of letting nextLine() throw NoSuchElementException
    if (!scanner.hasNextLine()) {
        break;
    }
    String question = scanner.nextLine().trim();
    if (question.equalsIgnoreCase("exit")) {
        break;
    }
    // Ignore blank lines rather than sending an empty question to Waii
    if (question.isEmpty()) {
        continue;
    }
    GeneratedQuery generatedQuery = generateQuery(question);
    GetQueryResultResponse queryResponse = runQuery(generatedQuery);
    // Print the result set
    printResultSet(queryResponse);
}
Continuously reads questions from the user, generates and runs queries, and prints the results until the user types "exit".
Helper Methods
printConnections
/** Prints each connection's database, prefixed by its position in {@code conns}. */
private static void printConnections(DBConnection[] conns) {
    for (int index = 0; index < conns.length; index++) {
        System.out.println(index + " " + conns[index].getDatabase());
    }
}
Prints the list of available database connections.
activateExistingConnection
/**
 * Lists all existing database connections, asks the user to pick one by number, and activates it.
 * Invalid input (not a number, or out of range) is rejected and the user is asked again.
 *
 * @throws IOException if the connections cannot be fetched, none exist, or activation fails
 */
private static void activateExistingConnection() throws IOException {
    GetDBConnectionResponse response = waii.getDatabase().getConnections(new GetDBConnectionRequest());
    DBConnection[] connections = response.getConnectors();
    if (connections == null || connections.length == 0) {
        throw new IOException("no existing database connections; create a new connection first.");
    }
    printConnections(connections);
    // NOTE: a second Scanner on System.in can miss input already buffered by another Scanner
    // (e.g. when stdin is piped); sharing a single Scanner avoids this.
    Scanner scanner = new Scanner(System.in);
    int index = -1;
    while (index < 0 || index >= connections.length) {
        System.out.print("Enter connection number: ");
        try {
            index = Integer.parseInt(scanner.nextLine().trim());
        } catch (NumberFormatException e) {
            index = -1;
        }
        if (index < 0 || index >= connections.length) {
            System.out.println("Please enter a number between 0 and " + (connections.length - 1) + ".");
        }
    }
    DBConnection active = connections[index];
    dbConnectionKey = active.getKey();
    waii.getDatabase().activateConnection(dbConnectionKey);
}
Fetches and prints all database connections, and activates a connection based on user input.
findConnection
/**
 * Looks up the stored connection equal to {@code db}, using {@code DBConnection.equals}.
 * NOTE(review): confirm that DBConnection.equals compares field values rather than identity.
 *
 * @return the first matching connection listed by the server, or null if none matches
 */
private static DBConnection findConnection(DBConnection db) throws IOException {
    DBConnection[] existing = waii.getDatabase().getConnections(new GetDBConnectionRequest()).getConnectors();
    DBConnection match = null;
    for (int i = 0; i < existing.length && match == null; i++) {
        if (db.equals(existing[i])) {
            match = existing[i];
        }
    }
    return match;
}
Fetches all database connections and returns the first one equal to the given connection, or null if none matches.
createAndActivateDbConnection
/**
 * Registers a Snowflake connection with the given details (unless an equal one already exists),
 * activates it, and blocks until every schema reports no tables pending indexing.
 *
 * @throws IOException if the connection cannot be created or found, or an API call fails
 * @throws InterruptedException if interrupted while waiting between status polls
 */
private static void createAndActivateDbConnection(String accountName, String username, String password, String database, String warehouse, String role) throws IOException, InterruptedException {
    DBConnection requested = new DBConnection(null, "snowflake", "Snowflake connection", accountName, username, password, database, warehouse, role, null, null, null, null, true, null);
    // Reuse an equal existing connection; only register a new one when none is found
    DBConnection connection = findConnection(requested);
    if (connection == null) {
        ModifyDBConnectionRequest modifyRequest = new ModifyDBConnectionRequest();
        modifyRequest.setUpdated(new DBConnection[]{requested});
        waii.getDatabase().modifyConnections(modifyRequest);
        // Look it up again to get the stored connection, whose key is used below
        connection = findConnection(requested);
    }
    if (connection == null) {
        throw new IOException("failed to create connection.");
    }
    dbConnectionKey = connection.getKey();
    waii.getDatabase().activateConnection(dbConnectionKey);
    // Poll every 5 seconds until indexing has finished
    while (true) {
        DBConnectionIndexingStatus indexingStatus = getIndexingStatus(waii, dbConnectionKey);
        if (indexingStatus != null && isIndexingComplete(indexingStatus)) {
            break;
        }
        System.out.println("Indexing in progress... Please wait.");
        Thread.sleep(5000);
    }
}
Creates a new Snowflake database connection (or reuses an identical existing one), activates it, and polls every five seconds until indexing is complete.
generateQuery
/**
 * Generates SQL for {@code question}, sending the session's tweak history and parent UUID as
 * context, then records the question/SQL pair and the new query's UUID for the next request.
 *
 * @return the query generated by Waii
 * @throws IOException if the generation request fails
 */
private static GeneratedQuery generateQuery(String question) throws IOException {
    QueryGenerationRequest request = new QueryGenerationRequest()
            .setAsk(question)
            .setTweakHistory(tweaks.toArray(new Tweak[0]))
            .setParentUuid(parentUuid);
    GeneratedQuery generatedQuery = waii.getQuery().generate(request);
    // When Waii reports a new query (not a refinement of the previous one), start a fresh history.
    // Boolean.TRUE.equals(...) also copes with a null flag in case getIsNew() returns a Boolean.
    if (Boolean.TRUE.equals(generatedQuery.getIsNew())) {
        tweaks.clear();
    }
    tweaks.add(new Tweak().setAsk(question).setSql(generatedQuery.getQuery()));
    parentUuid = generatedQuery.getUuid();
    return generatedQuery;
}
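Generates a SQL query from the user's question and the current query history. Sending the tweak history gives query generation the context of earlier questions, and sending the parentUuid makes the requests part of the same session. If Waii reports that the generated query is new, the history is reset before the new question and its SQL are recorded.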
runQuery
/** Submits the generated SQL for execution and returns the results fetched for that query id. */
private static GetQueryResultResponse runQuery(GeneratedQuery generatedQuery) throws IOException {
    RunQueryRequest submission = new RunQueryRequest();
    submission.setQuery(generatedQuery.getQuery());
    RunQueryResponse submitted = waii.getQuery().submit(submission);
    return waii.getQuery().getResults(new GetQueryResultRequest().setQueryId(submitted.getQueryId()));
}
Runs the generated query and retrieves the query results.
printResultSet
/**
 * Prints the rows of {@code response} as a left-aligned text table.
 *
 * <p>Column names and order come from the first row. Every row's cells are looked up by column
 * name, so rows whose maps iterate in a different order still line up. Null cell values (SQL NULL)
 * are printed as "null" instead of causing a NullPointerException.
 */
private static void printResultSet(GetQueryResultResponse response) {
    var rows = response.getRows();
    if (rows == null || rows.isEmpty()) {
        System.out.println("No results found.");
        return;
    }
    String[] headers = rows.get(0).keySet().toArray(new String[0]);
    // Each column is as wide as its longest header or cell text
    int[] columnWidths = new int[headers.length];
    for (int i = 0; i < headers.length; i++) {
        columnWidths[i] = headers[i].length();
    }
    for (Map<String, Object> row : rows) {
        for (int i = 0; i < headers.length; i++) {
            columnWidths[i] = Math.max(columnWidths[i], String.valueOf(row.get(headers[i])).length());
        }
    }
    System.out.println(formatLine(headers, columnWidths, " | "));
    String[] dashes = new String[headers.length];
    for (int i = 0; i < headers.length; i++) {
        dashes[i] = "-".repeat(columnWidths[i]);
    }
    System.out.println(formatLine(dashes, columnWidths, "-+-"));
    for (Map<String, Object> row : rows) {
        String[] cells = new String[headers.length];
        for (int i = 0; i < headers.length; i++) {
            cells[i] = String.valueOf(row.get(headers[i]));
        }
        System.out.println(formatLine(cells, columnWidths, " | "));
    }
}

/** Joins {@code cells}, each right-padded with spaces to its column width, with {@code delimiter}. */
private static String formatLine(String[] cells, int[] widths, String delimiter) {
    StringBuilder line = new StringBuilder();
    for (int i = 0; i < cells.length; i++) {
        if (i > 0) {
            line.append(delimiter);
        }
        // Pad manually: String.format("%-0s", ...) would throw for a zero-width column
        line.append(cells[i]).append(" ".repeat(widths[i] - cells[i].length()));
    }
    return line.toString();
}
Formats and prints the query results in a tabular format.
getIndexingStatus
/**
 * Returns the indexing status Waii reports for {@code connectionKey}, or null when no listed
 * connection has that key.
 */
private static DBConnectionIndexingStatus getIndexingStatus(Waii waii, String connectionKey) throws IOException {
    GetDBConnectionResponse listing = waii.getDatabase().getConnections(new GetDBConnectionRequest());
    boolean listed = false;
    for (DBConnection candidate : listing.getConnectors()) {
        if (candidate.getKey().equals(connectionKey)) {
            listed = true;
            break;
        }
    }
    return listed ? listing.getConnectorStatus().get(connectionKey) : null;
}
Fetches the indexing status of a database connection. When you add a new database, a knowledge graph is generated for it; the indexing status tells you whether this process is complete.
isIndexingComplete
/** Returns true when no schema of the connection still has tables pending indexing. */
private static boolean isIndexingComplete(DBConnectionIndexingStatus status) {
    return status.getSchemaStatus().values().stream()
            .noneMatch(schemaStatus -> schemaStatus.getPendingIndexingTables() > 0);
}
Checks if the indexing of the database connection is complete.
Here's the full program for reference:
package ai.waii.chat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import com.google.gson.Gson;
import ai.waii.Waii;
import ai.waii.clients.database.DBConnection;
import ai.waii.clients.database.DBConnectionIndexingStatus;
import ai.waii.clients.database.GetDBConnectionRequest;
import ai.waii.clients.database.GetDBConnectionResponse;
import ai.waii.clients.database.ModifyDBConnectionRequest;
import ai.waii.clients.database.ModifyDBConnectionResponse;
import ai.waii.clients.database.SchemaIndexingStatus;
import ai.waii.clients.query.GeneratedQuery;
import ai.waii.clients.query.GetQueryResultRequest;
import ai.waii.clients.query.GetQueryResultResponse;
import ai.waii.clients.query.QueryGenerationRequest;
import ai.waii.clients.query.RunQueryRequest;
import ai.waii.clients.query.RunQueryResponse;
import ai.waii.clients.query.Tweak;
/**
 * Interactive console chat over a database using the Waii SDK.
 *
 * <p>The user creates (or selects) a database connection, then asks questions. Each question is
 * turned into SQL by Waii, executed, and the result set is printed as a table. Typing "exit" (or
 * closing standard input) ends the session.
 *
 * <p>The API URL and key are read from the {@code WAII_URL} and {@code WAII_API_KEY} environment
 * variables when set; otherwise the defaults below are used.
 */
public class DatabaseChatProgram {
    /** Default API endpoint; HTTPS so the API key is not sent in clear text. */
    private static final String DEFAULT_WAII_URL = "https://sql.waii.ai/api/";
    /** Placeholder key; set WAII_API_KEY or replace this with your own key. */
    private static final String DEFAULT_WAII_API_KEY = "<your-api-key>";
    /** Delay between indexing-status polls after a connection is activated. */
    private static final long INDEXING_POLL_INTERVAL_MS = 5000;

    // One reader for standard input shared by every prompt: a second Scanner on System.in can miss
    // input that an earlier Scanner has already buffered (e.g. when stdin is piped).
    private static final Scanner stdin = new Scanner(System.in);

    private static Waii waii;
    private static String dbConnectionKey;
    // Question/SQL pairs of the current session, sent as context with each generation request
    private static List<Tweak> tweaks = new ArrayList<>();
    // UUID of the most recently generated query
    private static String parentUuid = null;

    public static void main(String[] args) throws InterruptedException {
        // Initialize Waii SDK with URL and API key
        waii = new Waii(envOrDefault("WAII_URL", DEFAULT_WAII_URL),
                envOrDefault("WAII_API_KEY", DEFAULT_WAII_API_KEY));
        try {
            String yesno = prompt("Create new connection? (y/n): ").trim();
            if (yesno.equalsIgnoreCase("y")) {
                System.out.println("Enter Snowflake connection details:");
                String accountName = prompt("Account Name: ");
                String username = prompt("Username: ");
                // NOTE: the password is echoed to the terminal as it is typed.
                String password = prompt("Password: ");
                String database = prompt("Database: ");
                String warehouse = prompt("Warehouse: ");
                String role = prompt("Role: ");
                createAndActivateDbConnection(accountName, username, password, database, warehouse, role);
            } else {
                activateExistingConnection();
            }
            chatLoop();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads questions until "exit" or end of input, generating, running and printing a query for
     * each. A failure while handling one question is reported and the loop continues.
     */
    private static void chatLoop() {
        while (true) {
            System.out.print("Enter your question: ");
            if (!stdin.hasNextLine()) {
                return;
            }
            String question = stdin.nextLine().trim();
            if (question.equalsIgnoreCase("exit")) {
                return;
            }
            if (question.isEmpty()) {
                continue;
            }
            try {
                GeneratedQuery generatedQuery = generateQuery(question);
                GetQueryResultResponse queryResponse = runQuery(generatedQuery);
                printResultSet(queryResponse);
            } catch (IOException e) {
                // A failed API call for one question should not end the whole session
                System.err.println("Error: " + e.getMessage());
            }
        }
    }

    /**
     * Prints {@code label} and returns the next line of standard input.
     *
     * @throws IOException if standard input is exhausted
     */
    private static String prompt(String label) throws IOException {
        System.out.print(label);
        if (!stdin.hasNextLine()) {
            throw new IOException("unexpected end of input");
        }
        return stdin.nextLine();
    }

    /** Returns the environment variable {@code name}, or {@code fallback} when unset or blank. */
    private static String envOrDefault(String name, String fallback) {
        String value = System.getenv(name);
        return (value == null || value.isBlank()) ? fallback : value;
    }

    /** Prints each connection's database, prefixed by its index in {@code conns}. */
    private static void printConnections(DBConnection[] conns) {
        int i = 0;
        for (DBConnection conn : conns) {
            System.out.println(i + " " + conn.getDatabase());
            i++;
        }
    }

    /**
     * Lists all existing connections, asks the user to pick one by number, and activates it.
     *
     * @throws IOException if the connections cannot be fetched, none exist, or activation fails
     */
    private static void activateExistingConnection() throws IOException {
        GetDBConnectionResponse response = waii.getDatabase().getConnections(new GetDBConnectionRequest());
        DBConnection[] connections = response.getConnectors();
        if (connections == null || connections.length == 0) {
            throw new IOException("no existing database connections; create a new connection first.");
        }
        printConnections(connections);
        DBConnection active = connections[readConnectionIndex(connections.length)];
        dbConnectionKey = active.getKey();
        waii.getDatabase().activateConnection(dbConnectionKey);
    }

    /** Prompts until the user enters an integer in [0, count) and returns it. */
    private static int readConnectionIndex(int count) throws IOException {
        while (true) {
            String input = prompt("Enter connection number: ").trim();
            try {
                int index = Integer.parseInt(input);
                if (index >= 0 && index < count) {
                    return index;
                }
            } catch (NumberFormatException ignored) {
                // not a number: fall through to the hint and ask again
            }
            System.out.println("Please enter a number between 0 and " + (count - 1) + ".");
        }
    }

    /**
     * Returns the first stored connection equal to {@code db} (per {@code DBConnection.equals}),
     * or null if none matches. Nothing is printed: connection objects carry credentials.
     */
    private static DBConnection findConnection(DBConnection db) throws IOException {
        GetDBConnectionResponse response = waii.getDatabase().getConnections(new GetDBConnectionRequest());
        for (DBConnection candidate : response.getConnectors()) {
            if (db.equals(candidate)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Registers a Snowflake connection with the given details (unless an equal one already exists),
     * activates it, and blocks until indexing is complete.
     *
     * @throws IOException if the connection cannot be created or found, or an API call fails
     * @throws InterruptedException if interrupted while waiting for indexing
     */
    private static void createAndActivateDbConnection(String accountName, String username, String password,
            String database, String warehouse, String role) throws IOException, InterruptedException {
        DBConnection requested = new DBConnection(null, "snowflake", "Snowflake connection",
                accountName, username, password, database, warehouse, role, null, null, null, null, true, null);
        // Reuse an equal existing connection; only register a new one when none is found
        DBConnection connection = findConnection(requested);
        if (connection == null) {
            ModifyDBConnectionRequest modifyRequest = new ModifyDBConnectionRequest();
            modifyRequest.setUpdated(new DBConnection[]{requested});
            waii.getDatabase().modifyConnections(modifyRequest);
            // Look it up again to get the stored connection, whose key is used below
            connection = findConnection(requested);
        }
        if (connection == null) {
            throw new IOException("failed to create connection.");
        }
        dbConnectionKey = connection.getKey();
        waii.getDatabase().activateConnection(dbConnectionKey);
        waitForIndexing(dbConnectionKey);
    }

    /** Polls the indexing status of {@code connectionKey} until it reports completion. */
    private static void waitForIndexing(String connectionKey) throws IOException, InterruptedException {
        while (true) {
            DBConnectionIndexingStatus indexingStatus = getIndexingStatus(waii, connectionKey);
            if (indexingStatus != null && isIndexingComplete(indexingStatus)) {
                return;
            }
            System.out.println("Indexing in progress... Please wait.");
            Thread.sleep(INDEXING_POLL_INTERVAL_MS);
        }
    }

    /**
     * Generates SQL for {@code question} using the session's tweak history and parent UUID, then
     * records the question/SQL pair and the new query's UUID for the next request.
     */
    private static GeneratedQuery generateQuery(String question) throws IOException {
        QueryGenerationRequest request = new QueryGenerationRequest()
                .setAsk(question)
                .setTweakHistory(tweaks.toArray(new Tweak[0]))
                .setParentUuid(parentUuid);
        GeneratedQuery generatedQuery = waii.getQuery().generate(request);
        // A new (non-follow-up) query starts a fresh history; null-safe if getIsNew() is a Boolean
        if (Boolean.TRUE.equals(generatedQuery.getIsNew())) {
            tweaks.clear();
        }
        tweaks.add(new Tweak().setAsk(question).setSql(generatedQuery.getQuery()));
        parentUuid = generatedQuery.getUuid();
        return generatedQuery;
    }

    /** Submits the generated SQL for execution and returns the fetched results. */
    private static GetQueryResultResponse runQuery(GeneratedQuery generatedQuery) throws IOException {
        RunQueryRequest runRequest = new RunQueryRequest();
        runRequest.setQuery(generatedQuery.getQuery());
        RunQueryResponse runResponse = waii.getQuery().submit(runRequest);
        return waii.getQuery().getResults(new GetQueryResultRequest().setQueryId(runResponse.getQueryId()));
    }

    /**
     * Prints the rows of {@code response} as a left-aligned text table. Columns come from the first
     * row and cells are looked up by column name. Null values (SQL NULL) print as "null".
     */
    private static void printResultSet(GetQueryResultResponse response) {
        var rows = response.getRows();
        if (rows == null || rows.isEmpty()) {
            System.out.println("No results found.");
            return;
        }
        String[] headers = rows.get(0).keySet().toArray(new String[0]);
        int[] columnWidths = new int[headers.length];
        for (int i = 0; i < headers.length; i++) {
            columnWidths[i] = headers[i].length();
        }
        for (Map<String, Object> row : rows) {
            for (int i = 0; i < headers.length; i++) {
                columnWidths[i] = Math.max(columnWidths[i], String.valueOf(row.get(headers[i])).length());
            }
        }
        System.out.println(formatLine(headers, columnWidths, " | "));
        String[] dashes = new String[headers.length];
        for (int i = 0; i < headers.length; i++) {
            dashes[i] = "-".repeat(columnWidths[i]);
        }
        System.out.println(formatLine(dashes, columnWidths, "-+-"));
        for (Map<String, Object> row : rows) {
            String[] cells = new String[headers.length];
            for (int i = 0; i < headers.length; i++) {
                cells[i] = String.valueOf(row.get(headers[i]));
            }
            System.out.println(formatLine(cells, columnWidths, " | "));
        }
    }

    /** Joins {@code cells}, each right-padded to its column width, with {@code delimiter}. */
    private static String formatLine(String[] cells, int[] widths, String delimiter) {
        StringBuilder line = new StringBuilder();
        for (int i = 0; i < cells.length; i++) {
            if (i > 0) {
                line.append(delimiter);
            }
            // Pad manually: String.format("%-0s", ...) would throw for a zero-width column
            line.append(cells[i]).append(" ".repeat(widths[i] - cells[i].length()));
        }
        return line.toString();
    }

    /**
     * Returns the indexing status reported for {@code connectionKey}, or null when the connection
     * is not listed or no status map is available yet (the caller keeps polling in that case).
     */
    private static DBConnectionIndexingStatus getIndexingStatus(Waii waii, String connectionKey) throws IOException {
        GetDBConnectionResponse response = waii.getDatabase().getConnections(new GetDBConnectionRequest());
        for (DBConnection connection : response.getConnectors()) {
            if (connection.getKey().equals(connectionKey)) {
                var statuses = response.getConnectorStatus();
                return statuses == null ? null : statuses.get(connectionKey);
            }
        }
        return null;
    }

    /** Returns true when no schema of the connection still has tables pending indexing. */
    private static boolean isIndexingComplete(DBConnectionIndexingStatus status) {
        for (SchemaIndexingStatus schemaStatus : status.getSchemaStatus().values()) {
            if (schemaStatus.getPendingIndexingTables() > 0) {
                return false;
            }
        }
        return true;
    }
}