Skip to main content

Sample Program Walkthrough

DatabaseChatProgram is a sample console program that illustrates how to use the Waii Java SDK to generate SQL queries from natural-language questions and run them.

Class and Variables

Class Declaration

public class DatabaseChatProgram {

The main class of the program.

Variables

private static Waii waii;
private static String dbConnectionKey;
private static List<Tweak> tweaks = new ArrayList<Tweak>();
private static String parentUuid = null;
  • waii: An instance of the Waii SDK, used to interact with the Waii API.
  • dbConnectionKey: The key of the currently active database connection.
  • tweaks: A list of Tweak objects representing modifications to queries.
  • parentUuid: The UUID of the parent query, used for keeping track of query history.

Main Method

Main Method Declaration

public static void main(String[] args) throws InterruptedException {

The entry point of the program.

Initializing Waii SDK

waii = new Waii("http://sql.waii.ai/api/", "<your-api-key>");

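Initializes the Waii SDK with the API URL and your API key (replace `<your-api-key>` with your own key). The full program at the end of this page uses `http://localhost:9859/api/` instead; point it at the URL of your own Waii deployment.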

User Input for Database Connection

Scanner scanner = new Scanner(System.in);

System.out.print("Create new connection? (y/n): ");
String yesno = scanner.nextLine();

Prompts the user to decide whether to create a new database connection or use an existing one. Answering `y` creates a new connection; any other answer selects an existing one.

Creating or Activating a Connection

if (yesno.equals("y")) {
// Code to create a new connection
createAndActivateDbConnection(accountName, username, password, database, warehouse, role);
} else {
// Code to activate an existing connection
activateExistingConnection();
}

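Based on the user's answer, either creates a new database connection or activates an existing one. In this excerpt the connection details (`accountName`, `username`, `password`, `database`, `warehouse`, `role`) are assumed to have been read already; the full program below prompts for each of them.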

Main Loop for User Questions

while (true) {
System.out.print("Enter your question: ");
String question = scanner.nextLine();

if (question.equalsIgnoreCase("exit")) {
break;
}

GeneratedQuery generatedQuery = generateQuery(question);
GetQueryResultResponse queryResponse = runQuery(generatedQuery);

// Print the result set
printResultSet(queryResponse);
}

Continuously reads questions from the user, generates and runs a query for each one, and prints the results until the user types "exit" (case-insensitive).

Helper Methods

printConnections

/**
 * Prints every connection on its own line as "index database"; the user later types the
 * index to choose a connection.
 *
 * @param conns connections to list, in the order their indices refer to
 */
private static void printConnections(DBConnection[] conns) {
    for (int index = 0; index < conns.length; index++) {
        System.out.println(index + " " + conns[index].getDatabase());
    }
}

Prints each available database connection as its index followed by its database name. The user enters this index to select a connection.

activateExistingConnection

/**
 * Lists all database connections known to Waii, lets the user pick one by index and activates it.
 *
 * @throws IOException if the connections cannot be fetched, none exist, input ends before a
 *     valid choice is made, or activation fails
 */
private static void activateExistingConnection() throws IOException {
    GetDBConnectionResponse response = waii.getDatabase().getConnections(new GetDBConnectionRequest());
    DBConnection[] connections = response.getConnectors();
    if (connections == null || connections.length == 0) {
        throw new IOException("no existing database connections found.");
    }

    printConnections(connections);

    // NOTE: a new Scanner over System.in can miss input that another Scanner has already
    // buffered; in a real program, create one Scanner and share it instead.
    Scanner scanner = new Scanner(System.in);
    DBConnection active = connections[readConnectionIndex(scanner, connections.length)];
    dbConnectionKey = active.getKey();
    waii.getDatabase().activateConnection(dbConnectionKey);
}

/**
 * Prompts until the user enters an integer in {@code [0, count)}.
 *
 * @throws IOException if input ends before a valid index is entered
 */
private static int readConnectionIndex(Scanner scanner, int count) throws IOException {
    while (true) {
        System.out.print("Enter connection number: ");
        if (!scanner.hasNextLine()) {
            throw new IOException("no connection selected before end of input.");
        }
        try {
            int index = Integer.parseInt(scanner.nextLine().trim());
            if (index >= 0 && index < count) {
                return index;
            }
        } catch (NumberFormatException ignored) {
            // Not a number: fall through and prompt again.
        }
        System.out.println("Please enter a number between 0 and " + (count - 1) + ".");
    }
}

Fetches and prints all database connections, and activates a connection based on user input.

findConnection

/**
 * Fetches all database connections and returns the first one that {@code equals} the given one.
 *
 * @param db connection to look for
 * @return the matching connection as returned by the server, or {@code null} if none matches
 * @throws IOException if the connections cannot be fetched
 */
private static DBConnection findConnection(DBConnection db) throws IOException {
    DBConnection[] candidates = waii.getDatabase()
            .getConnections(new GetDBConnectionRequest())
            .getConnectors();

    DBConnection match = null;
    for (int i = 0; i < candidates.length && match == null; i++) {
        if (db.equals(candidates[i])) {
            match = candidates[i];
        }
    }
    return match;
}

Fetches all database connections and returns the first one equal to the given connection (as defined by `DBConnection.equals`), or `null` if none matches.

createAndActivateDbConnection

/**
 * Registers a Snowflake connection (unless an equal one already exists), activates it and
 * blocks until Waii reports that indexing has finished.
 *
 * @throws IOException if a Waii request fails or the connection cannot be found after creation
 * @throws InterruptedException if interrupted while waiting for indexing
 */
private static void createAndActivateDbConnection(String accountName, String username, String password,
        String database, String warehouse, String role) throws IOException, InterruptedException {
    DBConnection dbConnection = new DBConnection(null, "snowflake", "Snowflake connection",
            accountName, username, password, database, warehouse, role, null, null, null, null, true, null);

    // Only register the connection when the server does not already have an equal one.
    if (findConnection(dbConnection) == null) {
        ModifyDBConnectionRequest modifyRequest = new ModifyDBConnectionRequest();
        modifyRequest.setUpdated(new DBConnection[]{dbConnection});
        waii.getDatabase().modifyConnections(modifyRequest);
    }

    // Look the connection up again and take its key from the server's copy.
    DBConnection registered = findConnection(dbConnection);
    if (registered == null) {
        throw new IOException("failed to create connection.");
    }

    dbConnectionKey = registered.getKey();
    waii.getDatabase().activateConnection(dbConnectionKey);

    // Poll every 5 seconds until indexing is complete. NOTE: the wait has no upper bound.
    while (true) {
        DBConnectionIndexingStatus indexingStatus = getIndexingStatus(waii, dbConnectionKey);
        if (indexingStatus != null && isIndexingComplete(indexingStatus)) {
            break;
        }
        System.out.println("Indexing in progress... Please wait.");
        Thread.sleep(5000);
    }
}

Registers the Snowflake connection if an equal one does not already exist, activates it, and then polls every 5 seconds until indexing is complete.

generateQuery

/**
 * Asks Waii to generate SQL for a question and records the exchange as conversation context.
 *
 * <p>The accumulated tweaks are sent as history and {@code parentUuid} links the request to the
 * previous query, so follow-up questions can refine earlier ones.
 *
 * @param question natural-language question from the user
 * @return the generated query
 * @throws IOException if the request to Waii fails
 */
private static GeneratedQuery generateQuery(String question) throws IOException {
    QueryGenerationRequest request = new QueryGenerationRequest()
            .setAsk(question)
            .setTweakHistory(tweaks.toArray(new Tweak[0]))
            .setParentUuid(parentUuid);

    GeneratedQuery generatedQuery = waii.getQuery().generate(request);

    // A query flagged as new starts a fresh conversation, so earlier tweaks no longer apply.
    // Boolean.TRUE.equals(...) also avoids an unboxing NPE should the flag be a null Boolean.
    if (Boolean.TRUE.equals(generatedQuery.getIsNew())) {
        tweaks = new ArrayList<>();
    }
    tweaks.add(new Tweak().setAsk(question).setSql(generatedQuery.getQuery()));

    parentUuid = generatedQuery.getUuid();
    return generatedQuery;
}

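Generates a SQL query for the user's question using the current conversation state. Sending the tweak history gives query generation the context of earlier questions, and sending `parentUuid` makes the requests part of the same session. When the generated query is reported as new (`getIsNew()`), the tweak history is reset before the new question is recorded.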

runQuery

/**
 * Submits the generated SQL for execution and fetches its result.
 *
 * @param generatedQuery query produced by {@code generateQuery}
 * @return the query result returned by Waii
 * @throws IOException if submitting the query or fetching its result fails
 */
private static GetQueryResultResponse runQuery(GeneratedQuery generatedQuery) throws IOException {
    RunQueryRequest submission = new RunQueryRequest();
    submission.setQuery(generatedQuery.getQuery());

    RunQueryResponse submitted = waii.getQuery().submit(submission);
    return waii.getQuery().getResults(
            new GetQueryResultRequest().setQueryId(submitted.getQueryId()));
}

Runs the generated query and retrieves the query results.

printResultSet

/**
 * Prints the query result as a left-aligned text table, or "No results found." if it has no rows.
 *
 * <p>Column names and order come from the first row. Every row is read by column name, so rows
 * whose maps iterate in a different order still line up, and keys absent from the first row
 * are ignored. Null values (SQL NULL) print as "null".
 *
 * @param response result returned by {@code runQuery}
 */
private static void printResultSet(GetQueryResultResponse response) {
    List<? extends Map<String, ?>> rows = response.getRows();
    if (rows == null || rows.isEmpty()) {
        System.out.println("No results found.");
        return;
    }

    String[] headers = rows.get(0).keySet().toArray(new String[0]);

    // Each column is as wide as its longest header or cell text.
    int[] widths = new int[headers.length];
    for (int c = 0; c < headers.length; c++) {
        widths[c] = String.valueOf(headers[c]).length();
    }
    for (Map<String, ?> row : rows) {
        for (int c = 0; c < headers.length; c++) {
            widths[c] = Math.max(widths[c], cellText(row, headers[c]).length());
        }
    }

    System.out.println(formatRow(headers, widths));

    // "-+-" lines up with the " | " column separator used by formatRow.
    StringBuilder separator = new StringBuilder();
    for (int c = 0; c < widths.length; c++) {
        if (c > 0) {
            separator.append("-+-");
        }
        separator.append("-".repeat(widths[c]));
    }
    System.out.println(separator);

    for (Map<String, ?> row : rows) {
        String[] cells = new String[headers.length];
        for (int c = 0; c < headers.length; c++) {
            cells[c] = cellText(row, headers[c]);
        }
        System.out.println(formatRow(cells, widths));
    }
}

/** Text for one cell; null values and missing columns both render as "null". */
private static String cellText(Map<String, ?> row, String column) {
    return String.valueOf(row.get(column));
}

/** Left-pads nothing, right-pads each cell with spaces to its column width, joined by " | ". */
private static String formatRow(String[] cells, int[] widths) {
    StringBuilder line = new StringBuilder();
    for (int c = 0; c < cells.length; c++) {
        if (c > 0) {
            line.append(" | ");
        }
        String text = String.valueOf(cells[c]);
        line.append(text).append(" ".repeat(widths[c] - text.length()));
    }
    return line.toString();
}

Formats and prints the query results in a tabular format.

getIndexingStatus

/**
 * Returns the indexing status Waii reports for the connection with the given key.
 *
 * @param waii SDK client to query
 * @param connectionKey key of the connection to check
 * @return the connection's indexing status, or {@code null} if the connection is not listed or
 *     the status map has no entry for it
 * @throws IOException if the connections cannot be fetched
 */
private static DBConnectionIndexingStatus getIndexingStatus(Waii waii, String connectionKey) throws IOException {
    GetDBConnectionResponse response = waii.getDatabase().getConnections(new GetDBConnectionRequest());
    DBConnection[] connectors = response.getConnectors();
    if (connectionKey == null || connectors == null || response.getConnectorStatus() == null) {
        return null;
    }

    for (DBConnection connection : connectors) {
        // Compare from the non-null key so connections without a key are skipped, not an NPE.
        if (connectionKey.equals(connection.getKey())) {
            return response.getConnectorStatus().get(connectionKey);
        }
    }
    return null;
}

Fetches the indexing status of a database connection. When you add a new database, a knowledge graph is generated for it; the indexing status tells you whether this process has completed.

isIndexingComplete

/**
 * Reports whether no schema of the connection still has tables waiting to be indexed.
 *
 * @param status indexing status of one connection
 * @return {@code true} when every reported schema has zero pending tables (also when the schema
 *     map is empty); {@code false} while any schema has pending tables or the map is missing
 */
private static boolean isIndexingComplete(DBConnectionIndexingStatus status) {
    if (status.getSchemaStatus() == null) {
        // NOTE(review): assumes a missing schema map means progress has not been reported yet.
        return false;
    }
    for (SchemaIndexingStatus schemaStatus : status.getSchemaStatus().values()) {
        if (schemaStatus != null && schemaStatus.getPendingIndexingTables() > 0) {
            return false;
        }
    }
    return true;
}

Checks if the indexing of the database connection is complete.


Here's the full program for reference:

package ai.waii.chat;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

import com.google.gson.Gson;

import ai.waii.Waii;
import ai.waii.clients.database.DBConnection;
import ai.waii.clients.database.DBConnectionIndexingStatus;
import ai.waii.clients.database.GetDBConnectionRequest;
import ai.waii.clients.database.GetDBConnectionResponse;
import ai.waii.clients.database.ModifyDBConnectionRequest;
import ai.waii.clients.database.ModifyDBConnectionResponse;
import ai.waii.clients.database.SchemaIndexingStatus;
import ai.waii.clients.query.GeneratedQuery;
import ai.waii.clients.query.GetQueryResultRequest;
import ai.waii.clients.query.GetQueryResultResponse;
import ai.waii.clients.query.QueryGenerationRequest;
import ai.waii.clients.query.RunQueryRequest;
import ai.waii.clients.query.RunQueryResponse;
import ai.waii.clients.query.Tweak;

/**
 * Interactive console sample for the Waii Java SDK.
 *
 * <p>The program either registers a new Snowflake connection or activates an existing one. It
 * then repeatedly reads a natural-language question, asks Waii to generate SQL for it, runs the
 * query and prints the result as a text table. Typing {@code exit} (any case) or closing
 * standard input ends the session.
 */
public class DatabaseChatProgram {

    /** API endpoint used when the {@code WAII_URL} environment variable is not set. */
    private static final String DEFAULT_API_URL = "http://localhost:9859/api/";

    /** Placeholder key used when {@code WAII_API_KEY} is not set; replace it with a real key. */
    private static final String DEFAULT_API_KEY = "your-api-key";

    /** Pause between indexing-status polls, in milliseconds. */
    private static final long INDEXING_POLL_INTERVAL_MS = 5000;

    /**
     * The single reader for console input. Every prompt shares it because a second Scanner over
     * System.in can miss input that the first one has already buffered.
     */
    private static final Scanner stdin = new Scanner(System.in);

    /** Waii SDK client; initialised at the start of {@link #main}. */
    private static Waii waii;

    /** Key of the database connection currently activated in Waii. */
    private static String dbConnectionKey;

    /** Question/SQL pairs of the current conversation, sent back as context for the next question. */
    private static List<Tweak> tweaks = new ArrayList<Tweak>();

    /** UUID of the previously generated query; links follow-up requests to the same session. */
    private static String parentUuid = null;

    /**
     * Entry point: sets up the database connection, then runs the question/answer loop.
     *
     * @param args ignored
     * @throws InterruptedException if interrupted while waiting for a new connection to be indexed
     */
    public static void main(String[] args) throws InterruptedException {
        // Endpoint and key may come from the environment so a real key need not live in source.
        waii = new Waii(envOrDefault("WAII_URL", DEFAULT_API_URL),
                envOrDefault("WAII_API_KEY", DEFAULT_API_KEY));

        try {
            String yesno = prompt("Create new connection? (y/n): ");

            if (yesno.equals("y")) {
                System.out.println("Enter Snowflake connection details:");
                String accountName = prompt("Account Name: ");
                String username = prompt("Username: ");
                // NOTE: the password is echoed to the terminal.
                String password = prompt("Password: ");
                String database = prompt("Database: ");
                String warehouse = prompt("Warehouse: ");
                String role = prompt("Role: ");

                createAndActivateDbConnection(accountName, username, password, database, warehouse, role);
            } else {
                activateExistingConnection();
            }

            chatLoop();
        } catch (IOException e) {
            // Connection setup failed; without a connection there is nothing more to do.
            e.printStackTrace();
        }
    }

    /**
     * Reads questions until "exit" (any case) or end of input and prints each query's result.
     * A failure on one question is reported and the loop moves on to the next question.
     */
    private static void chatLoop() {
        while (true) {
            System.out.print("Enter your question: ");
            if (!stdin.hasNextLine()) {
                System.out.println();
                return;
            }
            String question = stdin.nextLine();

            if (question.equalsIgnoreCase("exit")) {
                return;
            }
            if (question.isBlank()) {
                continue; // nothing to ask
            }

            try {
                GeneratedQuery generatedQuery = generateQuery(question);
                GetQueryResultResponse queryResponse = runQuery(generatedQuery);
                printResultSet(queryResponse);
            } catch (IOException e) {
                System.err.println("Could not answer the question: " + e);
            }
        }
    }

    /** Prints {@code label} and returns the next line the user types. */
    private static String prompt(String label) {
        System.out.print(label);
        return stdin.nextLine();
    }

    /** Returns environment variable {@code name}, or {@code fallback} when it is unset or blank. */
    private static String envOrDefault(String name, String fallback) {
        String value = System.getenv(name);
        return (value == null || value.isBlank()) ? fallback : value;
    }

    /** Prints every connection on its own line as "index database". */
    private static void printConnections(DBConnection[] conns) {
        for (int index = 0; index < conns.length; index++) {
            System.out.println(index + " " + conns[index].getDatabase());
        }
    }

    /**
     * Lists all database connections known to Waii, lets the user pick one by index and activates it.
     *
     * @throws IOException if the connections cannot be fetched, none exist, input ends before a
     *     valid choice is made, or activation fails
     */
    private static void activateExistingConnection() throws IOException {
        GetDBConnectionResponse response = waii.getDatabase().getConnections(new GetDBConnectionRequest());
        DBConnection[] connections = response.getConnectors();
        if (connections == null || connections.length == 0) {
            throw new IOException("no existing database connections found.");
        }

        printConnections(connections);

        DBConnection active = connections[readConnectionIndex(connections.length)];
        dbConnectionKey = active.getKey();
        waii.getDatabase().activateConnection(dbConnectionKey);
    }

    /**
     * Prompts until the user enters an integer in {@code [0, count)}.
     *
     * @throws IOException if input ends before a valid index is entered
     */
    private static int readConnectionIndex(int count) throws IOException {
        while (true) {
            System.out.print("Enter connection number: ");
            if (!stdin.hasNextLine()) {
                throw new IOException("no connection selected before end of input.");
            }
            try {
                int index = Integer.parseInt(stdin.nextLine().trim());
                if (index >= 0 && index < count) {
                    return index;
                }
            } catch (NumberFormatException ignored) {
                // Not a number: fall through and prompt again.
            }
            System.out.println("Please enter a number between 0 and " + (count - 1) + ".");
        }
    }

    /**
     * Fetches all database connections and returns the first one that {@code equals} {@code db}.
     *
     * @return the matching connection as returned by the server, or {@code null} if none matches
     */
    private static DBConnection findConnection(DBConnection db) throws IOException {
        GetDBConnectionResponse response = waii.getDatabase().getConnections(new GetDBConnectionRequest());
        DBConnection[] connectors = response.getConnectors();
        if (connectors == null) {
            return null;
        }
        // Candidates are deliberately not dumped to stdout: they may carry credentials.
        for (DBConnection candidate : connectors) {
            if (db.equals(candidate)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Registers a Snowflake connection (unless an equal one already exists), activates it and
     * blocks until Waii reports that indexing has finished.
     *
     * @throws IOException if a Waii request fails or the connection cannot be found after creation
     * @throws InterruptedException if interrupted while waiting for indexing
     */
    private static void createAndActivateDbConnection(String accountName, String username, String password,
            String database, String warehouse, String role) throws IOException, InterruptedException {
        DBConnection dbConnection = new DBConnection(null, "snowflake", "Snowflake connection",
                accountName, username, password, database, warehouse, role, null, null, null, null, true, null);

        // Only register the connection when the server does not already have an equal one.
        if (findConnection(dbConnection) == null) {
            ModifyDBConnectionRequest modifyRequest = new ModifyDBConnectionRequest();
            modifyRequest.setUpdated(new DBConnection[]{dbConnection});
            waii.getDatabase().modifyConnections(modifyRequest);
        }

        // Look the connection up again and take its key from the server's copy.
        DBConnection registered = findConnection(dbConnection);
        if (registered == null) {
            throw new IOException("failed to create connection.");
        }

        dbConnectionKey = registered.getKey();
        waii.getDatabase().activateConnection(dbConnectionKey);

        waitForIndexing(dbConnectionKey);
    }

    /** Polls the connection's indexing status until it is complete. NOTE: the wait is unbounded. */
    private static void waitForIndexing(String connectionKey) throws IOException, InterruptedException {
        while (true) {
            DBConnectionIndexingStatus indexingStatus = getIndexingStatus(waii, connectionKey);
            if (indexingStatus != null && isIndexingComplete(indexingStatus)) {
                return;
            }
            System.out.println("Indexing in progress... Please wait.");
            Thread.sleep(INDEXING_POLL_INTERVAL_MS);
        }
    }

    /**
     * Asks Waii to generate SQL for a question and records the exchange as conversation context.
     *
     * @param question natural-language question from the user
     * @return the generated query
     * @throws IOException if the request to Waii fails
     */
    private static GeneratedQuery generateQuery(String question) throws IOException {
        QueryGenerationRequest request = new QueryGenerationRequest()
                .setAsk(question)
                .setTweakHistory(tweaks.toArray(new Tweak[0]))
                .setParentUuid(parentUuid);

        GeneratedQuery generatedQuery = waii.getQuery().generate(request);

        // A query flagged as new starts a fresh conversation, so earlier tweaks no longer apply.
        // Boolean.TRUE.equals(...) also avoids an unboxing NPE should the flag be a null Boolean.
        if (Boolean.TRUE.equals(generatedQuery.getIsNew())) {
            tweaks = new ArrayList<>();
        }
        tweaks.add(new Tweak().setAsk(question).setSql(generatedQuery.getQuery()));

        parentUuid = generatedQuery.getUuid();
        return generatedQuery;
    }

    /**
     * Submits the generated SQL for execution and fetches its result.
     *
     * @throws IOException if submitting the query or fetching its result fails
     */
    private static GetQueryResultResponse runQuery(GeneratedQuery generatedQuery) throws IOException {
        RunQueryRequest runRequest = new RunQueryRequest();
        runRequest.setQuery(generatedQuery.getQuery());

        RunQueryResponse runResponse = waii.getQuery().submit(runRequest);
        return waii.getQuery().getResults(
                new GetQueryResultRequest().setQueryId(runResponse.getQueryId()));
    }

    /**
     * Prints the query result as a left-aligned text table, or "No results found." if it has no rows.
     *
     * <p>Column names and order come from the first row. Every row is read by column name, so rows
     * whose maps iterate in a different order still line up, and keys absent from the first row
     * are ignored. Null values (SQL NULL) print as "null".
     */
    private static void printResultSet(GetQueryResultResponse response) {
        List<? extends Map<String, ?>> rows = response.getRows();
        if (rows == null || rows.isEmpty()) {
            System.out.println("No results found.");
            return;
        }

        String[] headers = rows.get(0).keySet().toArray(new String[0]);

        // Each column is as wide as its longest header or cell text.
        int[] widths = new int[headers.length];
        for (int c = 0; c < headers.length; c++) {
            widths[c] = String.valueOf(headers[c]).length();
        }
        for (Map<String, ?> row : rows) {
            for (int c = 0; c < headers.length; c++) {
                widths[c] = Math.max(widths[c], cellText(row, headers[c]).length());
            }
        }

        System.out.println(formatRow(headers, widths));

        // "-+-" lines up with the " | " column separator used by formatRow.
        StringBuilder separator = new StringBuilder();
        for (int c = 0; c < widths.length; c++) {
            if (c > 0) {
                separator.append("-+-");
            }
            separator.append("-".repeat(widths[c]));
        }
        System.out.println(separator);

        for (Map<String, ?> row : rows) {
            String[] cells = new String[headers.length];
            for (int c = 0; c < headers.length; c++) {
                cells[c] = cellText(row, headers[c]);
            }
            System.out.println(formatRow(cells, widths));
        }
    }

    /** Text for one cell; null values and missing columns both render as "null". */
    private static String cellText(Map<String, ?> row, String column) {
        return String.valueOf(row.get(column));
    }

    /** Right-pads each cell with spaces to its column width and joins the cells with " | ". */
    private static String formatRow(String[] cells, int[] widths) {
        StringBuilder line = new StringBuilder();
        for (int c = 0; c < cells.length; c++) {
            if (c > 0) {
                line.append(" | ");
            }
            String text = String.valueOf(cells[c]);
            line.append(text).append(" ".repeat(widths[c] - text.length()));
        }
        return line.toString();
    }

    /**
     * Returns the indexing status Waii reports for the connection with the given key.
     *
     * @return the status, or {@code null} if the connection is not listed or the status map has
     *     no entry for it
     */
    private static DBConnectionIndexingStatus getIndexingStatus(Waii waii, String connectionKey) throws IOException {
        GetDBConnectionResponse response = waii.getDatabase().getConnections(new GetDBConnectionRequest());
        DBConnection[] connectors = response.getConnectors();
        if (connectionKey == null || connectors == null || response.getConnectorStatus() == null) {
            return null;
        }

        for (DBConnection connection : connectors) {
            // Compare from the non-null key so connections without a key are skipped, not an NPE.
            if (connectionKey.equals(connection.getKey())) {
                return response.getConnectorStatus().get(connectionKey);
            }
        }
        return null;
    }

    /**
     * Reports whether no schema of the connection still has tables waiting to be indexed.
     *
     * @return {@code true} when every reported schema has zero pending tables (also when the
     *     schema map is empty); {@code false} while any schema has pending tables or the map is missing
     */
    private static boolean isIndexingComplete(DBConnectionIndexingStatus status) {
        if (status.getSchemaStatus() == null) {
            // NOTE(review): assumes a missing schema map means progress has not been reported yet.
            return false;
        }
        for (SchemaIndexingStatus schemaStatus : status.getSchemaStatus().values()) {
            if (schemaStatus != null && schemaStatus.getPendingIndexingTables() > 0) {
                return false;
            }
        }
        return true;
    }
}